Skip to content

Config

influpaint.batch.config

Configuration libraries for InfluPaint research.

copaint_config_library(timesteps)

CoPaint inpainting configurations

Source code in influpaint/batch/config.py
def copaint_config_library(timesteps):
    """CoPaint inpainting configurations.

    Args:
        timesteps: number of diffusion steps; used for both
            ``ddpm_num_steps`` and ``num_inference_steps`` in every preset.

    Returns:
        dict mapping preset name -> ``config.Config``.

    Friendly names map:
      - celebahq_try3   -> short-jump (TT)     : jump_length=5,  use_timetravel=True,  num_iteration_optimize_xt=5
      - celebahq_noTTJ5 -> short-jump (no TT)  : jump_length=5,  use_timetravel=False, num_iteration_optimize_xt=2
      - celebahq        -> long-jump (TT)      : jump_length=10, use_timetravel=True,  num_iteration_optimize_xt=2
    """

    def _preset(jump_length, jump_n_sample, use_timetravel,
                coef_xt_reg, coef_xt_reg_decay, num_iteration_optimize_xt):
        # Every preset shares the same skeleton; only the six knobs above vary.
        return config.Config(default_config_dict={
            "respace_interpolate": False,
            "ddim": {
                "ddim_sigma": 0.0,
                "schedule_params": {
                    "ddpm_num_steps": timesteps,
                    "jump_length": jump_length,
                    "jump_n_sample": jump_n_sample,
                    "num_inference_steps": timesteps,
                    "schedule_type": "linear",
                    "time_travel_filter_type": "none",
                    "use_timetravel": use_timetravel
                }
            },
            "optimize_xt": {
                "coef_xt_reg": coef_xt_reg,
                "coef_xt_reg_decay": coef_xt_reg_decay,
                "filter_xT": False,
                "lr_xt": 0.02,
                "lr_xt_decay": 1.012,
                "mid_interval_num": 1,
                "num_iteration_optimize_xt": num_iteration_optimize_xt,
                "optimize_before_time_travel": True,
                "optimize_xt": True,
                "use_adaptive_lr_xt": True,
                "use_smart_lr_xt_decay": True
            },
            "debug": False
        }, use_argparse=False)

    config_lib = {
        # try1: longest jumps, stronger reg decay (was jump_length=10, jump_n_sample=2).
        "celebahq_try1": _preset(20, 4, True, 0.00001, 1.05, 5),
        # short-jump (no TT); jump_length was 10 before 2025-07.
        "celebahq_noTTJ5": _preset(5, 2, False, 0.0001, 1.01, 2),
        "celebahq_noTT2": _preset(10, 2, False, 0.0001, 1.01, 5),
        # short-jump (TT); jump_length was 10 and num_iteration_optimize_xt was 2.
        "celebahq_try3": _preset(5, 2, True, 0.0001, 1.01, 5),
        # long-jump (TT): the original CoPaint celebahq defaults.
        "celebahq": _preset(10, 2, True, 0.0001, 1.01, 2),
    }
    # NOTE: an "imagenet" preset (num_inference_steps=200, coef_xt_reg=0.01,
    # coef_xt_reg_decay=1.0, otherwise long-jump TT) used to live here
    # commented out; recover it from version control if needed.
    return config_lib

dataset_library(season_setup, channels)

Dataset configurations

Source code in influpaint/batch/config.py
def dataset_library(season_setup, channels):
    """Dataset configurations.

    Returns a mapping from DATASET_GRID name to a zero-argument factory.
    Each factory lazily loads its NetCDF training set via
    ``training_datasets.FluDataset.from_xarray`` when called.
    """
    build_day = "2025-07-17"
    # Legacy datasets (fluview / SMHR1 / CSP loaders) were retired from the
    # active mapping; their construction calls live in version control.

    def _factory(grid_name):
        # Bind grid_name here so each closure loads its own file
        # (avoids the late-binding-in-a-loop pitfall).
        return lambda: training_datasets.FluDataset.from_xarray(
            f"training_datasets/TS_{grid_name}_{build_day}.nc", channels=channels
        )

    return {grid: _factory(grid) for grid in ("100S", "70S30M", "30S70M", "100M")}

