Models

The models package hosts the suite of probabilistic models supported by stormi.

RNA_1layer

stormi.models.RNA_1layer

Functions

| Name | Description |
|------|-------------|
| dstate_dt | Compute the derivative of the state vector for the coupled system using a neural network. |
| mlp | Single- or multi-layer perceptron (MLP) with residual connections. |
| model | NumPyro model for coupled transcription and splicing dynamics. |

dstate_dt

stormi.models.RNA_1layer.dstate_dt(t, state, args)

Compute the derivative of the state vector for the coupled system using a neural network.

Parameters

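| Name | Type | Description | Default |
|------|------|-------------|---------|
| t | | Current time (scalar). | required |
| state | | State vector [u_1, …, u_G, s_1, …, s_G] holding the unspliced (u) and spliced (s) abundances of G genes. | required |
| args | | Tuple (G, beta_g, gamma_g, nn_params, T_ON), where G is the number of genes, beta_g and gamma_g are the per-gene splicing and degradation rates, and nn_params holds the neural network parameters. | required |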

Returns

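| Name | Type | Description |
|------|------|-------------|
| | | Derivative of the state vector, with the same layout as state: [du_1/dt, …, du_G/dt, ds_1/dt, …, ds_G/dt]. |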
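A minimal NumPy sketch of a vector field with this signature, assuming the standard RNA-velocity kinetics du/dt = alpha(s) - beta*u and ds/dt = beta*u - gamma*s, with the transcription rate alpha predicted by a network from the spliced abundances. The network `toy_transcription_nn`, its parameter layout, and the choice of network input are hypothetical, not stormi's implementation; T_ON is unpacked but ignored because its role is not documented here. The `(t, state, args)` argument order matches the vector-field convention of JAX ODE solvers such as diffrax's `ODETerm`, and `jax.numpy` can replace `numpy` one-for-one inside the two functions.

```python
import numpy as np


def toy_transcription_nn(nn_params, s):
    """Hypothetical stand-in for the network: one dense layer plus softplus,
    so predicted transcription rates are always positive."""
    z = nn_params["W"] @ s + nn_params["b"]
    return np.logaddexp(0.0, z)  # softplus(z) = log(1 + exp(z))


def dstate_dt_sketch(t, state, args):
    """Vector field for du/dt = alpha(s) - beta*u, ds/dt = beta*u - gamma*s.

    t is unused because this toy system is autonomous.
    """
    G, beta_g, gamma_g, nn_params, T_ON = args  # T_ON is ignored in this sketch
    u, s = state[:G], state[G:]  # unspliced block, then spliced block
    alpha = toy_transcription_nn(nn_params, s)  # per-gene transcription rates
    du = alpha - beta_g * u  # production minus splicing
    ds = beta_g * u - gamma_g * s  # splicing minus degradation
    return np.concatenate([du, ds])


# Usage: integrate from an all-zero state with forward Euler steps.
G = 2
rng = np.random.default_rng(0)
nn_params = {"W": 0.1 * rng.normal(size=(G, G)), "b": np.zeros(G)}
args = (G, np.array([1.0, 0.5]), np.array([0.3, 0.2]), nn_params, 0.0)
state, dt = np.zeros(2 * G), 0.01
for step in range(1000):
    state = state + dt * dstate_dt_sketch(step * dt, state, args)
```

With a constant transcription rate, the field vanishes at u* = alpha/beta and s* = alpha/gamma, which is a quick way to check any implementation of these kinetics.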

mlp

stormi.models.RNA_1layer.mlp(params, x)

Single- or multi-layer perceptron (MLP) with residual connections.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| params | Dict | Dictionary containing neural network parameters (weights and biases). | required |
| x | Any | Input data array. | required |

Returns

| Name | Type | Description |
|------|------|-------------|
| | Any | Output array after passing through the MLP. |
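One common way to build an MLP with residual connections is to add each layer's output back onto its input whenever the layer preserves the width. The sketch below does that in NumPy; the dictionary layout (`"weights"` and `"biases"` lists), the tanh activation, and the name `mlp_sketch` are assumptions for illustration, since the exact parameter structure stormi expects is not documented here.

```python
import numpy as np


def mlp_sketch(params, x):
    """Residual MLP: h <- h + layer(h) whenever a layer preserves width.

    Assumed (hypothetical) layout: params = {"weights": [W_1, ..., W_L],
    "biases": [b_1, ..., b_L]}, with W_l of shape (d_in, d_out).
    One entry gives a single-layer perceptron.
    """
    weights, biases = params["weights"], params["biases"]
    h = np.asarray(x)
    last = len(weights) - 1
    for i, (W, b) in enumerate(zip(weights, biases)):
        z = h @ W + b
        if i < last:
            z = np.tanh(z)  # nonlinearity on hidden layers only
        # Skip connection only when input and output widths match
        h = h + z if z.shape == h.shape else z
    return h


# Usage: a 4 -> 4 -> 2 network applied to a batch of 3 inputs.
rng = np.random.default_rng(0)
params = {
    "weights": [rng.normal(size=(4, 4)), rng.normal(size=(4, 2))],
    "biases": [np.zeros(4), np.zeros(2)],
}
out = mlp_sketch(params, rng.normal(size=(3, 4)))
```

Because `h @ W` contracts the last axis, the same function accepts a single input vector or a batch of row vectors. Keeping the final layer linear lets the output take any sign; a caller needing positive rates would apply a transform such as softplus afterwards.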

model

stormi.models.RNA_1layer.model(
    data,
    M_c,
    obs2sample,
    batch_index,
    tf_indices,
    total_num_cells,
    n_batch,
    prior_time,
    prior_timespan,
    unknown_idx,
    T_limits,
    return_alpha=False,
    Tmax_alpha=50.0,
    Tmax_beta=1.0,
    splicing_rate_alpha_hyp_prior_alpha=20.0,
    splicing_rate_alpha_hyp_prior_mean=5.0,
    splicing_rate_mean_hyp_prior_alpha=10.0,
    splicing_rate_mean_hyp_prior_mean=1.0,
    degradation_rate_alpha_hyp_prior_alpha=20.0,
    degradation_rate_alpha_hyp_prior_mean=5.0,
    degradation_rate_mean_hyp_prior_alpha=10.0,
    degradation_rate_mean_hyp_prior_mean=1.0,
    transcription_rate_alpha_hyp_prior_alpha=20.0,
    transcription_rate_alpha_hyp_prior_mean=2.0,
    transcription_rate_mean_hyp_prior_alpha=10.0,
    transcription_rate_mean_hyp_prior_mean=5.0,
    lambda_alpha=1.0,
    lambda_mean=1.0,
    kappa_alpha=1.0,
    kappa_mean=1.0,
    detection_mean_hyp_prior_alpha=1.0,
    detection_mean_hyp_prior_beta=1.0,
    detection_hyp_prior_alpha=10.0,
    detection_i_prior_alpha=100.0,
    detection_gi_prior_alpha=200.0,
    gene_add_alpha_hyp_prior_alpha=9.0,
    gene_add_alpha_hyp_prior_beta=3.0,
    gene_add_mean_hyp_prior_alpha=1.0,
    gene_add_mean_hyp_prior_beta=100.0,
    stochastic_v_ag_hyp_prior_alpha=9.0,
    stochastic_v_ag_hyp_prior_beta=3.0,
    sde_rng_key=0,
    **kwargs,
)

NumPyro model for coupled transcription and splicing dynamics.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| data | Any | Observed data array of shape (num_cells, num_genes, num_modalities). | required |
| M_c | Any | Number of cells in each metacell. | required |
| batch_index | Any | Array giving the batch assignment of each cell. | required |
| tf_indices | Any | Indices of the genes that are transcription factors (TFs). | required |
| total_num_cells | int | Number of cells in the full dataset. | required |
| n_batch | int | Number of batches. | required |
| Tmax_alpha | float | Alpha parameter of the Tmax prior. | 50.0 |
| Tmax_beta | float | Beta parameter of the Tmax prior. | 1.0 |
| sde_rng_key | | Random number generator key. | 0 |
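The table names Tmax_alpha and Tmax_beta without stating the distribution. Assuming a Gamma prior in NumPyro's `Gamma(concentration, rate)` parameterization (an assumption, not confirmed here), the defaults give a prior mean of 50 and a standard deviation of about 7.1. The signature's many `*_hyp_prior_alpha` / `*_hyp_prior_mean` pairs suggest a (shape, mean) specification, which converts to a rate via rate = shape / mean; that reading is likewise an assumption. The helper names below are hypothetical.

```python
def gamma_mean_var(shape, rate):
    """Mean and variance of a Gamma(shape, rate) distribution."""
    return shape / rate, shape / rate**2


def shape_mean_to_rate(shape, mean):
    """Rate of a Gamma prior specified by (shape, mean), since mean = shape / rate."""
    return shape / mean


# Tmax prior, read (by assumption) as Gamma(Tmax_alpha=50.0, Tmax_beta=1.0)
tmax_mean, tmax_var = gamma_mean_var(50.0, 1.0)

# A (..._alpha, ..._mean) pair, e.g. splicing_rate_alpha_hyp_prior_alpha=20.0
# and splicing_rate_alpha_hyp_prior_mean=5.0, read as shape and mean
rate = shape_mean_to_rate(20.0, 5.0)
spl_mean, spl_var = gamma_mean_var(20.0, rate)
```

Under this reading, the variance equals mean² / shape, so raising an `_alpha` value while keeping the matching `_mean` fixed tightens the prior around the same mean.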

Returns

| Name | Type | Description |
|------|------|-------------|
| | | None; the function defines the probabilistic model for inference. |
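Before calling the model it can help to confirm that the inputs agree with the shapes documented in the parameter table. The helper `check_model_inputs` below is hypothetical and not part of stormi. It checks only the documented parameters, treats M_c as one cell count per observation and data as possibly a subsample of the full dataset (both assumptions), and leaves undocumented arguments such as obs2sample, prior_time, and T_limits alone.

```python
import numpy as np


def check_model_inputs(data, M_c, batch_index, tf_indices, total_num_cells, n_batch):
    """Hypothetical sanity checks based on the documented parameter shapes."""
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError("data must have shape (num_cells, num_genes, num_modalities)")
    num_cells, num_genes, _ = data.shape
    batch_index = np.asarray(batch_index)
    if batch_index.shape[0] != num_cells:
        raise ValueError("batch_index needs one entry per cell")
    if batch_index.min() < 0 or batch_index.max() >= n_batch:
        raise ValueError("batch_index values must lie in [0, n_batch)")
    tf_indices = np.asarray(tf_indices)
    if tf_indices.size and (tf_indices.min() < 0 or tf_indices.max() >= num_genes):
        raise ValueError("tf_indices must be valid gene indices")
    if np.any(np.asarray(M_c) < 1):
        raise ValueError("each metacell must contain at least one cell")
    # Assumes data may be a subsample of the full dataset, never larger than it
    if total_num_cells < num_cells:
        raise ValueError("total_num_cells cannot be smaller than the observed cells")
    return num_cells, num_genes


# Usage: toy counts for 100 cells, 20 genes, and 2 modalities across 3 batches.
rng = np.random.default_rng(0)
data = rng.poisson(5.0, size=(100, 20, 2)).astype(np.float32)
batch_index = rng.integers(0, 3, size=100)
dims = check_model_inputs(data, np.ones(100), batch_index, [0, 3, 7], 100, 3)
```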