Skip to contents

0) Set parameters

2) Get FluSight data repo

The FluSight GitHub repository stores forecast data for the FluSight Influenza Forecasting Hub, run by the US CDC. This project collects forecasts for weekly new hospitalizations due to confirmed influenza. More information can be found in the README of the repository: https://github.com/cdcepi/FluSight-forecast-hub.

We will copy a set of forecasts from this repository, and use them to build a simple ensemble, visualize the forecasts, and score them against observed data.

To switch to a different forecasting hub, change the disease argument in the clone_hub_repos() function below.

# Clone the forecast hub for the chosen disease into the current working
# directory and keep the value returned by clone_hub_repos() as repo_dir.
repo_dir <- clone_hub_repos(
  clone_dir = getwd(),
  disease = forecast_disease
)
## Cloning repository...
## Cloning into 'FluSight-forecast-hub'...
## Updating files:  36% (1004/2714)Updating files:  37% (1005/2714)Updating files:  38% (1032/2714)Updating files:  39% (1059/2714)Updating files:  40% (1086/2714)Updating files:  41% (1113/2714)Updating files:  42% (1140/2714)Updating files:  43% (1168/2714)Updating files:  44% (1195/2714)Updating files:  45% (1222/2714)Updating files:  46% (1249/2714)Updating files:  47% (1276/2714)Updating files:  48% (1303/2714)Updating files:  49% (1330/2714)Updating files:  50% (1357/2714)Updating files:  51% (1385/2714)Updating files:  52% (1412/2714)Updating files:  53% (1439/2714)Updating files:  54% (1466/2714)Updating files:  55% (1493/2714)Updating files:  56% (1520/2714)Updating files:  57% (1547/2714)Updating files:  58% (1575/2714)Updating files:  59% (1602/2714)Updating files:  60% (1629/2714)Updating files:  61% (1656/2714)Updating files:  62% (1683/2714)Updating files:  63% (1710/2714)Updating files:  64% (1737/2714)Updating files:  65% (1765/2714)Updating files:  66% (1792/2714)Updating files:  67% (1819/2714)Updating files:  68% (1846/2714)Updating files:  69% (1873/2714)Updating files:  69% (1897/2714)Updating files:  70% (1900/2714)Updating files:  71% (1927/2714)Updating files:  72% (1955/2714)Updating files:  73% (1982/2714)Updating files:  74% (2009/2714)Updating files:  75% (2036/2714)Updating files:  76% (2063/2714)Updating files:  77% (2090/2714)Updating files:  78% (2117/2714)Updating files:  79% (2145/2714)Updating files:  80% (2172/2714)Updating files:  81% (2199/2714)Updating files:  82% (2226/2714)Updating files:  83% (2253/2714)Updating files:  84% (2280/2714)Updating files:  85% (2307/2714)Updating files:  86% (2335/2714)Updating files:  87% (2362/2714)Updating files:  88% (2389/2714)Updating files:  89% (2416/2714)Updating files:  90% (2443/2714)Updating files:  91% (2470/2714)Updating files:  92% (2497/2714)Updating files:  93% (2525/2714)Updating files:  94% (2552/2714)Updating files:  95% (2579/2714)Updating files:  96% (2606/2714)Updating 
files:  97% (2633/2714)Updating files:  98% (2660/2714)Updating files:  98% (2672/2714)Updating files:  99% (2687/2714)Updating files: 100% (2714/2714)Updating files: 100% (2714/2714), done.
## Using repo_dir: /home/runner/work/AMPH_Forecast_Suite/AMPH_Forecast_Suite/vignettes/FluSight-forecast-hub

Copy specific forecast round to the model-output folder

We will copy forecasts from a set of FluSight models: FluSight-baseline, MOBS-GLEAM_FLUH, and FluSight-ensemble.

# Hub models whose forecasts are copied from the FluSight repository
models_to_copy <- c("FluSight-baseline", "MOBS-GLEAM_FLUH", "FluSight-ensemble")

# Model folders already present directly under model-output. This listing runs
# before the copy step below, so the order of these two statements matters.
models_created_in_AMPH <- list.dirs(
  path = "model-output",
  full.names = FALSE,
  recursive = FALSE
)

# Copy the selected hub forecasts for the reference date
copy_fch_outputs(repo_dir, reference_date, models_to_copy)
## Copied files for date 2024-12-07 to /home/runner/work/AMPH_Forecast_Suite/AMPH_Forecast_Suite/vignettes/model-output

4) Load model output (hub forecasts & your forecasts)

# Folder that is searched for model output files
output_path <- file.path("model-output")

