Skip to content

Commit

Permalink
Merge pull request #1 from bacpop/phyml_test
Browse files Browse the repository at this point in the history
Add likelihood comparison against phyml
  • Loading branch information
jhellewell14 authored Feb 23, 2024
2 parents 6378788 + f3ae9b2 commit b1e4584
Show file tree
Hide file tree
Showing 9 changed files with 388 additions and 121 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ jobs:
- nightly
steps:
- uses: actions/checkout@v3
# Refresh the package index first so the install does not fail on a stale
# package list on the hosted runner image.
- name: Install phyml
  run: sudo apt-get update && sudo apt-get install -y phyml
- run: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }}
- run: cargo build --verbose
- run: cargo test --verbose
14 changes: 11 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
[package]
name = "maple"
name = "bactrees"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ndarray = "0.15.6"
rand = "0.8.5"
needletail = "0.5.1"
nalgebra = "0.32.3"
clap = { version = "4.5", features = ["derive"]}

[dev-dependencies]
# testing
regex = "1.10"
snapbox = "0.4"
predicates = "2.1"
assert_fs = "1.0"
pretty_assertions = "1.3"
float-cmp = "0.9"

[profile.release]
debug = 1
19 changes: 19 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

use clap::Parser;

// Command-line options for the program, parsed through clap's derive API.
// NOTE: plain `//` comments are used for the notes added here because clap's
// derive reads `///` doc comments (the two field descriptions below) as
// help text, so those are effectively part of the program's output.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
pub struct Args {
// Path of the alignment to load; both a short and a long option are requested,
// and "listeria0.aln" is used when the option is omitted.
/// Alignment file in FASTA format
#[arg(short, long, default_value = "listeria0.aln")]
pub alignment: String,

// Long-only boolean flag, false unless given. When set, `main` prints the
// likelihood of the starting tree and skips the optimisation loop.
/// Write the likelihood of the tree and alignment, do not optimise
#[arg(long, default_value_t = false)]
pub no_optimise: bool,
}

/// Function to parse command line args into [`Args`] struct
pub fn cli_args() -> Args {
Args::parse()
}
124 changes: 124 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
mod gen_list;
mod import;
mod likelihoods;
mod node;
mod phylo2vec;
mod tests;
mod tree;
mod dspsa;

use rand::Rng;
use rand::random;

use crate::dspsa::peturbation_vec;
use crate::gen_list::*;
use crate::phylo2vec::*;
use crate::tree::Tree;
use crate::likelihoods::logse;
use crate::node::Node;
use crate::dspsa::phi;
use crate::dspsa::piv;
use std::collections::HashSet;
use std::thread::current;
use std::time::Instant;
extern crate nalgebra as na;

pub mod cli;
use crate::cli::*;

