
Vaxflux

vaxflux

Model seasonal vaccination uptake curves.

Covariate

Bases: ABC, BaseModel

Abstract base class for covariates in vaxflux.

Subclasses must implement sample() and adhere to the following shape contract:

  • Seasonal covariates (covariate is None): sample() must register and return a numpyro.deterministic site named f"covariate_values_{self.prefix}" with shape (num_seasons,).
  • Categorical covariates (covariate is not None): sample() must register and return a numpyro.deterministic site named f"covariate_values_{self.prefix}" with shape (num_seasons, num_categories).

Attributes:

Name Type Description
parameter str

The name of the model parameter this covariate affects.

covariate str | None

The name of the covariate variable.

prefix property

Get the prefix for the covariate based on the parameter name.

Returns:

Type Description
str

The prefix string derived from the parameter name.

extra_dims(_season_coord, _category_short_coord)

Return ArviZ dimension mappings for any additional numpyro sites.

Returns dimension mappings for any additional numpyro sites this covariate registers beyond the standard covariate_values_* deterministic.

The covariate_values_* site and the raw {prefix} sample site are handled automatically by VaxfluxModel; only override this method when sample() registers extra named sites (e.g. hyperparameter draws) that should carry labelled coordinates in the ArviZ output.

Parameters:

Name Type Description Default
_season_coord str

The name of the season coordinate ("season").

required
_category_short_coord str | None

The name of the non-baseline-category coordinate for this covariate, or None for seasonal covariates.

required

Returns:

Type Description
dict[str, list[str]]

A mapping from numpyro site name to a list of coordinate dimension names, in the same format as the dims argument accepted by arviz.from_numpyro. Return an empty dict (the default) if there are no additional sites to label.

Source code in src/vaxflux/_covariates.py
def extra_dims(
    self,
    _season_coord: str,
    _category_short_coord: str | None,
) -> dict[str, list[str]]:
    """
    Return ArviZ dimension mappings for any additional numpyro sites.

    Returns dimension mappings for any additional numpyro sites this
    covariate registers beyond the standard `covariate_values_*`
    deterministic.

    The `covariate_values_*` site and the raw `{prefix}` sample site are
    handled automatically by `VaxfluxModel`; only override this method when
    `sample()` registers *extra* named sites (e.g. hyperparameter draws) that
    should carry labelled coordinates in the ArviZ output.

    Args:
        _season_coord: The name of the season coordinate (`"season"`).
        _category_short_coord: The name of the non-baseline-category
            coordinate for this covariate, or `None` for seasonal
            covariates.

    Returns:
        A mapping from numpyro site name to a list of coordinate dimension
        names, in the same format as the `dims` argument accepted by
        `arviz.from_numpyro`.  Return an empty dict (the default) if there
        are no additional sites to label.
    """
    return {}

sample(*, season_names, covariate_categories) abstractmethod

Abstract method to sample covariate values using model context.

Parameters:

Name Type Description Default
season_names list[str]

Season names defined in the model.

required
covariate_categories list[str] | None

All categories for this covariate including the baseline at index 0, or None for seasonal covariates.

required

Returns:

Type Description
NumericalArrayLike

A numerical array-like structure containing the sampled covariate values.

Source code in src/vaxflux/_covariates.py
@abstractmethod
def sample(
    self,
    *,
    season_names: list[str],
    covariate_categories: list[str] | None,
) -> NumericalArrayLike:
    """
    Abstract method to sample covariate values using model context.

    Args:
        season_names: Season names defined in the model.
        covariate_categories: All categories for this covariate including the
            baseline at index 0, or `None` for seasonal covariates.

    Returns:
        A numerical array-like structure containing the sampled covariate values.
    """
    raise NotImplementedError

CovariateCategories

Bases: BaseModel

A representation of the categories for a covariate.

Curve()

Bases: ABC

Abstract class for implementations of uptake curves.

Attributes:

Name Type Description
parameters tuple[str, ...]

A tuple of parameter names required by the prevalence function.

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from vaxflux import Curve
>>> class ClampedSlopeCurve(Curve):
...     def prevalence(self, t, m) -> jax.Array:
...         return jnp.clip(t * m, 0.0, 1.0)
>>> t = jnp.linspace(-1.0, 3.0, num=8)
>>> curve = ClampedSlopeCurve()
>>> curve.parameters
('m',)
>>> curve.prevalence(t, m=0.5)
Array([0.        , 0.        , 0.0714286 , 0.35714293, 0.6428572 ,
       0.9285715 , 1.        , 1.        ], dtype=float32)
>>> curve.incidence(t, m=0.5)
Array([0. , 0. , 0.5, 0.5, 0.5, 0.5, 0. , 0. ], dtype=float32)
Source code in src/vaxflux/_curves.py
def __init__(self) -> None:
    self.parameters = tuple(inspect.signature(self.prevalence).parameters)[1:]
    in_axes = (0,) + (None,) * len(self.parameters)
    self._grad_prevalence = jax.vmap(jax.grad(self.prevalence), in_axes=in_axes)
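The first two lines of `__init__` derive the curve's parameter names from the `prevalence` signature and build the `in_axes` tuple handed to `jax.vmap`. The sketch below reproduces just that bookkeeping with the standard library; `ToyCurve` is a hypothetical stand-in, since only its signature matters here.

```python
import inspect


class ToyCurve:
    # Hypothetical stand-in for a Curve subclass; only the signature is used.
    def prevalence(self, t, m, r, s):
        return t


curve = ToyCurve()
# A bound method's signature omits `self`, so dropping the first entry removes `t`.
parameters = tuple(inspect.signature(curve.prevalence).parameters)[1:]
# Map over axis 0 of `t`; pass every curve parameter through unbatched.
in_axes = (0,) + (None,) * len(parameters)
print(parameters, in_axes)  # ('m', 'r', 's') (0, None, None, None)
```

Because `t` is mapped element by element, `jax.grad` only ever differentiates a scalar output, which is why the vmap wrapper is needed.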

incidence(t, **kwargs)

Compute the incidence at time t, defined as the derivative of the prevalence with respect to t.

Parameters:

Name Type Description Default
t NumericalArrayLike

The time steps to evaluate the incidence curve at.

required
**kwargs NumericalArrayLike

Additional parameters required by the incidence model.

{}

Returns:

Type Description
Array

The incidence curve evaluated at time t.

Source code in src/vaxflux/_curves.py
def incidence(
    self, t: NumericalArrayLike, **kwargs: NumericalArrayLike
) -> jax.Array:
    """
    Compute the incidence at time `t`.

    Args:
        t: The time steps to evaluate the incidence curve at.
        **kwargs: Additional parameters required by the incidence model.

    Returns:
        The incidence curve evaluated at time `t`.
    """
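    # Pass the curve parameters positionally in signature order, so the result
    # does not depend on the order in which keyword arguments were given.
    return cast(
        "jax.Array",
        self._grad_prevalence(t, *(kwargs[name] for name in self.parameters)),
    )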
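Since incidence is the time derivative of prevalence, a central finite difference gives a quick, library-free cross-check of `incidence`. The NumPy sketch below swaps `jax.grad` for that numerical approximation and reuses the clamped-slope prevalence from the `Curve` example; `incidence_fd` and the step size `h` are illustrative choices, not vaxflux API.

```python
import numpy as np


def prevalence(t, m):
    # Same clamped-slope prevalence as the ClampedSlopeCurve example.
    return np.clip(t * m, 0.0, 1.0)


def incidence_fd(t, m, h=1e-6):
    # Central-difference approximation of d(prevalence)/dt.
    return (prevalence(t + h, m) - prevalence(t - h, m)) / (2.0 * h)


t = np.linspace(-1.0, 3.0, num=8)
print(incidence_fd(t, m=0.5))
```

Away from the kinks at prevalence 0 and 1, the approximation should agree with the `curve.incidence(t, m=0.5)` output shown in the `Curve` example.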

plot(t, *, parameter_sets=None, labels=None, plot_incidence=True, plot_prevalence=True, figsize=None, title=None, **kwargs)

Plot the curve's incidence and/or prevalence over time.

Parameters:

Name Type Description Default
t NumericalArrayLike

The time steps to evaluate and plot.

required
parameter_sets Sequence[dict[str, NumericalArrayLike]] | None

Optional collection of parameter dictionaries to overlay on the same axes. When omitted, kwargs are used as a single set.

None
labels Sequence[str] | None

Optional labels for each parameter set. When omitted, labels are generated from self.parameters.

None
plot_incidence bool

Whether to plot the incidence curve.

True
plot_prevalence bool

Whether to plot the prevalence curve.

True
figsize tuple[float, float] | None

Optional figure size override.

None
title str | None

Optional figure title.

None
**kwargs NumericalArrayLike

Parameters forwarded to the curve methods when plotting a single parameter set.

{}

Returns:

Type Description
Figure

The Matplotlib figure.

Raises:

Type Description
ValueError

If neither incidence nor prevalence is selected.

ValueError

If both parameter_sets and kwargs are provided.

ValueError

If labels do not match parameter_sets.

ValueError

If t is not one-dimensional.

Source code in src/vaxflux/_curves.py
def plot(
    self,
    t: NumericalArrayLike,
    *,
    parameter_sets: Sequence[dict[str, NumericalArrayLike]] | None = None,
    labels: Sequence[str] | None = None,
    plot_incidence: bool = True,
    plot_prevalence: bool = True,
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
    **kwargs: NumericalArrayLike,
) -> Figure:
    """
    Plot the curve's incidence and/or prevalence over time.

    Args:
        t: The time steps to evaluate and plot.
        parameter_sets: Optional collection of parameter dictionaries to overlay on
            the same axes. When omitted, `kwargs` are used as a single set.
        labels: Optional labels for each parameter set. When omitted, labels are
            generated from `self.parameters`.
        plot_incidence: Whether to plot the incidence curve.
        plot_prevalence: Whether to plot the prevalence curve.
        figsize: Optional figure size override.
        title: Optional figure title.
        **kwargs: Parameters forwarded to the curve methods when plotting a single
            parameter set.

    Returns:
        The Matplotlib figure.

    Raises:
        ValueError: If neither incidence nor prevalence is selected.
        ValueError: If both `parameter_sets` and `kwargs` are provided.
        ValueError: If `labels` do not match `parameter_sets`.
        ValueError: If `t` is not one-dimensional.

    """
    if not plot_incidence and not plot_prevalence:
        msg = (
            "At least one of `plot_incidence` or `plot_prevalence` must be `True`."
        )
        raise ValueError(msg)

    t_values = _numerical_array_like_to_1d_jax_array(t)
    curve_parameter_sets = _normalize_parameter_sets(
        parameters=self.parameters,
        parameter_sets=parameter_sets,
        kwargs=kwargs,
    )
    incidence_series = (
        [
            jnp.asarray(self.incidence(t_values, **params))
            for params in curve_parameter_sets
        ]
        if plot_incidence
        else None
    )
    prevalence_series = (
        [
            jnp.asarray(self.prevalence(t_values, **params))
            for params in curve_parameter_sets
        ]
        if plot_prevalence
        else None
    )
    return _plot_curve_figure(
        t=t_values,
        parameter_sets=curve_parameter_sets,
        labels=labels,
        incidence_series=incidence_series,
        prevalence_series=prevalence_series,
        parameter_names=self.parameters,
        title=title,
        figsize=figsize,
    )
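The helpers `_normalize_parameter_sets` and `_plot_curve_figure` are private and not shown here, so the sketch below only restates the four documented `ValueError` conditions as a hypothetical, standalone `check_plot_arguments` function. The order of the checks is an assumption; the real method converts `t` before normalizing the parameter sets.

```python
import numpy as np


def check_plot_arguments(
    t,
    *,
    parameter_sets=None,
    labels=None,
    plot_incidence=True,
    plot_prevalence=True,
    **kwargs,
):
    """Hypothetical restatement of Curve.plot's documented argument rules."""
    if not plot_incidence and not plot_prevalence:
        raise ValueError("Select incidence, prevalence, or both.")
    if parameter_sets is not None and kwargs:
        raise ValueError("Pass either `parameter_sets` or keyword parameters, not both.")
    # Keyword parameters form a single parameter set when no sets are given.
    sets = list(parameter_sets) if parameter_sets is not None else [dict(kwargs)]
    if labels is not None and len(labels) != len(sets):
        raise ValueError("Provide exactly one label per parameter set.")
    if np.ndim(t) != 1:
        raise ValueError("`t` must be one-dimensional.")
    return sets


print(check_plot_arguments(np.linspace(0.0, 1.0, 5), m=0.5))
```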

prevalence(t, **kwargs) abstractmethod

Compute the prevalence at time t.

Parameters:

Name Type Description Default
t NumericalArrayLike

The time steps to evaluate the prevalence curve at.

required
**kwargs NumericalArrayLike

Additional parameters required by the prevalence model.

{}

Returns:

Type Description
Array

The prevalence curve evaluated at time t, bounded between 0 and 1.

Source code in src/vaxflux/_curves.py
@abstractmethod
def prevalence(
    self, t: NumericalArrayLike, **kwargs: NumericalArrayLike
) -> jax.Array:
    """
    Compute the prevalence at time `t`.

    Args:
        t: The time steps to evaluate the prevalence curve at.
        **kwargs: Additional parameters required by the prevalence model.

    Returns:
        The prevalence curve evaluated at time `t` bounded between 0 and 1.
    """
    raise NotImplementedError

prevalence_difference(t0, t1, **kwargs)

Compute prevalence differences between two time arrays.

Parameters:

Name Type Description Default
t0 NumericalArrayLike

The start time steps for the interval.

required
t1 NumericalArrayLike

The end time steps for the interval.

required
**kwargs NumericalArrayLike

Additional parameters required by the prevalence model.

{}

Returns:

Type Description
Array

The prevalence difference for each interval.

Source code in src/vaxflux/_curves.py
def prevalence_difference(
    self,
    t0: NumericalArrayLike,
    t1: NumericalArrayLike,
    **kwargs: NumericalArrayLike,
) -> jax.Array:
    """
    Compute prevalence differences between two time arrays.

    Args:
        t0: The start time steps for the interval.
        t1: The end time steps for the interval.
        **kwargs: Additional parameters required by the prevalence model.

    Returns:
        The prevalence difference for each interval.
    """
    return self.prevalence(jnp.asarray(t1), **kwargs) - self.prevalence(
        jnp.asarray(t0), **kwargs
    )
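A useful property of `prevalence_difference` is that the differences over contiguous intervals telescope: they sum to the change in prevalence across the whole span. The NumPy sketch below checks this with the logistic form used by `LogisticCurve`. The weekly intervals and parameter values are hypothetical.

```python
import numpy as np


def expit(x):
    return 1.0 / (1.0 + np.exp(-x))


def logistic_prevalence(t, m, r, s):
    # Same functional form as LogisticCurve.prevalence, written with NumPy.
    return expit(m) * expit(np.exp(r) * (t - s))


# Contiguous weekly intervals: days 0-7, 7-14, 14-21, 21-28 (hypothetical).
t0 = np.array([0.0, 7.0, 14.0, 21.0])
t1 = np.array([7.0, 14.0, 21.0, 28.0])
params = {"m": 0.5, "r": -2.0, "s": 14.0}

diff = logistic_prevalence(t1, **params) - logistic_prevalence(t0, **params)
total = logistic_prevalence(t1[-1], **params) - logistic_prevalence(t0[0], **params)
print(diff, diff.sum(), total)
```

Each entry of `diff` is the uptake within one reporting interval, which is the quantity an incidence-style observation over `[t0, t1]` would be compared against.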

DateRange

Bases: BaseModel

A representation of a date range for uptake scenarios.

Examples:

>>> from vaxflux import DateRange
>>> date_range = DateRange(
...     season="2023/2024",
...     start_date="2023-12-01",
...     end_date="2023-12-31",
...     report_date="2024-01-01",
... )
>>> date_range.season
'2023/2024'
>>> date_range.start_date
datetime.date(2023, 12, 1)
>>> date_range.end_date
datetime.date(2023, 12, 31)
>>> date_range.report_date
datetime.date(2024, 1, 1)
>>> DateRange(
...     season="2023/2024",
...     start_date="2023-12-01",
...     end_date="2023-11-30",
...     report_date="2023-12-01",
... )
Traceback (most recent call last):
    ...
pydantic_core._pydantic_core.ValidationError: 1 validation error for DateRange
  Value error, The end date, 2023-11-30, must be after or the same as the start date 2023-12-01. [...]
    For further information visit ...
>>> DateRange(
...     season="2023/2024",
...     start_date="2023-12-01",
...     end_date="2023-12-31",
...     report_date="2023-12-30",
... )
Traceback (most recent call last):
    ...
pydantic_core._pydantic_core.ValidationError: 1 validation error for DateRange
  Value error, The report date, 2023-12-30, must be after or the same as the end date 2023-12-31. [...]
    For further information visit ...
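The two failures above come from ordering rules on the dates. The standard-library sketch below restates those rules as a hypothetical `check_date_range` function; it is not the pydantic validator vaxflux uses, only an illustration of the same checks.

```python
import datetime


def check_date_range(start_date: str, end_date: str, report_date: str) -> None:
    """Hypothetical restatement of the DateRange ordering rules."""
    start = datetime.date.fromisoformat(start_date)
    end = datetime.date.fromisoformat(end_date)
    report = datetime.date.fromisoformat(report_date)
    # The end date may equal the start date (a single-day range).
    if end < start:
        raise ValueError(
            f"The end date, {end}, must be after or the same as the start date {start}."
        )
    # Data can be reported on the last day of the range or later.
    if report < end:
        raise ValueError(
            f"The report date, {report}, must be after or the same as the end date {end}."
        )


check_date_range("2023-12-01", "2023-12-31", "2024-01-01")  # valid
```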

Implementation

Bases: BaseModel

A representation of an implementation of an intervention.

name property

Get the name of the implementation.

Returns:

Type Description
str

The name of the implementation, formatted as a string.

Intervention

Bases: BaseModel

A representation of the effect of an intervention on vaccine uptake.

numpyro_distribution()

Return a numpyro distribution for the intervention.

Source code in src/vaxflux/_interventions.py
def numpyro_distribution(self) -> dist.Distribution:
    """Return a numpyro distribution for the intervention."""
    return getattr(dist, self.distribution)(**self.distribution_kwargs)

sample(name)

Sample an intervention effect from its distribution.

Source code in src/vaxflux/_interventions.py
def sample(self, name: str) -> JaxArray:
    """Sample an intervention effect from its distribution."""
    return cast("JaxArray", numpyro.sample(name, self.numpyro_distribution()))

LogisticCurve()

Bases: Curve

Logistic uptake curve implementation.

The logistic curve has parameters \(m\), \(r\), and \(s\), and its prevalence function is given by:

\[ f(t\vert m,r,s)=\mathrm{invlogit}\left(m\right)\,\mathrm{invlogit}\left(e^r\left(t-s\right)\right) \]

Examples:

>>> import jax.numpy as jnp
>>> from vaxflux import LogisticCurve
>>> t = jnp.array([0.0, 1.0, 2.0, 3.0])
>>> curve = LogisticCurve()
>>> prevalence = curve.prevalence(t, m=0.0, r=1.0, s=1.0)
>>> prevalence
Array([0.03095159, 0.25      , 0.4690484 , 0.4978322 ], dtype=float32)
>>> incidence = curve.incidence(t, m=0.0, r=1.0, s=1.0)
>>> incidence
Array([0.0789269 , 0.33978522, 0.07892691, 0.00586712], dtype=float32)
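The formula can be reproduced with plain NumPy to see how the parameters behave. In the sketch below, which is an independent reimplementation rather than vaxflux code, the first call should match the prevalence values above up to float32 rounding. The second call shows that \(m\) sets the asymptote on the logit scale: choosing \(m = \mathrm{logit}(0.6)\) makes the curve level off at 60% uptake.

```python
import numpy as np


def expit(x):
    return 1.0 / (1.0 + np.exp(-x))


def logit(p):
    return np.log(p / (1.0 - p))


def logistic_prevalence(t, m, r, s):
    # invlogit(m) * invlogit(e^r * (t - s)), the LogisticCurve formula.
    return expit(m) * expit(np.exp(r) * (t - s))


t = np.array([0.0, 1.0, 2.0, 3.0])
print(logistic_prevalence(t, m=0.0, r=1.0, s=1.0))
# Far past the midpoint the second factor is ~1, leaving invlogit(m) = 0.6.
print(logistic_prevalence(1e3, m=logit(0.6), r=1.0, s=1.0))
```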
Source code in src/vaxflux/_curves.py
def __init__(self) -> None:
    self.parameters = tuple(inspect.signature(self.prevalence).parameters)[1:]
    in_axes = (0,) + (None,) * len(self.parameters)
    self._grad_prevalence = jax.vmap(jax.grad(self.prevalence), in_axes=in_axes)

prevalence(t, m, r, s)

Compute the logistic prevalence at time t.

Parameters:

Name Type Description Default
t NumericalArrayLike

The time steps to evaluate the prevalence curve at.

required
m NumericalArrayLike

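The logit of the curve's maximum value; prevalence levels off at invlogit(m).

required
r NumericalArrayLike

The log of the curve's growth rate; the time offset t - s is scaled by e^r.

required
s NumericalArrayLike

The time of the sigmoid's midpoint, where prevalence reaches half of its maximum.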

required

Returns:

Type Description
Array

The logistic prevalence curve evaluated at time t.

Source code in src/vaxflux/_curves.py
def prevalence(  # type: ignore[override]
    self,
    t: NumericalArrayLike,
    m: NumericalArrayLike,
    r: NumericalArrayLike,
    s: NumericalArrayLike,
) -> jax.Array:
    """
    Compute the logistic prevalence at time `t`.

    Args:
        t: The time steps to evaluate the prevalence curve at.
        m: The curve's maximum value.
        r: The steepness of the curve.
        s: The x-value of the sigmoid's midpoint.

    Returns:
        The logistic prevalence curve evaluated at time `t`.
    """
    return jsp.expit(m) * jsp.expit(jnp.exp(r) * (t - s))

PartiallyPooledGaussianCovariate

Bases: Covariate

Covariate model using a partially pooled Gaussian approach.

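\[ \begin{aligned} \mu_0 &\sim \mathrm{Normal}(\mu_{\mu}, \mu_{\sigma}) \\ \sigma_0 &\sim \mathrm{HalfNormal}(\sigma) \\ x_j &\sim \mathrm{Normal}(\mu_0, \sigma_0) \end{aligned} \]

Here \(\mu_0\) and \(\sigma_0\) are shared scalars, and \(j\) indexes seasons for seasonal covariates or non-baseline categories for categorical covariates.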

For seasonal covariates (covariate is None) this samples one value per season using shared scalar hyperpriors, producing shape (num_seasons,).

For categorical covariates (covariate is not None) this samples one value per non-baseline category using shared scalar hyperpriors, then broadcasts that constant effect across all seasons to produce shape (num_seasons, num_categories). The baseline category (index 0) is always zero.

Attributes:

Name Type Description
mu_mu float

The location parameter of the partially pooled mean.

mu_sigma float

The scale parameter of the partially pooled mean.

sigma float

The scale of the half-normal distribution for the standard deviation.
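The categorical branch of `sample()` draws shared scalar hyperparameters, one effect per non-baseline category, then pads and broadcasts. The NumPy sketch below simulates one prior draw of that process outside numpyro so the resulting array can be inspected. The season names, categories, and hyperprior settings are hypothetical, and a HalfNormal draw is taken as the absolute value of a zero-mean Normal draw.

```python
import numpy as np

rng = np.random.default_rng(0)
season_names = ["2021/2022", "2022/2023", "2023/2024"]
categories = ["18-49", "50-64", "65+"]  # index 0 is the baseline
mu_mu, mu_sigma, sigma = 0.0, 1.0, 0.5  # hypothetical hyperprior settings

# Shared scalar hyperparameters.
mu = rng.normal(mu_mu, mu_sigma)
tau = abs(rng.normal(0.0, sigma))  # one HalfNormal(sigma) draw
# One effect per non-baseline category.
effects = rng.normal(mu, tau, size=len(categories) - 1)
# Prepend a zero for the baseline, then repeat the row for every season.
padded = np.pad(effects, (1, 0), mode="constant")
values = np.broadcast_to(padded[np.newaxis, :], (len(season_names), len(categories)))
print(values)
```

Every row is identical, which reflects that this class models a category effect that is constant across seasons.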

sample(*, season_names, covariate_categories)

Sample shared hyperparameters, then draw covariate values from a Gaussian with the sampled mean and standard deviation.

Returns:

Type Description
NumericalArrayLike

A numerical array-like structure containing the sampled covariate values.

Source code in src/vaxflux/_covariates.py
def sample(
    self,
    *,
    season_names: list[str],
    covariate_categories: list[str] | None,
) -> NumericalArrayLike:
    """
    Sample values from a Gaussian distribution defined by the mean and stddev.

    Returns:
        A numerical array-like structure containing the sampled covariate values.
    """
    mu_sample = numpyro.sample(
        f"{self.prefix}_mu", dist.Normal(self.mu_mu, self.mu_sigma)
    )
    sigma_sample = numpyro.sample(
        f"{self.prefix}_sigma", dist.HalfNormal(self.sigma)
    )
    if covariate_categories is None:
        with numpyro.plate(f"covariate_{self.prefix}", len(season_names)):
            sampled_values = cast(
                "NumericalArrayLike",
                numpyro.sample(self.prefix, dist.Normal(mu_sample, sigma_sample)),
            )
        return cast(
            "NumericalArrayLike",
            numpyro.deterministic(
                f"covariate_values_{self.prefix}", sampled_values
            ),
        )
    with numpyro.plate(f"covariate_{self.prefix}", len(covariate_categories) - 1):
        sampled_values = cast(
            "NumericalArrayLike",
            numpyro.sample(self.prefix, dist.Normal(mu_sample, sigma_sample)),
        )
    padded = jnp.pad(sampled_values, (1, 0), mode="constant")
    inflated = jnp.broadcast_to(
        padded[jnp.newaxis, :],
        (len(season_names), len(covariate_categories)),
    )
    return cast(
        "NumericalArrayLike",
        numpyro.deterministic(f"covariate_values_{self.prefix}", inflated),
    )

PooledGaussianCovariate

Bases: Covariate

Covariate model using a pooled Gaussian approach.

\[ \begin{aligned} x_k &\sim \mathrm{Normal}(\mu, \sigma) \end{aligned} \]

For seasonal covariates (covariate is None) this samples one value per season, producing shape (num_seasons,).

For categorical covariates (covariate is not None) this samples one value per non-baseline category, then broadcasts that constant effect across all seasons to produce shape (num_seasons, num_categories). The baseline category (index 0) is always zero.

Attributes:

Name Type Description
mu float

The mean of the Gaussian distribution.

sigma float

The standard deviation of the Gaussian distribution.
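Unlike the partially pooled variant, this class uses a fixed mean and standard deviation, so the seasonal branch reduces to independent draws, one per season. A minimal NumPy sketch of one prior draw, with hypothetical `mu`, `sigma`, and season names:

```python
import numpy as np

rng = np.random.default_rng(1)
season_names = ["2021/2022", "2022/2023", "2023/2024"]
mu, sigma = 0.0, 0.25  # hypothetical fixed prior settings

# Seasonal covariate: independent Normal(mu, sigma) draws, one per season.
values = rng.normal(mu, sigma, size=len(season_names))
print(values.shape)  # (3,)
```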

sample(*, season_names, covariate_categories)

Sample covariate values from a Gaussian with fixed mean mu and standard deviation sigma.

Returns:

Type Description
NumericalArrayLike

A numerical array-like structure containing the sampled covariate values.

Source code in src/vaxflux/_covariates.py
def sample(
    self,
    *,
    season_names: list[str],
    covariate_categories: list[str] | None,
) -> NumericalArrayLike:
    """
    Sample values from a Gaussian distribution defined by the mean and stddev.

    Returns:
        A numerical array-like structure containing the sampled covariate values.
    """
    if covariate_categories is None:
        with numpyro.plate(f"covariate_{self.prefix}", len(season_names)):
            sampled_values = cast(
                "NumericalArrayLike",
                numpyro.sample(self.prefix, dist.Normal(self.mu, self.sigma)),
            )
        return cast(
            "NumericalArrayLike",
            numpyro.deterministic(
                f"covariate_values_{self.prefix}", sampled_values
            ),
        )
    with numpyro.plate(f"covariate_{self.prefix}", len(covariate_categories) - 1):
        sampled_values = cast(
            "NumericalArrayLike",
            numpyro.sample(self.prefix, dist.Normal(self.mu, self.sigma)),
        )
    padded = jnp.pad(sampled_values, (1, 0), mode="constant")
    inflated = jnp.broadcast_to(
        padded[jnp.newaxis, :],
        (len(season_names), len(covariate_categories)),
    )
    return cast(
        "NumericalArrayLike",
        numpyro.deterministic(f"covariate_values_{self.prefix}", inflated),
    )

SeasonRange

Bases: BaseModel

A representation of a season range for uptake scenarios.

Examples:

>>> from vaxflux import SeasonRange
>>> season_range = SeasonRange(
...     season="2023/2024",
...     start_date="2023-12-01",
...     end_date="2024-03-31",
... )
>>> season_range.season
'2023/2024'
>>> season_range.start_date
datetime.date(2023, 12, 1)
>>> season_range.end_date
datetime.date(2024, 3, 31)
>>> SeasonRange(
...     season="2023/2024",
...     start_date="2024-03-31",
...     end_date="2023-12-01",
... )
Traceback (most recent call last):
    ...
pydantic_core._pydantic_core.ValidationError: 1 validation error for SeasonRange
  Value error, The end date, 2023-12-01, must be after or the same as the start date 2024-03-31. [...]
    For further information visit ...

SeasonVaryingPartiallyPooledGaussianCovariate

Bases: Covariate

Season-varying covariate model using partial pooling across seasons.

\[ \begin{aligned} \mu_k &\sim \mathrm{Normal}(\mu_{\mu}, \mu_{\sigma}) \\ \sigma_k &\sim \mathrm{HalfNormal}(\sigma) \\ x_{s,k} &\sim \mathrm{Normal}(\mu_k, \sigma_k) \end{aligned} \]

This covariate samples category effects that vary by season, while shrinking them toward category-level means. It is only valid for categorical covariates (covariate is not None) and always returns shape (num_seasons, num_categories).

Attributes:

Name Type Description
mu_mu float

The location parameter of the partially pooled mean.

mu_sigma float

The scale parameter of the partially pooled mean.

sigma float

The scale of the half-normal distribution for the standard deviation.

Notes

Unlike the other covariate classes, this class does not support seasonal covariates, because its effects already vary by season. Using it with covariate set to None raises an error at model definition time, from the covariate name validation.
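The generative process differs from `PartiallyPooledGaussianCovariate` in that each non-baseline category gets its own \(\mu_k\) and \(\sigma_k\), and a fresh effect is drawn for every season. The NumPy sketch below simulates one prior draw to show the resulting `(num_seasons, num_categories)` layout. The sizes and hyperprior settings are hypothetical, and the padding mirrors the `jnp.pad(..., ((0, 0), (1, 0)))` call in `sample()`.

```python
import numpy as np

rng = np.random.default_rng(2)
num_seasons, num_categories = 4, 3  # categories include the baseline
mu_mu, mu_sigma, sigma = 0.0, 1.0, 0.5  # hypothetical hyperprior settings

# Category-level hyperparameters, one per non-baseline category.
mu_k = rng.normal(mu_mu, mu_sigma, size=num_categories - 1)
sigma_k = np.abs(rng.normal(0.0, sigma, size=num_categories - 1))  # HalfNormal draws
# Season-level effects: rows are seasons, columns are non-baseline categories.
x = rng.normal(mu_k, sigma_k, size=(num_seasons, num_categories - 1))
# Add a zero baseline column on the left.
values = np.pad(x, ((0, 0), (1, 0)), mode="constant")
print(values)
```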

extra_dims(season_coord, category_short_coord)

Return dims for the hyperparameter and raw-draw sites.

Registers {prefix}_mu, {prefix}_sigma, and {prefix} with the appropriate coordinate names.

Source code in src/vaxflux/_covariates.py
def extra_dims(
    self,
    season_coord: str,
    category_short_coord: str | None,
) -> dict[str, list[str]]:
    """
    Return dims for the hyperparameter and raw-draw sites.

    Registers `{prefix}_mu`, `{prefix}_sigma`, and `{prefix}` with
    the appropriate coordinate names.
    """
    short = category_short_coord or season_coord
    return {
        f"{self.prefix}_mu": [short],
        f"{self.prefix}_sigma": [short],
        self.prefix: [season_coord, short],
    }
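To see what this method contributes to the ArviZ `dims` mapping, the sketch below copies its logic into a standalone function with the prefix passed explicitly. The prefix `"m_age"` and the coordinate name `"age_short"` are hypothetical; the actual prefix is derived from the parameter name and the short coordinate is chosen by the model.

```python
def season_varying_extra_dims(prefix, season_coord, category_short_coord):
    # Same logic as SeasonVaryingPartiallyPooledGaussianCovariate.extra_dims.
    short = category_short_coord or season_coord
    return {
        f"{prefix}_mu": [short],
        f"{prefix}_sigma": [short],
        prefix: [season_coord, short],
    }


dims = season_varying_extra_dims("m_age", "season", "age_short")
print(dims)
```

A mapping in this shape is what `arviz.from_numpyro` accepts through its `dims` argument, so the hyperparameter draws come out labelled by category and the raw draws by season and category.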

sample(*, season_names, covariate_categories)

Sample season-varying covariate values.

Parameters:

Name Type Description Default
season_names list[str]

Season names in the model.

required
covariate_categories list[str] | None

Covariate categories including the baseline. Always non-None at runtime because covariate is required to be a non-empty string on this class, so the model always resolves and passes categories.

required

Returns:

Type Description
NumericalArrayLike

A numerical array-like structure containing the sampled covariate values.

Source code in src/vaxflux/_covariates.py
def sample(
    self,
    *,
    season_names: list[str],
    covariate_categories: list[str] | None,
) -> NumericalArrayLike:
    """
    Sample season-varying covariate values.

    Args:
        season_names: Season names in the model.
        covariate_categories: Covariate categories including the baseline.
            Always non-`None` at runtime because `covariate` is required
            to be a non-empty string on this class, so the model always
            resolves and passes categories.

    Returns:
        A numerical array-like structure containing the sampled covariate values.
    """
    # covariate_categories is always non-None here: the model only resolves
    # and passes categories when covariate is not None, and covariate is
    # required to be a str on this class.
    categories = cast("list[str]", covariate_categories)
    num_seasons = len(season_names)
    num_categories = len(categories)
    # Draw category-level parameters from hyper priors.
    with numpyro.plate(f"covariate_{self.prefix}", num_categories - 1):
        mu_sample = numpyro.sample(
            f"{self.prefix}_mu",
            dist.Normal(self.mu_mu, self.mu_sigma),
        )
        sigma_sample = numpyro.sample(
            f"{self.prefix}_sigma",
            dist.HalfNormal(self.sigma),
        )
    # Draw season-level effects per category from the category-level priors.
    with numpyro.plate(f"{self.prefix}_season", num_seasons):
        sampled_values = numpyro.sample(
            self.prefix,
            dist.Normal(mu_sample, sigma_sample).to_event(1),
        )
    padded = jnp.pad(sampled_values, ((0, 0), (1, 0)), mode="constant")
    return cast(
        "NumericalArrayLike",
        numpyro.deterministic(f"covariate_values_{self.prefix}", padded),
    )

VaxfluxInferenceData

Bases: InferenceData

Container for inference data specific to vaxflux.

This class extends ArviZ's InferenceData with functionality specific to vaxflux models. Because it is a subclass, users can pass it to any function that expects an ArviZ InferenceData while still having access to the vaxflux-specific methods and attributes.

coords cached property

Return model coordinates without chain/draw dimensions.

Returns:

Type Description
dict[str, list[str]]

The coordinate mapping for the merged prior or posterior dataset without

dict[str, list[str]]

chain and draw dimensions (which differ between prior/posterior).

covariate_categories cached property

A mapping of covariate names to their category labels.

Returns:

Type Description
dict[str, dict[str, str]]

A dictionary mapping covariate names to dictionaries of category labels

dict[str, dict[str, str]]

where each inner dictionary maps coordinate-safe category names to the

dict[str, dict[str, str]]

original category labels.

merged_posterior cached property

Return a merged posterior dataset combining posterior and posterior predictive.

Returns:

Type Description
Dataset

An xarray Dataset with posterior and posterior predictive variables.

merged_prior cached property

Return a merged prior dataset combining prior and prior predictive.

Returns:

Type Description
Dataset

An xarray Dataset with prior and prior predictive variables.

posterior_observations cached property

Return posterior observations as a formatted DataFrame.

Returns:

Type Description
DataFrame

A DataFrame with observation metadata and values. Will contain the columns

DataFrame

'chain', 'draw', 'season', 'season_start_date', 'season_end_date',

DataFrame

'start_date', 'end_date', 'report_date', any covariate columns,

DataFrame

'type', and 'value'.
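Because this DataFrame is in long format, with one row per chain, draw, and observation, posterior summaries are a `groupby` away. The pandas sketch below uses a small hypothetical frame with a subset of the documented columns; the `'incidence'` value in the `type` column is an assumption made only for illustration.

```python
import pandas as pd

# Hypothetical long-format draws for a single observation window.
obs = pd.DataFrame(
    {
        "chain": [0, 0, 1, 1],
        "draw": [0, 1, 0, 1],
        "season": ["2023/2024"] * 4,
        "start_date": ["2023-12-01"] * 4,
        "end_date": ["2023-12-31"] * 4,
        "type": ["incidence"] * 4,
        "value": [0.10, 0.12, 0.11, 0.13],
    }
)

# Collapse chains and draws into one summary row per observation.
keys = ["season", "start_date", "end_date", "type"]
summary = obs.groupby(keys)["value"].agg(["mean", "std"])
print(summary)
```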

prior_observations cached property

Return prior observations as a formatted DataFrame.

Returns:

Type Description
DataFrame

A DataFrame with observation metadata and values. Will contain the columns

DataFrame

'chain', 'draw', 'season', 'season_start_date', 'season_end_date',

DataFrame

'start_date', 'end_date', 'report_date', any covariate columns,

DataFrame

'type', and 'value'.

from_numpyro(*args, **kwargs) classmethod

Create inference data from NumPyro outputs via ArviZ.

For more details, see ArviZ's from_numpyro function.

Parameters:

Name Type Description Default
*args Any

Positional arguments forwarded to arviz.from_numpyro.

()
**kwargs Any

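Keyword arguments forwarded to arviz.from_numpyro. An optional observations keyword (a pandas DataFrame) is removed before forwarding, and a copy of it is stored on the result as the observations attribute.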

{}

Returns:

Type Description
VaxfluxInferenceData

A VaxfluxInferenceData instance populated from NumPyro results.

Source code in src/vaxflux/_vaxflux_inference_data.py
@classmethod
def from_numpyro(cls, *args: Any, **kwargs: Any) -> "VaxfluxInferenceData":
    """
    Create inference data from NumPyro outputs via ArviZ.

    For more details, see
    [ArviZ's `from_numpyro` function](https://python.arviz.org/en/stable/api/generated/arviz.from_numpyro.html).

    Args:
        *args: Positional arguments forwarded to `arviz.from_numpyro`.
        **kwargs: Keyword arguments forwarded to `arviz.from_numpyro`.

    Returns:
        A `VaxfluxInferenceData` instance populated from NumPyro results.

    """
    observations = cast("pd.DataFrame | None", kwargs.pop("observations", None))
    idata = cast(
        "az.InferenceData",
        az.from_numpyro(*args, **kwargs),  # type: ignore[no-untyped-call]
    )
    idata.__class__ = cls
    idata.observations = observations.copy() if observations is not None else None  # type: ignore[attr-defined]
    return cast("VaxfluxInferenceData", idata)

plot_posterior_predictive(*, intervals=(0.5, 0.8, 0.95), plot_incidence=True, plot_prevalence=True, plot_observations=None, predictive='latent', figsize=None)

Plot posterior predictive checks with shaded intervals.

Parameters:

Name Type Description Default
intervals Sequence[float]

Central predictive intervals to plot.

(0.5, 0.8, 0.95)
plot_incidence bool

Whether to plot incidence panels.

True
plot_prevalence bool

Whether to plot prevalence panels.

True
plot_observations bool | None

Whether to plot observed data points. When None, observed points are plotted only if self.observations is available.

None
predictive Literal['latent', 'observation']

Whether to plot latent curve intervals or observation-level predictive intervals.

'latent'
figsize tuple[float, float] | None

Optional figure size override.

None

Returns:

Type Description
Figure

The Matplotlib figure.
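Each entry of `intervals` is a central interval level. Assuming the usual equal-tailed definition (this page does not say how vaxflux computes the bounds), a level \(q\) leaves \((1-q)/2\) of the draws in each tail. The NumPy sketch below computes those bounds from hypothetical draws shaped `(draw, time)`; the nested widths are what produce the layered shading in the plot.

```python
import numpy as np


def central_interval(draws, level):
    # An equal-tailed interval leaves (1 - level) / 2 of the mass in each tail.
    tail = (1.0 - level) / 2.0
    return np.quantile(draws, [tail, 1.0 - tail], axis=0)


rng = np.random.default_rng(3)
draws = rng.normal(0.4, 0.05, size=(4000, 10))  # hypothetical (draw, time) samples

for level in (0.5, 0.8, 0.95):
    lower, upper = central_interval(draws, level)
    print(level, float((upper - lower).mean()))
```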

Source code in src/vaxflux/_vaxflux_inference_data.py
def plot_posterior_predictive(
    self,
    *,
    intervals: Sequence[float] = (0.5, 0.8, 0.95),
    plot_incidence: bool = True,
    plot_prevalence: bool = True,
    plot_observations: bool | None = None,
    predictive: Literal["latent", "observation"] = "latent",
    figsize: tuple[float, float] | None = None,
) -> Figure:
    """
    Plot posterior predictive checks with shaded intervals.

    Args:
        intervals: Central predictive intervals to plot.
        plot_incidence: Whether to plot incidence panels.
        plot_prevalence: Whether to plot prevalence panels.
        plot_observations: Whether to plot observed data points. When `None`,
            uses `self.observations` availability.
        predictive: Whether to plot latent curve intervals or observation-level
            predictive intervals.
        figsize: Optional figure size override.

    Returns:
        The Matplotlib figure.

    """
    predictive_label = (
        "Observation Predictive" if predictive == "observation" else "Predictive"
    )
    return self._plot_predictive(
        data=self._predictive_observations(
            predictive=predictive,
            fallback=self.posterior_observations,
            dataset=(
                self.merged_posterior
                if predictive == "observation"
                else getattr(self, "posterior_predictive", None)
            ),
        ),
        title=f"Posterior {predictive_label} Checks",
        intervals=intervals,
        plot_incidence=plot_incidence,
        plot_prevalence=plot_prevalence,
        plot_observations=plot_observations,
        figsize=figsize,
    )
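A central interval of mass q is conventionally bounded by the (1 - q) / 2 and (1 + q) / 2 quantiles of the predictive draws. The sketch below applies that convention with NumPy to the default intervals=(0.5, 0.8, 0.95). It illustrates the standard computation only. The body of _plot_predictive is not shown in this section, and the function name central_interval_bounds is hypothetical.

```python
import numpy as np


def central_interval_bounds(draws, intervals=(0.5, 0.8, 0.95)):
    """Return {mass: (lower, upper)} central-interval bounds over axis 0.

    Hypothetical helper: a central interval with mass q leaves (1 - q) / 2
    probability in each tail, so its bounds are the (1 - q) / 2 and
    (1 + q) / 2 quantiles of the draws.
    """
    draws = np.asarray(draws, dtype=float)
    bounds = {}
    for mass in intervals:
        if not 0.0 < mass < 1.0:
            raise ValueError(f"Interval mass must be in (0, 1), got {mass}.")
        lower_q = (1.0 - mass) / 2.0
        upper_q = (1.0 + mass) / 2.0
        # Quantiles are taken over draws (axis 0), one bound per time point.
        lower, upper = np.quantile(draws, [lower_q, upper_q], axis=0)
        bounds[mass] = (lower, upper)
    return bounds
```

Nested intervals drawn from widest to narrowest give the familiar layered shading of a predictive check.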

plot_prior_predictive(*, intervals=(0.5, 0.8, 0.95), plot_incidence=True, plot_prevalence=True, plot_observations=None, predictive='latent', figsize=None)

Plot prior predictive checks with shaded intervals.

Parameters:

Name Type Description Default
intervals Sequence[float]

Central predictive intervals to plot.

(0.5, 0.8, 0.95)
plot_incidence bool

Whether to plot incidence panels.

True
plot_prevalence bool

Whether to plot prevalence panels.

True
plot_observations bool | None

Whether to plot observed data points. When None, observations are plotted if self.observations is available.

None
predictive Literal['latent', 'observation']

Which intervals to plot: "latent" for intervals of the latent uptake curve, or "observation" for observation-level predictive intervals.

'latent'
figsize tuple[float, float] | None

Optional figure size override.

None

Returns:

Type Description
Figure

The Matplotlib figure.

Source code in src/vaxflux/_vaxflux_inference_data.py
def plot_prior_predictive(
    self,
    *,
    intervals: Sequence[float] = (0.5, 0.8, 0.95),
    plot_incidence: bool = True,
    plot_prevalence: bool = True,
    plot_observations: bool | None = None,
    predictive: Literal["latent", "observation"] = "latent",
    figsize: tuple[float, float] | None = None,
) -> Figure:
    """
    Plot prior predictive checks with shaded intervals.

    Args:
        intervals: Central predictive intervals to plot.
        plot_incidence: Whether to plot incidence panels.
        plot_prevalence: Whether to plot prevalence panels.
        plot_observations: Whether to plot observed data points. When `None`,
            uses `self.observations` availability.
        predictive: Whether to plot latent curve intervals or observation-level
            predictive intervals.
        figsize: Optional figure size override.

    Returns:
        The Matplotlib figure.

    """
    predictive_label = (
        "Observation Predictive" if predictive == "observation" else "Predictive"
    )
    return self._plot_predictive(
        data=self._predictive_observations(
            predictive=predictive,
            fallback=self.prior_observations,
            dataset=(
                self.merged_prior
                if predictive == "observation"
                else getattr(self, "prior_predictive", None)
            ),
        ),
        title=f"Prior {predictive_label} Checks",
        intervals=intervals,
        plot_incidence=plot_incidence,
        plot_prevalence=plot_prevalence,
        plot_observations=plot_observations,
        figsize=figsize,
    )

posterior_prevalence_scenarios(scenarios)

Summarize posterior end-of-season prevalence scenarios.

Parameters:

Name Type Description Default
scenarios dict[str, tuple[float, float]]

Mapping of scenario name to (low, high) quantile bounds.

required

Returns:

Type Description
DataFrame

A DataFrame with scenario quantiles and prevalence bounds.

Source code in src/vaxflux/_vaxflux_inference_data.py
def posterior_prevalence_scenarios(
    self,
    scenarios: dict[str, tuple[float, float]],
) -> pd.DataFrame:
    """
    Summarize posterior end-of-season prevalence scenarios.

    Args:
        scenarios: Mapping of scenario name to (low, high) quantile bounds.

    Returns:
        A DataFrame with scenario quantiles and prevalence bounds.

    """
    return self._prevalence_scenarios_from_observations(
        observations=self.posterior_observations,
        scenarios=scenarios,
    )
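This section does not document the exact columns returned by the scenario summaries. The sketch below shows one plausible reading of the (low, high) quantile bounds. It assumes each bound is a quantile of the draws of end-of-season prevalence, so a scenario maps to the prevalence values at those two quantiles. The function name and output column names are hypothetical.

```python
import numpy as np
import pandas as pd


def prevalence_scenarios(draws, scenarios):
    """Summarize end-of-season prevalence draws into named quantile scenarios.

    Hypothetical helper: `draws` is a 1-D array of end-of-season prevalence
    samples, and each scenario maps a name to (low, high) quantile bounds.
    """
    draws = np.asarray(draws, dtype=float)
    rows = []
    for name, (low, high) in scenarios.items():
        if not 0.0 <= low <= high <= 1.0:
            raise ValueError(f"Scenario {name!r} needs 0 <= low <= high <= 1.")
        # Prevalence values at the scenario's lower and upper quantiles.
        low_value, high_value = np.quantile(draws, [low, high])
        rows.append(
            {
                "scenario": name,
                "low_quantile": low,
                "high_quantile": high,
                "low_prevalence": float(low_value),
                "high_prevalence": float(high_value),
            }
        )
    return pd.DataFrame(rows)
```

For example, {"pessimistic": (0.0, 0.25), "optimistic": (0.75, 1.0)} would report the bottom and top quarters of the prevalence distribution under this reading.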

prior_prevalence_scenarios(scenarios)

Summarize prior end-of-season prevalence scenarios.

Parameters:

Name Type Description Default
scenarios dict[str, tuple[float, float]]

Mapping of scenario name to (low, high) quantile bounds.

required

Returns:

Type Description
DataFrame

A DataFrame with scenario quantiles and prevalence bounds.

Source code in src/vaxflux/_vaxflux_inference_data.py
def prior_prevalence_scenarios(
    self,
    scenarios: dict[str, tuple[float, float]],
) -> pd.DataFrame:
    """
    Summarize prior end-of-season prevalence scenarios.

    Args:
        scenarios: Mapping of scenario name to (low, high) quantile bounds.

    Returns:
        A DataFrame with scenario quantiles and prevalence bounds.

    """
    return self._prevalence_scenarios_from_observations(
        observations=self.prior_observations,
        scenarios=scenarios,
    )

VaxfluxModel(curve)

Vaccine uptake model builder.

VaxfluxModel stores the curve, time ranges, covariates, interventions, and observations that define a vaccination uptake model. Configure the model by incrementally adding these components, then use rendering or sampling methods to inspect or estimate the resulting probabilistic model.

Examples:

>>> import pandas as pd
>>> from vaxflux import LogisticCurve, SeasonRange, VaxfluxModel
>>> observations = pd.DataFrame(
...     {
...         "season": ["2023/2024"],
...         "start_date": ["2023-12-01"],
...         "end_date": ["2023-12-07"],
...         "type": ["incidence"],
...         "value": [0.18],
...     }
... )
>>> model = VaxfluxModel(curve=LogisticCurve())
>>> model = (
...     model.add_seasons(
...         SeasonRange(
...             season="2023/2024",
...             start_date="2023-12-01",
...             end_date="2024-03-31",
...         )
...     )
...     .add_observations(observations)
...     .add_observation_process(
...         kind="normal",
...         noise=0.05,
...     )
... )
>>> model
<vaxflux.VaxfluxModel object at ...>
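The add_* methods return self, which is what makes the chained example above work. Several of them also require a prerequisite component to be present already. add_dates requires the seasons of its date ranges, add_implementations requires the referenced interventions, and add_observation_process requires observations, which may be added only once. The hypothetical builder below reproduces just those ordering rules with plain strings so they can be seen in isolation. It is not vaxflux code.

```python
class MiniModelBuilder:
    """Hypothetical builder mirroring the ordering rules VaxfluxModel documents."""

    def __init__(self):
        self.seasons = []
        self.dates = []
        self.interventions = []
        self.implementations = []
        self.observations = None
        self.observation_kind = None

    def add_seasons(self, *names):
        self.seasons.extend(names)
        return self  # returning self enables method chaining

    def add_dates(self, *pairs):
        # Each pair is (season, label); its season must already be registered.
        missing = {season for season, _ in pairs} - set(self.seasons)
        if missing:
            raise ValueError(f"Seasons not added: {sorted(missing)}.")
        self.dates.extend(pairs)
        return self

    def add_interventions(self, *names):
        self.interventions.extend(names)
        return self

    def add_implementations(self, *names):
        # Each implementation names an intervention that must already exist.
        missing = set(names) - set(self.interventions)
        if missing:
            raise ValueError(f"Interventions not added: {sorted(missing)}.")
        self.implementations.extend(names)
        return self

    def add_observations(self, rows):
        # Only one set of observations is allowed.
        if self.observations is not None:
            raise ValueError("Observations have already been added.")
        self.observations = list(rows)
        return self

    def add_observation_process(self, kind):
        if self.observations is None:
            raise ValueError("Add observations before an observation process.")
        self.observation_kind = kind
        return self
```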

Initialize the VaxfluxModel with a given uptake curve.

Parameters:

Name Type Description Default
curve Curve

An instance of a Curve subclass representing the uptake curve.

required
Source code in src/vaxflux/_vaxflux_model.py
def __init__(self, curve: Curve) -> None:
    """
    Initialize the `VaxfluxModel` with a given uptake curve.

    Args:
        curve: An instance of a `Curve` subclass representing the uptake curve.

    """
    self._curve = curve
    self._seasons: list[SeasonRange] = []
    self._dates: list[DateRange] = []
    self._covariate_categories: list[CovariateCategories] = []
    self._covariates: list[Covariate] = []
    self._observations: VaxfluxObservations | None = None
    self._interventions: list[Intervention] = []
    self._implementations: list[Implementation] = []
    self._observation_process_kind: None | Literal["normal"] = None
    self._observation_process_noise: float = 0.05
    self._observation_process_partially_pool_by_season: bool = True
    self._observation_process_partially_pool_by_covariate: bool = False
    self._observation_process_prevalence_penalty: float = 0.0
    self._observation_labels: list[str] = []

__repr__()

Return a string representation of the model.

Returns:

Type Description
str

A string representation of the VaxfluxModel instance.

Source code in src/vaxflux/_vaxflux_model.py
def __repr__(self) -> str:
    """
    Return a string representation of the model.

    Returns:
        A string representation of the `VaxfluxModel` instance.
    """
    return f"<vaxflux.VaxfluxModel object at {hex(id(self))}>"

add_covariate_categories(*args)

Add one or more covariate categories to the model.

Parameters:

Name Type Description Default
*args CovariateCategories | list[CovariateCategories]

One or more CovariateCategories objects or sequences of CovariateCategories objects.

()

Returns:

Type Description
Self

The model instance for method chaining.

Raises:

Type Description
TypeError

If any argument is not a CovariateCategories or a sequence of CovariateCategories objects.

Source code in src/vaxflux/_vaxflux_model.py
def add_covariate_categories(
    self, *args: CovariateCategories | list[CovariateCategories]
) -> Self:
    """
    Add one or more covariate categories to the model.

    Args:
        *args: One or more `CovariateCategories` objects or sequences of
            `CovariateCategories` objects.

    Returns:
        The model instance for method chaining.

    Raises:
        TypeError: If any argument is not a `CovariateCategories` or a sequence of
            `CovariateCategories` objects.

    """
    covariate_categories = _collect_args(
        args, CovariateCategories, "CovariateCategories"
    )
    self._covariate_categories.extend(covariate_categories)
    return self
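add_covariate_categories, like the other add_* methods, accepts individual objects, lists of objects, or a mix of both, and rejects anything else with a TypeError. The helper _collect_args is not shown in this section, so the following is a hypothetical reimplementation of that documented contract. One detail matters here: str is itself a Sequence. A flattening helper must therefore exclude it explicitly. Otherwise "invalid_argument" would be iterated character by character instead of being reported as "got str", as the doctests in this section expect.

```python
from collections.abc import Sequence


def _type_error(name, value):
    # Message format copied from the doctests in this section.
    return TypeError(
        f"Arguments must be {name} objects or sequences of {name} objects, "
        f"got {type(value).__name__}."
    )


def collect_args(args, cls, name):
    """Hypothetical helper: flatten *args of `cls` instances or sequences of them."""
    collected = []
    for arg in args:
        if isinstance(arg, cls):
            collected.append(arg)
        elif isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
            # Validate every element before accepting the sequence.
            for item in arg:
                if not isinstance(item, cls):
                    raise _type_error(name, item)
            collected.extend(arg)
        else:
            raise _type_error(name, arg)
    return collected
```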

add_covariates(*args)

Add one or more covariates to the model.

Parameters:

Name Type Description Default
*args Covariate | list[Covariate]

One or more Covariate objects or sequences of Covariate objects.

()

Returns:

Type Description
Self

The model instance for method chaining.

Raises:

Type Description
TypeError

If any argument is not a Covariate or a sequence of Covariate objects.

ValueError

If duplicate parameter/covariate pairs are provided.

Examples:

>>> from vaxflux._covariates import PartiallyPooledGaussianCovariate
>>> from vaxflux._curves import LogisticCurve
>>> model = VaxfluxModel(curve=LogisticCurve())
>>> model.add_covariates(
...     PartiallyPooledGaussianCovariate(
...         parameter="m",
...         covariate="age",
...         mu_mu=0.5,
...         mu_sigma=0.1,
...         sigma=0.2,
...     )
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_covariates(
...     [
...         PartiallyPooledGaussianCovariate(
...             parameter="sigma",
...             mu_mu=0.2,
...             mu_sigma=0.03,
...             sigma=0.1,
...         ),
...     ]
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_covariates("invalid_argument")
Traceback (most recent call last):
    ...
TypeError: Arguments must be Covariate objects or sequences of Covariate objects, got str.
Source code in src/vaxflux/_vaxflux_model.py
def add_covariates(self, *args: Covariate | list[Covariate]) -> Self:
    """
    Add one or more covariates to the model.

    Args:
        *args: One or more `Covariate` objects or sequences of `Covariate` objects.

    Returns:
        The model instance for method chaining.

    Raises:
        TypeError: If any argument is not a `Covariate` or a sequence of
            `Covariate` objects.
        ValueError: If duplicate parameter/covariate pairs are provided.

    Examples:
        >>> from vaxflux._covariates import PartiallyPooledGaussianCovariate
        >>> from vaxflux._curves import LogisticCurve
        >>> model = VaxfluxModel(curve=LogisticCurve())
        >>> model.add_covariates(
        ...     PartiallyPooledGaussianCovariate(
        ...         parameter="m",
        ...         covariate="age",
        ...         mu_mu=0.5,
        ...         mu_sigma=0.1,
        ...         sigma=0.2,
        ...     )
        ... )
        <vaxflux.VaxfluxModel object at ...>
        >>> model.add_covariates(
        ...     [
        ...         PartiallyPooledGaussianCovariate(
        ...             parameter="sigma",
        ...             mu_mu=0.2,
        ...             mu_sigma=0.03,
        ...             sigma=0.1,
        ...         ),
        ...     ]
        ... )
        <vaxflux.VaxfluxModel object at ...>
        >>> model.add_covariates("invalid_argument")
        Traceback (most recent call last):
            ...
        TypeError: Arguments must be Covariate objects or sequences of Covariate objects, got str.

    """  # noqa: E501
    covariates = _collect_args(args, Covariate, "Covariate")  # type: ignore[type-abstract]
    existing_pairs = Counter(
        (cov.parameter, cov.covariate) for cov in self._covariates
    )
    new_pairs = Counter((cov.parameter, cov.covariate) for cov in covariates)
    duplicate_pairs = [
        pair
        for pair, count in new_pairs.items()
        if count + existing_pairs.get(pair, 0) > 1
    ]
    if duplicate_pairs:
        msg = (
            "Duplicate covariate parameter/covariate pairs found: "
            f"{duplicate_pairs}."
        )
        raise ValueError(msg)
    self._covariates.extend(covariates)
    return self
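The duplicate check in add_covariates counts (parameter, covariate) pairs across both the covariates already on the model and the new batch, so a repeat is caught whether it occurs within one call or across calls. The same logic is extracted below into a standalone function. The function name is hypothetical, since the method inlines this logic. Because ("m", "age") and ("m", None) are different pairs, the check allows a parameter to carry both a seasonal covariate (covariate=None) and a categorical covariate.

```python
from collections import Counter


def find_duplicate_pairs(existing, new):
    """Return (parameter, covariate) pairs that would appear more than once.

    Hypothetical helper: `existing` and `new` are iterables of
    (parameter, covariate) tuples, where covariate may be None.
    """
    existing_counts = Counter(existing)
    new_counts = Counter(new)
    return [
        pair
        for pair, count in new_counts.items()
        if count + existing_counts.get(pair, 0) > 1
    ]
```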

add_dates(*args)

Add one or more date ranges to the model.

Parameters:

Name Type Description Default
*args DateRange | list[DateRange]

One or more DateRange objects or sequences of DateRange objects.

()

Returns:

Type Description
Self

The model instance for method chaining.

Raises:

Type Description
TypeError

If any argument is not a DateRange or a sequence of DateRange objects.

ValueError

If duplicate date ranges are found.

ValueError

If overlapping date ranges are found.

ValueError

If any date range belongs to a season that has not been added to the model.

Examples:

>>> from vaxflux import LogisticCurve, DateRange, SeasonRange
>>> model = VaxfluxModel(curve=LogisticCurve())
>>> model.add_seasons(
...     SeasonRange(
...         season="2023/2024",
...         start_date="2023-12-01",
...         end_date="2024-03-31",
...     )
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_dates(
...     DateRange(
...         season="2023/2024",
...         start_date="2023-12-01",
...         end_date="2023-12-07",
...         report_date="2023-12-08",
...     )
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_dates("invalid_argument")
Traceback (most recent call last):
    ...
TypeError: Arguments must be DateRange objects or sequences of DateRange objects, got str.
Source code in src/vaxflux/_vaxflux_model.py
def add_dates(self, *args: DateRange | list[DateRange]) -> Self:
    """
    Add one or more date ranges to the model.

    Args:
        *args: One or more `DateRange` objects or sequences of `DateRange` objects.

    Returns:
        The model instance for method chaining.

    Raises:
        TypeError: If any argument is not a `DateRange` or a sequence of
            `DateRange` objects.
        ValueError: If duplicate date ranges are found.
        ValueError: If overlapping date ranges are found.

    Examples:
        >>> from vaxflux import LogisticCurve, DateRange, SeasonRange
        >>> model = VaxfluxModel(curve=LogisticCurve())
        >>> model.add_seasons(
        ...     SeasonRange(
        ...         season="2023/2024",
        ...         start_date="2023-12-01",
        ...         end_date="2024-03-31",
        ...     )
        ... )
        <vaxflux.VaxfluxModel object at ...>
        >>> model.add_dates(
        ...     DateRange(
        ...         season="2023/2024",
        ...         start_date="2023-12-01",
        ...         end_date="2023-12-07",
        ...         report_date="2023-12-08",
        ...     )
        ... )
        <vaxflux.VaxfluxModel object at ...>
        >>> model.add_dates("invalid_argument")
        Traceback (most recent call last):
            ...
        TypeError: Arguments must be DateRange objects or sequences of DateRange objects, got str.

    """  # noqa: E501
    dates = _collect_ranges(args, DateRange, "DateRange")
    if missing_seasons := {date.season for date in dates} - {
        season.season for season in self._seasons
    }:
        msg = (
            "The following date range seasons have not been added "
            f"to the model: {sorted(missing_seasons)}."
        )
        raise ValueError(msg)
    _validate_ranges(dates, self._dates)
    self._dates.extend(dates)
    return self
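add_dates first rejects date ranges whose season has not been added to the model. It then delegates the duplicate and overlap checks to _validate_ranges, whose body is not shown in this section. The sketch below reproduces that sequence with plain (season, start, end) tuples. Two choices in it are assumptions rather than documented rules: ranges are treated as inclusive intervals, and all ranges are compared globally. Consult _validate_ranges for the exact behavior. The function name is hypothetical.

```python
from datetime import date

# Hypothetical representation of a date range: (season, start, end).
DateSpan = tuple[str, date, date]


def validate_date_ranges(
    season_names: list[str], existing: list[DateSpan], new: list[DateSpan]
) -> None:
    """Hypothetical stand-in for the checks add_dates performs."""
    missing = {season for season, _, _ in new} - set(season_names)
    if missing:
        raise ValueError(f"Seasons not added to the model: {sorted(missing)}.")
    # After sorting by start date, any overlap must show up between neighbours.
    spans = sorted((start, end) for _, start, end in [*existing, *new])
    for previous, current in zip(spans, spans[1:]):
        if previous == current:
            raise ValueError(f"Duplicate date range: {current}.")
        if current[0] <= previous[1]:
            raise ValueError(f"Overlapping date ranges: {previous} and {current}.")
```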

add_implementations(*args)

Add one or more implementations to the model.

Parameters:

Name Type Description Default
*args Implementation | list[Implementation]

One or more Implementation objects or sequences of Implementation objects.

()

Returns:

Type Description
Self

The model instance for method chaining.

Raises:

Type Description
TypeError

If any argument is not an Implementation or a sequence of Implementation objects.

ValueError

If an implementation's intervention has not been added to the model.

Source code in src/vaxflux/_vaxflux_model.py
def add_implementations(self, *args: Implementation | list[Implementation]) -> Self:
    """
    Add one or more implementations to the model.

    Args:
        *args: One or more `Implementation` objects or sequences of
            `Implementation` objects.

    Returns:
        The model instance for method chaining.

    Raises:
        TypeError: If any argument is not an `Implementation` or a sequence of
            `Implementation` objects.
        ValueError: If an implementation's intervention has not been added to the
            model.

    """
    implementations = _collect_args(args, Implementation, "Implementation")
    if missing_interventions := {
        implementation.intervention for implementation in implementations
    } - {intervention.name for intervention in self._interventions}:
        msg = (
            "The following implementation interventions have not been added "
            f"to the model: {sorted(missing_interventions)}."
        )
        raise ValueError(msg)
    self._implementations.extend(implementations)
    return self

add_interventions(*args)

Add one or more interventions to the model.

Parameters:

Name Type Description Default
*args Intervention | list[Intervention]

One or more Intervention objects or sequences of Intervention objects.

()

Returns:

Type Description
Self

The model instance for method chaining.

Raises:

Type Description
TypeError

If any argument is not an Intervention or a sequence of Intervention objects.

Source code in src/vaxflux/_vaxflux_model.py
def add_interventions(self, *args: Intervention | list[Intervention]) -> Self:
    """
    Add one or more interventions to the model.

    Args:
        *args: One or more `Intervention` objects or sequences of `Intervention`
            objects.

    Returns:
        The model instance for method chaining.

    Raises:
        TypeError: If any argument is not an `Intervention` or a sequence of
            `Intervention` objects.

    """
    interventions = _collect_args(args, Intervention, "Intervention")
    self._interventions.extend(interventions)
    return self

add_observation_process(kind, noise, *, partially_pool_by_season=True, partially_pool_by_covariate=False, prevalence_penalty=0.0)

Add an observation process linking the model to its observations. Observations must already have been added with add_observations.

Parameters:

Name Type Description Default
kind Literal['normal']

The kind of observation process to add. Only "normal" is currently supported.

required
noise float

The observation noise parameter.

required
partially_pool_by_season bool

Whether to partially pool observation noise by season.

True
partially_pool_by_covariate bool

Whether to partially pool observation noise by covariate.

False
prevalence_penalty float

Penalty to apply when modeled prevalence exceeds 1, computed as -prevalence_penalty * excess^2, where excess is the amount by which prevalence exceeds 1. The default of 0.0 applies no penalty; use math.inf for a hard constraint.

0.0

Returns:

Type Description
Self

The model instance for method chaining.

Raises:

Type Description
ValueError

If no observations have been added to the model.

ValueError

If an unsupported observation process kind is provided.

ValueError

If the prevalence penalty is negative.

Source code in src/vaxflux/_vaxflux_model.py
def add_observation_process(
    self,
    kind: Literal["normal"],
    noise: float,
    *,
    partially_pool_by_season: bool = True,
    partially_pool_by_covariate: bool = False,
    prevalence_penalty: float = 0.0,
) -> Self:
    """
    Add an observation process to the model based on the added observations.

    Args:
        kind: The kind of observation process to add.
        noise: The observation noise parameter.
        partially_pool_by_season: Whether to partially pool observation noise by
            season.
        partially_pool_by_covariate: Whether to partially pool observation noise by
            covariate.
        prevalence_penalty: Penalty to apply when modeled prevalence exceeds 1. The
            penalty is applied as `-prevalence_penalty * excess^2` where `excess`
            is the excess prevalence. Use `math.inf` for a hard constraint.

    Returns:
        The model instance for method chaining.

    Raises:
        ValueError: If no observations have been added to the model.
        ValueError: If an unsupported observation process kind is provided.
        ValueError: If the prevalence penalty is negative.

    """
    if self._observations is None:
        msg = "No observations have been added to the model, nothing to observe."
        raise ValueError(msg)
    if kind != "normal":
        msg = f"Unsupported observation process kind: {kind}."
        raise ValueError(msg)
    if prevalence_penalty < 0:
        msg = "prevalence_penalty must be non-negative."
        raise ValueError(msg)
    self._observation_process_kind = kind
    self._observation_process_noise = noise
    self._observation_process_partially_pool_by_season = partially_pool_by_season
    self._observation_process_partially_pool_by_covariate = (
        partially_pool_by_covariate
    )
    self._observation_process_prevalence_penalty = prevalence_penalty
    return self
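Read literally, the documented penalty contributes -prevalence_penalty * excess^2, with excess = max(prevalence - 1, 0). That term is presumably added to the model's log density. This section does not show how vaxflux wires it into NumPyro, so the hypothetical function below is only a scalar sketch of the formula. It also handles an edge case raised by the math.inf option: in floating point, inf * 0 is nan, so a literal implementation has to short-circuit when there is no excess.

```python
import math


def prevalence_penalty_term(prevalence: float, prevalence_penalty: float) -> float:
    """Hypothetical scalar sketch of -prevalence_penalty * excess**2.

    excess = max(prevalence - 1, 0), so prevalence at or below 1 costs nothing.
    """
    # NaN < 0 is False, so NaN must be rejected explicitly.
    if prevalence_penalty < 0 or math.isnan(prevalence_penalty):
        raise ValueError("prevalence_penalty must be non-negative.")
    excess = max(prevalence - 1.0, 0.0)
    if excess == 0.0:
        # Avoid inf * 0 = nan when prevalence_penalty is math.inf.
        return 0.0
    return -prevalence_penalty * excess**2
```

With prevalence_penalty=math.inf, any excess yields -inf, which rejects that configuration outright. This matches the "hard constraint" description.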

add_observations(observations)

Add observations to the model.

Parameters:

Name Type Description Default
observations VaxfluxObservations | DataFrame

The observations to add to the model: either a vaxflux.VaxfluxObservations instance or a raw pandas.DataFrame. The input is converted with VaxfluxObservations.from_dataframe, which validates it.

required

Returns:

Type Description
Self

The model instance for method chaining.

Raises:

Type Description
ValueError

If observations have already been added to the model (only one set of observations is allowed).

Source code in src/vaxflux/_vaxflux_model.py
def add_observations(
    self, observations: VaxfluxObservations | pd.DataFrame
) -> Self:
    """
    Add observations to the model.

    Args:
        observations: The observations to add to the model. Accepts either a
            `vaxflux.VaxfluxObservations` instance or a raw `pandas.DataFrame`
            which will be validated on construction.

    Returns:
        The model instance for method chaining.

    Raises:
        ValueError: If observations were already added to the model.

    """
    if self._observations is not None:
        msg = "Observations have already been added to the model."
        raise ValueError(msg)
    self._observations = VaxfluxObservations.from_dataframe(observations)
    return self

add_seasons(*args)

Add one or more seasons to the model.

Parameters:

Name Type Description Default
*args SeasonRange | list[SeasonRange]

One or more SeasonRange objects or sequences of SeasonRange objects.

()

Returns:

Type Description
Self

The model instance for method chaining.

Raises:

Type Description
TypeError

If any argument is not a SeasonRange or a sequence of SeasonRange objects.

ValueError

If duplicate or overlapping season ranges are found.

Examples:

>>> from vaxflux import LogisticCurve, SeasonRange
>>> model = VaxfluxModel(curve=LogisticCurve())
>>> model.add_seasons(
...     SeasonRange(
...         season="2023/2024",
...         start_date="2023-12-01",
...         end_date="2024-03-31",
...     )
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_seasons(
...     SeasonRange(
...         season="2024/2025",
...         start_date="2024-12-01",
...         end_date="2025-03-31",
...     ),
...     SeasonRange(
...         season="2025/2026",
...         start_date="2025-12-01",
...         end_date="2026-03-31",
...     ),
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_seasons(
...     [
...         SeasonRange(
...             season="2026/2027",
...             start_date="2026-12-01",
...             end_date="2027-03-31",
...         ),
...         SeasonRange(
...             season="2027/2028",
...             start_date="2027-12-01",
...             end_date="2028-03-31",
...         ),
...     ]
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_seasons("invalid_argument")
Traceback (most recent call last):
    ...
TypeError: Arguments must be SeasonRange objects or sequences of SeasonRange objects, got str.
Source code in src/vaxflux/_vaxflux_model.py
def add_seasons(self, *args: SeasonRange | list[SeasonRange]) -> Self:
    """
    Add one or more seasons to the model.

    Args:
        *args: One or more `SeasonRange` objects or sequences of `SeasonRange`
            objects.

    Returns:
        The model instance for method chaining.

    Raises:
        TypeError: If any argument is not a `SeasonRange` or a sequence of
            `SeasonRange` objects.

    Examples:
        >>> from vaxflux import LogisticCurve, SeasonRange
        >>> model = VaxfluxModel(curve=LogisticCurve())
        >>> model.add_seasons(
        ...     SeasonRange(
        ...         season="2023/2024",
        ...         start_date="2023-12-01",
        ...         end_date="2024-03-31",
        ...     )
        ... )
        <vaxflux.VaxfluxModel object at ...>
        >>> model.add_seasons(
        ...     SeasonRange(
        ...         season="2024/2025",
        ...         start_date="2024-12-01",
        ...         end_date="2025-03-31",
        ...     ),
        ...     SeasonRange(
        ...         season="2025/2026",
        ...         start_date="2025-12-01",
        ...         end_date="2026-03-31",
        ...     ),
        ... )
        <vaxflux.VaxfluxModel object at ...>
        >>> model.add_seasons(
        ...     [
        ...         SeasonRange(
        ...             season="2026/2027",
        ...             start_date="2026-12-01",
        ...             end_date="2027-03-31",
        ...         ),
        ...         SeasonRange(
        ...             season="2027/2028",
        ...             start_date="2027-12-01",
        ...             end_date="2028-03-31",
        ...         ),
        ...     ]
        ... )
        <vaxflux.VaxfluxModel object at ...>
        >>> model.add_seasons("invalid_argument")
        Traceback (most recent call last):
            ...
        TypeError: Arguments must be SeasonRange objects or sequences of SeasonRange objects, got str.

    """  # noqa: E501
    seasons = _collect_ranges(args, SeasonRange, "SeasonRange")
    _validate_ranges(seasons, self._seasons)
    self._seasons.extend(seasons)
    return self

from_csv(*args, **kwargs) classmethod

Construct a default model from observations stored in a CSV file.

The CSV is read with VaxfluxObservations.from_csv, then used to build a model with:

  • a default LogisticCurve,
  • seasons and date ranges inferred from the observations (each added only when at least one range can be inferred),
  • one CovariateCategories entry per observation covariate, listing every observed level converted to a string and sorted,
  • one SeasonVaryingPartiallyPooledGaussianCovariate per curve-parameter/covariate pair with weak default hyperparameters (mu_mu=0.0, mu_sigma=1.0, sigma=1.0),
  • no interventions or implementations,
  • a "normal" observation process with noise=0.001.

Parameters:

Name Type Description Default
*args Any

Positional arguments forwarded to VaxfluxObservations.from_csv.

()
**kwargs Any

Keyword arguments forwarded to VaxfluxObservations.from_csv.

{}

Returns:

Type Description
Self

A configured VaxfluxModel instance.

Source code in src/vaxflux/_vaxflux_model.py
@classmethod
def from_csv(cls, *args: Any, **kwargs: Any) -> Self:
    """
    Construct a default model from observations stored in a CSV file.

    The CSV is read with `VaxfluxObservations.from_csv`, then used to build a
    model with:

    - a default `LogisticCurve`,
    - inferred seasons and dates from the observations,
    - one `CovariateCategories` entry per observation covariate using the full
      sorted list of observed levels,
    - one `SeasonVaryingPartiallyPooledGaussianCovariate` per
      curve-parameter/covariate pair with weak default hyperparameters
      (`mu_mu=0.0`, `mu_sigma=1.0`, `sigma=1.0`),
    - no interventions or implementations,
    - a `"normal"` observation process with `noise=0.001`.

    Args:
        *args: Positional arguments forwarded to
            `VaxfluxObservations.from_csv`.
        **kwargs: Keyword arguments forwarded to
            `VaxfluxObservations.from_csv`.

    Returns:
        A configured `VaxfluxModel` instance.

    """
    observations = VaxfluxObservations.from_csv(*args, **kwargs)
    model = cls(curve=LogisticCurve())

    season_ranges: list[SeasonRange] = _infer_ranges_from_observations(
        observations.data, [], "season"
    )
    if season_ranges:
        model.add_seasons(season_ranges)

    date_ranges: list[DateRange] = _infer_ranges_from_observations(
        observations.data, [], "date"
    )
    if date_ranges:
        model.add_dates(date_ranges)

    covariate_categories = [
        CovariateCategories(
            covariate=covariate_name,
            categories=tuple(
                sorted(observations.data[covariate_name].astype(str).unique())
            ),
        )
        for covariate_name in observations.covariate_columns
    ]
    if covariate_categories:
        model.add_covariate_categories(covariate_categories)
        model.add_covariates(
            [
                SeasonVaryingPartiallyPooledGaussianCovariate(
                    parameter=parameter,
                    covariate=covariate_category.covariate,
                    mu_mu=0.0,
                    mu_sigma=1.0,
                    sigma=1.0,
                )
                for parameter in model._curve.parameters
                for covariate_category in covariate_categories
            ]
        )

    return model.add_observations(observations).add_observation_process(
        kind="normal", noise=0.001
    )
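The covariate defaults that from_csv derives can be reproduced in isolation. The sketch below uses hypothetical curve parameter names ("m", "r") and returns plain dicts instead of CovariateCategories and SeasonVaryingPartiallyPooledGaussianCovariate objects. The function name is also hypothetical. Because levels are converted with astype(str) before sorting, the order is lexicographic: numeric-looking levels such as 9 and 10 sort as "10", "9".

```python
from itertools import product

import pandas as pd


def default_covariate_specs(
    observations: pd.DataFrame,
    covariate_columns: list[str],
    curve_parameters: list[str],
) -> tuple[dict[str, tuple[str, ...]], list[dict[str, object]]]:
    """Hypothetical helper mirroring the documented from_csv defaults."""
    # Sorted string levels per covariate column.
    categories = {
        column: tuple(sorted(observations[column].astype(str).unique()))
        for column in covariate_columns
    }
    # One weak-prior covariate per (curve parameter, covariate) pair.
    specs = [
        {
            "parameter": parameter,
            "covariate": column,
            "mu_mu": 0.0,
            "mu_sigma": 1.0,
            "sigma": 1.0,
        }
        for parameter, column in product(curve_parameters, covariate_columns)
    ]
    return categories, specs
```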

render_model(**kwargs)

Render the model graph using NumPyro's rendering utilities.

Parameters:

Name Type Description Default
**kwargs Any

Keyword arguments forwarded to numpyro.render_model.

{}

Returns:

Type Description
Any

The rendered model object (typically a Graphviz object).

Source code in src/vaxflux/_vaxflux_model.py
def render_model(self, **kwargs: Any) -> Any:  # noqa: ANN401
    """
    Render the model graph using NumPyro's rendering utilities.

    Args:
        **kwargs: Keyword arguments forwarded to `numpyro.render_model`.

    Returns:
        The rendered model object (typically a Graphviz object).

    """
    self._pre_model()
    return _render_model(self._model, **kwargs)

sample(prior_samples=None, warmup=None, samples=None, chains=1, random_seed=1, mcmc_args=None, nuts_args=None, prior_predictive_args=None, posterior_predictive_args=None, arviz_args=None)

Sample from the prior predictive, posterior, and posterior predictive distributions.

Parameters:

Name Type Description Default
prior_samples int | None

Number of prior predictive samples to draw. When None, prior predictive sampling is skipped.

None
warmup int | None

Number of warmup (burn-in) steps.

None
samples int | None

Number of posterior samples to draw. When None, posterior and posterior predictive sampling are skipped.

None
chains int

Number of MCMC chains to run.

1
random_seed int

Seed for the random number generator.

1
mcmc_args dict[str, Any] | None

Additional arguments for the MCMC sampler.

None
nuts_args dict[str, Any] | None

Additional arguments for the NUTS kernel.

None
prior_predictive_args dict[str, Any] | None

Additional arguments for the prior predictive sampler.

None
posterior_predictive_args dict[str, Any] | None

Additional arguments for the posterior predictive sampler.

None
arviz_args dict[str, Any] | None

Additional arguments for ArviZ's from_numpyro.

None

Returns:

Type Description
VaxfluxInferenceData

The inference data containing whichever sampling stages were requested.

Raises:

Type Description
ValueError

If neither prior_samples nor samples is provided.

Source code in src/vaxflux/_vaxflux_model.py
def sample(  # noqa: PLR0913
    self,
    prior_samples: int | None = None,
    warmup: int | None = None,
    samples: int | None = None,
    chains: int = 1,
    random_seed: int = 1,
    mcmc_args: dict[str, Any] | None = None,
    nuts_args: dict[str, Any] | None = None,
    prior_predictive_args: dict[str, Any] | None = None,
    posterior_predictive_args: dict[str, Any] | None = None,
    arviz_args: dict[str, Any] | None = None,
) -> VaxfluxInferenceData:
    """
    Sample from the prior predictive, posterior, and posterior predictive.

    Args:
        prior_samples: Number of prior predictive samples to draw. When `None`,
            prior predictive sampling is skipped.
        warmup: Number of warmup (burn-in) steps.
        samples: Number of posterior samples to draw. When `None`, posterior and
            posterior predictive sampling are skipped.
        chains: Number of MCMC chains to run.
        random_seed: Seed for the random number generator.
        mcmc_args: Additional arguments for the MCMC sampler.
        nuts_args: Additional arguments for the NUTS kernel.
        prior_predictive_args: Additional arguments for the prior predictive
            sampler.
        posterior_predictive_args: Additional arguments for the posterior
            predictive sampler.
        arviz_args: Additional arguments for ArviZ's `from_numpyro`.

    Returns:
        The inference data containing whichever sampling stages were requested.

    Raises:
        ValueError: If neither `prior_samples` nor `samples` is provided.

    """
    if prior_samples is None and samples is None:
        msg = "At least one of `prior_samples` or `samples` must be provided."
        raise ValueError(msg)

    # Prepare arguments
    mcmc_args = mcmc_args or {}
    nuts_args = nuts_args or {}
    prior_predictive_args = prior_predictive_args or {}
    posterior_predictive_args = posterior_predictive_args or {}
    from_numpyro_kwargs = dict(arviz_args or {})

    # Prepare the model
    self._pre_model()
    rng_key = key(random_seed)
    prior_key, mcmc_key, posterior_key = split(rng_key, num=3)
    coords, dims = self._coords_and_dims(
        base_coords=from_numpyro_kwargs.get("coords"),
        base_dims=from_numpyro_kwargs.get("dims"),
    )
    from_numpyro_kwargs["coords"] = coords
    from_numpyro_kwargs["dims"] = dims
    if self._observations is not None:
        from_numpyro_kwargs["observations"] = self._observations.data.copy()

    # Sample from prior predictive if requested
    if prior_samples is not None:
        prior_predictive = Predictive(
            self._model, num_samples=prior_samples, **prior_predictive_args
        )
        from_numpyro_kwargs["prior"] = prior_predictive(
            prior_key, simulate_observations=True
        )

    # Sample from posterior/posterior predictive if requested
    if samples is not None:
        kernel = NUTS(self._model, **nuts_args)
        mcmc = MCMC(
            kernel,
            num_warmup=warmup,
            num_samples=samples,
            num_chains=chains,
            **mcmc_args,
        )
        mcmc_key = split(mcmc_key, num=chains) if chains > 1 else mcmc_key
        mcmc.run(mcmc_key)

        posterior_samples = mcmc.get_samples()
        posterior_predictive = Predictive(
            self._model,
            posterior_samples=posterior_samples,
            **posterior_predictive_args,
        )
        from_numpyro_kwargs["posterior_predictive"] = posterior_predictive(
            posterior_key, simulate_observations=True
        )
        from_numpyro_kwargs["posterior"] = mcmc

    return VaxfluxInferenceData.from_numpyro(**from_numpyro_kwargs)

VaxfluxObservations(data)

Container for vaxflux observations data.

Wraps a validated pandas DataFrame and provides a stable interface for working with observations. Use from_dataframe to construct an instance from a raw DataFrame.

The underlying DataFrame is accessible via the data property. Columns that are not part of the standard metadata schema are exposed via the covariate_columns property for downstream use by the model.

Examples:

>>> import pandas as pd
>>> from vaxflux import VaxfluxObservations
>>> obs = VaxfluxObservations.from_dataframe(
...     pd.DataFrame(
...         {
...             "season": ["2023/2024"],
...             "start_date": ["2023-10-01"],
...             "end_date": ["2023-10-31"],
...             "type": ["incidence"],
...             "value": [0.15],
...             "location": ["ABC County"],
...         }
...     )
... )
>>> obs.data.shape
(1, 7)
>>> obs.covariate_columns
['location']
>>> len(obs)
1
>>> obs["season"]
0    2023/2024
Name: season, dtype: str
Source code in src/vaxflux/_vaxflux_observations.py
def __init__(self, data: pd.DataFrame) -> None:
    self._data = _validate_and_format_observations(data)

covariate_columns property

Column names that are not part of the standard metadata schema.

Returns:

Type Description
list[str]

A sorted list of column names that are not in the set of known metadata columns (season, start_date, end_date, report_date, type, value).
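The rule behind covariate_columns amounts to a sorted set difference against the metadata names listed above. The following is a minimal standalone sketch under that assumption; `METADATA_COLUMNS` and `covariate_columns` here are hypothetical names operating on a plain list of column names, and the actual vaxflux implementation may differ.

```python
# Metadata columns as listed in the covariate_columns documentation.
METADATA_COLUMNS = frozenset(
    {"season", "start_date", "end_date", "report_date", "type", "value"}
)


def covariate_columns(columns: list[str]) -> list[str]:
    """Sorted names of non-metadata columns (hypothetical standalone helper)."""
    return sorted(c for c in columns if c not in METADATA_COLUMNS)


cols = [
    "season", "start_date", "end_date", "report_date",
    "type", "value", "location", "age_group",
]
print(covariate_columns(cols))  # ['age_group', 'location']
```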

data property

The validated underlying observations DataFrame.

from_csv(*args, **kwargs) classmethod

Construct a VaxfluxObservations by reading a CSV with pandas.

This is a thin wrapper around pandas.read_csv; all positional and keyword arguments are forwarded directly to that function before the resulting DataFrame is validated and wrapped.

Parameters:

Name Type Description Default
*args Any

Positional arguments forwarded to pandas.read_csv.

()
**kwargs Any

Keyword arguments forwarded to pandas.read_csv.

{}

Returns:

Type Description
VaxfluxObservations

A validated VaxfluxObservations instance.

Raises:

Type Description
Exception

Propagates any exception raised by pandas.read_csv.

ValueError

If the resulting DataFrame is empty or missing required columns.

ValueError

If the value column contains NaN or negative values.

ValueError

If the type column contains unsupported values.

NotImplementedError

If prevalence observations are provided.

NotImplementedError

If observations with differing report dates are provided.

Source code in src/vaxflux/_vaxflux_observations.py
@classmethod
def from_csv(cls, *args: Any, **kwargs: Any) -> "VaxfluxObservations":
    """
    Construct a `VaxfluxObservations` by reading a CSV with pandas.

    This is a thin wrapper around `pandas.read_csv`; all positional and keyword
    arguments are forwarded directly to that function before the resulting
    DataFrame is validated and wrapped.

    Args:
        *args: Positional arguments forwarded to `pandas.read_csv`.
        **kwargs: Keyword arguments forwarded to `pandas.read_csv`.

    Returns:
        A validated `VaxfluxObservations` instance.

    Raises:
        Exception: Propagates any exception raised by `pandas.read_csv`.
        ValueError: If the resulting DataFrame is empty or missing required
            columns.
        ValueError: If the `value` column contains NaN or negative values.
        ValueError: If the `type` column contains unsupported values.
        NotImplementedError: If prevalence observations are provided.
        NotImplementedError: If observations with differing report dates are
            provided.
    """
    return cls.from_dataframe(pd.read_csv(*args, **kwargs))
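from_csv only parses the file and hands the result to from_dataframe, so every validation error listed above comes from the shared validation step. To illustrate one of those checks (NaN or negative values), here is a sketch that swaps pandas for the standard-library csv module; `check_value_column` is hypothetical and covers only the empty-input, missing-value-column, and NaN/negative rules, not the full vaxflux validation.

```python
import csv
import io
import math


def check_value_column(csv_text: str) -> list[float]:
    """Parse `value` and apply the NaN/negative rule (hypothetical helper)."""
    reader = csv.DictReader(io.StringIO(csv_text))
    rows = list(reader)
    # Empty input or a missing `value` column is rejected up front.
    if not rows or "value" not in (reader.fieldnames or []):
        msg = "Observations must be non-empty and include a `value` column."
        raise ValueError(msg)
    values = [float(row["value"]) for row in rows]
    for v in values:
        # NaN and negative observations are both invalid.
        if math.isnan(v) or v < 0:
            msg = f"Invalid observation value: {v!r}"
            raise ValueError(msg)
    return values


text = (
    "season,start_date,end_date,type,value\n"
    "2023/2024,2023-10-01,2023-10-31,incidence,0.15\n"
)
print(check_value_column(text))  # [0.15]
```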

from_dataframe(data) classmethod

Construct a VaxfluxObservations from a DataFrame or existing instance.

If data is already a VaxfluxObservations, it is returned unchanged. If it is a pandas.DataFrame, it is validated and wrapped.

Parameters:

Name Type Description Default
data VaxfluxObservations | DataFrame

A raw observations DataFrame or an existing VaxfluxObservations instance.

required

Returns:

Type Description
VaxfluxObservations

A validated VaxfluxObservations instance.

Raises:

Type Description
ValueError

If the DataFrame is empty or missing required columns.

ValueError

If the value column contains NaN or negative values.

ValueError

If the type column contains unsupported values.

NotImplementedError

If prevalence observations are provided.

NotImplementedError

If observations with differing report dates are provided (nowcasting is not yet supported).

Source code in src/vaxflux/_vaxflux_observations.py
@classmethod
def from_dataframe(
    cls,
    data: "VaxfluxObservations | pd.DataFrame",
) -> "VaxfluxObservations":
    """
    Construct a `VaxfluxObservations` from a DataFrame or existing instance.

    If `data` is already a `VaxfluxObservations`, it is returned
    unchanged. If it is a `pandas.DataFrame`, it is validated and
    wrapped.

    Args:
        data: A raw observations DataFrame or an existing `VaxfluxObservations`
            instance.

    Returns:
        A validated `VaxfluxObservations` instance.

    Raises:
        ValueError: If the DataFrame is empty or missing required columns.
        ValueError: If the `value` column contains NaN or negative values.
        ValueError: If the `type` column contains unsupported values.
        NotImplementedError: If prevalence observations are provided.
        NotImplementedError: If observations with differing report dates are
            provided (nowcasting is not yet supported).

    """
    if isinstance(data, VaxfluxObservations):
        return data
    return cls(data)
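Because from_dataframe returns an existing instance unchanged, callers can pass either a raw table or an already-validated object without triggering validation twice. A stripped-down illustration of that pattern follows; `Observations` and `from_data` are hypothetical stand-ins that hold a list of dicts instead of a validated DataFrame.

```python
class Observations:
    """Toy stand-in for VaxfluxObservations (hypothetical, minimal validation)."""

    def __init__(self, rows: list[dict]) -> None:
        # Placeholder for the real validation step; runs only on construction.
        if not rows:
            raise ValueError("The observations are empty.")
        self._rows = list(rows)

    @classmethod
    def from_data(cls, data: "Observations | list[dict]") -> "Observations":
        # Already wrapped: hand back the same object, skipping re-validation.
        if isinstance(data, Observations):
            return data
        return cls(data)


obs = Observations.from_data([{"season": "2023/2024", "value": 0.15}])
print(Observations.from_data(obs) is obs)  # True
```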

daily_date_ranges(season_ranges, range_days=0, remainder='raise')

Create daily date ranges from the season ranges.

Parameters:

Name Type Description Default
season_ranges list[SeasonRange] | SeasonRange

The season ranges to create the daily date ranges from.

required
range_days int

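The offset in days from each range's start date to its end date; must be at least 0. Both endpoints are inclusive, so each range spans range_days + 1 days and 0 yields single-day ranges.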

0
remainder Literal['fill', 'raise', 'skip']

The strategy to handle the remainder of days when the season ranges do not divide evenly into daily date ranges. Options are "fill" to end the final date range early at the season's end date, "raise" to raise an error, and "skip" to drop the incomplete final range.

'raise'

Returns:

Type Description
list[DateRange]

The daily date ranges for the uptake scenarios.

Raises:

Type Description
ValueError

If range_days is negative.

ValueError

If the number of days for each daily date range does not divide evenly into the season range and remainder is 'raise'.

Source code in src/vaxflux/_dates.py
def daily_date_ranges(
    season_ranges: list[SeasonRange] | SeasonRange,
    range_days: int = 0,
    remainder: Literal["fill", "raise", "skip"] = "raise",
) -> list[DateRange]:
    """
    Create daily date ranges from the season ranges.

    Args:
        season_ranges: The season ranges to create the daily date ranges from.
        range_days: The offset in days from each range's start date to its end
            date, must be at least 0. Both endpoints are inclusive, so each range
            spans `range_days + 1` days and 0 yields single-day ranges.
        remainder: The strategy to handle the remainder of days when the season ranges
            do not divide evenly into daily date ranges. Options are "fill" to end the
            final date range early at the season's end date, "raise" to raise an
            error, and "skip" to drop the incomplete final range.

    Returns:
        The daily date ranges for the uptake scenarios.

    Raises:
        ValueError: If `range_days` is negative.
        ValueError: If the number of days for each daily date range does not divide
            evenly into the season range and `remainder` is 'raise'.

    """
    season_ranges = (
        [season_ranges] if isinstance(season_ranges, SeasonRange) else season_ranges
    )
    if range_days < 0:
        msg = "The number of days for each daily date range must be at least 0."
        raise ValueError(
            msg,
        )
    date_ranges = []
    td = timedelta(days=range_days)
    td_one_day = timedelta(days=1)
    for season_range in season_ranges:
        start_date = season_range.start_date
        while start_date <= season_range.end_date:
            end_date = start_date + td
            if end_date > season_range.end_date:
                if remainder == "raise":
                    msg = (
                        "The number of days for each daily date range does not divide "
                        f"evenly into the season range for {season_range.season}."
                    )
                    raise ValueError(
                        msg,
                    )
                if remainder == "fill":
                    end_date = season_range.end_date
                else:
                    break
            date_ranges.append(
                DateRange(
                    season=season_range.season,
                    start_date=start_date,
                    end_date=end_date,
                    report_date=end_date,
                ),
            )
            start_date = end_date + td_one_day
    return date_ranges
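The splitting loop above can be traced with a simplified, single-season sketch. It assumes plain `(start, end)` tuples in place of the library's SeasonRange and DateRange classes; `split_season` is a hypothetical name, and it omits the season label and report date that daily_date_ranges attaches to each range.

```python
from datetime import date, timedelta
from typing import Literal


def split_season(
    start: date,
    end: date,
    range_days: int = 0,
    remainder: Literal["fill", "raise", "skip"] = "raise",
) -> list[tuple[date, date]]:
    """Simplified single-season version of daily_date_ranges (hypothetical)."""
    if range_days < 0:
        raise ValueError("range_days must be at least 0.")
    ranges: list[tuple[date, date]] = []
    current = start
    while current <= end:
        # Inclusive end: each full range covers range_days + 1 days.
        stop = current + timedelta(days=range_days)
        if stop > end:
            if remainder == "raise":
                raise ValueError("Ranges do not divide the season evenly.")
            if remainder == "fill":
                stop = end  # truncate the final range at the season end
            else:
                break  # "skip": drop the incomplete final range
        ranges.append((current, stop))
        current = stop + timedelta(days=1)
    return ranges


week_start, week_end = date(2023, 10, 1), date(2023, 10, 7)
print(len(split_season(week_start, week_end, range_days=0)))  # 7
print(split_season(week_start, week_end, range_days=2, remainder="skip"))
# [(datetime.date(2023, 10, 1), datetime.date(2023, 10, 3)), (datetime.date(2023, 10, 4), datetime.date(2023, 10, 6))]
```

For this 7-day season, range_days=2 gives 3-day ranges, and 7 = 2 × 3 + 1 leaves one extra day: "fill" turns it into a final one-day range (October 7 to October 7), "skip" drops it, and "raise" rejects the split. By contrast, range_days=6 yields a single range covering the whole week.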