# List every parquet/csv file below output_path (searching sub-folders too),
# then keep the paths that contain the reference date.
all_model_files <- list.files(
  path = output_path,
  pattern = "\\.(parquet|csv)$",
  full.names = TRUE,
  recursive = TRUE
)
file_paths <- all_model_files[grepl(reference_date, all_model_files)]

# Fail early when nothing matches the requested reference date
if (length(file_paths) == 0) {
  stop("No model-output files found for reference_date = ", reference_date,
       ". Try a different date.")
}


# Read each file, keep only its quantile rows, and stack the results into one
# data frame. model_id is the name of the folder that contains the file.
projection_data_all <- purrr::map_dfr(file_paths, function(path) {
  model_df <- read_model_file(path)

  # Both columns are used below; report the offending file if one is absent
  for (required_col in c("output_type", "output_type_id")) {
    if (!required_col %in% names(model_df)) {
      stop("Missing '", required_col, "' in: ", path)
    }
  }

  quantile_rows <- dplyr::filter(model_df, .data$output_type == "quantile")

  dplyr::mutate(
    quantile_rows,
    # quantile levels become numeric; values that cannot be parsed turn into
    # NA without a warning
    output_type_id = suppressWarnings(as.numeric(.data$output_type_id)),
    model_id = basename(dirname(path)),
    location = as.character(location)
  )
})

# Fill a missing target_end_date with reference_date plus 7 days per horizon
# step, then drop the "model" and "origin_date" columns when they are present.
with_target_dates <- dplyr::mutate(
  projection_data_all,
  target_end_date = dplyr::coalesce(
    target_end_date,
    reference_date + 7 * as.integer(horizon)
  )
)
prep_proj_data <- dplyr::select(
  with_target_dates,
  -tidyselect::any_of(c("model", "origin_date"))
)

# Convert to the hubverse model_out_tbl format and keep only the models listed
# in models_created_in_AMPH and models_to_copy
models_to_keep <- c(models_created_in_AMPH, models_to_copy)
projection_data_tbl <- prep_proj_data %>%
  hubUtils::as_model_out_tbl() %>%
  dplyr::filter(model_id %in% models_to_keep)


# Location metadata (names/abbreviations) shipped with AMPHForecastSuite.
# An earlier, disabled variant read auxiliary-data/locations.csv instead.
data(loc_data, package = "AMPHForecastSuite")

# Lookup from location to location_name with two kinds of keys: the
# lower-cased abbreviation, and loc_data's own location column.
names_by_abbreviation <- loc_data %>%
  dplyr::mutate(location = tolower(abbreviation)) %>%
  dplyr::select(location, location_name)
names_by_location <- dplyr::select(loc_data, location, location_name)
location_lookup <- dplyr::bind_rows(names_by_abbreviation, names_by_location)

# Attach the names; rows without a match fall back to their location value
projection_data_tbl2 <- projection_data_tbl %>%
  dplyr::left_join(location_lookup, by = "location") %>%
  dplyr::mutate(location_name = dplyr::coalesce(location_name, location))

# Show which models are available
knitr::kable(dplyr::distinct(projection_data_tbl, model_id))
model_id
AMPH-epipredict-arx
AMPH-epipredict-climate
AMPH-neuralnetwork
AMPH-sarima
FluSight-baseline
FluSight-ensemble
MOBS-GLEAM_FLUH

5) Pick location, start date, and uncertainty bands

# Quantile levels bounding the middle 80% interval
uncertainty <- c(0.1, 0.9)

# Location of interest: "US" or a full state name (it must match
# location_name in target_data)
loc <- state_name

# Start of the time window: 12 weeks before the reference date
start_date <- lubridate::as_date(reference_date) - lubridate::weeks(12)

6) Build a simple equal-weight ensemble

We will use the simple_ensemble() function from the hubEnsembles package to build a simple, equal-weight ensemble across a set of models. In this example, we ensemble only a selected subset of models (the FluSight-baseline and FluSight-ensemble, among others, are excluded), and we build the ensemble for a single location and forecast date.

# Filter to the location of interest and the chosen forecast round
round_dat <- projection_data_tbl2 %>%
  dplyr::filter(
    .data$location_name == loc,
    # NOTE(review): both sides of this comparison resolve to the `target`
    # column, so it only drops rows whose target is NA. If a target parameter
    # is defined earlier in the vignette, compare against it explicitly
    # (e.g. `.data$target == .env$target`) -- confirm the intent.
    target == target,
    output_type == "quantile",
    horizon >= 0
  ) %>%
  dplyr::collect()