ddpm_library(image_size, channels, epoch, device, batch_size, unet)

Model configurations

Source code in influpaint/batch/config.py
def ddpm_library(image_size, channels, epoch, device, batch_size, unet):
    """Model configurations.

    Builds a dict of named ``ddpm.DDPM`` wrappers that differ only in the
    number of diffusion timesteps and the beta schedule; all other
    hyperparameters (image size, channels, batch size, epochs, device,
    l2 loss) are shared.

    NOTE(review): every entry receives the *same* ``unet`` object, so all
    five DDPM wrappers share one underlying network instance — confirm that
    is intended before training more than one of them.
    """
    # (timesteps, beta_schedule) per model key; suffix "l"=linear, "c"=cosine.
    grid = {
        "U200l": (200, "linear"),
        "U500l": (500, "linear"),
        "U200c": (200, "cosine"),
        "U500c": (500, "cosine"),
        "U800c": (800, "cosine"),
    }
    return {
        name: ddpm.DDPM(
            model=unet,
            image_size=image_size,
            channels=channels,
            batch_size=batch_size,
            epochs=epoch,
            timesteps=ts,
            beta_schedule=schedule,
            device=device,
            loss_type="l2",
        )
        for name, (ts, schedule) in grid.items()
    }

get_dataset(dataset_name, season_setup, channels)

Get dataset by name, handling lambda functions

Source code in influpaint/batch/config.py
def get_dataset(dataset_name, season_setup, channels):
    """Get dataset by name, handling lambda functions.

    Args:
        dataset_name: key into :func:`dataset_library`.
        season_setup, channels: forwarded to :func:`dataset_library`.

    Returns:
        The loaded dataset: factory entries (callables) are invoked,
        plain entries are returned as-is.

    Raises:
        KeyError: if ``dataset_name`` is not in the library; the message
            lists the available names to ease debugging.
    """
    dataset_spec = dataset_library(season_setup, channels)
    try:
        dataset_factory = dataset_spec[dataset_name]
    except KeyError:
        available = ", ".join(sorted(dataset_spec))
        raise KeyError(
            f"Unknown dataset '{dataset_name}'. Available: {available}"
        ) from None

    # New-style entries are zero-arg lambdas (lazy loaders); legacy entries
    # may be pre-built dataset objects.
    if callable(dataset_factory):
        return dataset_factory()
    else:
        return dataset_factory

transform_library(scaling_per_channel, data_mean, data_std)

Transform configuration

