Skip to content

Commit

Permalink
Add CellBender v3 support to read_cellbender() function
Browse files Browse the repository at this point in the history
  • Loading branch information
cakirb committed Jan 4, 2024
1 parent 20d782f commit cd599d9
Showing 1 changed file with 116 additions and 0 deletions.
116 changes: 116 additions & 0 deletions sctk/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,44 @@ def read_cellbender(
else:
raise ValueError("The data doesn't look like cellbender output")
n_var, n_obs = tuple(mat["shape"][()])
if "metadata" in f:
ad = read_cellbender_v3(f,
feat=feat,
feat_name=feat_name,
vardict=vardict,
n_var=n_var,
n_obs=n_obs,
remove_zero=True,
remove_nan=True,
train_history=False,
latent_gene_encoding=False,
add_suffix=None,)
else:
ad = read_cellbender_v2(mat,
feat=feat,
feat_name=feat_name,
vardict=vardict,
n_var=n_var,
n_obs=n_obs,
remove_zero=True,
remove_nan=True,
train_history=False,
latent_gene_encoding=False,
add_suffix=None,)
return ad

def read_cellbender_v2(mat,
feat,
feat_name,
vardict,
n_var,
n_obs,
remove_zero=True,
remove_nan=True,
train_history=False,
latent_gene_encoding=False,
add_suffix=None,
):
cols = ["latent_cell_probability", "latent_RT_efficiency"]
if "barcode_indices_for_latents" in mat:
bidx = mat["barcode_indices_for_latents"][()]
Expand Down Expand Up @@ -194,6 +232,84 @@ def read_cellbender(

return ad1

def read_cellbender_v3(f,
feat,
feat_name,
vardict,
n_var,
n_obs,
remove_zero=True,
remove_nan=True,
train_history=False,
latent_gene_encoding=False,
add_suffix=None,
):
import numpy as np
import anndata
import scipy.sparse as sp
import pandas as pd
cols = ["cell_probability", "droplet_efficiency"]
if "barcode_indices_for_latents" in f["droplet_latents"]:
bidx = f["droplet_latents"]["barcode_indices_for_latents"][()]
obsdict = {}
for x in cols:
val = np.empty(n_obs)
val.fill(np.nan)
val[bidx] = f["droplet_latents"][x][()]
obsdict[x] = val
if latent_gene_encoding:
lge = f["droplet_latents"]["gene_expression_encoding"][()]
obsm = np.empty((n_obs, lge.shape[1]))
obsm.fill(np.nan)
obsm[bidx, :] = lge
else:
obsdict = {x: f["matrix"][x] for x in cols}
if latent_gene_encoding:
obsm = f["droplet_latents"]["gene_expression_encoding"][()]
barcodes = np.array(
[b[:-2] if b.endswith("-1") else b for b in f["matrix"]["barcodes"][()].astype(str)]
)
ad = anndata.AnnData(
X=sp.csr_matrix(
(f["matrix"]["data"][()], f["matrix"]["indices"][()], f["matrix"]["indptr"][()]),
shape=(n_obs, n_var),
),
var=pd.DataFrame(vardict, index=feat_name.astype(str)),
obs=pd.DataFrame(obsdict, index=barcodes),
uns={
"target_false_positive_rate": f['metadata']["target_false_positive_rate"][()],
"test_elbo": list(f['metadata']['learning_curve_test_elbo']),
"test_epoch": list(f['metadata']['learning_curve_test_epoch']),
"overall_change_in_train_elbo": list(f['metadata']["overall_change_in_train_elbo"]),
}
if train_history
else {},
)
ad.var_names_make_unique()
if latent_gene_encoding:
ad.obsm["X_latent_gene_encoding"] = obsm

mask_nan = np.isnan(ad.obs.cell_probability)
mask_0 = ad.X.sum(axis=1).A1 <= 0

mask_remove = np.zeros(n_obs).astype(bool)
if remove_nan:
mask_remove = mask_remove | mask_nan
if remove_zero:
mask_remove = mask_remove | mask_0
idx_remove = np.where(mask_remove)[0]

idx_sort = pd.Series(np.argsort(ad.obs_names))
idx_sort = idx_sort[~idx_sort.isin(idx_remove)]

ad1 = ad[idx_sort.values].copy()
del ad

if add_suffix:
ad1.obs_names = ad1.obs_names.astype(str) + add_suffix

return ad1


def read_h5ad(input_h5ad, component="all", **kwargs):
if component == "all":
Expand Down

0 comments on commit cd599d9

Please sign in to comment.