# Models that contribute to the ensemble. IDs are matched exactly
# (case-sensitive) against model_id, which comes from the model-output folder
# names, so the SARIMA model must be spelled "AMPH-sarima"; "AMPH-SARIMA"
# matches nothing and silently leaves that model out of the ensemble.
ensemble_members <- c("AMPH-sarima", "AMPH-neuralnetwork", "MOBS-GLEAM_FLUH")

# Make a missing contributor visible instead of silently averaging fewer models
missing_members <- setdiff(ensemble_members, unique(round_dat$model_id))
if (length(missing_members) > 0) {
  warning(
    "Ensemble members not found in round_dat: ",
    paste(missing_members, collapse = ", "),
    call. = FALSE
  )
}

# Generate a simple (equal-weight) ensemble across contributing models and
# label it "AMPH-ensemble"
round_ens <- round_dat %>%
  dplyr::filter(.data$model_id %in% ensemble_members) %>%
  hubEnsembles::simple_ensemble() %>%
  dplyr::mutate(model_id = "AMPH-ensemble")

# Combine ensemble with individual models for plotting
plot_df <- dplyr::bind_rows(round_dat, round_ens)

# Diagnostic (disabled): number of distinct quantile levels per model
# lapply(unique(round_dat$model_id), function(x){round_dat %>% filter(model_id==x) %>% pull(output_type_id) %>% unique() %>% length()})

# Models present in the plot data (individual models plus the ensemble)
unique(plot_df$model_id)
## [1] "AMPH-epipredict-arx"     "AMPH-epipredict-climate"
## [3] "AMPH-neuralnetwork"      "AMPH-sarima"            
## [5] "FluSight-baseline"       "FluSight-ensemble"      
## [7] "MOBS-GLEAM_FLUH"         "AMPH-ensemble"

7) Prepare data for visualization

# Pull updated target data, using a forecast_date five weeks after the
# reference date.
# NOTE(review): save_data = TRUE presumably writes the target-data CSV that is
# read back in the next step -- confirm against get_nhsn_data().
new_target_data_date <- lubridate::as_date(reference_date) + lubridate::weeks(5)
target_data_plot <- get_nhsn_data(
  forecast_date = new_target_data_date,
  disease = forecast_disease,
  geo_values = geo_ids,
  save_data = TRUE
)
## Important: forecast_date is more than 1 week ago. Pulling data issued prior to forecast_date.
## Pulling data issued on or before 2025-01-11
## Warning: No API key found. You will be limited to non-complex queries and encounter rate
## limits if you proceed.
##  See `?save_api_key()` for details on obtaining and setting API keys.
## This warning is displayed once every 8 hours.
# Load the target-data CSV for new_target_data_date
target_data_file <- file.path(
  "target-data",
  paste0("target-hospital-admissions-", new_target_data_date, ".csv")
)
target_data_plot <- readr::read_csv(target_data_file, show_col_types = FALSE)

# Forecasts in tidy form for plotting: rename the date column to target_date,
# make the quantile level numeric, sort, and keep the first row for each
# model / horizon / date / quantile combination.
plot_tbl <- hubUtils::as_model_out_tbl(plot_df)
plot_tbl <- dplyr::rename(plot_tbl, target_date = target_end_date)
plot_tbl <- dplyr::mutate(
  plot_tbl,
  output_type_id = suppressWarnings(as.numeric(output_type_id))
)
plot_tbl <- dplyr::arrange(
  plot_tbl,
  model_id, horizon, target_date, output_type_id
)
proj_data <- dplyr::distinct(
  plot_tbl,
  model_id, horizon, target_date, output_type_id,
  .keep_all = TRUE
)

# Observed data restricted to dates after start_date, with the date column
# renamed to `date`
target_data_plot <- target_data_plot %>%
  dplyr::rename(date = target_end_date) %>%
  dplyr::filter(date > start_date)

