Skip to content

Commit

Permalink
Merge pull request #58 from bacpop/i57-delete-fix
Browse files Browse the repository at this point in the history
Fix for ska delete
  • Loading branch information
johnlees authored Oct 4, 2023
2 parents 4eb1446 + 247b24d commit dd2c90a
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/generic_modes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub fn align<IntT: for<'a> UInt<'a>>(
/// Convert array to dictionary representation, runs map against a reference,
/// prints out in requested format.
pub fn map<IntT: for<'a> UInt<'a>>(
ska_array: &mut MergeSkaArray<IntT>,
ska_array: &MergeSkaArray<IntT>,
ska_ref: &mut RefSka<IntT>,
output: &Option<String>,
format: &FileType,
Expand Down Expand Up @@ -75,7 +75,7 @@ pub fn map<IntT: for<'a> UInt<'a>>(
///
/// Subsequent files are most easily passed as a slice with `[1..]`
pub fn merge<IntT: for<'a> UInt<'a>>(
first_array: &mut MergeSkaArray<IntT>,
first_array: &MergeSkaArray<IntT>,
skf_files: &[String],
output: &str,
) {
Expand Down
28 changes: 18 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ pub fn main() {
check_threads(*threads);

log::info!("Loading skf as dictionary");
if let Ok(mut ska_array) = load_array::<u64>(input, *threads) {
if let Ok(ska_array) = load_array::<u64>(input, *threads) {
log::info!(
"Making skf of reference k={} rc={}",
ska_array.kmer_len(),
Expand All @@ -552,8 +552,8 @@ pub fn main() {
*ambig_mask,
*repeat_mask,
);
map(&mut ska_array, &mut ska_ref, output, format, *threads);
} else if let Ok(mut ska_array) = load_array::<u128>(input, *threads) {
map(&ska_array, &mut ska_ref, output, format, *threads);
} else if let Ok(ska_array) = load_array::<u128>(input, *threads) {
log::info!(
"Making skf of reference k={} rc={}",
ska_array.kmer_len(),
Expand All @@ -566,7 +566,7 @@ pub fn main() {
*ambig_mask,
*repeat_mask,
);
map(&mut ska_array, &mut ska_ref, output, format, *threads);
map(&ska_array, &mut ska_ref, output, format, *threads);
} else {
panic!("Could not read input file(s): {input:?}");
}
Expand Down Expand Up @@ -609,10 +609,10 @@ pub fn main() {
}

log::info!("Loading first alignment");
if let Ok(mut first_array) = MergeSkaArray::<u64>::load(&skf_files[0]) {
merge(&mut first_array, &skf_files[1..], output);
} else if let Ok(mut first_array) = MergeSkaArray::<u128>::load(&skf_files[0]) {
merge(&mut first_array, &skf_files[1..], output);
if let Ok(first_array) = MergeSkaArray::<u64>::load(&skf_files[0]) {
merge(&first_array, &skf_files[1..], output);
} else if let Ok(first_array) = MergeSkaArray::<u128>::load(&skf_files[0]) {
merge(&first_array, &skf_files[1..], output);
} else {
panic!("Could not read input file: {}", skf_files[0]);
}
Expand Down Expand Up @@ -651,7 +651,11 @@ pub fn main() {
*min_freq,
filter,
*ambig_mask,
if output.is_none() { skf_file } else { output.as_ref().unwrap().as_str() },
if output.is_none() {
skf_file
} else {
output.as_ref().unwrap().as_str()
},
);
} else if let Ok(mut ska_array) = MergeSkaArray::<u128>::load(skf_file) {
weed(
Expand All @@ -661,7 +665,11 @@ pub fn main() {
*min_freq,
filter,
*ambig_mask,
if output.is_none() { skf_file } else { output.as_ref().unwrap().as_str() },
if output.is_none() {
skf_file
} else {
output.as_ref().unwrap().as_str()
},
);
} else {
panic!("Could not read input file: {skf_file}");
Expand Down
80 changes: 65 additions & 15 deletions src/merge_ska_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,8 @@ where
///
/// Recalculates counts, and removes any totally empty rows.
fn update_counts(&mut self) {
let mut new_counts = Vec::new();
new_counts.reserve(self.variant_count.len());

let mut new_sk = Vec::new();
new_sk.reserve(self.split_kmers.len());
let mut new_counts = Vec::with_capacity(self.variant_count.len());
let mut new_sk = Vec::with_capacity(self.split_kmers.len());

let mut new_variants = Array2::zeros((0, self.names.len()));
for (var_row, sk) in self.variants.outer_iter().zip(self.split_kmers.iter()) {
Expand All @@ -120,8 +117,7 @@ where
/// Convert a dynamic [`MergeSkaDict`] to static array representation.
pub fn new(dynamic: &MergeSkaDict<IntT>) -> Self {
let mut variants = Array2::zeros((0, dynamic.nsamples()));
let mut split_kmers: Vec<IntT> = Vec::new();
split_kmers.reserve(dynamic.ksize());
let mut split_kmers: Vec<IntT> = Vec::with_capacity(dynamic.ksize());
let mut variant_count: Vec<usize> = Vec::new();
for (kmer, bases) in dynamic.kmer_dict() {
split_kmers.push(*kmer);
Expand Down Expand Up @@ -209,20 +205,21 @@ where
if !del_name_set.is_empty() {
panic!("Could not find sample(s): {:?}", del_name_set);
}
self.names = new_names;

let mut idx_it = idx_list.iter();
let mut next_idx = idx_it.next();
let new_size = self.names.len() - idx_list.len();
let mut filtered_variants = Array2::zeros((self.ksize(), 0));
for (sample_idx, sample_variants) in self.variants.t().outer_iter().enumerate() {
if *next_idx.unwrap_or(&new_size) == sample_idx {
next_idx = idx_it.next();
} else {
filtered_variants.push_column(sample_variants).unwrap();
if let Some(next_idx_val) = next_idx {
if *next_idx_val == sample_idx {
next_idx = idx_it.next();
continue;
}
}
filtered_variants.push_column(sample_variants).unwrap();
}
self.variants = filtered_variants;
self.names = new_names;
self.update_counts();
}

Expand Down Expand Up @@ -356,8 +353,8 @@ where
.progress_count(self.variants.ncols() as u64)
.enumerate()
.map(|(i, row)| {
let mut partial_dists: Vec<(f64, f64)> = Vec::new();
partial_dists.reserve(self.variants.ncols() - (i + 1));
let mut partial_dists: Vec<(f64, f64)> =
Vec::with_capacity(self.variants.ncols() - (i + 1));
for j in (i + 1)..self.variants.ncols() {
partial_dists.push(Self::variant_dist(
&row,
Expand Down Expand Up @@ -571,3 +568,56 @@ where
})
}
}

#[cfg(test)]
mod tests {
use super::*; // Import functions and types from the parent module
use ndarray::array;

fn setup_struct() -> MergeSkaArray<u64> {
let example_array = MergeSkaArray::<u64> {
k: 31,
rc: true,
split_kmers: vec![0, 1],
variants: array![[1, 2, 3], [4, 5, 6]],
variant_count: vec![3, 3],
ska_version: "NA".to_string(),
k_bits: 64,
names: vec![
"Sample1".to_string(),
"Sample2".to_string(),
"Sample3".to_string(),
],
};
example_array
}

#[test]
fn test_delete_samples_normal() {
let mut my_struct = setup_struct();

my_struct.delete_samples(&["Sample1", "Sample2"]);

// Check that the samples were deleted
assert_eq!(my_struct.names, vec!["Sample3"]);
assert_eq!(my_struct.variants, array![[3], [6]]);
}

#[test]
#[should_panic(expected = "Invalid number of samples to remove")]
fn test_delete_samples_empty_or_all() {
let mut my_struct = setup_struct();

// This should panic
my_struct.delete_samples(&[]);
}

#[test]
#[should_panic(expected = "Could not find sample(s): ")]
fn test_delete_samples_non_existent() {
let mut my_struct = setup_struct();

// This should panic because "Sample4" does not exist
my_struct.delete_samples(&["Sample4"]);
}
}
3 changes: 1 addition & 2 deletions src/ska_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,7 @@ where
let ref_base = self.seq[map_chrom][map_pos];
let ref_allele = u8_to_base(ref_base);

let mut genotype_vec = Vec::new();
genotype_vec.reserve(var_t_owned.ncols());
let mut genotype_vec = Vec::with_capacity(var_t_owned.ncols());
let mut alt_bases: Vec<Base> = Vec::new();
let mut variant = false;
for mapped_base in sample_variants {
Expand Down

0 comments on commit dd2c90a

Please sign in to comment.