From 9d8eeaacda87f692f912ed9c0701653442aad835 Mon Sep 17 00:00:00 2001 From: Maksym Arutyunyan <107496615+maksym-arutyunyan@users.noreply.github.com> Date: Fri, 21 Oct 2022 16:23:50 +0200 Subject: [PATCH] add custom errors (#8) --- Cargo.toml | 2 +- src/complexity.rs | 17 +++++++---------- src/error.rs | 18 ++++++++++++++++++ src/lib.rs | 6 +++--- src/linalg.rs | 7 +++++-- src/name.rs | 10 +++++----- tests/api.rs | 14 ++++++++++++++ tests/execution_time.rs | 4 ++-- 8 files changed, 55 insertions(+), 23 deletions(-) create mode 100644 src/error.rs diff --git a/Cargo.toml b/Cargo.toml index 61df850..1f4449d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "big_o" description = "Infers asymptotic computational complexity" -version = "0.1.3" +version = "0.1.4" edition = "2021" authors = ["Maksym Arutyunyan"] repository = "https://github.com/maksym-arutyunyan/big_o" diff --git a/src/complexity.rs b/src/complexity.rs index 9e07999..16bc788 100644 --- a/src/complexity.rs +++ b/src/complexity.rs @@ -1,3 +1,4 @@ +use crate::error::Error; use crate::linalg; use crate::name; use crate::name::Name; @@ -20,7 +21,7 @@ pub struct Complexity { } /// Returns a calculated approximation function `f(x)` -fn get_function(name: Name, params: Params) -> Result f64>, &'static str> { +fn get_function(name: Name, params: Params) -> Result f64>, Error> { if let (Some(a), Some(b)) = match name { Name::Polynomial => (params.gain, params.power), Name::Exponential => (params.gain, params.base), @@ -38,13 +39,13 @@ fn get_function(name: Name, params: Params) -> Result f64>, & }; Ok(f) } else { - Err("No cofficients to compute f(x)") + Err(Error::MissingFunctionCoeffsError) } } /// Computes values of `f(x)` given `x` #[allow(dead_code)] -fn compute_f(name: Name, params: Params, x: Vec) -> Result, &'static str> { +fn compute_f(name: Name, params: Params, x: Vec) -> Result, Error> { let f = get_function(name, params)?; let y = x.into_iter().map(f).collect(); Ok(y) @@ -111,11 +112,7 @@ fn delinearize(name: Name, gain: f64, offset: f64) -> Params { } } -fn calculate_residuals( - name: Name, - params: Params, - data: Vec<(f64, f64)>, -) -> Result { +fn calculate_residuals(name: Name, params: Params, data: Vec<(f64, f64)>) -> Result { let f = get_function(name, params)?; let residuals = data.into_iter().map(|(x, y)| (y - f(x)).abs()).sum(); @@ -150,7 +147,7 @@ fn rank(name: Name, params: Params) -> u32 { } /// Fits a function of given complexity into input data. -pub fn fit(name: Name, data: Vec<(f64, f64)>) -> Result { +pub fn fit(name: Name, data: Vec<(f64, f64)>) -> Result { let linearized = data .clone() .into_iter() @@ -189,7 +186,7 @@ pub fn fit(name: Name, data: Vec<(f64, f64)>) -> Result Result { +pub fn complexity(string: &str) -> Result { let name: Name = string.try_into()?; Ok(crate::complexity::ComplexityBuilder::new(name).build()) } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..dbb5aa1 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,18 @@ +use std::fmt; + +#[derive(Debug)] +pub enum Error { + LSTSQError(String), + ParseNotationError, + MissingFunctionCoeffsError, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::LSTSQError(msg) => write!(f, "LSTQS failed: {msg}"), + Error::ParseNotationError => write!(f, "Can't convert string to Name"), + Error::MissingFunctionCoeffsError => write!(f, "No cofficients to compute f(x)"), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index bff3cc4..1f68ece 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ //! ``` mod complexity; +mod error; mod linalg; mod name; mod params; @@ -27,6 +28,7 @@ mod validate; pub use crate::complexity::complexity; pub use crate::complexity::Complexity; +pub use crate::error::Error; pub use crate::name::Name; pub use crate::params::Params; @@ -46,9 +48,7 @@ pub use crate::params::Params; /// assert_approx_eq!(complexity.params.offset.unwrap(), 0.0, 1e-6); /// assert!(complexity.rank < big_o::complexity("O(n^3)").unwrap().rank); /// ``` -pub fn infer_complexity( - data: Vec<(f64, f64)>, -) -> Result<(Complexity, Vec), &'static str> { +pub fn infer_complexity(data: Vec<(f64, f64)>) -> Result<(Complexity, Vec), Error> { let mut all_fitted: Vec = Vec::new(); for name in name::all_names() { let complexity = complexity::fit(name, data.clone())?; diff --git a/src/linalg.rs b/src/linalg.rs index 2526a1b..61b67fd 100644 --- a/src/linalg.rs +++ b/src/linalg.rs @@ -1,7 +1,9 @@ +use crate::error::Error; + /// Fits a line `f(x) = gain * x + offset` into input `data` points. /// /// Returns linear coeffs `gain`, `offset` and `residuals`. -pub fn fit_line(data: Vec<(f64, f64)>) -> Result<(f64, f64, f64), &'static str> { +pub fn fit_line(data: Vec<(f64, f64)>) -> Result<(f64, f64, f64), Error> { use nalgebra::{Dynamic, OMatrix, OVector, U2}; let (xs, ys): (Vec, Vec) = data.iter().cloned().unzip(); @@ -14,7 +16,8 @@ pub fn fit_line(data: Vec<(f64, f64)>) -> Result<(f64, f64, f64), &'static str> let b = OVector::::from_row_slice(&ys); let epsilon = 1e-10; - let results = lstsq::lstsq(&a, &b, epsilon)?; + let results = + lstsq::lstsq(&a, &b, epsilon).map_err(|msg| Error::LSTSQError(msg.to_string()))?; let gain = results.solution[0]; let offset = results.solution[1]; diff --git a/src/name.rs b/src/name.rs index 9db6aee..250380c 100644 --- a/src/name.rs +++ b/src/name.rs @@ -1,3 +1,4 @@ +use crate::error::Error; use std::fmt; use std::str::FromStr; @@ -49,7 +50,7 @@ impl From for &str { } impl TryFrom<&str> for Name { - type Error = &'static str; + type Error = Error; fn try_from(string: &str) -> Result { match &string.to_lowercase()[..] { @@ -61,13 +62,13 @@ impl TryFrom<&str> for Name { "o(n^3)" | "cubic" => Ok(Name::Cubic), "o(n^m)" | "polynomial" => Ok(Name::Polynomial), "o(c^n)" | "exponential" => Ok(Name::Exponential), - _ => Err("Can't convert string to Name"), + _ => Err(Error::ParseNotationError), } } } impl FromStr for Name { - type Err = &'static str; + type Err = Error; fn from_str(s: &str) -> Result { s.try_into() @@ -80,9 +81,8 @@ impl fmt::Display for Name { } } - #[cfg(test)] -mod tests{ +mod tests { use super::*; const NOTATION_TEST_CASES: [(&str, Name); 8] = [ diff --git a/tests/api.rs b/tests/api.rs index b754f6e..66995f8 100644 --- a/tests/api.rs +++ b/tests/api.rs @@ -222,3 +222,17 @@ fn infer_exponential() { assert_approx_eq!(complexity.params.base.unwrap(), base, EPSILON); assert!(complexity.rank <= big_o::complexity("O(c^n)").unwrap().rank); } + +#[test] +#[should_panic] +fn empty_input_failure() { + let data: Vec<(f64, f64)> = vec![]; + let (_complexity, _all) = big_o::infer_complexity(data).unwrap(); +} + +#[test] +#[should_panic] +fn zero_input_failure() { + let data: Vec<(f64, f64)> = vec![(0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]; + let (_complexity, _all) = big_o::infer_complexity(data).unwrap(); +} diff --git a/tests/execution_time.rs b/tests/execution_time.rs index a3e65ae..c4a64f8 100644 --- a/tests/execution_time.rs +++ b/tests/execution_time.rs @@ -1,6 +1,6 @@ +use std::fmt::Write as _; use std::time::Instant; -use std::{thread, time}; -use std::fmt::Write as _; // import without risk of name clashing +use std::{thread, time}; // import without risk of name clashing #[allow(dead_code)] fn write_csv(data: &Vec<(f64, f64)>, path: &str) {