/// Library entry point.
///
/// Builds a random tree, attaches the alignment named on the command line,
/// and prints the tree likelihood, its Newick string and its integer vector.
/// Unless `args.no_optimise` is set, it then runs 201 iterations of the
/// optimisation loop built on the `dspsa` helpers (`peturbation_vec`, `piv`,
/// `phi`) and prints the rounded final vector plus the likelihood recorded at
/// every iteration. The total wall-clock time is written to stderr.
pub fn main() {
    let args = cli_args();

    // Single timer for the whole run; reported on stderr at the end.
    let start = Instant::now();

    // Rate matrix: -1 on the diagonal, 1/3 everywhere else (every row sums to
    // zero) — the form of a Jukes–Cantor style substitution model.
    let q: na::Matrix4<f64> = na::Matrix4::new(
        -1.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0,
        1.0 / 3.0, -1.0, 1.0 / 3.0, 1.0 / 3.0,
        1.0 / 3.0, 1.0 / 3.0, -1.0, 1.0 / 3.0,
        1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, -1.0,
    );

    // NOTE(review): the tree size (27) is hard-coded while the alignment is
    // user-supplied; presumably it has to match the number of sequences in
    // the alignment — confirm.
    let mut tr = phylo2vec_quad(random_tree(27));
    tr.add_genetic_data(&args.alignment);

    tr.update_likelihood_postorder(&q);

    println!("{}", tr.get_tree_likelihood());
    println!("{:?}", tr.newick());
    println!("{:?}", tr.tree_vec);

    if !args.no_optimise {
        // Continuous working copy of the integer tree vector.
        let mut theta: Vec<f64> = tr.tree_vec.iter().map(|x| *x as f64).collect();
        let n = theta.len();

        // Gain sequence constants: the step size at iteration k is
        // a / (1 + stability + k)^alpha.
        let a: f64 = 2.0;
        let stability: f64 = 2.0;
        let alpha: f64 = 0.75;

        // Likelihood of the "plus" tree at every iteration.
        let mut llvec: Vec<f64> = Vec::new();

        for k in 0..=200 {
            println!("k: {:?}", k);

            // Perturbation vector for this iteration.
            let delta = peturbation_vec(n);

            // Pi vector derived from the current theta.
            let pivec: Vec<f64> = piv(&theta);

            // Candidate tree vectors on either side of pi. Float-to-integer
            // `as` casts saturate, so a negative value becomes 0.
            let thetaplus: Vec<usize> = pivec
                .iter()
                .zip(delta.iter())
                .map(|(x, y)| (x + (y / 2.0)).round() as usize)
                .collect();
            let thetaminus: Vec<usize> = pivec
                .iter()
                .zip(delta.iter())
                .map(|(x, y)| (x - (y / 2.0)).round() as usize)
                .collect();

            // Likelihood at each candidate tree.
            tr.update_tree(Some(thetaplus), false);
            tr.update_likelihood(&q);
            let x1 = tr.get_tree_likelihood();

            tr.update_tree(Some(thetaminus), false);
            tr.update_likelihood(&q);
            let x2 = tr.get_tree_likelihood();

            // Gradient estimate: the likelihood difference applied along the
            // perturbation. Zero entries are kept at exactly 0.0 so that a
            // non-finite difference cannot turn them into NaN.
            let ldiff = x1 - x2;
            let ghat: Vec<f64> = delta
                .iter()
                .map(|el| if *el != 0.0 { el * ldiff } else { 0.0 })
                .collect();

            let ak = a / (1.0 + stability + k as f64).powf(alpha);

            theta = theta
                .iter()
                .zip(ghat.iter())
                .map(|(t, g)| *t - ak * *g)
                .collect();

            llvec.push(x1);
        }

        let out: Vec<f64> = phi(&theta).iter().map(|x| x.round()).collect();
        println!("final theta: {:?}", out);

        println!("{:?}", llvec);
    }

    let elapsed = start.elapsed();

    eprintln!("Done in {}s", elapsed.as_secs());
    eprintln!("Done in {}ms", elapsed.as_millis());
    eprintln!("Done in {}ns", elapsed.as_nanos());
}
119 changes: 1 addition & 118 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,120 +1,3 @@
mod gen_list;
mod import;
mod likelihoods;
mod node;
mod phylo2vec;
mod tests;
mod tree;
mod dspsa;

use rand::Rng;
use rand::random;

use crate::dspsa::peturbation_vec;
use crate::gen_list::*;
use crate::phylo2vec::*;
use crate::tree::Tree;
use crate::likelihoods::logse;
use crate::node::Node;
use crate::dspsa::phi;
use crate::dspsa::piv;
use std::collections::HashSet;
use std::thread::current;
use std::time::Instant;
extern crate nalgebra as na;