Source code in influpaint/batch/config.py
def transform_library(scaling_per_channel, data_mean, data_std):
    """Transform configuration.

    Builds two dictionaries of torchvision pipelines:

    - ``transforms_spec``: name -> {"reg": forward transform, "inv": inverse}.
      Each "inv" list is written in the SAME step order as its "reg"
      pipeline and then reversed with ``[::-1]``, so the inverse steps run
      last-to-first (undoing the forward pipeline back-to-front).
    - ``transform_enrich``: data-augmentation pipelines (Poisson noise,
      random time-shift padding, random rescaling).

    Args:
        scaling_per_channel: per-channel divisor applied by the "Lins"/"Sqrt"
            pipelines (assumed broadcastable against the data tensor —
            TODO confirm shape contract with callers).
        data_mean: mean used by the Z-score variants ("LogZs", "LinsZs").
        data_std: std used by the Z-score variants.

    Returns:
        Tuple ``(transforms_spec, transform_enrich)``.
    """
    from torchvision import transforms
    from influpaint.datasets import transforms as epitransforms


    transform_enrich = {
        # Identity: no enrichment.
        "No": transforms.Compose([]),
        # Poisson noise + up-to-±15-step time shift + wide random rescale.
        "PoisPadScale": transforms.Compose([
            transforms.Lambda(lambda t: epitransforms.transform_poisson(t)),
            transforms.Lambda(lambda t: epitransforms.transform_random_padintime(t, min_shift=-15, max_shift=15)),
            transforms.Lambda(lambda t: epitransforms.transform_randomscale(t, min=.1, max=1.9)),
        ]),
        # Milder variant: ±4-step shift, narrower rescale range.
        "PoisPadScaleSmall": transforms.Compose([
            transforms.Lambda(lambda t: epitransforms.transform_poisson(t)),
            transforms.Lambda(lambda t: epitransforms.transform_random_padintime(t, min_shift=-4, max_shift=4)),
            transforms.Lambda(lambda t: epitransforms.transform_randomscale(t, min=.7, max=1.3)),
        ]),
        # Poisson noise only.
        "Pois": transforms.Compose([
            transforms.Lambda(lambda t: epitransforms.transform_poisson(t)),
        ])
    }

    transforms_spec = {
        # No scaling (linear scale)
        "Lins": {
            "reg": transforms.Compose([
                transforms.Lambda(lambda t: epitransforms.transform_channelwisescale(t, scale=1/scaling_per_channel)),
                # scale=2 presumably maps normalized data into the model's
                # expected value range — confirm against training code.
                transforms.Lambda(lambda t: epitransforms.transform_channelwisescale(t, scale=2)),
            ]),
            # Same step order as "reg", reversed by [::-1] so inverses run
            # back-to-front.
            "inv": transforms.Compose([
                transforms.Lambda(lambda t: epitransforms.transform_channelwisescale_inv(t, scale=1/scaling_per_channel)),
                transforms.Lambda(lambda t: epitransforms.transform_channelwisescale_inv(t, scale=2)),
            ][::-1])  
        },
        # sqrt scale
        "Sqrt": {
            "reg": transforms.Compose([
                transforms.Lambda(lambda t: epitransforms.transform_channelwisescale(t, scale=1/scaling_per_channel)),
                epitransforms.transform_sqrt,
                transforms.Lambda(lambda t: epitransforms.transform_channelwisescale(t, scale=2)),
            ]),
            # Mirrors "reg" step-for-step; [::-1] reverses execution order.
            "inv": transforms.Compose([
                transforms.Lambda(lambda t: epitransforms.transform_channelwisescale_inv(t, scale=1/scaling_per_channel)),
                epitransforms.transform_sqrt_inv,
                transforms.Lambda(lambda t: epitransforms.transform_channelwisescale_inv(t, scale=2)),
            ][::-1])  
        },
        # Log-transform followed by Z-score
        # NOTE(review): standardization uses np.log(data_mean) and
        # np.log(data_std) — the log of the raw statistics, not the
        # mean/std of the log-transformed data. Confirm this is intended.
        "LogZs": {
            "reg": transforms.Compose([
                # Use log1p for numerical stability, calculates log(1+t)
                transforms.Lambda(lambda t: torch.log1p(t)),
                # Standardize the log-transformed data
                transforms.Lambda(lambda t: (t - np.log(data_mean)) / np.log(data_std)),
            ]),
            "inv": transforms.Compose([
                # Inverse of log1p is expm1
                transforms.Lambda(lambda t: torch.expm1(t)),
                # Inverse of standardization
                transforms.Lambda(lambda t: t * np.log(data_std) + np.log(data_mean)),
            ][::-1]) # Reverses to apply inv_zscore, then inv_log
        },
        # Plain Z-score on the raw (linear) scale.
        "LinsZs": {
            "reg": transforms.Compose([
                transforms.Lambda(lambda t: (t - data_mean) / data_std),
            ]),
            # Single step, so no [::-1] needed here.
            "inv": transforms.Compose([
                transforms.Lambda(lambda t: t * data_std + data_mean),
            ])
        }
    }

    return transforms_spec, transform_enrich