
Mixer

influpaint.datasets.mixer

dataset_mixer.py - Epidemic Data Augmentation and Frame Construction

Combines multiple epidemic surveillance datasets into a unified training corpus for diffusion models. Addresses common challenges in epidemic modeling:

  • Dataset Rebalancing: Uses multipliers to weight data sources
  • Temporal Completeness: Ensures all frames have complete weekly coverage (1-53)
  • Spatial Completeness: Fills missing location-season combinations
  • Gap Handling: Fills missing weeks and locations
  • Peak Scaling: Scales epidemic curves to target peak intensities

Key Components:

  1. Multiplier Calculation: Compute dataset weights for target proportions (see the sketch after this list)
  2. Frame Construction: Build complete epidemic frames for training
  3. Gap Filling: Handle missing data intelligently
  4. Peak Scaling: Scale frames to realistic epidemic intensities
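
For intuition on component 1, here is a minimal sketch of how a proportion-based entry could be turned into an integer replication multiplier. This is illustrative only: the module's actual _calculate_h1_multipliers logic may differ, and proportion_to_multiplier / n_base_frames are hypothetical names.

import math

def proportion_to_multiplier(proportion: float, total: int, n_base_frames: int) -> int:
    # Frames this dataset should contribute (proportion of the target total),
    # divided by the number of base frames it already provides.
    target = proportion * total
    return max(1, math.ceil(target / n_base_frames))

proportion_to_multiplier(0.7, 1000, 350)  # -> 2: replicate each base frame twice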

Typical Usage:

Step 1: Combine datasets into hierarchical structure

all_datasets_df = pd.concat([fluview_df, nc_df, smh_traj_df])

Step 2: Configure dataset inclusion, weighting, and scaling

config = { "fluview": {"proportion": 0.7, "total": 1000, "to_scale": True}, # 70% + scaling "smh_traj": {"proportion": 0.3, "total": 1000, "to_scale": False} # 30% + no scaling }

Step 3: Define scaling distribution for peak intensities

scaling_dist = np.array([1000, 2000, 3000, 5000, 8000, 12000]) # US peak values

Step 4: Build complete frames with configurable location handling and scaling

frames = build_frames(all_datasets_df, config, season_axis, fill_missing_locations="zeros", scaling_distribution=scaling_dist)

Alternative: Use explicit multipliers instead of proportions

config = { "fluview": {"multiplier": 2, "to_scale": True}, # Include twice + scaling "smh_traj": {"multiplier": 1, "to_scale": False} # Include once + no scaling } frames = build_frames(all_datasets_df, config, season_axis, scaling_distribution=scaling_dist)

Peak Scaling:

When "to_scale": True is specified for a dataset: - Each frame gets independently scaled to a random peak from scaling_distribution - Scaling preserves epidemic curve shape while adjusting intensity - US peak = max(weekly_sum_across_all_locations) is used as scaling reference - Origin tracking includes scaling target: "[scaled_to_X.X]" - Provides realistic intensity variation for training data augmentation

Output Format:

Each frame is a complete epidemic season with:

  • All weeks (1-53) represented
  • All locations covered
  • Consistent data structure for array conversion
  • Optional peak scaling applied
  • Full provenance tracking in the 'origin' column
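
For example, a quick sanity check on one returned frame might look like this (illustrative; it relies only on the columns and the season_axis.locations attribute described on this page):

frame = frames[0]
assert set(frame["season_week"]) == set(range(1, 54))             # weeks 1-53 present
assert set(frame["location_code"]) == set(season_axis.locations)  # every expected location
print(frame["origin"].iloc[0])  # e.g. "fluview/<H2>/<season>/<sample>", plus "[scaled_to_X.X]" if scaled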

Enables training on heterogeneous surveillance data while maintaining epidemiological structure and realistic intensity distributions.

build_frames(all_datasets_df, config, season_axis, fill_missing_locations='error', scaling_distribution=None)

Build complete epidemic frames from hierarchical dataset structure.

Handles 4-level hierarchy: H1 → H2 → Season → Sample and creates complete frames while preserving dataset origins.

Parameters:

all_datasets_df (DataFrame), required
    Combined dataset with required columns:
    - datasetH1: Top-level dataset category (e.g., 'fluview', 'smh_traj')
    - datasetH2: Sub-dataset within H1 (e.g., 'round4_CADPH-FluCAT_A-2024-08-01')
    - fluseason: Flu season year
    - sample: Sample identifier within each H2/season combination
    - location_code, season_week, value, week_enddate: Epidemic data

config (dict), required
    Configuration for dataset inclusion and weighting:
    - Keys: H1 dataset names (must exist in the datasetH1 column)
    - Values: Either {"multiplier": int} or {"proportion": float, "total": int}
    - Optional: {"to_scale": bool} to enable frame scaling

season_axis (SeasonAxis), required
    Season axis object providing location definitions

fill_missing_locations (str), default 'error'
    Strategy for handling missing locations (see the sketch after this parameter list):
    - "error": Fail if any expected location is missing (default)
    - "zeros": Fill missing locations with zeros
    - "random": Fill missing locations with random other season data
    - "skip": Skip frames with missing locations

scaling_distribution (ndarray), default None
    Array of values to draw from for scaling. Required if any dataset in config has "to_scale": True.
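
As an illustration of the "zeros" strategy, a hypothetical helper is sketched below (not the module's internal code; it omits week_enddate for brevity):

import pandas as pd

def fill_with_zeros(frame: pd.DataFrame, season_axis) -> pd.DataFrame:
    # Add all-zero rows for every expected location missing from this frame.
    missing = set(season_axis.locations) - set(frame["location_code"])
    fillers = [
        pd.DataFrame({"location_code": loc, "season_week": range(1, 54), "value": 0.0})
        for loc in missing
    ]
    return pd.concat([frame, *fillers], ignore_index=True) if fillers else frame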

Returns:

list
    Complete epidemic frames, where each frame contains:
    - All weeks (1-53) for all expected locations (based on season_axis)
    - Origin column tracking source: "H1/H2/season/sample"
    - Replicated datasets as specified by config

Example

config = { "fluview": {"multiplier": 1, "to_scale": True}, "smh_traj": {"proportion": 0.7, "total": 1000, "to_scale": False} } scaling_dist = np.array([1000, 2000, 3000, 5000, 8000]) # Peak values to scale to frames = build_frames(all_datasets_df, config, season_axis, fill_missing_locations="zeros", scaling_distribution=scaling_dist)

Notes
  • If H1 dataset is included, ALL H2s and seasons within it are included
  • Minimum frames = sum(n_H2 * n_seasons) for each included H1
  • Samples are replicated, not created (e.g., sample_1_copy1, sample_1_copy2)
  • Location completeness is enforced based on season_axis.locations
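
The minimum-frame rule in the notes above can be estimated directly from the combined dataframe, as in this illustrative sketch:

included = all_datasets_df[all_datasets_df["datasetH1"].isin(list(config))]
min_frames = sum(
    grp["datasetH2"].nunique() * grp["fluseason"].nunique()   # n_H2 * n_seasons per H1
    for _, grp in included.groupby("datasetH1")
)
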
Source code in influpaint/datasets/mixer.py
def build_frames(all_datasets_df: pd.DataFrame, config: dict, season_axis: SeasonAxis, 
                 fill_missing_locations: str = "error", scaling_distribution: np.ndarray = None) -> list:
    """
    Build complete epidemic frames from hierarchical dataset structure.

    Handles 4-level hierarchy: H1 → H2 → Season → Sample and creates complete 
    frames while preserving dataset origins.

    Args:
        all_datasets_df (pd.DataFrame): Combined dataset with required columns:
            - datasetH1: Top-level dataset category (e.g., 'fluview', 'smh_traj')  
            - datasetH2: Sub-dataset within H1 (e.g., 'round4_CADPH-FluCAT_A-2024-08-01')
            - fluseason: Flu season year
            - sample: Sample identifier within each H2/season combination
            - location_code, season_week, value, week_enddate: Epidemic data

        config (dict): Configuration for dataset inclusion and weighting:
            - Keys: H1 dataset names (must exist in datasetH1 column)
            - Values: Either {"multiplier": int} or {"proportion": float, "total": int}
            - Optional: {"to_scale": bool} to enable frame scaling

        season_axis (SeasonAxis): Season axis object providing location definitions

        fill_missing_locations (str): Strategy for handling missing locations:
            - "error": Fail if any expected location is missing (default)
            - "zeros": Fill missing locations with zeros
            - "random": Fill missing locations with random other season data
            - "skip": Skip frames with missing locations

        scaling_distribution (np.ndarray, optional): Array of values to draw from for scaling.
            Required if any dataset in config has "to_scale": True

    Returns:
        list: Complete epidemic frames, where each frame contains:
            - All weeks (1-53) for all expected locations (based on season_axis)
            - Origin column tracking source: "H1/H2/season/sample"
            - Replicated datasets as specified by config

    Example:
        config = {
            "fluview": {"multiplier": 1, "to_scale": True},
            "smh_traj": {"proportion": 0.7, "total": 1000, "to_scale": False}
        }
        scaling_dist = np.array([1000, 2000, 3000, 5000, 8000])  # Peak values to scale to
        frames = build_frames(all_datasets_df, config, season_axis, 
                             fill_missing_locations="zeros", scaling_distribution=scaling_dist)

    Notes:
        - If H1 dataset is included, ALL H2s and seasons within it are included
        - Minimum frames = sum(n_H2 * n_seasons) for each included H1
        - Samples are replicated, not created (e.g., sample_1_copy1, sample_1_copy2)
        - Location completeness is enforced based on season_axis.locations
    """
    # Validate input dataframe
    required_columns = ['datasetH1', 'datasetH2', 'fluseason', 'sample', 
                       'location_code', 'season_week', 'value', 'week_enddate']
    _validate_required_columns(all_datasets_df, required_columns, "Input dataframe")

    # Validate config references existing H1 datasets
    available_h1 = set(all_datasets_df['datasetH1'].unique())
    config_h1 = set(config.keys())
    missing_h1 = config_h1 - available_h1
    if missing_h1:
        raise ValueError(f"Config references non-existent H1 datasets: {missing_h1}")

    # Validate fill_missing_locations parameter
    valid_strategies = {"error", "zeros", "random", "skip"}
    if fill_missing_locations not in valid_strategies:
        raise ValueError(f"fill_missing_locations must be one of: {valid_strategies}")

    # Validate scaling parameters
    needs_scaling = any(cfg.get("to_scale", False) for cfg in config.values())
    if needs_scaling and scaling_distribution is None:
        raise ValueError("scaling_distribution must be provided when any dataset has to_scale=True")
    if scaling_distribution is not None and len(scaling_distribution) == 0:
        raise ValueError("scaling_distribution cannot be empty")

    # Calculate multipliers for each H1 dataset
    h1_multipliers = _calculate_h1_multipliers(all_datasets_df, config)

    # Pre-compute global lookup table for intelligent filling (do this once)
    global_lookup = None
    if fill_missing_locations == "random":
        print("Pre-computing intelligent fill lookup table...")
        global_lookup = _build_global_lookup_table(all_datasets_df)

    # Build frames for each included H1 dataset  
    all_frames = []
    frame_summary = {}

    print("Building frames...")
    for h1_name, multiplier in h1_multipliers.items():
        h1_config = config[h1_name]
        should_scale = h1_config.get("to_scale", False)
        scale_info = " with scaling" if should_scale else ""
        print(f"Processing {h1_name} (multiplier={multiplier}{scale_info})...")
        h1_data = all_datasets_df[all_datasets_df['datasetH1'] == h1_name].copy()

        h1_frames = _build_h1_frames(h1_data, h1_name, multiplier, season_axis, fill_missing_locations, 
                                   all_datasets_df, global_lookup, should_scale, scaling_distribution)
        all_frames.extend(h1_frames)

        # Build summary for this H1 dataset
        h2_counts = {}
        for frame in h1_frames:
            if 'datasetH2' in frame.columns:
                h2 = frame['datasetH2'].iloc[0]
                # Handle NaN values properly
                if pd.isna(h2):
                    h2 = "<missing_datasetH2>"
                h2_counts[h2] = h2_counts.get(h2, 0) + 1

        frame_summary[h1_name] = {
            'total_frames': len(h1_frames),
            'multiplier': multiplier,
            'h2_breakdown': h2_counts
        }

    # Print summary
    print(f"Created {len(all_frames)} total frames:")
    for h1_name, summary in frame_summary.items():
        print(f"  {h1_name}: {summary['total_frames']} frames (multiplier={summary['multiplier']})")
        #for h2, count in summary['h2_breakdown'].items():
        #    print(f"    {h2}: {count} frames")

    # Clean up frames by removing unnecessary metadata columns
    cleaned_frames = []
    essential_columns = ['location_code', 'season_week', 'value', 'week_enddate', 'origin']

    for frame in all_frames:
        # Keep only essential columns that have actual data
        available_essential = [col for col in essential_columns if col in frame.columns]
        cleaned_frame = frame[available_essential].copy()
        cleaned_frames.append(cleaned_frame)

    return cleaned_frames