Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
theosanderson committed Oct 14, 2021
1 parent 93d05e0 commit deca517
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 31 deletions.
43 changes: 28 additions & 15 deletions src/chronumental/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ def main():
type=float)

parser.add_argument(
'-vd',
'-variance_dates',
default=0.3,
type=float,
help=
"Scale factor for date distribution. Essentially a measure of how uncertain we think the measured dates are."
)

parser.add_argument('-vb',
parser.add_argument('-variance_branch_length',
default=1,
type=float,
help="Scale factor for branch length distribution")
help="Scale factor for branch length distribution. Essentially how close we want to match the expectation of the Poisson.")

parser.add_argument('--steps',
default=1000,
Expand Down Expand Up @@ -100,10 +100,10 @@ def main():
args = parser.parse_args()

if args.dates_out is None:
args.dates_out = prepend_to_file_name(args.dates, "dates")+".tsv"
args.dates_out = prepend_to_file_name(args.dates, "chronumental_dates")+".tsv"

if args.tree_out is None:
args.tree_out = prepend_to_file_name(args.tree, "tree")
args.tree_out = prepend_to_file_name(args.tree, "chronumental_timetree")

metadata = input.get_metadata(args.dates)

Expand Down Expand Up @@ -150,7 +150,7 @@ def main():

print(branch_distances_array)

my_model = models.FixedClock(rows, cols, branch_distances_array, args.clock, args.vb ,args.vd, terminal_target_dates_array)
my_model = models.FixedClock(rows, cols, branch_distances_array, args.clock, args.variance_branch_length ,args.variance_dates, terminal_target_dates_array)

print("Performing SVI:")
svi = SVI(my_model.model, my_model.guide, optim.Adam(args.lr), Trace_ELBO())
Expand All @@ -175,31 +175,44 @@ def main():
length_cor = np.corrcoef(
branch_distances_array,
times)[0, 1] # This correlation should be relatively high
print(step, loss, date_cor, date_error, max_date_error, length_cor,
svi.get_params(state)['latent_mutation_rate_auto_loc'])
print(f"Step:{step}\tLoss:{loss}\tDate correlation:{date_cor:10.4f}\tMean date error:{date_error:10.4f}\tMax date error:{max_date_error:10.4f}\tLength correlation:{length_cor:10.4f}\tInferred mutation rate:{svi.get_params(state)['latent_mutation_rate_auto_loc']:10.4f}")

tree2 = input.read_tree(args.tree)

branch_length_lookup = dict(
zip(names_init,
svi.get_params(state)['latent_time_length_auto_loc'].tolist()))
for i, node in enumerate(tree2.traverse_postorder()):
my_model.get_branch_times(svi.get_params(state)).tolist()))

total_lengths_in_time = {}

total_lengths= dict()

for i, node in enumerate(tree2.traverse_preorder()):
if not node.label:
node_name = f"internal_node_{i}"
if args.name_all_nodes:
node.label = node_name
else:
node_name = node.label.replace("'", "")
node.branch_length = branch_length_lookup[node_name]
if not node.parent:
total_lengths[node] = node.branch_length
else:
total_lengths[node] = node.branch_length + total_lengths[node.parent]

if node.label:
total_lengths_in_time[node.label.replace("'","")] = total_lengths[node]




tree2.write_tree_newick(args.tree_out)
print(f"Wrote tree to {args.tree_out}")

new_dates_absolute = [lookup[reference_point] + datetime.timedelta(days=x) for x in new_dates.tolist()]

output_meta = pd.DataFrame({'strain': terminal_names,
'predicted_date': new_dates_absolute})
origin_date = lookup[reference_point]
output_dates = {name: origin_date + datetime.timedelta(days=x) for name,x in total_lengths_in_time.items()}

names, values = zip(*output_dates.items())
output_meta = pd.DataFrame({"strain": names, "predicted_date": values})


output_meta.to_csv(args.dates_out, sep="\t", index=False)
Expand Down
26 changes: 15 additions & 11 deletions src/chronumental/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,25 @@ def get_oldest(full):
return oldest_date, reference_point

def get_target_dates(tree, lookup, reference_point):
terminal_targets = {}
for terminal in tqdm.tqdm(tree.traverse_leaves(),
"Creating target date array"):

terminal.label = terminal.label.replace("'", "")
if terminal.label in lookup:
date = lookup[terminal.label]
diff = (date - lookup[reference_point]).days
terminal_targets[terminal.label] = diff
return terminal_targets
"""
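Returns a dictionary mapping terminal (leaf) names to the integer dates being targeted.
Dates are expressed in days relative to the date of the reference point, which forms an arbitrary origin.
Leaves whose names are not present in the lookup are omitted from the result.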
"""
terminal_targets = {}
for terminal in tqdm.tqdm(tree.traverse_leaves(),
"Creating target date array"):

terminal.label = terminal.label.replace("'", "")
if terminal.label in lookup:
date = lookup[terminal.label]
diff = (date - lookup[reference_point]).days
terminal_targets[terminal.label] = diff
return terminal_targets


def get_initial_branch_lengths_and_name_all_nodes(tree):
initial_branch_lengths = {}
for i, node in tqdm.tqdm(enumerate(tree.traverse_postorder()),
for i, node in tqdm.tqdm(enumerate(tree.traverse_preorder()),
"finding initial branch_lengths"):
if not node.label:
name = f"internal_node_{i}"
Expand Down
10 changes: 5 additions & 5 deletions src/chronumental/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from . import helpers

class FixedClock(object):
def __init__(self, rows, cols, branch_distances_array, clock_rate, vb ,vd, terminal_target_dates_array):
def __init__(self, rows, cols, branch_distances_array, clock_rate, variance_branch_length ,variance_dates, terminal_target_dates_array):
self.rows = rows
self.cols = cols
self.branch_distances_array = branch_distances_array
self.clock_rate = clock_rate
self.terminal_target_dates_array = terminal_target_dates_array
self.vb = vb
self.vd = vd
self.variance_branch_length = variance_branch_length
self.variance_dates = variance_dates

self.initial_time = 365 * (
branch_distances_array
Expand All @@ -37,7 +37,7 @@ def model(self):
"latent_time_length",
dist.TruncatedNormal(low=0,
loc=self.initial_time,
scale=self.vb,
scale=self.variance_branch_length,
validate_args=True))

mutation_rate = numpyro.sample(
Expand All @@ -57,7 +57,7 @@ def model(self):
final_dates = numpyro.sample(
f"final_dates",
dist.Normal(calced_dates,
self.vd * jnp.ones(calced_dates.shape[0])),
self.variance_dates * jnp.ones(calced_dates.shape[0])),
obs=self.terminal_target_dates_array)

def get_branch_times(self , params):
Expand Down

0 comments on commit deca517

Please sign in to comment.