Simplified RNA 1-layer model

1 Imports

Code
import os
import scanpy as sc
import pandas as pd
import numpy as np
import celldynamics as cd
import anndata

from celldynamics.models.RNA_1layer_simple import model, prepare_model_input, sample_prior
from celldynamics.guides import AmortizedNormal
from numpyro.infer.initialization import init_to_mean

import scipy.stats
import seaborn as sns
import matplotlib.pyplot as plt

2 Test fits on prior predictive samples

# Load the mouse-brain 10X multiome training data.
data_directory = "/home/tchari/CellDynamicsData/mouse_brain/10X_multiome_mouse_brain_1Lin.h5ad"
adata_rna = sc.read_h5ad(data_directory)
adata_rna.var_names_make_unique()

# Load the mouse transcription-factor list: first column of a header-less text file.
repo_path = '/home/tchari/CellDynamics/'
# os.path.join avoids the doubled separator that `repo_path + '/' + ...` produces
# when repo_path already ends with '/'.
tfs_path = os.path.join(repo_path, 'Mouse_TFs.txt')
tf_list = list(pd.read_csv(tfs_path, header = None).iloc[:,0])

# log1p-transform adata_rna (in place), then keep the top genes plus TFs.
sc.pp.log1p(adata_rna)
# Previously count_threshold was 10^6.
# NOTE(review): confirm 103 is intended and not a mangled 10**3.
adata_rna = cd.filter_genes(adata_rna, tf_list, n_top_genes = 1000, count_threshold = 103)
adata_rna
# Build the amortized guide and the model input from the filtered real data.
guide = AmortizedNormal(model, init_loc_fn=init_to_mean, predict_detection_l_c=False)
model_input = prepare_model_input(
    adata_rna,
    tf_list,
    n_cells_col = "n_cells")

model_input["return_alpha"] = True
# Training is disabled in this section; only the prior is inspected.
# guide, svi, svi_state, losses, model_input = cd.train_svi(
#     model,
#     guide,
#     model_input = model_input,
#     max_iterations = 1000,
#     min_lr = 0.001,
#     max_lr = 0.01,
#     ramp_up_fraction = 0.1,
#     cell_batch_size = 0,
#     log_interval = 100,
# )

# Draw one sample from the prior.
prior_samples = sample_prior(model, model_input, num_samples = 1)
# Explicit prints: bare expressions are only shown when they end a notebook cell.
print("data_target shape:", prior_samples['data_target'].shape)
print("prior sites:", list(prior_samples))

# Index convention (from the original inline comments): axis 0 = prior sample,
# axis 1 = cells, then the gene, and the last axis selects unspliced (0) or
# spliced (1). predictions_rearranged has one extra axis before the last; index 0 is used.
sample_idx, gene_idx, spliced_idx = 0, 0, 1
prior_mu = np.asarray(prior_samples['mu'])[sample_idx, :, gene_idx, spliced_idx]
prior_data = np.asarray(prior_samples['data_target'])[sample_idx, :, gene_idx, spliced_idx]
prior_bio_mu = np.asarray(prior_samples['predictions_rearranged'])[sample_idx, :, gene_idx, 0, spliced_idx]

# Prior mu vs. the sampled data_target, one gene across all cells.
# A new figure keeps this plot from drawing onto whatever axes are current.
plt.figure()
plt.scatter(prior_mu, prior_data)
plt.loglog()
plt.xlabel('Prior mu')
plt.ylabel('Prior data_target')
print("corr(mu, data_target):", np.corrcoef(prior_mu, prior_data)[0, 1])

# "Biological" mu (predictions_rearranged) vs. mu, with the identity line y = x.
plt.figure()
plt.scatter(prior_bio_mu, prior_mu)
identity = np.sort(prior_bio_mu)  # sorted so the reference line is drawn once, monotonically
plt.plot(identity, identity, color='black')
plt.loglog()
plt.xlabel('Bio. mu (prior)')
plt.ylabel('Sampled mu (prior)')
print("corr(mu, predictions_rearranged):", np.corrcoef(prior_mu, prior_bio_mu)[0, 1])
# Average the cell times over the prior samples and shift them so that the
# earliest cell sits at t = 0.
T_c = np.mean(prior_samples['T_c'], axis=0)
adata_rna.obs['Time'] = T_c - np.min(T_c)

