Vaxflux
vaxflux
Model seasonal vaccination uptake curves.
Covariate
Bases: ABC, BaseModel
Abstract base class for covariates in vaxflux.
Subclasses must implement sample() and adhere to the following shape contract:
- Seasonal covariates (covariate is None): sample() must register and return a numpyro.deterministic site named f"covariate_values_{self.prefix}" with shape (num_seasons,).
- Categorical covariates (covariate is not None): sample() must register and return a numpyro.deterministic site named f"covariate_values_{self.prefix}" with shape (num_seasons, num_categories).
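The shape contract can be checked mechanically before a subclass is wired into a model. The following is a minimal NumPy sketch; check_covariate_values is a hypothetical helper, not part of vaxflux, and it signals a seasonal covariate with num_categories=None.

```python
import numpy as np


def check_covariate_values(values, num_seasons, num_categories=None):
    """Validate sampled covariate values against the documented shape contract.

    Hypothetical helper for illustration; it is not part of the vaxflux API.
    """
    values = np.asarray(values)
    if num_categories is None:
        # Seasonal covariate: one value per season.
        expected = (num_seasons,)
    else:
        # Categorical covariate: one column per category, baseline first.
        expected = (num_seasons, num_categories)
    if values.shape != expected:
        raise ValueError(f"Expected shape {expected}, got {values.shape}.")
    return values


# A seasonal covariate over three seasons passes.
check_covariate_values(np.zeros(3), num_seasons=3)
# A categorical covariate over three seasons and two categories passes.
check_covariate_values(np.zeros((3, 2)), num_seasons=3, num_categories=2)
```

A mismatched array, such as a (3, 2) array declared as seasonal, raises ValueError.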
Attributes:
| Name | Type | Description |
|---|---|---|
| parameter | str | The name of the model parameter this covariate affects. |
| covariate | str \| None | The name of the covariate variable. |
prefix
property
Get the prefix for the covariate based on the parameter name.
Returns:
| Type | Description |
|---|---|
| str | The prefix string derived from the parameter name. |
extra_dims(_season_coord, _category_short_coord)
Return ArviZ dimension mappings for any additional numpyro sites.
Only the sites this covariate registers beyond the standard covariate_values_* deterministic need to be listed.
The covariate_values_* site and the raw {prefix} sample site are handled automatically by VaxfluxModel; override this method only when sample() registers extra named sites (e.g. hyperparameter draws) that should carry labelled coordinates in the ArviZ output.
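As a sketch of what an override could return, suppose a hypothetical subclass also registers a per-category site named f"{prefix}_scale" inside sample(). The plain function below mirrors the extra_dims signature; the site name, its coordinate, and the choice to label nothing for seasonal covariates are illustrative assumptions, not vaxflux behaviour.

```python
def extra_dims_for_scale_site(prefix, _season_coord, category_short_coord):
    """Dims mapping for a hypothetical extra site registered in sample().

    Uses the documented return format: numpyro site name -> list of
    coordinate dimension names. Returns an empty dict when there is
    nothing extra to label.
    """
    if category_short_coord is None:
        # Seasonal covariate: this hypothetical site is not registered.
        return {}
    # One scale value per non-baseline category.
    return {f"{prefix}_scale": [category_short_coord]}


print(extra_dims_for_scale_site("example", "season", "example_category"))
# {'example_scale': ['example_category']}
```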
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| _season_coord | str | The name of the season coordinate (…) | required |
| _category_short_coord | str \| None | The name of the non-baseline-category coordinate for this covariate, or … | required |
Returns:
| Type | Description |
|---|---|
| dict[str, list[str]] | A mapping from numpyro site name to a list of coordinate dimension names, in the same format as the …; … are no additional sites to label. |
Source code in src/vaxflux/_covariates.py
sample(*, season_names, covariate_categories)
abstractmethod
Abstract method to sample covariate values using model context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| season_names | list[str] | Season names defined in the model. | required |
| covariate_categories | list[str] \| None | All categories for this covariate including the baseline at index 0, or … | required |
Returns:
| Type | Description |
|---|---|
| NumericalArrayLike | A numerical array-like structure containing the sampled covariate values. |
Source code in src/vaxflux/_covariates.py
CovariateCategories
Bases: BaseModel
A representation of the categories for a covariate.
Curve()
Bases: ABC
Abstract class for implementations of uptake curves.
Attributes:
| Name | Type | Description |
|---|---|---|
| parameters | tuple[str, ...] | A tuple of parameter names required by the prevalence function. |
Examples:
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.typing import ArrayLike
>>> from vaxflux import Curve
>>> class ClampedSlopeCurve(Curve):
...     def prevalence(self, t: ArrayLike, m: ArrayLike) -> jax.Array:
...         return jnp.clip(t * m, 0.0, 1.0)
>>> t = jnp.linspace(-1.0, 3.0, num=8)
>>> curve = ClampedSlopeCurve()
>>> curve.parameters
('m',)
>>> curve.prevalence(t, m=0.5)
Array([0. , 0. , 0.0714286 , 0.35714293, 0.6428572 ,
0.9285715 , 1. , 1. ], dtype=float32)
>>> curve.incidence(t, m=0.5)
Array([0. , 0. , 0.5, 0.5, 0.5, 0.5, 0. , 0. ], dtype=float32)
Source code in src/vaxflux/_curves.py
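In the example, curve.parameters is ('m',) even though the subclass never sets it, which suggests the names are read from the prevalence signature. The sketch below shows one way to do that with the standard library's inspect module; parameter_names is a hypothetical helper and this is an assumption about the mechanism, not vaxflux's code.

```python
import inspect


def parameter_names(prevalence):
    """Return the named parameters of a prevalence callable, excluding t.

    Bound methods already omit self; *args and **kwargs are skipped.
    """
    named_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return tuple(
        name
        for name, param in inspect.signature(prevalence).parameters.items()
        if name != "t" and param.kind in named_kinds
    )


class ClampedSlope:
    def prevalence(self, t, m):
        return min(max(t * m, 0.0), 1.0)


print(parameter_names(ClampedSlope().prevalence))  # ('m',)
```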
incidence(t, **kwargs)
Compute the incidence at time t.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | NumericalArrayLike | The time steps to evaluate the incidence curve at. | required |
| **kwargs | NumericalArrayLike | Additional parameters required by the incidence model. | {} |
Returns:
| Type | Description |
|---|---|
| Array | The incidence curve evaluated at time t. |
Source code in src/vaxflux/_curves.py
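The ClampedSlopeCurve example is consistent with incidence being the time derivative of prevalence: the slope m = 0.5 appears wherever prevalence is not clipped, and 0 appears where it is. A NumPy sketch of that relationship with a central finite difference follows; finite_difference_incidence is a hypothetical helper, and vaxflux itself works in JAX and may compute the derivative differently (for example by automatic differentiation).

```python
import numpy as np


def clamped_prevalence(t, m):
    # Same prevalence as the ClampedSlopeCurve example, written with NumPy.
    return np.clip(t * m, 0.0, 1.0)


def finite_difference_incidence(prevalence, t, h=1e-4, **params):
    """Approximate d(prevalence)/dt with a central difference of step h."""
    t = np.asarray(t, dtype=float)
    return (prevalence(t + h, **params) - prevalence(t - h, **params)) / (2.0 * h)


# Points chosen away from the clipping boundaries at t = 0 and t = 2.
t = np.array([-0.5, 0.5, 1.0, 2.5])
print(finite_difference_incidence(clamped_prevalence, t, m=0.5))
# approximately [0, 0.5, 0.5, 0]
```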
plot(t, *, parameter_sets=None, labels=None, plot_incidence=True, plot_prevalence=True, figsize=None, title=None, **kwargs)
Plot the curve's incidence and/or prevalence over time.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | NumericalArrayLike | The time steps to evaluate and plot. | required |
| parameter_sets | Sequence[dict[str, NumericalArrayLike]] \| None | Optional collection of parameter dictionaries to overlay on the same axes. When omitted, … | None |
| labels | Sequence[str] \| None | Optional labels for each parameter set. When omitted, labels are generated from … | None |
| plot_incidence | bool | Whether to plot the incidence curve. | True |
| plot_prevalence | bool | Whether to plot the prevalence curve. | True |
| figsize | tuple[float, float] \| None | Optional figure size override. | None |
| title | str \| None | Optional figure title. | None |
| **kwargs | NumericalArrayLike | Parameters forwarded to the curve methods when plotting a single parameter set. | {} |
Returns:
| Type | Description |
|---|---|
| Figure | The Matplotlib figure. |
Raises:
| Type | Description |
|---|---|
| ValueError | If neither incidence nor prevalence is selected. |
| ValueError | If both … |
| ValueError | If … |
| ValueError | If … |
Source code in src/vaxflux/_curves.py
prevalence(t, **kwargs)
abstractmethod
Compute the prevalence at time t.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | NumericalArrayLike | The time steps to evaluate the prevalence curve at. | required |
| **kwargs | NumericalArrayLike | Additional parameters required by the prevalence model. | {} |
Returns:
| Type | Description |
|---|---|
| Array | The prevalence curve evaluated at time t. |
Source code in src/vaxflux/_curves.py
prevalence_difference(t0, t1, **kwargs)
Compute prevalence differences between two time arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t0 | NumericalArrayLike | The start time steps for the interval. | required |
| t1 | NumericalArrayLike | The end time steps for the interval. | required |
| **kwargs | NumericalArrayLike | Additional parameters required by the prevalence model. | {} |
Returns:
| Type | Description |
|---|---|
| Array | The prevalence difference for each interval. |
Source code in src/vaxflux/_curves.py
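Judging from the parameter descriptions, prevalence_difference evaluates prevalence at the end and start of each interval and subtracts, giving the uptake accrued within that interval. A NumPy sketch under that assumption, using a hypothetical clamped prevalence function and a standalone helper of the same name:

```python
import numpy as np


def clamped_prevalence(t, m):
    return np.clip(t * m, 0.0, 1.0)


def prevalence_difference(prevalence, t0, t1, **params):
    """Uptake accrued between t0 and t1, element-wise (assumed semantics)."""
    t0 = np.asarray(t0, dtype=float)
    t1 = np.asarray(t1, dtype=float)
    return prevalence(t1, **params) - prevalence(t0, **params)


t0 = np.array([0.0, 1.0, 2.0])
t1 = np.array([1.0, 2.0, 3.0])
print(prevalence_difference(clamped_prevalence, t0, t1, m=0.4))
# approximately [0.4, 0.4, 0.2]
```

Over contiguous intervals the differences telescope, so their sum equals the prevalence at the last end time minus the prevalence at the first start time (here 1.0).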
DateRange
Bases: BaseModel
A representation of a date range for uptake scenarios.
Examples:
>>> from vaxflux import DateRange
>>> date_range = DateRange(
... season="2023/2024",
... start_date="2023-12-01",
... end_date="2023-12-31",
... report_date="2024-01-01",
... )
>>> date_range.season
'2023/2024'
>>> date_range.start_date
datetime.date(2023, 12, 1)
>>> date_range.end_date
datetime.date(2023, 12, 31)
>>> date_range.report_date
datetime.date(2024, 1, 1)
>>> DateRange(
... season="2023/2024",
... start_date="2023-12-01",
... end_date="2023-11-30",
... report_date="2023-12-01",
... )
Traceback (most recent call last):
...
pydantic_core._pydantic_core.ValidationError: 1 validation error for DateRange
Value error, The end date, 2023-11-30, must be after or the same as the start date 2023-12-01. [...]
For further information visit ...
>>> DateRange(
... season="2023/2024",
... start_date="2023-12-01",
... end_date="2023-12-31",
... report_date="2023-12-30",
... )
Traceback (most recent call last):
...
pydantic_core._pydantic_core.ValidationError: 1 validation error for DateRange
Value error, The report date, 2023-12-30, must be after or the same as the end date 2023-12-31. [...]
For further information visit ...
Implementation
Bases: BaseModel
A representation of an implementation of an intervention.
name
property
Get the name of the implementation.
Returns:
| Type | Description |
|---|---|
| str | The name of the implementation, formatted as a string. |
Intervention
Bases: BaseModel
A representation of the effect of an intervention on vaccine uptake.
numpyro_distribution()
LogisticCurve()
Bases: Curve
Logistic uptake curve implementation.
This class implements a logistic uptake curve whose prevalence function has parameters \(m\), \(r\), and \(s\).
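One form consistent with the doctest values below (a reconstruction from those numbers, not a quoted definition) treats \(m\) on the logit scale and \(r\) on the log scale:

\[
p(t) = \frac{\operatorname{logit}^{-1}(m)}{1 + \exp\left(-e^{r}(t - s)\right)},
\qquad
\frac{dp}{dt} = e^{r}\,p(t)\left(1 - \frac{p(t)}{\operatorname{logit}^{-1}(m)}\right).
\]

Here \(\operatorname{logit}^{-1}(m) = 1/(1 + e^{-m})\) is the upper asymptote, \(e^{r}\) the growth rate, and \(s\) the midpoint. With \(m = 0\), \(r = 1\), \(s = 1\) this gives an asymptote of 0.5, \(p(1) = 0.25\), and \(dp/dt = e/8 \approx 0.3398\) at \(t = 1\), matching the example.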
Examples:
>>> import jax.numpy as jnp
>>> from vaxflux import LogisticCurve
>>> t = jnp.array([0.0, 1.0, 2.0, 3.0])
>>> curve = LogisticCurve()
>>> prevalence = curve.prevalence(t, m=0.0, r=1.0, s=1.0)
>>> prevalence
Array([0.03095159, 0.25 , 0.4690484 , 0.4978322 ], dtype=float32)
>>> incidence = curve.incidence(t, m=0.0, r=1.0, s=1.0)
>>> incidence
Array([0.0789269 , 0.33978522, 0.07892691, 0.00586712], dtype=float32)
Source code in src/vaxflux/_curves.py
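The example's numbers can be reproduced with a standard-library sketch that assumes prevalence expit(m) / (1 + exp(-exp(r) * (t - s))) and incidence as its time derivative. Both forms are inferred from the doctest output rather than taken from the vaxflux source, and small last-digit differences come from the example's float32 arithmetic.

```python
import math


def expit(x):
    return 1.0 / (1.0 + math.exp(-x))


def logistic_prevalence(t, m, r, s):
    # Assumed form: upper asymptote expit(m), growth rate exp(r), midpoint s.
    return expit(m) * expit(math.exp(r) * (t - s))


def logistic_incidence(t, m, r, s):
    # Time derivative of logistic_prevalence: A * k * sigma * (1 - sigma).
    sigma = expit(math.exp(r) * (t - s))
    return expit(m) * math.exp(r) * sigma * (1.0 - sigma)


for t in (0.0, 1.0, 2.0, 3.0):
    print(t, logistic_prevalence(t, 0.0, 1.0, 1.0), logistic_incidence(t, 0.0, 1.0, 1.0))
```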
prevalence(t, m, r, s)
Compute the logistic prevalence at time t.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | NumericalArrayLike | The time steps to evaluate the prevalence curve at. | required |
| m | NumericalArrayLike | The curve's maximum value. | required |
| r | NumericalArrayLike | The steepness of the curve. | required |
| s | NumericalArrayLike | The x-value of the sigmoid's midpoint. | required |
Returns:
| Type | Description |
|---|---|
| Array | The logistic prevalence curve evaluated at time t. |
Source code in src/vaxflux/_curves.py
PartiallyPooledGaussianCovariate
Bases: Covariate
Covariate model using a partially pooled Gaussian approach.
For seasonal covariates (covariate is None) this samples one value per season
using shared scalar hyperpriors, producing shape (num_seasons,).
For categorical covariates (covariate is not None) this samples one value per
non-baseline category using shared scalar hyperpriors, then broadcasts that
constant effect across all seasons to produce shape
(num_seasons, num_categories). The baseline category (index 0) is always zero.
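A forward-simulation sketch of that structure with NumPy follows, assuming the hyperpriors implied by the attribute descriptions: a shared mean drawn from Normal(mu_mu, mu_sigma), a shared standard deviation drawn from HalfNormal(sigma), then one Normal draw per season (or per non-baseline category). The real class registers numpyro sample sites instead; simulate_partially_pooled is a hypothetical name.

```python
import numpy as np


def simulate_partially_pooled(rng, mu_mu, mu_sigma, sigma, num_seasons, num_categories=None):
    """Draw one realisation of partially pooled covariate values (sketch)."""
    mu = rng.normal(mu_mu, mu_sigma)      # shared (pooled) mean
    tau = abs(rng.normal(0.0, sigma))     # half-normal standard deviation
    if num_categories is None:
        # Seasonal covariate: one value per season, shape (num_seasons,).
        return rng.normal(mu, tau, size=num_seasons)
    # Categorical covariate: one effect per non-baseline category,
    # with the baseline (index 0) fixed at zero.
    effects = rng.normal(mu, tau, size=num_categories - 1)
    row = np.concatenate([[0.0], effects])
    # The same row is repeated for every season.
    return np.broadcast_to(row, (num_seasons, num_categories))


rng = np.random.default_rng(0)
print(simulate_partially_pooled(rng, 0.5, 0.1, 0.2, num_seasons=3).shape)  # (3,)
values = simulate_partially_pooled(rng, 0.5, 0.1, 0.2, num_seasons=3, num_categories=4)
print(values.shape)  # (3, 4)
```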
Attributes:
| Name | Type | Description |
|---|---|---|
| mu_mu | float | The location parameter of the partially pooled mean. |
| mu_sigma | float | The scale parameter of the partially pooled mean. |
| sigma | float | The scale of the half-normal distribution for the standard deviation. |
sample(*, season_names, covariate_categories)
Sample values from a Gaussian distribution defined by the mean and stddev.
Returns:
| Type | Description |
|---|---|
| NumericalArrayLike | A numerical array-like structure containing the sampled covariate values. |
Source code in src/vaxflux/_covariates.py
PooledGaussianCovariate
Bases: Covariate
Covariate model using a pooled Gaussian approach.
For seasonal covariates (covariate is None) this samples one value per season,
producing shape (num_seasons,).
For categorical covariates (covariate is not None) this samples one value per
non-baseline category, then broadcasts that constant effect across all seasons to
produce shape (num_seasons, num_categories). The baseline category (index 0)
is always zero.
Attributes:
| Name | Type | Description |
|---|---|---|
| mu | float | The mean of the Gaussian distribution. |
| sigma | float | The standard deviation of the Gaussian distribution. |
sample(*, season_names, covariate_categories)
Sample values from a Gaussian distribution defined by the mean and stddev.
Returns:
| Type | Description |
|---|---|
| NumericalArrayLike | A numerical array-like structure containing the sampled covariate values. |
Source code in src/vaxflux/_covariates.py
SeasonRange
Bases: BaseModel
A representation of a season range for uptake scenarios.
Examples:
>>> from vaxflux import SeasonRange
>>> season_range = SeasonRange(
... season="2023/2024",
... start_date="2023-12-01",
... end_date="2024-03-31",
... )
>>> season_range.season
'2023/2024'
>>> season_range.start_date
datetime.date(2023, 12, 1)
>>> season_range.end_date
datetime.date(2024, 3, 31)
>>> SeasonRange(
... season="2023/2024",
... start_date="2024-03-31",
... end_date="2023-12-01",
... )
Traceback (most recent call last):
...
pydantic_core._pydantic_core.ValidationError: 1 validation error for SeasonRange
Value error, The end date, 2023-12-01, must be after or the same as the start date 2024-03-31. [...]
For further information visit ...
SeasonVaryingPartiallyPooledGaussianCovariate
Bases: Covariate
Season-varying covariate model using partial pooling across seasons.
This covariate samples category effects that vary by season, while shrinking
them toward category-level means. It is only valid for categorical covariates
(covariate is not None) and always returns shape
(num_seasons, num_categories).
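A NumPy sketch of season-varying partial pooling follows. It assumes, from the attribute descriptions and the {prefix}_mu and {prefix}_sigma sites that extra_dims registers, one mean and one half-normal scale per non-baseline category, with each season's effect drawn around its category mean; whether the scales are per category or shared is an assumption, and simulate_season_varying is a hypothetical name.

```python
import numpy as np


def simulate_season_varying(rng, mu_mu, mu_sigma, sigma, num_seasons, num_categories):
    """Draw season-varying category effects with a zero baseline (sketch)."""
    k = num_categories - 1                        # non-baseline categories
    mu = rng.normal(mu_mu, mu_sigma, size=k)      # category-level means
    tau = np.abs(rng.normal(0.0, sigma, size=k))  # half-normal scales
    # Raw draws: seasons x non-baseline categories, centred on each mean.
    raw = rng.normal(mu, tau, size=(num_seasons, k))
    # Prepend the baseline column of zeros -> (num_seasons, num_categories).
    return np.concatenate([np.zeros((num_seasons, 1)), raw], axis=1)


rng = np.random.default_rng(0)
values = simulate_season_varying(rng, 0.0, 1.0, 1.0, num_seasons=4, num_categories=3)
print(values.shape)  # (4, 3)
```

Unlike the partially pooled sketch for the non-season-varying class, rows here differ from season to season.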
Attributes:
| Name | Type | Description |
|---|---|---|
| mu_mu | float | The location parameter of the partially pooled mean. |
| mu_sigma | float | The scale parameter of the partially pooled mean. |
| sigma | float | The scale of the half-normal distribution for the standard deviation. |
Notes
Unlike the other covariate classes, this class does not support seasonal covariates because its effects are explicitly season-varying. Using it with covariate set to None raises an error at model definition time, during covariate name validation.
extra_dims(season_coord, category_short_coord)
Return dims for the hyperparameter and raw-draw sites.
Registers {prefix}_mu, {prefix}_sigma, and {prefix} with
the appropriate coordinate names.
Source code in src/vaxflux/_covariates.py
sample(*, season_names, covariate_categories)
Sample season-varying covariate values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| season_names | list[str] | Season names in the model. | required |
| covariate_categories | list[str] \| None | Covariate categories including the baseline. Always non-None. | required |
Returns:
| Type | Description |
|---|---|
| NumericalArrayLike | A numerical array-like structure containing the sampled covariate values. |
Source code in src/vaxflux/_covariates.py
VaxfluxInferenceData
Bases: InferenceData
Container for inference data specific to vaxflux.
This class extends ArviZ's InferenceData with functionality specific to vaxflux models. Because it is still an InferenceData, it can be passed to any function that expects one, while also providing vaxflux-specific methods and attributes.
coords
cached
property
Return model coordinates without chain/draw dimensions.
Returns:
| Type | Description |
|---|---|
| dict[str, list[str]] | The coordinate mapping for the merged prior or posterior dataset without chain and draw dimensions (which differ between prior/posterior). |
covariate_categories
cached
property
A mapping of covariate names to their category labels.
Returns:
| Type | Description |
|---|---|
| dict[str, dict[str, str]] | A dictionary mapping covariate names to dictionaries of category labels, where each inner dictionary maps coordinate-safe category names to the original category labels. |
merged_posterior
cached
property
Return a merged posterior dataset combining posterior and posterior predictive.
Returns:
| Type | Description |
|---|---|
| Dataset | An xarray Dataset with posterior and posterior predictive variables. |
merged_prior
cached
property
Return a merged prior dataset combining prior and prior predictive.
Returns:
| Type | Description |
|---|---|
| Dataset | An xarray Dataset with prior and prior predictive variables. |
posterior_observations
cached
property
Return posterior observations as a formatted DataFrame.
Returns:
| Type | Description |
|---|---|
| DataFrame | A DataFrame with observation metadata and values. Will contain the columns 'chain', 'draw', 'season', 'season_start_date', 'season_end_date', 'start_date', 'end_date', 'report_date', any covariate columns, 'type', and 'value'. |
prior_observations
cached
property
Return prior observations as a formatted DataFrame.
Returns:
| Type | Description |
|---|---|
| DataFrame | A DataFrame with observation metadata and values. Will contain the columns 'chain', 'draw', 'season', 'season_start_date', 'season_end_date', 'start_date', 'end_date', 'report_date', any covariate columns, 'type', and 'value'. |
from_numpyro(*args, **kwargs)
classmethod
Create inference data from NumPyro outputs via ArviZ.
For more details, see ArviZ's from_numpyro function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Any | Positional arguments forwarded to ArviZ's from_numpyro. | () |
| **kwargs | Any | Keyword arguments forwarded to ArviZ's from_numpyro. | {} |
Returns:
| Type | Description |
|---|---|
| VaxfluxInferenceData | A VaxfluxInferenceData instance. |
Source code in src/vaxflux/_vaxflux_inference_data.py
plot_posterior_predictive(*, intervals=(0.5, 0.8, 0.95), plot_incidence=True, plot_prevalence=True, plot_observations=None, predictive='latent', figsize=None)
Plot posterior predictive checks with shaded intervals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| intervals | Sequence[float] | Central predictive intervals to plot. | (0.5, 0.8, 0.95) |
| plot_incidence | bool | Whether to plot incidence panels. | True |
| plot_prevalence | bool | Whether to plot prevalence panels. | True |
| plot_observations | bool \| None | Whether to plot observed data points. When … | None |
| predictive | Literal['latent', 'observation'] | Whether to plot latent curve intervals or observation-level predictive intervals. | 'latent' |
| figsize | tuple[float, float] \| None | Optional figure size override. | None |
Returns:
| Type | Description |
|---|---|
| Figure | The Matplotlib figure. |
Source code in src/vaxflux/_vaxflux_inference_data.py
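Each entry in intervals is the probability mass of a central interval, so 0.8 corresponds to the 10th and 90th percentiles of the draws. A NumPy sketch of that conversion; central_interval_bounds is a hypothetical helper, and vaxflux may compute or label the bounds differently.

```python
import numpy as np


def central_interval_bounds(draws, intervals=(0.5, 0.8, 0.95)):
    """Map each central interval width to its (lower, upper) quantiles of draws."""
    bounds = {}
    for width in intervals:
        tail = (1.0 - width) / 2.0  # probability mass left in each tail
        lower, upper = np.quantile(draws, [tail, 1.0 - tail])
        bounds[width] = (float(lower), float(upper))
    return bounds


draws = np.linspace(0.0, 1.0, 101)  # stand-in for draws of one quantity
print(central_interval_bounds(draws))
```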
plot_prior_predictive(*, intervals=(0.5, 0.8, 0.95), plot_incidence=True, plot_prevalence=True, plot_observations=None, predictive='latent', figsize=None)
Plot prior predictive checks with shaded intervals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| intervals | Sequence[float] | Central predictive intervals to plot. | (0.5, 0.8, 0.95) |
| plot_incidence | bool | Whether to plot incidence panels. | True |
| plot_prevalence | bool | Whether to plot prevalence panels. | True |
| plot_observations | bool \| None | Whether to plot observed data points. When … | None |
| predictive | Literal['latent', 'observation'] | Whether to plot latent curve intervals or observation-level predictive intervals. | 'latent' |
| figsize | tuple[float, float] \| None | Optional figure size override. | None |
Returns:
| Type | Description |
|---|---|
| Figure | The Matplotlib figure. |
Source code in src/vaxflux/_vaxflux_inference_data.py
posterior_prevalence_scenarios(scenarios)
Summarize posterior end-of-season prevalence scenarios.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| scenarios | dict[str, tuple[float, float]] | Mapping of scenario name to (low, high) quantile bounds. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | A DataFrame with scenario quantiles and prevalence bounds. |
Source code in src/vaxflux/_vaxflux_inference_data.py
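Each scenario maps a name to a pair of quantile bounds on end-of-season prevalence. A NumPy sketch of that summary, assuming the end-of-season prevalence draws are available as a flat array; prevalence_scenarios and its output keys are hypothetical, and the real method returns its own DataFrame layout.

```python
import numpy as np


def prevalence_scenarios(draws, scenarios):
    """Summarize end-of-season prevalence draws for each named scenario (sketch)."""
    rows = []
    for name, (low, high) in scenarios.items():
        lower, upper = np.quantile(draws, [low, high])
        rows.append(
            {
                "scenario": name,
                "low_quantile": low,
                "high_quantile": high,
                "low_prevalence": float(lower),
                "high_prevalence": float(upper),
            }
        )
    return rows


# Stand-in for end-of-season prevalence draws.
draws = np.linspace(0.30, 0.50, 101)
scenarios = {"pessimistic": (0.05, 0.25), "optimistic": (0.75, 0.95)}
for row in prevalence_scenarios(draws, scenarios):
    print(row)
```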
prior_prevalence_scenarios(scenarios)
Summarize prior end-of-season prevalence scenarios.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| scenarios | dict[str, tuple[float, float]] | Mapping of scenario name to (low, high) quantile bounds. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | A DataFrame with scenario quantiles and prevalence bounds. |
Source code in src/vaxflux/_vaxflux_inference_data.py
VaxfluxModel(curve)
Vaccine uptake model builder.
VaxfluxModel stores the curve, time ranges, covariates, interventions, and
observations that define a vaccination uptake model. Configure the model by
incrementally adding these components, then use rendering or sampling methods to
inspect or estimate the resulting probabilistic model.
Examples:
>>> import pandas as pd
>>> from vaxflux import LogisticCurve, SeasonRange, VaxfluxModel
>>> observations = pd.DataFrame(
... {
... "season": ["2023/2024"],
... "start_date": ["2023-12-01"],
... "end_date": ["2023-12-07"],
... "type": ["incidence"],
... "value": [0.18],
... }
... )
>>> model = VaxfluxModel(curve=LogisticCurve())
>>> model = (
... model.add_seasons(
... SeasonRange(
... season="2023/2024",
... start_date="2023-12-01",
... end_date="2024-03-31",
... )
... )
... .add_observations(observations)
... .add_observation_process(
... kind="normal",
... noise=0.05,
... )
... )
>>> model
<vaxflux.VaxfluxModel object at ...>
Initialize the VaxfluxModel with a given uptake curve.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| curve | Curve | An instance of a Curve subclass. | required |
Source code in src/vaxflux/_vaxflux_model.py
__repr__()
Return a string representation of the model.
Returns:
| Type | Description |
|---|---|
| str | A string representation of the VaxfluxModel. |
add_covariate_categories(*args)
Add one or more covariate categories to the model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | CovariateCategories \| list[CovariateCategories] | One or more CovariateCategories objects, or lists of them. | () |
Returns:
| Type | Description |
|---|---|
| Self | The model instance for method chaining. |
Raises:
| Type | Description |
|---|---|
| TypeError | If any argument is not a CovariateCategories object or a list of them. |
Source code in src/vaxflux/_vaxflux_model.py
add_covariates(*args)
Add one or more covariates to the model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Covariate \| list[Covariate] | One or more Covariate objects, or lists of them. | () |
Returns:
| Type | Description |
|---|---|
| Self | The model instance for method chaining. |
Raises:
| Type | Description |
|---|---|
| TypeError | If any argument is not a Covariate object or a list of them. |
| ValueError | If duplicate parameter/covariate pairs are provided. |
Examples:
>>> from vaxflux import VaxfluxModel
>>> from vaxflux._covariates import PartiallyPooledGaussianCovariate
>>> from vaxflux._curves import LogisticCurve
>>> model = VaxfluxModel(curve=LogisticCurve())
>>> model.add_covariates(
... PartiallyPooledGaussianCovariate(
... parameter="m",
... covariate="age",
... mu_mu=0.5,
... mu_sigma=0.1,
... sigma=0.2,
... )
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_covariates(
... [
... PartiallyPooledGaussianCovariate(
... parameter="sigma",
... mu_mu=0.2,
... mu_sigma=0.03,
... sigma=0.1,
... ),
... ]
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_covariates("invalid_argument")
Traceback (most recent call last):
...
TypeError: Arguments must be Covariate objects or sequences of Covariate objects, got str.
Source code in src/vaxflux/_vaxflux_model.py
add_dates(*args)
Add one or more date ranges to the model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | DateRange \| list[DateRange] | One or more DateRange objects, or lists of them. | () |
Returns:
| Type | Description |
|---|---|
| Self | The model instance for method chaining. |
Raises:
| Type | Description |
|---|---|
| TypeError | If any argument is not a DateRange object or a list of them. |
| ValueError | If duplicate date ranges are found. |
| ValueError | If overlapping date ranges are found. |
Examples:
>>> from vaxflux import DateRange, LogisticCurve, SeasonRange, VaxfluxModel
>>> model = VaxfluxModel(curve=LogisticCurve())
>>> model.add_seasons(
... SeasonRange(
... season="2023/2024",
... start_date="2023-12-01",
... end_date="2024-03-31",
... )
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_dates(
... DateRange(
... season="2023/2024",
... start_date="2023-12-01",
... end_date="2023-12-07",
... report_date="2023-12-08",
... )
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_dates("invalid_argument")
Traceback (most recent call last):
...
TypeError: Arguments must be DateRange objects or sequences of DateRange objects, got str.
Source code in src/vaxflux/_vaxflux_model.py
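The duplicate and overlap checks can be expressed by sorting the ranges within each season and tracking the latest end date seen so far (comparing only neighbours would miss a long range that overlaps a later, non-adjacent one). This standard-library sketch uses a hypothetical helper on plain (season, start, end) tuples; it assumes inclusive dates, so ranges sharing a day overlap, and that ranges in different seasons never conflict. vaxflux's exact rules may differ.

```python
from datetime import date
from itertools import groupby


def find_date_range_problems(ranges):
    """Report duplicate or overlapping (season, start, end) ranges within each season."""
    problems = []
    for season, group in groupby(sorted(ranges), key=lambda r: r[0]):
        previous = None
        latest_end = None
        for current in group:
            if current == previous:
                problems.append(f"duplicate in {season}: {current[1]} to {current[2]}")
            elif latest_end is not None and current[1] <= latest_end:
                problems.append(f"overlap in {season}: {current[1]} to {current[2]}")
            latest_end = current[2] if latest_end is None else max(latest_end, current[2])
            previous = current
    return problems


weekly = [
    ("2023/2024", date(2023, 12, 1), date(2023, 12, 7)),
    ("2023/2024", date(2023, 12, 8), date(2023, 12, 14)),
]
print(find_date_range_problems(weekly))  # []
```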
add_implementations(*args)
Add one or more implementations to the model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Implementation \| list[Implementation] | One or more Implementation objects, or lists of them. | () |
Returns:
| Type | Description |
|---|---|
| Self | The model instance for method chaining. |
Raises:
| Type | Description |
|---|---|
| TypeError | If any argument is not an Implementation object or a list of them. |
| ValueError | If an implementation's intervention has not been added to the model. |
Source code in src/vaxflux/_vaxflux_model.py
add_interventions(*args)
Add one or more interventions to the model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Intervention \| list[Intervention] | One or more Intervention objects, or lists of them. | () |
Returns:
| Type | Description |
|---|---|
| Self | The model instance for method chaining. |
Raises:
| Type | Description |
|---|---|
| TypeError | If any argument is not an Intervention object or a list of them. |
Source code in src/vaxflux/_vaxflux_model.py
add_observation_process(kind, noise, *, partially_pool_by_season=True, partially_pool_by_covariate=False, prevalence_penalty=0.0)
Add an observation process to the model based on the added observations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kind | Literal['normal'] | The kind of observation process to add. | required |
| noise | float | The observation noise parameter. | required |
| partially_pool_by_season | bool | Whether to partially pool observation noise by season. | True |
| partially_pool_by_covariate | bool | Whether to partially pool observation noise by covariate. | False |
| prevalence_penalty | float | Penalty to apply when modeled prevalence exceeds 1. The penalty is applied as … | 0.0 |
Returns:
| Type | Description |
|---|---|
| Self | The model instance for method chaining. |
Raises:
| Type | Description |
|---|---|
| ValueError | If no observations have been added to the model. |
| ValueError | If an unsupported observation process kind is provided. |
| ValueError | If the prevalence penalty is negative. |
Source code in src/vaxflux/_vaxflux_model.py
add_observations(observations)
Add observations to the model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| observations | VaxfluxObservations \| DataFrame | The observations to add to the model. Accepts either a VaxfluxObservations object or a pandas DataFrame. | required |
Returns:
| Type | Description |
|---|---|
| Self | The model instance for method chaining. |
Raises:
| Type | Description |
|---|---|
| ValueError | If observations were already added to the model. |
Source code in src/vaxflux/_vaxflux_model.py
add_seasons(*args)
Add one or more seasons to the model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | SeasonRange \| list[SeasonRange] | One or more SeasonRange objects, or lists of them. | () |
Returns:
| Type | Description |
|---|---|
| Self | The model instance for method chaining. |
Raises:
| Type | Description |
|---|---|
| TypeError | If any argument is not a SeasonRange object or a list of them. |
Examples:
>>> from vaxflux import LogisticCurve, SeasonRange, VaxfluxModel
>>> model = VaxfluxModel(curve=LogisticCurve())
>>> model.add_seasons(
... SeasonRange(
... season="2023/2024",
... start_date="2023-12-01",
... end_date="2024-03-31",
... )
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_seasons(
... SeasonRange(
... season="2024/2025",
... start_date="2024-12-01",
... end_date="2025-03-31",
... ),
... SeasonRange(
... season="2025/2026",
... start_date="2025-12-01",
... end_date="2026-03-31",
... ),
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_seasons(
... [
... SeasonRange(
... season="2026/2027",
... start_date="2026-12-01",
... end_date="2027-03-31",
... ),
... SeasonRange(
... season="2027/2028",
... start_date="2027-12-01",
... end_date="2028-03-31",
... ),
... ]
... )
<vaxflux.VaxfluxModel object at ...>
>>> model.add_seasons("invalid_argument")
Traceback (most recent call last):
...
TypeError: Arguments must be SeasonRange objects or sequences of SeasonRange objects, got str.
Source code in src/vaxflux/_vaxflux_model.py
from_csv(*args, **kwargs)
classmethod
Construct a default model from observations stored in a CSV file.
The CSV is read with VaxfluxObservations.from_csv, then used to build a model with:
- a default LogisticCurve,
- inferred seasons and dates from the observations,
- one CovariateCategories entry per observation covariate using the full sorted list of observed levels,
- one SeasonVaryingPartiallyPooledGaussianCovariate per curve-parameter/covariate pair with weak default hyperparameters (mu_mu=0.0, mu_sigma=1.0, sigma=1.0),
- no interventions or implementations,
- a "normal" observation process with noise=0.001.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Any | Positional arguments forwarded to VaxfluxObservations.from_csv. | () |
| **kwargs | Any | Keyword arguments forwarded to VaxfluxObservations.from_csv. | {} |
Returns:
| Type | Description |
|---|---|
| Self | A configured VaxfluxModel. |
Source code in src/vaxflux/_vaxflux_model.py
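The defaults pair every curve parameter with every observation covariate, so the number of default covariates is the product of the two counts. A standard-library sketch of that enumeration, using the LogisticCurve parameter names m, r, and s from its prevalence signature and hypothetical covariate columns; the dictionaries only mimic the constructor arguments and are not vaxflux objects.

```python
from itertools import product

# Curve parameters of LogisticCurve, taken from prevalence(t, m, r, s).
curve_parameters = ("m", "r", "s")
# Covariate columns found in a hypothetical observations CSV.
covariate_columns = ["age_group", "location"]
# Weak default hyperparameters listed for from_csv.
defaults = {"mu_mu": 0.0, "mu_sigma": 1.0, "sigma": 1.0}

covariate_specs = [
    {"parameter": parameter, "covariate": covariate, **defaults}
    for parameter, covariate in product(curve_parameters, covariate_columns)
]
print(len(covariate_specs))  # 6
print(covariate_specs[0])
```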
render_model(**kwargs)
Render the model graph using NumPyro's rendering utilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| **kwargs | Any | Keyword arguments forwarded to NumPyro's render_model. | {} |
Returns:
| Type | Description |
|---|---|
| Any | The rendered model object (typically a Graphviz object). |
Source code in src/vaxflux/_vaxflux_model.py
sample(prior_samples=None, warmup=None, samples=None, chains=1, random_seed=1, mcmc_args=None, nuts_args=None, prior_predictive_args=None, posterior_predictive_args=None, arviz_args=None)
Sample from the prior predictive, posterior, and posterior predictive.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prior_samples | int \| None | Number of prior predictive samples to draw. When … | None |
| warmup | int \| None | Number of warmup (burn-in) steps. | None |
| samples | int \| None | Number of posterior samples to draw. When … | None |
| chains | int | Number of MCMC chains to run. | 1 |
| random_seed | int | Seed for the random number generator. | 1 |
| mcmc_args | dict[str, Any] \| None | Additional arguments for the MCMC sampler. | None |
| nuts_args | dict[str, Any] \| None | Additional arguments for the NUTS kernel. | None |
| prior_predictive_args | dict[str, Any] \| None | Additional arguments for the prior predictive sampler. | None |
| posterior_predictive_args | dict[str, Any] \| None | Additional arguments for the posterior predictive sampler. | None |
| arviz_args | dict[str, Any] \| None | Additional arguments for ArviZ's from_numpyro. | None |
Returns:
| Type | Description |
|---|---|
| VaxfluxInferenceData | The inference data containing whichever sampling stages were requested. |
Source code in src/vaxflux/_vaxflux_model.py
VaxfluxObservations(data)
Container for vaxflux observations data.
Wraps a validated pandas DataFrame and provides a stable interface for
working with observations. Use from_dataframe to construct an
instance from a raw DataFrame.
The underlying DataFrame is accessible via the data property.
Columns that are not part of the standard metadata schema are exposed via
the covariate_columns property for downstream use by the model.
Examples:
>>> import pandas as pd
>>> from vaxflux import VaxfluxObservations
>>> obs = VaxfluxObservations.from_dataframe(
... pd.DataFrame(
... {
... "season": ["2023/2024"],
... "start_date": ["2023-10-01"],
... "end_date": ["2023-10-31"],
... "type": ["incidence"],
... "value": [0.15],
... "location": ["ABC County"],
... }
... )
... )
>>> obs.data.shape
(1, 7)
>>> obs.covariate_columns
['location']
>>> len(obs)
1
>>> obs["season"]
0 2023/2024
Name: season, dtype: str
Source code in src/vaxflux/_vaxflux_observations.py
covariate_columns
property
Column names that are not part of the standard metadata schema.
Returns:
| Type | Description |
|---|---|
| list[str] | A sorted list of column names that are not in the set of known metadata columns (…). |
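The property can be pictured as a sorted set difference between a DataFrame's columns and the metadata schema. The schema set below is an assumption assembled from the column names this page mentions (season, start_date, end_date, report_date, type, value); the authoritative list lives in vaxflux's source, and covariate_columns here is a standalone function, not the property itself.

```python
# Assumed metadata schema; the real set is defined inside vaxflux.
METADATA_COLUMNS = {"season", "start_date", "end_date", "report_date", "type", "value"}


def covariate_columns(columns):
    """Return the sorted non-metadata column names (sketch of the property)."""
    return sorted(set(columns) - METADATA_COLUMNS)


print(covariate_columns(
    ["season", "start_date", "end_date", "report_date", "type", "value", "location"]
))  # ['location']
```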
data
property
The validated underlying observations DataFrame.
from_csv(*args, **kwargs)
classmethod
Construct a VaxfluxObservations by reading a CSV with pandas.
This is a thin wrapper around pandas.read_csv; all positional and keyword
arguments are forwarded directly to that function before the resulting
DataFrame is validated and wrapped.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Any | Positional arguments forwarded to pandas.read_csv. | () |
| **kwargs | Any | Keyword arguments forwarded to pandas.read_csv. | {} |
Returns:
| Type | Description |
|---|---|
| VaxfluxObservations | A validated VaxfluxObservations instance. |
Raises:
| Type | Description |
|---|---|
| Exception | Propagates any exception raised by pandas.read_csv. |
| ValueError | If the resulting DataFrame is empty or missing required columns. |
| ValueError | If the … |
| ValueError | If the … |
| NotImplementedError | If prevalence observations are provided. |
| NotImplementedError | If observations with differing report dates are provided. |
Source code in src/vaxflux/_vaxflux_observations.py
from_dataframe(data)
classmethod
Construct a VaxfluxObservations from a DataFrame or existing instance.
If data is already a VaxfluxObservations, it is returned
unchanged. If it is a pandas.DataFrame, it is validated and
wrapped.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | VaxfluxObservations \| DataFrame | A raw observations DataFrame or an existing VaxfluxObservations instance. | required |
Returns:
| Type | Description |
|---|---|
| VaxfluxObservations | A validated VaxfluxObservations instance. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the DataFrame is empty or missing required columns. |
| ValueError | If the … |
| ValueError | If the … |
| NotImplementedError | If prevalence observations are provided. |
| NotImplementedError | If observations with differing report dates are provided (nowcasting is not yet supported). |
Source code in src/vaxflux/_vaxflux_observations.py
daily_date_ranges(season_ranges, range_days=0, remainder='raise')
Create daily date ranges from the season ranges.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| season_ranges | list[SeasonRange] \| SeasonRange | The season ranges to create the daily date ranges from. | required |
| range_days | int | The number of days for each daily date range, must be at least 0. | 0 |
| remainder | Literal['fill', 'raise', 'skip'] | The strategy to handle the remainder of days when the season ranges do not divide evenly into daily date ranges. Options are "fill" to fill the remainder with the last date range, "raise" to raise an error, and "skip" to skip the remainder. | 'raise' |
Returns:
| Type | Description |
|---|---|
| list[DateRange] | The daily date ranges for the uptake scenarios. |
Raises:
| Type | Description |
|---|---|
| ValueError | If range_days is less than 0. |
| ValueError | If the number of days for each daily date range does not divide evenly into the season range and remainder is "raise". |
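A standard-library sketch of the splitting logic, under stated assumptions: each range covers range_days + 1 consecutive calendar days (so the default 0 yields one-day ranges), "fill" keeps a shorter final range covering the leftover days, and ranges come back as (start, end) pairs rather than DateRange objects, which also need a report date. split_season is a hypothetical name, not a vaxflux function.

```python
from datetime import date, timedelta


def split_season(start, end, range_days=0, remainder="raise"):
    """Split the inclusive span [start, end] into consecutive (start, end) ranges."""
    if range_days < 0:
        raise ValueError("range_days must be at least 0.")
    if remainder not in ("fill", "raise", "skip"):
        raise ValueError(f"Unknown remainder strategy {remainder!r}.")
    step = range_days + 1                 # days per range, inclusive
    total = (end - start).days + 1        # days in the season, inclusive
    full, leftover = divmod(total, step)
    if leftover and remainder == "raise":
        raise ValueError(f"{total} days do not divide evenly into {step}-day ranges.")
    ranges = [
        (start + timedelta(days=i * step), start + timedelta(days=i * step + range_days))
        for i in range(full)
    ]
    if leftover and remainder == "fill":
        # Assumed meaning of "fill": a shorter last range up to the season end.
        ranges.append((start + timedelta(days=full * step), end))
    return ranges


season_start, season_end = date(2023, 12, 1), date(2023, 12, 7)
print(len(split_season(season_start, season_end)))  # 7
print(split_season(season_start, season_end, range_days=2, remainder="fill")[-1])
```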