From be110b7413f3ff2031c7bc456222ec60a481eef3 Mon Sep 17 00:00:00 2001 From: Joel Hellewell Date: Mon, 26 Feb 2024 11:26:12 +0000 Subject: [PATCH] Moving optimisation code into Tree impl --- listeria_simple.fasta | 7 --- src/cli.rs | 2 +- src/dspsa.rs | 86 ++++++++++++++++++++++++++++++++++ src/lib.rs | 104 +++++++++++++++++++++--------------------- 4 files changed, 140 insertions(+), 59 deletions(-) delete mode 100644 listeria_simple.fasta diff --git a/listeria_simple.fasta b/listeria_simple.fasta deleted file mode 100644 index b75f677..0000000 --- a/listeria_simple.fasta +++ /dev/null @@ -1,7 +0,0 @@ ->1 -C ->2 -C ->3 -T - diff --git a/src/cli.rs b/src/cli.rs index 3a2e74f..5883de7 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -5,7 +5,7 @@ use clap::Parser; #[command(version, about, long_about = None)] pub struct Args { /// Alignment file in FASTA format - #[arg(short, long, default_value = "listeria0.aln")] + #[arg(short, long, default_value = "tests/test_files_in/listeria0.aln")] pub alignment: String, /// Write the likelihood of the tree and alignment, do not optimise diff --git a/src/dspsa.rs b/src/dspsa.rs index 8debcea..d52ebe3 100644 --- a/src/dspsa.rs +++ b/src/dspsa.rs @@ -1,4 +1,5 @@ use rand::Rng; +use crate::Tree; pub fn phi(v: &[f64]) -> Vec { v.iter().enumerate().map(|(i, value)| { @@ -27,4 +28,89 @@ pub fn peturbation_vec(n: usize) -> Vec { }).collect(); delta[0] = 0.0; delta +} + +pub fn theta_change(pivec: &Vec, delta: &Vec, plus: bool) -> Vec { + + let zip = pivec.iter().zip(delta.iter()); + + match plus { + true => { + zip + .map(|(x, y)| (x + (y / 2.0)).round() as usize) + .collect() + }, + false => { + zip + .map(|(x, y)| (x - (y / 2.0)).round() as usize) + .collect() + } + } +} + +impl Tree { + pub fn optimise(&mut self, q: &na::Matrix4, iterations: usize) { + // Convert tree vector to Vec + let mut theta: Vec = self.tree_vec.iter().map(|x| *x as f64).collect(); + println!("Current tree vector is: {:?}", theta); + println!("Current likelihood is: {}", self.get_tree_likelihood()); + let n: usize = theta.len(); + + // Tuning parameters for optimisation, will + // eventually have defaults or be passed in + let a: f64 = 2.0; + let cap_a: f64 = 2.0; + let alpha: f64 = 0.75; + + // Pre-allocate vectors + let mut delta: Vec = Vec::with_capacity(n); + let mut pivec: Vec = Vec::with_capacity(n); + let mut thetaplus: Vec = Vec::with_capacity(n); + let mut thetaminus: Vec = Vec::with_capacity(n); + + // Optimisation loop + for k in 0..=iterations { + println!("Optimisation step {} out of {}", k, iterations); + // Generate peturbation vector + delta = peturbation_vec(n); + + // Generate pi vector + pivec = piv(&theta); + + // Calculate theta+ and theta-, + // New tree vectors based on peturbation + thetaplus = theta_change(&pivec, &delta, true); + thetaminus = theta_change(&pivec, &delta, false); + + // Update tree and calculate likelihoods + self.update_tree(Some(thetaplus), false); + self.update_likelihood(&q); + let lplus: f64 = self.get_tree_likelihood(); + + self.update_tree(Some(thetaminus), false); + self.update_likelihood(&q); + let lminus: f64 = self.get_tree_likelihood(); + + // Update theta based on likelihoods of theta+/- + let ldiff = lplus - lminus; + + let ghat: Vec = delta.iter() + .map(|el| if !el.eq(&0.0) {el * ldiff} else {0.0}).collect(); + + let ak: f64 = a / (1.0 + cap_a + k as f64).powf(alpha); + + // Set new theta + theta = theta.iter().zip(ghat.iter()) + .map(|(theta, g)| *theta - ak * g).collect(); + println!("New tree vector is: {:?}", theta); + } + + // Update final tree after finishing optimisation + println!("New tree vector is: {:?}", theta); + let new_tree_vec: Vec = theta.iter().map(|x| *x as usize).collect(); + println!("New tree vector is: {:?}", new_tree_vec); + self.update_tree(Some(new_tree_vec), false); + self.update_likelihood(&q); + println!("New tree likelihood is {}", self.get_tree_likelihood()); + } } \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 607bf09..a54d743 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,74 +51,76 @@ pub fn main() { println!("{:?}", tr.newick()); println!("{:?}", tr.tree_vec); - if !args.no_optimise { - let mut theta: Vec = tr.tree_vec.iter().map(|x| *x as f64).collect(); - let n = theta.len(); + tr.optimise(&q, 5); - let a = 2.0; - let A = 2.0; - let alpha = 0.75; - // let k = 0; + // if !args.no_optimise { + // let mut theta: Vec = tr.tree_vec.iter().map(|x| *x as f64).collect(); + // let n = theta.len(); - let mut llvec: Vec = Vec::new(); + // let a = 2.0; + // let A = 2.0; + // let alpha = 0.75; + // // let k = 0; - let start = Instant::now(); - for k in 0..=200 { - println!("k: {:?}", k); - // println!("theta: {:?}", theta); + // let mut llvec: Vec = Vec::new(); - // // Peturbation vector - let delta = peturbation_vec(n); - // println!("delta: {:?}", delta); + // let start = Instant::now(); + // for k in 0..=200 { + // println!("k: {:?}", k); + // // println!("theta: {:?}", theta); - // // Pi vector - let pivec: Vec = piv(&theta); - // // println!("pivec: {:?}", pivec); + // // // Peturbation vector + // let delta = peturbation_vec(n); + // // println!("delta: {:?}", delta); - // // theta+/- - let thetaplus: Vec = pivec.iter().zip(delta.iter()).map(|(x, y)| (x + (y / 2.0)).round() as usize).collect(); - let thetaminus: Vec = pivec.iter().zip(delta.iter()).map(|(x, y)| (x - (y / 2.0)).round() as usize).collect(); + // // // Pi vector + // let pivec: Vec = piv(&theta); + // // // println!("pivec: {:?}", pivec); - // // println!("thetaplus: {:?}", thetaplus); - // // println!("thetaminus: {:?}", thetaminus); + // // // theta+/- + // let thetaplus: Vec = pivec.iter().zip(delta.iter()).map(|(x, y)| (x + (y / 2.0)).round() as usize).collect(); + // let thetaminus: Vec = pivec.iter().zip(delta.iter()).map(|(x, y)| (x - (y / 2.0)).round() as usize).collect(); - // // Calculate likelihood at theta trees - tr.update_tree(Some(thetaplus), false); - // // println!("tree changes: {:?}", tr.changes); - tr.update_likelihood(&q); - let x1 = tr.get_tree_likelihood(); - // // println!("thetaplus ll: {:?}", x1); + // // // println!("thetaplus: {:?}", thetaplus); + // // // println!("thetaminus: {:?}", thetaminus); - tr.update_tree(Some(thetaminus), false); - // // println!("tree changes: {:?}", tr.changes); - tr.update_likelihood(&q); - let x2 = tr.get_tree_likelihood(); - // // println!("thetaminus ll: {:?}", x2); + // // // Calculate likelihood at theta trees + // tr.update_tree(Some(thetaplus), false); + // // // println!("tree changes: {:?}", tr.changes); + // tr.update_likelihood(&q); + // let x1 = tr.get_tree_likelihood(); + // // // println!("thetaplus ll: {:?}", x1); - // // Calculations to work out new theta - let ldiff = x1 - x2; - let ghat: Vec = delta.iter().map(|el| if !el.eq(&0.0) {el * ldiff} else {0.0}).collect(); + // tr.update_tree(Some(thetaminus), false); + // // // println!("tree changes: {:?}", tr.changes); + // tr.update_likelihood(&q); + // let x2 = tr.get_tree_likelihood(); + // // // println!("thetaminus ll: {:?}", x2); - let ak = a / (1.0 + A + k as f64).powf(alpha); + // // // Calculations to work out new theta + // let ldiff = x1 - x2; + // let ghat: Vec = delta.iter().map(|el| if !el.eq(&0.0) {el * ldiff} else {0.0}).collect(); - theta = theta.iter().zip(ghat.iter()).map(|(theta, g)| *theta - ak * g).collect(); + // let ak = a / (1.0 + A + k as f64).powf(alpha); - llvec.push(x1); + // theta = theta.iter().zip(ghat.iter()).map(|(theta, g)| *theta - ak * g).collect(); - // // println!("ghat: {:?}", ghat); + // llvec.push(x1); - } + // // // println!("ghat: {:?}", ghat); - let out: Vec = phi(&theta).iter().map(|x| x.round()).collect(); - println!("final theta: {:?}", out); + // } - println!("{:?}", &llvec); - // println!("{:?}", &llvec[95..100]); - } - let end = Instant::now(); + // let out: Vec = phi(&theta).iter().map(|x| x.round()).collect(); + // println!("final theta: {:?}", out); - eprintln!("Done in {}s", end.duration_since(start).as_secs()); - eprintln!("Done in {}ms", end.duration_since(start).as_millis()); - eprintln!("Done in {}ns", end.duration_since(start).as_nanos()); + // println!("{:?}", &llvec); + // // println!("{:?}", &llvec[95..100]); + // } + // let end = Instant::now(); + + // eprintln!("Done in {}s", end.duration_since(start).as_secs()); + // eprintln!("Done in {}ms", end.duration_since(start).as_millis()); + // eprintln!("Done in {}ns", end.duration_since(start).as_nanos()); }