# PCA embedding coloured by time and cell type; marker size scales with n_cells.
time_pca_kwargs = dict(
    color=["Time", 'celltype'],
    ncols=2,
    size=100 * adata_rna.obs['n_cells'],
    cmap='inferno',
    title='Time since initial condition in hours',
)
sc.pp.pca(adata_rna)
sc.pl.pca(adata_rna, **time_pca_kwargs)

2.1 Build a new AnnData from the prior samples and a new model_input

# Assemble a synthetic AnnData from the single prior draw.
# The last axis of data_target is 0 = unspliced, 1 = spliced; spliced becomes X.
prior_draw = np.array(prior_samples['data_target'])[0]
# Each slice is copied so X and the layers own separate buffers; the in-place
# log1p below must only change X.
new_adata = anndata.AnnData(X=prior_draw[:, :, 1].copy())
new_adata
new_adata.obs['n_cells'] = list(adata_rna.obs['n_cells'])
new_adata.layers['unspliced'] = prior_draw[:, :, 0].copy()
new_adata.layers['spliced'] = prior_draw[:, :, 1].copy()

# Carry over the cell and gene identifiers of the source data.
new_adata.obs_names = adata_rna.obs_names
new_adata.var_names = adata_rna.var_names

sc.pp.log1p(new_adata)
sc.pp.highly_variable_genes(new_adata, n_top_genes=1000)
new_adata
guide = AmortizedNormal(model, init_loc_fn=init_to_mean, predict_detection_l_c=False)

model_input = prepare_model_input(new_adata, tf_list, n_cells_col="n_cells")
# # Check that data remained the same
# plt.scatter(prior_samples['mu'][0,:,0,1],model_input['data'][:,0,1]) #for one gene across all cells, unspliced (0) or spliced (1)
# plt.loglog()
# #Check that data remained the same
# plt.scatter(prior_samples['mu'][0,:,0,1],model_input['data'][:,0,1]) #for one gene across all cells, unspliced (0) or spliced (1)
# plt.loglog()

2.2 Run inference for new model_input

# Fit the model to the synthetic (prior-generated) data, requesting alpha from the model.
model_input["return_alpha"] = True

# Optimisation settings passed through to cd.train_svi.
svi_settings = {
    "max_iterations": 1000,
    "min_lr": 0.001,
    "max_lr": 0.01,
    "ramp_up_fraction": 0.1,
    "cell_batch_size": 0,
    "log_interval": 100,
}
guide, svi, svi_state, losses, model_input = cd.train_svi(
    model, guide, model_input=model_input, **svi_settings
)
cd.plot_elbo_loss(losses)

# Posterior estimates plus the deterministic sites compared against the prior.
# NOTE(review): num_samples=0 / modes=[1] presumably request mode estimates only — confirm.
posterior = cd.extract_posterior_estimates(
    model, guide, svi, svi_state,
    quantiles=[],
    num_samples=0,
    modes=[1],
    model_input=model_input,
    deterministic_sites=["predictions_rearranged", "T_c", "mu", "alpha_cg"],
)
# Compare the prior draw that generated the synthetic data with the fitted posterior.
# TODO: check that prior_samples has the sites mu, alpha_cg, T_c and predictions_rearranged.
# Note: low correlation observed between mu and the data generated from it.


def _scatter_prior_vs_inferred(x, y, xlabel, ylabel, loglog=False):
    """Scatter prior values ``x`` against inferred values ``y`` on a new figure.

    The axes are annotated with the Spearman rank correlation (rho) and the
    Pearson correlation (r) of ``x`` and ``y``. If ``loglog`` is true, both
    axes are switched to log scale after labelling.
    """
    plt.figure()  # new figure so successive comparisons do not overlay each other
    spearman, _ = scipy.stats.spearmanr(x, y)
    pearson, _ = scipy.stats.pearsonr(x, y)
    sns.scatterplot(x=x, y=y)
    plt.annotate(r"$\rho$:" + f"{spearman:.2f}", xy=(0.1, 0.9), xycoords='axes fraction')
    plt.annotate(r"$r$:" + f"{pearson:.2f}", xy=(0.1, 0.85), xycoords='axes fraction')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if loglog:
        plt.loglog()


