Guides

The guides package provides variational guides for use with stormi models.

AmortizedNormal

stormi.guides.AmortizedNormal(
    model,
    model_input,
    *,
    init_net_params=None,
    init_loc_fn=init_to_mean,
    init_seed=0,
    nn_width=32,
    nn_depth=2,
    props_t=(0.7, 0.1, 0.1, 0.1),
    props_y=(0.5, 0.3, 0.1, 0.1),
    props_l=(0.5, 0.3, 0.1, 0.1),
    props_pw=(0.7, 0.1, 0.1, 0.1),
    hvg_n_top=2000,
    rna_embeddings=None,
    gene_nn_width=32,
    gene_nn_depth=2,
    share_gene_trunk=True,
)

Compose an AutoNormal with an amortized neural guide for cell-specific (‘local’) parameters.

This wrapper:

- infers which amortized heads to enable from model_input (RNA/ATAC presence, number of paths),
- builds a two-part AutoGuideList (AutoNormal for globals + MLP for locals),
- optionally warms up the MLP with a prior-only SVI and rebinds those weights,
- provides convenience utilities to get warm predictions, save/load warm params, and extract posterior means (global + local, batched if needed).
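To make the "MLP for locals" idea concrete, here is a minimal NumPy sketch of amortization in general. It is not stormi's implementation: the helper names (init_mlp, amortized_normal), the tanh activation, and the log1p input transform are assumptions chosen for illustration. Only the width and depth defaults mirror nn_width=32 and nn_depth=2 from the signature above.

```python
import numpy as np

def init_mlp(rng, in_dim, width=32, depth=2, out_dim=1):
    """Random weights for `depth` hidden layers plus loc and log-scale heads."""
    sizes = [in_dim] + [width] * depth
    hidden = [
        (rng.normal(0.0, 0.1, (a, b)), np.zeros(b))
        for a, b in zip(sizes[:-1], sizes[1:])
    ]
    last = sizes[-1]
    return {
        "hidden": hidden,
        "loc": (rng.normal(0.0, 0.1, (last, out_dim)), np.zeros(out_dim)),
        "log_scale": (rng.normal(0.0, 0.1, (last, out_dim)), np.zeros(out_dim)),
    }

def amortized_normal(params, x):
    """Map per-cell features to the loc and scale of a Normal over a local variable."""
    h = x
    for w, b in params["hidden"]:
        h = np.tanh(h @ w + b)
    loc = h @ params["loc"][0] + params["loc"][1]
    # exp keeps the scale strictly positive
    scale = np.exp(h @ params["log_scale"][0] + params["log_scale"][1])
    return loc, scale

rng = np.random.default_rng(0)
counts = rng.poisson(2.0, size=(5, 8)).astype(float)  # 5 toy cells, 8 toy genes
params = init_mlp(rng, in_dim=8)
loc, scale = amortized_normal(params, np.log1p(counts))
```

The point of the design is that one shared set of network weights yields variational parameters for every cell, so the number of guide parameters does not grow with the number of cells.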

Methods

Name Description
extract_all_medians Get (global_medians, local_medians) in one call, with batched local extraction.
load_warm Load the cached warm parameters from a pickle file.
plot_warm Plot UMAPs of warm-start predictions directly from the object.
save_warm Save the cached warm parameters to a pickle file.
warm_predictions Compute warm-start predictions using the cached warm parameters.
warm_up Run prior-only warm-up and rebind the amortized network with learned weights.

extract_all_medians

stormi.guides.AmortizedNormal.extract_all_medians(
    model_input,
    training_output,
    *,
    batch_size=1000,
    **kw,
)

Get (global_medians, local_medians) in one call, with batched local extraction.

Parameters

model_input : dict
    Input data dictionary used for inference. Must contain at least ‘data’, ‘obs2sample’, and ‘M_c’. Optional: ‘data_atac’.
training_output : dict
    Dictionary returned by training, expected to include {‘guide’: AutoGuideList, ‘svi’: SVI, ‘svi_state’: Any}.
batch_size : int, optional
    Batch size for local extraction. Default is 1000.
**kw :
    Additional keyword arguments forwarded to extract_local_means_full (e.g., num_paths).

Returns

tuple[dict, dict]
    (global_medians, local_medians) dictionaries.
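The batched local extraction can be pictured as slicing the per-cell arrays into chunks of batch_size, evaluating each chunk, and concatenating the results. The sketch below shows only that pattern. The helper batched_apply, the stand-in toy_local_fn, and the output key "row_mean" are hypothetical and are not part of stormi.

```python
import numpy as np

def batched_apply(fn, arrays, batch_size=1000):
    """Apply `fn` to row-slices of every array in `arrays`; concatenate dict outputs."""
    n_rows = next(iter(arrays.values())).shape[0]
    chunks = []
    for start in range(0, n_rows, batch_size):
        batch = {k: v[start:start + batch_size] for k, v in arrays.items()}
        chunks.append(fn(batch))
    return {k: np.concatenate([c[k] for c in chunks], axis=0) for k in chunks[0]}

def toy_local_fn(batch):
    # Hypothetical stand-in for a per-cell summary computed by the guide.
    return {"row_mean": batch["data"].mean(axis=1)}

data = np.arange(2500 * 3, dtype=float).reshape(2500, 3)
local = batched_apply(toy_local_fn, {"data": data}, batch_size=1000)
```

Here 2500 rows are processed as chunks of 1000, 1000, and 500, which bounds peak memory by the batch size rather than by the number of cells.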

load_warm

stormi.guides.AmortizedNormal.load_warm(path)

Load the cached warm parameters from a pickle file.

Parameters

path : str | pathlib.Path
    File path to read from.

Returns

dict
    The loaded parameter dict. Also rebinds the guide so that subsequent training starts from these weights.

plot_warm

stormi.guides.AmortizedNormal.plot_warm(
    adata_rna,
    model_input,
    *,
    data_atac=None,
    day_key=None,
    size=100,
    ncols=4,
    cmap='inferno',
    return_axes=False,
)

Plot UMAPs of warm-start predictions directly from the object.

Thin wrapper around warm_predictions + plot_warm_params.

Parameters

adata_rna : anndata.AnnData
model_input : dict
data_atac : jnp.ndarray or None
day_key : str or None
size, ncols, cmap, return_axes :
    See plot_warm_params.

Returns

list[matplotlib.axes.Axes] or None

save_warm

stormi.guides.AmortizedNormal.save_warm(path)

Save the cached warm parameters to a pickle file.

Parameters

path : str | pathlib.Path
    File path to write to.
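For readers unfamiliar with pickle-based caching, the following standard-library sketch shows the kind of round trip that save_warm and load_warm describe. The helpers save_params and load_params and the toy warm_params dict are illustrative assumptions. The real methods operate on the parameters cached on the object, and load_warm additionally rebinds the guide.

```python
import pickle
import tempfile
from pathlib import Path

def save_params(params, path):
    """Write a parameter dict to `path` with pickle."""
    with Path(path).open("wb") as fh:
        pickle.dump(params, fh)

def load_params(path):
    """Read a parameter dict back from `path`."""
    with Path(path).open("rb") as fh:
        return pickle.load(fh)

# Toy nested dict standing in for network weights.
warm_params = {"layer0": {"w": [[0.1, -0.2], [0.3, 0.4]], "b": [0.0, 0.0]}}

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "warm_params.pkl"
    save_params(warm_params, target)
    restored = load_params(target)
```

Because unpickling can execute arbitrary code, only load pickle files from sources you trust.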

warm_predictions

stormi.guides.AmortizedNormal.warm_predictions(model_input, data_atac=None)

Compute warm-start predictions using the cached warm parameters.

Parameters

model_input : dict
    Must contain arrays needed by warm_forward_predictions.
data_atac : jnp.ndarray or None
    Required when the cached params include an l head.

Returns

dict
    See warm_forward_predictions.

warm_up

stormi.guides.AmortizedNormal.warm_up(model_input, n_steps=10000, seed=0)

Run prior-only warm-up and rebind the amortized network with learned weights.

Parameters

model_input : dict
    Inputs required by warm_up_guide.
n_steps : int
    Number of warm-up SVI steps.
seed : int
    Random seed for SVI init.

Returns

dict
    The learned amortized-network parameters. Also cached on the object and used to rebuild the guide with these as initial values.
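One way the warm-up methods might be combined is to warm up once, cache the result, and reuse the cache on later runs. The helper below is a hypothetical convenience function, not part of stormi. It assumes only that guide exposes warm_up, save_warm, and load_warm with the signatures documented in this section, and that cache_path is a location you choose.

```python
from pathlib import Path

def warm_start(guide, model_input, cache_path, n_steps=10000, seed=0):
    """Return warm parameters, reusing a cached pickle when one exists.

    `guide` is any object exposing the warm_up / save_warm / load_warm
    methods documented in this section (an assumption of this sketch).
    """
    cache_path = Path(cache_path)
    if cache_path.exists():
        # Reuse earlier warm-up results; load_warm also rebinds the guide.
        return guide.load_warm(cache_path)
    params = guide.warm_up(model_input, n_steps=n_steps, seed=seed)
    guide.save_warm(cache_path)
    return params
```

A call such as warm_start(guide, model_input, "warm_params.pkl") would then skip the warm-up SVI whenever the cache file is already present. Delete the file to force a fresh warm-up, for example after changing the model or the input data.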