Retry with single core in geneformer (#41)
lazappi authored Jan 21, 2025
1 parent 01bdb9d commit 0211654
Showing 1 changed file with 38 additions and 12 deletions.
50 changes: 38 additions & 12 deletions src/methods/geneformer/script.py
@@ -23,22 +23,24 @@

print(">>> Reading input...", flush=True)
sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata
from exit_codes import exit_non_applicable
from read_anndata_partial import read_anndata

adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")

if adata.uns["dataset_organism"] != "homo_sapiens":
exit_non_applicable(
f"Geneformer can only be used with human data "
f"(dataset_organism == \"{adata.uns['dataset_organism']}\")"
f'(dataset_organism == "{adata.uns["dataset_organism"]}")'
)

# Set adata.var_names to gene IDs
adata.var_names = adata.var["feature_id"]
is_ensembl = all(var_name.startswith("ENSG") for var_name in adata.var_names)
if not is_ensembl:
raise ValueError(f"Geneformer requires adata.var_names to contain ENSEMBL gene ids")
exit_non_applicable(
"Geneformer requires adata.var_names to contain ENSEMBL gene ids"
)

print(f">>> Getting settings for model '{par['model']}'...", flush=True)
model_split = par["model"].split("-")
@@ -97,18 +99,42 @@
 adata.write_h5ad(os.path.join(input_dir, "input.h5ad"))
 print(adata)
 
+
+# Function to try parallel execution and fall back to a single process if it fails
+def tryParallelFunction(fun, label):
+    try:
+        return fun(nproc=n_processors)
+    except RuntimeError as e:
+        # Retry with nproc=1 if the error message contains "One of the subprocesses has abruptly died"
+        if "subprocess" in str(e) and "died" in str(e):
+            print(f"{label} failed. Error message: {e}", flush=True)
+            print("Retrying with nproc=1", flush=True)
+            return fun(nproc=1)
+        else:
+            raise e
+
+
 print(">>> Tokenizing data...", flush=True)
 special_token = model_details["dataset"] == "95M"
 print(f"Input size: {model_details['input_size']}, Special token: {special_token}")
-tokenizer = TranscriptomeTokenizer(
-    nproc=n_processors,
-    model_input_size=model_details["input_size"],
-    special_token=special_token,
-    gene_median_file=dictionary_files["gene_median"],
-    token_dictionary_file=dictionary_files["token"],
-    gene_mapping_file=dictionary_files["ensembl_mapping"],
-)
-tokenizer.tokenize_data(input_dir, tokenized_dir, "tokenized", file_format="h5ad")
+
+
+def tokenize_data(nproc):
+    tokenizer = TranscriptomeTokenizer(
+        nproc=nproc,
+        model_input_size=model_details["input_size"],
+        special_token=special_token,
+        gene_median_file=dictionary_files["gene_median"],
+        token_dictionary_file=dictionary_files["token"],
+        gene_mapping_file=dictionary_files["ensembl_mapping"],
+    )
+
+    tokenizer.tokenize_data(input_dir, tokenized_dir, "tokenized", file_format="h5ad")
+
+    return tokenizer
+
+
+tokenizer = tryParallelFunction(tokenize_data, "Tokenizing data")
 
 print(f">>> Getting model files for model '{par['model']}'...", flush=True)
 model_files = {
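The change above wraps the tokenization step in a retry helper: run with all available processors first and, if a worker dies (a RuntimeError whose message mentions a dead subprocess), rerun the same step on a single core. Below is a minimal, self-contained sketch of that pattern for illustration only; run_with_fallback and flaky_parallel_job are hypothetical names standing in for tryParallelFunction and tokenize_data, and the n_processors default is an assumption, not part of the repository.

# Minimal sketch of the retry-with-single-core pattern used in this commit.
# run_with_fallback mirrors tryParallelFunction; flaky_parallel_job is a
# hypothetical stand-in for the Geneformer tokenization step.

def run_with_fallback(fun, label, n_processors=8):
    try:
        return fun(nproc=n_processors)
    except RuntimeError as e:
        # Same heuristic as the diff: retry only when a subprocess appears to have died
        if "subprocess" in str(e) and "died" in str(e):
            print(f"{label} failed. Error message: {e}")
            print("Retrying with nproc=1")
            return fun(nproc=1)
        raise


def flaky_parallel_job(nproc):
    # Simulate a multi-process run that always crashes, to exercise the fallback
    if nproc > 1:
        raise RuntimeError("One of the subprocesses has abruptly died during map operation.")
    return f"finished with nproc={nproc}"


print(run_with_fallback(flaky_parallel_job, "Tokenizing data"))  # -> finished with nproc=1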