# Prior mu vs. the synthetic data the model was fitted to
# (one gene across all cells; last index 0 = unspliced, 1 = spliced).
plt.figure()
plt.scatter(prior_samples['mu'][0,:,0,1], model_input['data'][:,0,1])
plt.loglog()

# Prior draw (sample 0) and posterior estimates averaged over their leading axis.
prior_alpha_cg = np.asarray(prior_samples['alpha_cg'])[0]
prior_mu_all = np.asarray(prior_samples['mu'])[0]
post_alpha_cg = np.mean(posterior['deterministic']['alpha_cg'], axis=0)
post_mu = np.mean(posterior['deterministic']['mu'], axis=0)

# Every alpha_cg entry (linear axes).
_scatter_prior_vs_inferred(
    prior_alpha_cg.flatten(), post_alpha_cg.flatten(),
    'Prior alpha_cg', 'Inf. alpha_cg')

# alpha averaged over cells (axis 0) -> one value per gene (log axes).
_scatter_prior_vs_inferred(
    np.mean(prior_alpha_cg, axis=0).flatten(),
    np.mean(post_alpha_cg, axis=0).flatten(),
    '(Mean) Prior alpha_g', '(Mean) Inf. alpha_g', loglog=True)

# Every mu entry (log axes).
_scatter_prior_vs_inferred(
    prior_mu_all.flatten(), post_mu.flatten(),
    'Prior mu', 'Inf. mu', loglog=True)

# alpha of a single gene across all cells (linear axes).
ind = 50
_scatter_prior_vs_inferred(
    prior_alpha_cg[:, ind].flatten(), post_alpha_cg[:, ind].flatten(),
    'Prior alpha_c', 'Inf. alpha_c')

# Posterior mu and predictions_rearranged vs. the fitted data.
fig = cd.predictions_vs_data(
    model_input['data'],
    posterior['deterministic']['mu'],
    ylabel = 'Posterior Predictions',
    title = 'Posterior vs Data',
)
fig = cd.predictions_vs_data(
    model_input['data'],
    posterior['deterministic']['predictions_rearranged'][:,:,0,:,:],
    ylabel = 'Posterior Predictions',
    title = 'Posterior vs Data',
)

3 Test fits on real data

# Load the mouse-brain 10X multiome training data.
data_directory = "/home/tchari/CellDynamicsData/mouse_brain/10X_multiome_mouse_brain_1Lin.h5ad"
adata_rna = sc.read_h5ad(data_directory)
adata_rna.var_names_make_unique()

# Load the mouse transcription-factor list: first column of a header-less text file.
repo_path = '/home/tchari/CellDynamics/'
# os.path.join avoids the doubled separator that `repo_path + '/' + ...` produces
# when repo_path already ends with '/'.
tfs_path = os.path.join(repo_path, 'Mouse_TFs.txt')
tf_list = list(pd.read_csv(tfs_path, header = None).iloc[:,0])

# log1p-transform adata_rna (in place), then keep the top genes plus TFs.
sc.pp.log1p(adata_rna)
# Previously count_threshold was 10^6.
# NOTE(review): confirm 103 is intended and not a mangled 10**3.
adata_rna = cd.filter_genes(adata_rna, tf_list, n_top_genes = 1000, count_threshold = 103)
# Fresh guide and model input for the filtered real data.
guide = AmortizedNormal(model, init_loc_fn=init_to_mean, predict_detection_l_c=False)
model_input = prepare_model_input(
    adata_rna,
    tf_list,
    n_cells_col = "n_cells")
model_input["return_alpha"] = True
# One prior draw, compared against the observed data before fitting.
prior_samples = sample_prior(model, model_input, num_samples = 1)
# Marker genes: Hopx > stem cell, Eomes > intermediate progenitor,
# Satb2 > upper-layer neuron, Stmn2 > general neuron.
fig = cd.prior_data_geneset(prior_samples, model_input, adata_rna, geneset=["Eomes", "Hopx", "Satb2", "Stmn2"], subplot_size=(18, 12),
                           plot_alpha = True)
