Skip to content

Commit

Permalink
pivot matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
attack68 committed Feb 12, 2024
1 parent d085e75 commit 431005d
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 72 deletions.
61 changes: 19 additions & 42 deletions src/bin/scratch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,23 @@
// pub b: f64,
// }
//
// #[derive(Debug, Clone)]
// enum TestE {
// A(f64),
// B(Test),
// }
//
// impl Add for TestE {
// type Output = Self;
// fn add(self, other: Self) -> Self {
// use TestE::*;
// match self {
// A(MT) => match other {
// A(MTO) => A(MT + MTO),
// B(MTO) => B(Test {a: 10.0, b: 0.0})
// },
// B(MT) => match other {
// A(MTO) => B(Test{a: 10.0, b: 20.0}),
// B(MTO) => B(Test{a: 12.0, b: 30.0})
// }
// }
// }
// }
//
// fn add(a: TestE, b: TestE) -> TestE {
// a + b
// }
#[derive(Debug, Clone)]
enum Fs {
F64(f64),
F32(f32),
}

// fn main() {
// let x = Duals::Float(2.0);
// let y = Duals::Float(3.0);
// let z = x + y;
// println!("{:?}", z);
// let x_1 = NaiveDate::from_ymd_opt(2003, 12, 2).unwrap();
// let x_2 = NaiveDate::from_ymd_opt(2003, 12, 22).unwrap();
// let x = NaiveDate::from_ymd_opt(2003, 12, 18).unwrap();
//
// let y_1 = Duals::Float(100.0);
// let y_2 = Duals::Float(200.0);
//
// let z = interpolate_with_method(&x, &x_1, y_1, &x_2, y_2, "linear", None);
// println!("{:?}", z);
// println!("{}", x_1)
// }
fn add_one(x: Fs) -> Fs {
match x {
Fs::F64(v) => v + 1.0_f64,
Fs::F32(v) => v + 1.0_f32
}
}

fn main() {
// let x_1 = NaiveDate::from_ymd_opt(2003, 12, 2).unwrap();

let x = 32.5;
let y = add_one(Fs(x));
println!("{:?}", y)
}
28 changes: 16 additions & 12 deletions src/dual/linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,31 @@ enum Pivoting {
OnUnderlying,
}

fn argabsmax(a: ArrayView1<DualOrF64>) -> usize {
let a: (usize, DualOrF64) = a.iter().enumerate().fold((0, DualOrF64::F64(0.0)), |acc, (i, elem)| {
fn argabsmax(a: ArrayView1<i32>) -> usize {
let a: (usize, i32) = a.iter().enumerate().fold((0, 0), |acc, (i, elem)| {
if elem.abs() > acc.1 { (i, elem.clone()) } else { acc }
});
a.0
}

pub fn pivot_matrix(A: &Array2<DualOrF64>) -> (Array2<i32>, Array2<DualOrF64>) {
pub fn pivot_matrix(A: &Array2<T>) -> (Array2<i32>, Array2<T>) {
// pivot square matrix
let n = A.len_of(Axis(0));
let mut P: Array2<i32> = Array::eye(n);
let PA = A.clone();
let O = A.clone();
let mut Pa = A.to_owned(); // initialise PA and Original (or)
// let Or = A.to_owned();
for j in 0..n {
let i = argabsmax(O.slice(s![j.., j]));
if j != i {
// define row swaps i <-> j
let mut row_j = P.slice(s![j, ..]).clone();
let mut row_i = P.slice_mut(s![i, ..]);
Zip::from(row_i).and(row_j).apply(|x: &mut i32, y: &i32| std::mem::swap(x, &mut *y));
let k = argabsmax(Pa.slice(s![j.., j])) + j;
if j != k {
// define row swaps j <-> k (note that k > j by definition)
let (mut Pt, mut Pb) = P.slice_mut(s![.., ..]).split_at(Axis(0), k);
let (r1, r2) = (Pt.row_mut(j), Pb.row_mut(0));
Zip::from(r1).and(r2).apply(std::mem::swap);

let (mut Pt, mut Pb) = Pa.slice_mut(s![.., ..]).split_at(Axis(0), k);
let (r1, r2) = (Pt.row_mut(j), Pb.row_mut(0));
Zip::from(r1).and(r2).apply(std::mem::swap);
}
}
(P, PA)
(P, Pa)
}
53 changes: 35 additions & 18 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
pub mod dual;
use dual::dual1::{Dual, DualOrF64};
use dual::linalg::pivot_matrix;
use ndarray::{Array1, Array};
use ndarray::{Array2, arr2, s};
use ndarray::{Array1, Array, Dimension};
use ndarray::{Array2, arr2, s, Zip, Axis};

fn main() {
// let d1 = Dual::new(
Expand All @@ -25,21 +25,38 @@ fn main() {
//
// let elapsed = now.elapsed();
// println!("Elapsed: {:.2?}", elapsed / 100000);
let A: Array2<DualOrF64> = arr2(&[
[DualOrF64::F64(1.),DualOrF64::Dual(Dual::new(2.0, vec![], vec![]))],
[DualOrF64::F64(4.),DualOrF64::Dual(Dual::new(5.0, vec![], vec![]))],
]);
let a: (usize, DualOrF64) = A.slice(s![.., 0]).iter().enumerate().fold((0, DualOrF64::F64(0.0)), |acc, (i, elem)| {
if elem.abs() > acc.1 { (i, elem.clone()) } else { acc }
});
// let a = [1, 2, 3, 4, 5];
// let b = a.into_iter().enumerate().fold((0, 0), |s, (i, j)| (s.0 + i, s.1 + i * j));
// println!("{:?}", b); // Prints 40

let (x, y) = pivot_matrix(&A);

println!("{:?}", A);
println!("{:?}", A.slice(s![.., 0]));
println!("{:?}", a);


// let A: Array2<DualOrF64> = arr2(&[
// [DualOrF64::F64(1.),DualOrF64::Dual(Dual::new(2.0, vec![], vec![]))],
// [DualOrF64::F64(4.),DualOrF64::Dual(Dual::new(5.0, vec![], vec![]))],
// ]);
// let a: (usize, DualOrF64) = A.slice(s![.., 0]).iter().enumerate().fold((0, DualOrF64::F64(0.0)), |acc, (i, elem)| {
// if elem.abs() > acc.1 { (i, elem.clone()) } else { acc }
// });
// // let a = [1, 2, 3, 4, 5];
// // let b = a.into_iter().enumerate().fold((0, 0), |s, (i, j)| (s.0 + i, s.1 + i * j));
// // println!("{:?}", b); // Prints 40
//
// let (x, y) = pivot_matrix(&A);
//
// println!("{:?}", A);
// println!("{:?}", A.slice(s![.., 0]));
// println!("{:?}", a);

// let mut P: Array2<i32> = Array2::eye(3);
// let (mut Pt, mut Pb) = P.slice_mut(s![.., ..]).split_at(Axis(0), 1);
// let (r1, r2) = (Pt.row_mut(0), Pb.row_mut(0));
// Zip::from(r1).and(r2).apply(std::mem::swap);


let P: Array2<i32> = arr2(
&[[1, 2, 3, 4],
[10, 2, 5, 6],
[7, 8, 1, 1],
[2, 2, 2, 9]]
);
let (A, B) = pivot_matrix(&P);

println!("{:?}", B);
}

0 comments on commit 431005d

Please sign in to comment.