Models
The models package hosts the suite of probabilistic models supported by stormi.
RNA_1layer
stormi.models.RNA_1layer
Functions
| Name | Description |
|---|---|
| dstate_dt | Compute the derivative of the state vector for the coupled system using a neural network. |
| mlp | Single- or multi-layer perceptron (MLP) with residual connections. |
| model | NumPyro model for coupled transcription and splicing dynamics. |
dstate_dt
stormi.models.RNA_1layer.dstate_dt(t, state, args)
Compute the derivative of the state vector for the coupled system using a neural network.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| t | | Time scalar. | required |
| state | | State vector [u_1, …, u_G, s_1, …, s_G]. | required |
| args | tuple | Tuple of parameters (G, beta_g, gamma_g, nn_params, T_ON). | required |
Returns
| Name | Type | Description |
|---|---|---|
| | | Derivative of the state vector. |
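The sketch below illustrates what a derivative function with this `(t, state, args)` signature computes. It assumes the standard transcription and splicing kinetics, du_g/dt = alpha_g - beta_g * u_g and ds_g/dt = beta_g * u_g - gamma_g * s_g, with beta_g read as the splicing rate and gamma_g as the degradation rate. It also assumes the transcription rates alpha_g come from a one-layer network on the spliced abundances. The function name `toy_dstate_dt`, the `nn_params` layout `{"W", "b"}`, and the softplus output are hypothetical, and `T_ON` is ignored. This is not stormi's implementation.

```python
import numpy as np


def toy_dstate_dt(t, state, args):
    """Hypothetical sketch of a coupled transcription/splicing derivative.

    state = [u_1, ..., u_G, s_1, ..., s_G] (unspliced, then spliced).
    args  = (G, beta_g, gamma_g, nn_params, T_ON); T_ON is unused here.
    """
    G, beta_g, gamma_g, nn_params, _T_ON = args
    state = np.asarray(state, dtype=float)
    u, s = state[:G], state[G:]
    # Assumption: transcription rates are a softplus of a linear map of s,
    # which keeps alpha_g positive.
    alpha = np.logaddexp(0.0, nn_params["W"] @ s + nn_params["b"])
    du_dt = alpha - beta_g * u          # transcription minus splicing
    ds_dt = beta_g * u - gamma_g * s    # splicing minus degradation
    # t is unused: this toy system is autonomous.
    return np.concatenate([du_dt, ds_dt])
```

The `(t, state, args)` argument order is the vector-field convention used by JAX ODE libraries such as Diffrax (`diffrax.ODETerm`). This page does not say which solver stormi uses.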
mlp
stormi.models.RNA_1layer.mlp(params, x)
Single- or multi-layer perceptron (MLP) with residual connections.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| params | Dict | Dictionary containing neural network parameters (weights and biases). | required |
| x | Any | Input data array. | required |
Returns
| Name | Type | Description |
|---|---|---|
| Any | Output array after passing through the MLP. |
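A residual MLP adds each hidden layer's input back to its output whenever the shapes agree. The minimal NumPy sketch below shows the idea. The `params` layout `{"layers": [(W, b), ...]}`, the `tanh` activation, and the linear output layer are assumptions for illustration, and `toy_mlp` is a hypothetical name. With a single `(W, b)` pair, the function reduces to one linear layer, which is the "single-layer" case.

```python
import numpy as np


def toy_mlp(params, x):
    """Hypothetical residual MLP; params = {"layers": [(W, b), ...]}."""
    h = np.asarray(x, dtype=float)
    layers = params["layers"]
    for i, (W, b) in enumerate(layers):
        z = h @ W + b
        if i < len(layers) - 1:          # hidden layers only
            z = np.tanh(z)
            if z.shape == h.shape:       # skip connection when widths match
                z = z + h
        h = z                            # last layer stays linear
    return h
```

Because the sketch uses `h @ W`, it also accepts a batch of inputs of shape `(n, d_in)`.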
model
stormi.models.RNA_1layer.model(
data,
M_c,
obs2sample,
batch_index,
tf_indices,
total_num_cells,
n_batch,
prior_time,
prior_timespan,
unknown_idx,
T_limits,
return_alpha=False,
Tmax_alpha=50.0,
Tmax_beta=1.0,
splicing_rate_alpha_hyp_prior_alpha=20.0,
splicing_rate_alpha_hyp_prior_mean=5.0,
splicing_rate_mean_hyp_prior_alpha=10.0,
splicing_rate_mean_hyp_prior_mean=1.0,
degradation_rate_alpha_hyp_prior_alpha=20.0,
degradation_rate_alpha_hyp_prior_mean=5.0,
degradation_rate_mean_hyp_prior_alpha=10.0,
degradation_rate_mean_hyp_prior_mean=1.0,
transcription_rate_alpha_hyp_prior_alpha=20.0,
transcription_rate_alpha_hyp_prior_mean=2.0,
transcription_rate_mean_hyp_prior_alpha=10.0,
transcription_rate_mean_hyp_prior_mean=5.0,
lambda_alpha=1.0,
lambda_mean=1.0,
kappa_alpha=1.0,
kappa_mean=1.0,
detection_mean_hyp_prior_alpha=1.0,
detection_mean_hyp_prior_beta=1.0,
detection_hyp_prior_alpha=10.0,
detection_i_prior_alpha=100.0,
detection_gi_prior_alpha=200.0,
gene_add_alpha_hyp_prior_alpha=9.0,
gene_add_alpha_hyp_prior_beta=3.0,
gene_add_mean_hyp_prior_alpha=1.0,
gene_add_mean_hyp_prior_beta=100.0,
stochastic_v_ag_hyp_prior_alpha=9.0,
stochastic_v_ag_hyp_prior_beta=3.0,
sde_rng_key=0,
**kwargs,
)
NumPyro model for coupled transcription and splicing dynamics.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Observed data array of shape (num_cells, num_genes, num_modalities). | required |
| M_c | Any | Number of cells in each metacell. | required |
| batch_index | Any | Array indicating batch assignments for each cell. | required |
| tf_indices | Any | Indices of genes that are transcription factors (TFs). | required |
| total_num_cells | int | Number of cells in the full dataset. | required |
| n_batch | int | Number of batches. | required |
| Tmax_alpha | float | Alpha parameter of the Tmax prior. | 50.0 |
| Tmax_beta | float | Beta parameter of the Tmax prior. | 1.0 |
| sde_rng_key | | Random number generator key. | 0 |
Returns
| Name | Type | Description |
|---|---|---|
| | | None. Defines the probabilistic model for inference. |
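To make the generative idea behind `model` concrete, the toy forward simulation below is offered as an illustration only. It integrates the transcription and splicing kinetics from u = s = 0 up to each cell's time with forward Euler, scales the result by a detection efficiency and the metacell size `M_c`, and draws Poisson counts. It returns an array shaped `(num_cells, num_genes, num_modalities)`, with the two modalities assumed to be unspliced and spliced. The function `toy_simulate_counts`, the fixed scalar `detection`, the Poisson likelihood, and the scaling by `M_c` are all hypothetical simplifications. The actual NumPyro model places hierarchical priors on these quantities and is not reproduced here.

```python
import numpy as np


def toy_simulate_counts(alpha, beta, gamma, t_obs, M_c, detection=0.5, n_steps=1000, seed=0):
    """Hypothetical toy generative process (not stormi's model)."""
    rng = np.random.default_rng(seed)
    alpha, beta, gamma = (np.asarray(a, dtype=float) for a in (alpha, beta, gamma))
    t_obs = np.asarray(t_obs, dtype=float)
    M_c = np.asarray(M_c, dtype=float)
    G = alpha.shape[0]
    data = np.empty((t_obs.shape[0], G, 2), dtype=np.int64)
    for c, t_end in enumerate(t_obs):
        u, s = np.zeros(G), np.zeros(G)
        dt = t_end / n_steps
        for _ in range(n_steps):         # forward Euler integration
            du = alpha - beta * u
            ds = beta * u - gamma * s
            u, s = u + dt * du, s + dt * ds
        # Expected counts for (unspliced, spliced), shape (G, 2).
        mu = detection * M_c[c] * np.stack([u, s], axis=-1)
        data[c] = rng.poisson(mu)
    return data
```

A cell observed at time 0 has u = s = 0, so all of its simulated counts are zero.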