fig = cd.predictions_vs_data(
    model_input['data'],
    prior_samples['mu'],
    ylabel = 'Prior Predictions',
    title = 'Prior vs Data'
)
# Prior cell times averaged over samples, shifted so the earliest cell is at t = 0.
T_c = np.mean(prior_samples['T_c'], axis = 0)

adata_rna.obs['Time'] =  T_c - np.min(T_c)

sc.pp.pca(adata_rna)
sc.pl.pca(
adata_rna,
color=["Time", 'celltype'],
ncols = 2,
size = 100*adata_rna.obs['n_cells'],
cmap = 'inferno',
title = 'Time since initial condition in hours'
)
# sc.pl.umap reads adata_rna.obsm['X_umap'], which nothing above computes. If the
# loaded file does not already provide it, build it from the PCA just computed
# instead of failing with a KeyError.
if 'X_umap' not in adata_rna.obsm:
    sc.pp.neighbors(adata_rna)
    sc.tl.umap(adata_rna)
sc.pl.umap(
adata_rna,
color=["Time", 'celltype'],
ncols = 2,
size = 100*adata_rna.obs['n_cells'],
cmap = 'inferno',
title = 'Time since initial condition in hours'
)
# Fit the model to the real mouse-brain data.
svi_settings = {
    "max_iterations": 1000,
    "min_lr": 0.001,
    "max_lr": 0.01,
    "ramp_up_fraction": 0.1,
    "cell_batch_size": 0,
    "log_interval": 100,
}
guide, svi, svi_state, losses, model_input = cd.train_svi(
    model, guide, model_input=model_input, **svi_settings
)
cd.plot_elbo_loss(losses)

# Posterior estimates plus the deterministic sites used in the plots below.
# NOTE(review): num_samples=0 / modes=[1] presumably request mode estimates only — confirm.
posterior = cd.extract_posterior_estimates(
    model, guide, svi, svi_state,
    quantiles=[],
    num_samples=0,
    modes=[1],
    model_input=model_input,
    deterministic_sites=["predictions_rearranged", "T_c", "mu", "alpha_cg"],
)
# Optionally reload a previously saved posterior instead of re-fitting.
# NOTE(review): pickle.load executes arbitrary code — only load files you created.
# import pickle

# # Load the dictionary from the file
# with open("/g/stegle/aivazidis/data/posterior_RNA1layer_2500epochs_1.pkl", "rb") as f:
#     posterior = pickle.load(f)

# Posterior mu and predictions_rearranged vs. the observed data.
fig = cd.predictions_vs_data(
    model_input['data'],
    posterior['deterministic']['mu'],
    ylabel = 'Posterior Predictions',
    title = 'Posterior vs Data',
)
fig = cd.predictions_vs_data(
    model_input['data'],
    posterior['deterministic']['predictions_rearranged'][:,:,0,:,:],
    ylabel = 'Posterior Predictions',
    title = 'Posterior vs Data',
)
# Posterior cell times, shifted so the earliest cell is at t = 0.
T_c = np.mean(posterior['deterministic']['T_c'], axis = 0)

adata_rna.obs['Time'] =  T_c - np.min(T_c)

sc.pp.pca(adata_rna)
sc.pl.pca(
adata_rna,
color=["Time", 'celltype'],
ncols = 2,
size = 100*adata_rna.obs['n_cells'],
cmap = 'inferno',
title = 'Time since initial condition in hours'
)
# sc.pl.umap needs adata_rna.obsm['X_umap']; compute it from the PCA above
# when the loaded data does not already contain one.
if 'X_umap' not in adata_rna.obsm:
    sc.pp.neighbors(adata_rna)
    sc.tl.umap(adata_rna)
sc.pl.umap(
adata_rna,
color=["Time", 'celltype'],
ncols = 2,
size = 100*adata_rna.obs['n_cells'],
cmap = 'inferno',
title = 'Time since initial condition in hours'
)
fig = cd.posterior_data_geneset(
    posterior,
    model_input,
    adata_rna,
    ['Satb2', 'Eomes', 'Stmn2', 'Hopx'],
    plot_alpha=True
)