fn main() {
let start = Instant::now();

// Define rate matrix
let q: na::Matrix4<f64> = na::Matrix4::new(
-1.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0,
1.0 / 3.0, -1.0, 1.0 / 3.0, 1.0 / 3.0,
1.0 / 3.0, 1.0 / 3.0, -1.0, 1.0 / 3.0,
1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, -1.0,
);

// let mut tr = phylo2vec_quad(vec![0, 0, 0]);
let mut tr = phylo2vec_quad(random_tree(27));
// let filename = "listeria0.aln";
let filename = "listeria0.aln";
tr.add_genetic_data(filename);

tr.update_likelihood_postorder(&q);

// println!("{:?}", tr.mutation_lists);
println!("{}", tr.get_tree_likelihood());
println!("{:?}", tr.newick());
println!("{:?}", tr.tree_vec);
// println!("{:?}", tr.nodes);

let mut theta: Vec<f64> = tr.tree_vec.iter().map(|x| *x as f64).collect();
let n = theta.len();

let a = 2.0;
let A = 2.0;
let alpha = 0.75;
// let k = 0;

let mut llvec: Vec<f64> = Vec::new();

let start = Instant::now();
for k in 0..=200 {
println!("k: {:?}", k);
// println!("theta: {:?}", theta);

// // Peturbation vector
let delta = peturbation_vec(n);
// println!("delta: {:?}", delta);

// // Pi vector
let pivec: Vec<f64> = piv(&theta);
// // println!("pivec: {:?}", pivec);

// // theta+/-
let thetaplus: Vec<usize> = pivec.iter().zip(delta.iter()).map(|(x, y)| (x + (y / 2.0)).round() as usize).collect();
let thetaminus: Vec<usize> = pivec.iter().zip(delta.iter()).map(|(x, y)| (x - (y / 2.0)).round() as usize).collect();

// // println!("thetaplus: {:?}", thetaplus);
// // println!("thetaminus: {:?}", thetaminus);

// // Calculate likelihood at theta trees
tr.update_tree(Some(thetaplus), false);
// // println!("tree changes: {:?}", tr.changes);
tr.update_likelihood(&q);
let x1 = tr.get_tree_likelihood();
// // println!("thetaplus ll: {:?}", x1);

tr.update_tree(Some(thetaminus), false);
// // println!("tree changes: {:?}", tr.changes);
tr.update_likelihood(&q);
let x2 = tr.get_tree_likelihood();
// // println!("thetaminus ll: {:?}", x2);

// // Calculations to work out new theta
let ldiff = x1 - x2;
let ghat: Vec<f64> = delta.iter().map(|el| if !el.eq(&0.0) {el * ldiff} else {0.0}).collect();

let ak = a / (1.0 + A + k as f64).powf(alpha);

theta = theta.iter().zip(ghat.iter()).map(|(theta, g)| *theta as f64 - ak * g).collect();

llvec.push(x1);

// // println!("ghat: {:?}", ghat);

}

let out: Vec<f64> = phi(&theta).iter().map(|x| x.round()).collect();
println!("final theta: {:?}", out);
let end = Instant::now();


println!("{:?}", &llvec);
// println!("{:?}", &llvec[95..100]);

eprintln!("Done in {}s", end.duration_since(start).as_secs());
eprintln!("Done in {}ms", end.duration_since(start).as_millis());
eprintln!("Done in {}ns", end.duration_since(start).as_nanos());

bactrees::main();
}
78 changes: 78 additions & 0 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use std::path::{Path, PathBuf};
use std::fs::File;
use std::io::{Write, BufReader, BufRead, Error};

use assert_fs::{prelude::*, TempDir};
use predicates::prelude::*;

// Names used to build input/output paths for the tests.

/// Directory holding the test input files; it is canonicalized against the
/// process's current directory when a test working directory is set up.
static FILE_IN: &str = "tests/test_files_in";
/// Name under which the input directory is linked inside the working directory.
static SYM_IN: &str = "input";
/// Name of the directory of expected ("correct") outputs inside the working directory.
static SYM_TEST: &str = "correct";

/// Selects which directory inside the test working directory a file name is
/// resolved against.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TestDir {
    /// Resolve under the `input` entry of the working directory.
    Input,
    /// Resolve under the `correct` entry of the working directory.
    Correct,
}

/// Per-test fixture that owns a temporary working directory.
pub struct TestSetup {
// Temporary directory from `assert_fs`; the methods on this type derive
// their paths from it.
wd: TempDir,
}

impl TestSetup {
pub fn setup() -> Self {
let wd = assert_fs::TempDir::new().unwrap();
wd.child(SYM_IN)
.symlink_to_dir(
Path::new(FILE_IN)
.canonicalize()
.expect("Could not link expected files"),
)
.unwrap();
Self { wd }
}

pub fn get_wd(&self) -> String {
self.wd.path().display().to_string()
}

pub fn create_file(&self, filename: &str) -> Result<(File, String), Error> {
let path = self.wd.with_file_name(filename);
let output = File::create(path.clone())?;
Ok((output, path.canonicalize().unwrap().to_str().unwrap().to_string()))
}

pub fn file_path(&self, name: &str, file_type: TestDir) -> PathBuf {
match file_type {
TestDir::Input => {
PathBuf::from(&format!("{}/{}/{}", self.wd.path().display(), SYM_IN, name))
}
TestDir::Correct => PathBuf::from(&format!(
"{}/{}/{}",
self.wd.path().display(),
SYM_TEST,
name
)),
}
}

pub fn file_string(&self, name: &str, file_type: TestDir) -> String {
self.file_path(name, file_type)
.to_str()
.expect("Could not unpack file path")
.to_owned()
}

pub fn file_check(&self, name_out: &str, name_correct: &str) -> bool {
let predicate_file = predicate::path::eq_file(self.wd.child(name_out).path());
predicate_file.eval(self.file_path(name_correct, TestDir::Correct).as_path())
}

pub fn file_exists(&self, name_out: &str) -> bool {
let predicate_fn = predicate::path::is_file();
predicate_fn.eval(self.wd.child(name_out).path())
}
}
Loading

0 comments on commit b1e4584

Please sign in to comment.