head(proj_data)
## # A tibble: 6 × 10
##   model_id      target_date reference_date target horizon location location_name
##   <chr>         <date>      <date>         <chr>    <dbl> <chr>    <chr>        
## 1 AMPH-ensemble 2024-12-07  2024-12-07     wk in…       0 24       Maryland     
## 2 AMPH-ensemble 2024-12-07  2024-12-07     wk in…       0 24       Maryland     
## 3 AMPH-ensemble 2024-12-07  2024-12-07     wk in…       0 24       Maryland     
## 4 AMPH-ensemble 2024-12-07  2024-12-07     wk in…       0 24       Maryland     
## 5 AMPH-ensemble 2024-12-07  2024-12-07     wk in…       0 24       Maryland     
## 6 AMPH-ensemble 2024-12-07  2024-12-07     wk in…       0 24       Maryland     
## # ℹ 3 more variables: output_type <chr>, output_type_id <dbl>, value <dbl>
# Preview the first six rows of the observed data
head(target_data_plot, n = 6L)
## # A tibble: 6 × 10
##   location abbreviation location_name target    source disease signal date      
##      <dbl> <chr>        <chr>         <chr>     <chr>  <chr>   <chr>  <date>    
## 1       24 MD           Maryland      wk inc i… nhsn   influe… confi… 2024-09-21
## 2       24 MD           Maryland      wk inc i… nhsn   influe… confi… 2024-09-28
## 3       24 MD           Maryland      wk inc i… nhsn   influe… confi… 2024-10-05
## 4       24 MD           Maryland      wk inc i… nhsn   influe… confi… 2024-10-12
## 5       24 MD           Maryland      wk inc i… nhsn   influe… confi… 2024-10-19
## 6       24 MD           Maryland      wk inc i… nhsn   influe… confi… 2024-10-26
## # ℹ 2 more variables: issue_date <date>, observation <dbl>

8) Plot forecasts vs. truth

We use two different approaches to visualize the forecasts vs. observed data. The hubVis package provides a convenient function to plot forecasts, plot_step_ahead_model_output().

We can also build the plot with ggplot2 directly, though this requires more code.

Option 1: with hubVis

# This is having issues

# Plot the forecasts in proj_data against target_data_plot with hubVis: the
# median is used as the point forecast, intervals is set to 0.8, and the model
# named "AMPH-ensemble" is drawn in black.
hubVis::plot_step_ahead_model_output(
  proj_data,
  target_data = target_data_plot,
  intervals = 0.8,
  use_median_as_point = TRUE,
  ens_name = "AMPH-ensemble",
  ens_color = "black",
  show_legend = TRUE
)

Option 2: with ggplot2

library(dplyr)
library(tidyr)
library(ggplot2)
library(scales)

# target_data_plot <- readr::read_csv(
#   file.path("target-data", paste0("target-hospital-admissions-", new_target_data_date, ".csv")),
#   show_col_types = FALSE)

# model_id of the ensemble, taken from round_ens
ens_id <- unique(round_ens$model_id)[1]

# Quantile rows only; both the ribbon and the median line start from these
quantile_rows <- filter(proj_data, output_type == "quantile")

# 80% ribbon for all models: the 0.1 and 0.9 quantiles spread into
# ymin / ymax columns
ribbon_80 <- quantile_rows %>%
  filter(output_type_id %in% c(0.1, 0.9)) %>%
  select(model_id, horizon, target_date, output_type_id, value) %>%
  mutate(output_type_id = as.numeric(output_type_id)) %>%
  pivot_wider(
    names_from = output_type_id,
    values_from = value,
    names_prefix = "q"
  ) %>%
  rename(ymin = q0.1, ymax = q0.9)

# Median (0.5 quantile) for all models; the ensemble gets a thicker line
med_50 <- quantile_rows %>%
  filter(output_type_id == 0.5) %>%
  select(model_id, horizon, target_date, value) %>%
  mutate(line_width = if_else(model_id == ens_id, 1.1, 0.8))

# Legend order: every other model first, the ensemble last
other_models <- setdiff(unique(proj_data$model_id), ens_id)
model_levels <- c(other_models, ens_id)

# Okabe–Ito palette (color-blind friendly)
okabe_ito <- c(
  "#E69F00", "#56B4E9", "#009E73", "#F0E442",
  "#0072B2", "#D55E00", "#CC79A7", "#999999"
)

# Number of non-ensemble models (the last entry of model_levels is the ensemble)
n_other <- length(model_levels) - 1

# Use the fixed palette while it has enough colours; otherwise generate hues
if (n_other > length(okabe_ito)) {
  other_cols <- scales::hue_pal(l = 45, c = 100)(n_other)
} else {
  other_cols <- okabe_ito[seq_len(n_other)]
}

# Line colours: palette for the other models, black for the ensemble
color_vals <- c(other_cols, "#000000")
names(color_vals) <- model_levels

# Ribbon fills: same hues, with a dark gray for the ensemble
fill_vals <- c(other_cols, "#3A3A3A")
names(fill_vals) <- model_levels

# Aesthetics shared by the two observed-data layers (points and line)
obs_aes <- aes(x = date, y = observation)