# TODO: Is this single-nucleus seq?
# TODO: Update to have all rates.
fig = cd.posterior_data_geneset(
    posterior,
    model_input,
    adata_rna,
    list(adata_rna.var_names)[0:10],
    plot_alpha=True
)
# import pickle

# # Save the dictionary to a file
# with open("/g/stegle/aivazidis/data/posterior_RNA1layer_2500epochs_1.pkl", "wb") as f:
#     pickle.dump(posterior, f)

4 Test on LARRY data

# Load the LARRY training data (file mono_lin_well1.h5ad).
data_directory = "/home/tchari/CellDynamicsData/LARRY/mono_lin_well1.h5ad"
adata_rna = sc.read_h5ad(data_directory)
adata_rna.var_names_make_unique()

# Load the mouse transcription-factor list: first column of a header-less text file.
repo_path = '/home/tchari/CellDynamics/'
# os.path.join avoids the doubled separator that `repo_path + '/' + ...` produces
# when repo_path already ends with '/'.
tfs_path = os.path.join(repo_path, 'Mouse_TFs.txt')
tf_list = list(pd.read_csv(tfs_path, header = None).iloc[:,0])
# NOTE(review): unlike the mouse-brain sections, X is not log1p-transformed before
# filtering here — confirm this is intended.
#sc.pp.log1p(adata_rna)
# Previously count_threshold was 10^6.
# NOTE(review): confirm 102 is intended and not a mangled 10**2.
adata_rna = cd.filter_genes(adata_rna, tf_list, n_top_genes = 1000, count_threshold = 102)
guide = AmortizedNormal(model, init_loc_fn=init_to_mean, predict_detection_l_c=False)
adata_rna
adata_rna.var_names
# Model input for the LARRY data and one prior draw as a sanity check before fitting.
model_input = prepare_model_input(adata_rna, tf_list, n_cells_col="n_cells")
model_input["return_alpha"] = True
prior_samples = sample_prior(model, model_input, num_samples=1)
# Marker-gene plot from the mouse-brain section, disabled for LARRY:
# # #Hopx > Stem Cell, Eomes > intermediate progenitor, Satb2 > Upper Layer Neuron, Stmn2 > General Neuron
# # fig = cd.prior_data_geneset(prior_samples, model_input, adata_rna, geneset=["Eomes", "Hopx", "Satb2", "Stmn2"], subplot_size=(18, 12), 
# #                            plot_alpha = True)
fig = cd.predictions_vs_data(
    model_input['data'],
    prior_samples['mu'],
    ylabel='Prior Predictions',
    title='Prior vs Data',
)

# Prior cell times averaged over samples, shifted so the earliest cell is at t = 0.
T_c = np.mean(prior_samples['T_c'], axis=0)
adata_rna.obs['Time'] = T_c - np.min(T_c)

sc.pp.pca(adata_rna)
# One PCA figure per LARRY annotation column, each paired with the prior time.
for annotation in ('time_info', 'state_info'):
    sc.pl.pca(
        adata_rna,
        color=["Time", annotation],
        ncols=2,
        size=100 * adata_rna.obs['n_cells'],
        cmap='inferno',
        title='Time since initial condition in hours',
    )
# Fit the model to the LARRY training data.
svi_settings = {
    "max_iterations": 1000,
    "min_lr": 0.001,
    "max_lr": 0.01,
    "ramp_up_fraction": 0.1,
    "cell_batch_size": 0,
    "log_interval": 100,
}
guide, svi, svi_state, losses, model_input = cd.train_svi(
    model, guide, model_input=model_input, **svi_settings
)
cd.plot_elbo_loss(losses)

# Posterior estimates plus the deterministic sites used in the plots below.
# NOTE(review): num_samples=0 / modes=[1] presumably request mode estimates only — confirm.
posterior = cd.extract_posterior_estimates(
    model, guide, svi, svi_state,
    quantiles=[],
    num_samples=0,
    modes=[1],
    model_input=model_input,
    deterministic_sites=["predictions_rearranged", "T_c", "mu", "alpha_cg"],
)
# Optionally reload a previously saved posterior instead of re-fitting:
# import pickle

# # Load the dictionary from the file
# with open("/g/stegle/aivazidis/data/posterior_RNA1layer_2500epochs_1.pkl", "rb") as f:
#     posterior = pickle.load(f)

# Posterior vs. observed data: first mu, then predictions_rearranged
# (index 0 of its third axis).
for predicted in (
    posterior['deterministic']['mu'],
    posterior['deterministic']['predictions_rearranged'][:, :, 0, :, :],
):
    fig = cd.predictions_vs_data(
        model_input['data'],
        predicted,
        ylabel='Posterior Predictions',
        title='Posterior vs Data',
    )


# Open question: why are 'biological' mu predictions lower than mu after sampling?
# Posterior cell times, shifted so the earliest cell is at t = 0.
T_c = np.mean(posterior['deterministic']['T_c'], axis=0)
adata_rna.obs['Time'] = T_c - np.min(T_c)

sc.pp.pca(adata_rna)
sc.pl.pca(
    adata_rna,
    color=["Time", 'time_info', 'state_info'],
    ncols=3,
    size=100 * adata_rna.obs['n_cells'],
    cmap='inferno',
    title='Time since initial condition in hours',
)
# Per-gene posterior vs. data for the first ten genes.
fig = cd.posterior_data_geneset(
    posterior,
    model_input,
    adata_rna,
    list(adata_rna.var_names[:10]),
    plot_alpha=True,
)

4.1 Test on another well (held-out sample)

# Held-out LARRY sample (file mono_lin_well2.h5ad).
data_directory = "/home/tchari/CellDynamicsData/LARRY/mono_lin_well2.h5ad"
adata_rna_test = sc.read_h5ad(data_directory)
adata_rna_test.var_names_make_unique()
# NOTE(review): this disabled line refers to adata_rna, not adata_rna_test.
#sc.pp.log1p(adata_rna)
#adata_rna_test = cd.filter_genes(adata_rna_test, tf_list, n_top_genes = 1000, count_threshold = 102) #Changed from 10^6
# Keep exactly the genes (and gene order) of the training AnnData. Indexing an
# AnnData returns a view. .copy() materialises it, so later in-place edits
# (obs['Time'], sc.pp.pca) act on a real object instead of triggering
# anndata's implicit view-to-copy conversion and its warning.
adata_rna_test = adata_rna_test[:, adata_rna.var_names].copy()
adata_rna_test
model_input_test = prepare_model_input(
    adata_rna_test,
    tf_list,
    n_cells_col = "n_cells")
model_input_test["return_alpha"] = True
# Apply the guide and SVI state fitted on the training well to the held-out input.
# NOTE(review): num_samples=0 / modes=[1] presumably request mode estimates only — confirm.
posterior_test = cd.extract_posterior_estimates(
    model, guide, svi, svi_state,
    quantiles=[],
    num_samples=0,
    modes=[1],
    model_input=model_input_test,
    deterministic_sites=["predictions_rearranged", "T_c", "mu", "alpha_cg"],
)

# Held-out posterior vs. held-out data: first mu, then predictions_rearranged
# (index 0 of its third axis).
for predicted in (
    posterior_test['deterministic']['mu'],
    posterior_test['deterministic']['predictions_rearranged'][:, :, 0, :, :],
):
    fig = cd.predictions_vs_data(
        model_input_test['data'],
        predicted,
        ylabel='Posterior Predictions',
        title='Posterior vs Data (Test)',
    )


# Open question: why are 'biological' mu predictions lower than mu after sampling?
# Held-out cell times, shifted so the earliest cell is at t = 0.
T_c = np.mean(posterior_test['deterministic']['T_c'], axis=0)
adata_rna_test.obs['Time'] = T_c - np.min(T_c)

sc.pp.pca(adata_rna_test)
sc.pl.pca(
    adata_rna_test,
    color=["Time", 'time_info', 'state_info'],
    ncols=3,
    size=100 * adata_rna_test.obs['n_cells'],
    cmap='inferno',
    title='Time since initial condition in hours',
)
# Per-gene posterior vs. data for the first ten held-out genes.
fig = cd.posterior_data_geneset(
    posterior_test,
    model_input_test,
    adata_rna_test,
    list(adata_rna_test.var_names[:10]),
    plot_alpha=True,
)
model
posterior_test