# Layers are added in drawing order: interval ribbons, median lines, then the
# observed data in gray.
ggplot() +
  geom_ribbon(
    mapping = aes(x = target_date, ymin = ymin, ymax = ymax, fill = model_id),
    data = ribbon_80,
    alpha = 0.22,
    show.legend = TRUE
  ) +
  geom_line(
    mapping = aes(
      x = target_date, y = value,
      color = model_id, linewidth = line_width
    ),
    data = med_50,
    lineend = "round",
    alpha = 0.98,
    show.legend = TRUE
  ) +
  geom_point(
    mapping = obs_aes,
    data = target_data_plot,
    color = "grey50",
    size = 1.2,
    alpha = 0.85,
    inherit.aes = FALSE
  ) +
  geom_line(
    mapping = obs_aes,
    data = target_data_plot,
    color = "grey50",
    alpha = 0.85,
    inherit.aes = FALSE
  ) +
  scale_color_manual(values = color_vals, name = "Model") +
  scale_fill_manual(values = fill_vals, name = "Model") +
  # line_width already holds the widths to draw, so use them as-is
  scale_linewidth_identity() +
  labs(x = "Target date", y = "Weekly incident hospitalizations") +
  theme_minimal(base_size = 12) +
  theme(legend.position = "right")

Plot of forecasts vs. truth

10) Score forecasts (WIS, coverage)

Evaluation of forecasts is critical for assessing model performance and for building trust in forecasts. Proper scoring rules, such as the Weighted Interval Score (WIS), provide a way to evaluate the accuracy and calibration of probabilistic forecasts. Scoring requires observed data. We will use the scoringutils package to compute WIS for our forecasts.

# Observed values used for scoring, read from the target-data CSV for
# new_target_data_date.
scoring_target_data <- readr::read_csv(
  file.path(
    "target-data",
    paste0("target-hospital-admissions-", new_target_data_date, ".csv")
  ),
  show_col_types = FALSE
)

scoring_target_data <- scoring_target_data %>%
  filter(
    # Keep only the location being scored. The earlier condition
    # `location %in% location` compared the column with itself and was always
    # TRUE, so rows from any location passed through -- and the join step
    # later labels every remaining row with `loc`.
    .data$location_name == .env$loc,
    # Keep rows whose issue date is not earlier than the date they describe
    issue_date >= target_end_date,
    # Explicit Date instead of relying on character-to-Date coercion
    target_end_date > as.Date("2022-09-01")
  ) %>%
  select(geo_value = location, target_end_date, value = observation) %>%
  drop_na(value) %>%
  epiprocess::as_epi_df(time_value = target_end_date)
# Observations keyed like the forecasts: one row per target_date, labelled
# with the location being scored.
observed_tbl <- scoring_target_data %>%
  dplyr::rename(observation = value, target_date = time_value) %>%
  mutate(location_name = loc) %>%
  select(-geo_value)

# Attach the observation to every forecast row at (target_date, location_name)
# and rename columns to the names passed to scoringutils below.
scoring_df <- proj_data %>%
  dplyr::left_join(
    observed_tbl,
    by = c("target_date", "location_name"),
    relationship = "many-to-one"
  ) %>%
  dplyr::rename(
    model = model_id,
    predicted = value,
    observed = observation,
    quantile_level = output_type_id
  )

# Build a scoringutils quantile-forecast object. forecast_unit is given
# explicitly so that extra columns do not affect what counts as one forecast.
forecast <- scoringutils::as_forecast_quantile(
  scoring_df,
  forecast_unit  = c("model", "location_name", "target_date"),
  observed       = "observed",
  predicted      = "predicted",
  quantile_level = "quantile_level"
)

# Score the forecasts, then show the scores summarised by model
scores <- scoringutils::score(forecast)
score_summary <- scoringutils::summarise_scores(scores, by = "model")
knitr::kable(score_summary)
model wis overprediction underprediction dispersion bias interval_coverage_50 interval_coverage_90 ae_median
AMPH-ensemble 57.72073 0 53.93871 3.782014 -0.9200 0 0.25 85.29844
AMPH-epipredict-arx 62.20304 0 58.88705 3.315985 -0.9525 0 0.25 81.89191
AMPH-epipredict-climate 39.23780 0 11.72917 27.508632 -0.3500 1 1.00 75.50000
AMPH-neuralnetwork 55.28833 0 50.30912 4.979213 -0.9000 0 0.50 88.23638
AMPH-sarima 61.90821 0 57.58205 4.326163 -0.9125 0 0.25 85.92319
FluSight-baseline 54.71011 0 48.66304 6.047065 -0.8825 0 0.50 85.75000
FluSight-ensemble 50.65141 0 46.18478 4.466630 -0.9325 0 0.25 70.75000
MOBS-GLEAM_FLUH 62.53968 0 59.95487 2.584814 -0.9700 0 0.25 82.36050