Skip to content

Commit

Permalink
add custom errors (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
maksym-arutyunyan authored Oct 21, 2022
1 parent 5fb288d commit 9d8eeaa
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "big_o"
description = "Infers asymptotic computational complexity"
version = "0.1.3"
version = "0.1.4"
edition = "2021"
authors = ["Maksym Arutyunyan"]
repository = "https://github.com/maksym-arutyunyan/big_o"
Expand Down
17 changes: 7 additions & 10 deletions src/complexity.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::error::Error;
use crate::linalg;
use crate::name;
use crate::name::Name;
Expand All @@ -20,7 +21,7 @@ pub struct Complexity {
}

/// Returns a calculated approximation function `f(x)`
fn get_function(name: Name, params: Params) -> Result<Box<dyn Fn(f64) -> f64>, &'static str> {
fn get_function(name: Name, params: Params) -> Result<Box<dyn Fn(f64) -> f64>, Error> {
if let (Some(a), Some(b)) = match name {
Name::Polynomial => (params.gain, params.power),
Name::Exponential => (params.gain, params.base),
Expand All @@ -38,13 +39,13 @@ fn get_function(name: Name, params: Params) -> Result<Box<dyn Fn(f64) -> f64>, &
};
Ok(f)
} else {
Err("No cofficients to compute f(x)")
Err(Error::MissingFunctionCoeffsError)
}
}

/// Computes values of `f(x)` given `x`
#[allow(dead_code)]
fn compute_f(name: Name, params: Params, x: Vec<f64>) -> Result<Vec<f64>, &'static str> {
fn compute_f(name: Name, params: Params, x: Vec<f64>) -> Result<Vec<f64>, Error> {
let f = get_function(name, params)?;
let y = x.into_iter().map(f).collect();
Ok(y)
Expand Down Expand Up @@ -111,11 +112,7 @@ fn delinearize(name: Name, gain: f64, offset: f64) -> Params {
}
}

fn calculate_residuals(
name: Name,
params: Params,
data: Vec<(f64, f64)>,
) -> Result<f64, &'static str> {
fn calculate_residuals(name: Name, params: Params, data: Vec<(f64, f64)>) -> Result<f64, Error> {
let f = get_function(name, params)?;
let residuals = data.into_iter().map(|(x, y)| (y - f(x)).abs()).sum();

Expand Down Expand Up @@ -150,7 +147,7 @@ fn rank(name: Name, params: Params) -> u32 {
}

/// Fits a function of given complexity into input data.
pub fn fit(name: Name, data: Vec<(f64, f64)>) -> Result<Complexity, &'static str> {
pub fn fit(name: Name, data: Vec<(f64, f64)>) -> Result<Complexity, Error> {
let linearized = data
.clone()
.into_iter()
Expand Down Expand Up @@ -189,7 +186,7 @@ pub fn fit(name: Name, data: Vec<(f64, f64)>) -> Result<Complexity, &'static str
///
/// assert!(linear.rank < cubic.rank);
/// ```
pub fn complexity(string: &str) -> Result<Complexity, &'static str> {
pub fn complexity(string: &str) -> Result<Complexity, Error> {
let name: Name = string.try_into()?;
Ok(crate::complexity::ComplexityBuilder::new(name).build())
}
Expand Down
18 changes: 18 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use std::fmt;

#[derive(Debug)]
pub enum Error {
LSTSQError(String),
ParseNotationError,
MissingFunctionCoeffsError,
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::LSTSQError(msg) => write!(f, "LSTQS failed: {msg}"),
Error::ParseNotationError => write!(f, "Can't convert string to Name"),
Error::MissingFunctionCoeffsError => write!(f, "No cofficients to compute f(x)"),
}
}
}
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
//! ```
mod complexity;
mod error;
mod linalg;
mod name;
mod params;
mod validate;

pub use crate::complexity::complexity;
pub use crate::complexity::Complexity;
pub use crate::error::Error;
pub use crate::name::Name;
pub use crate::params::Params;

Expand All @@ -46,9 +48,7 @@ pub use crate::params::Params;
/// assert_approx_eq!(complexity.params.offset.unwrap(), 0.0, 1e-6);
/// assert!(complexity.rank < big_o::complexity("O(n^3)").unwrap().rank);
/// ```
pub fn infer_complexity(
data: Vec<(f64, f64)>,
) -> Result<(Complexity, Vec<Complexity>), &'static str> {
pub fn infer_complexity(data: Vec<(f64, f64)>) -> Result<(Complexity, Vec<Complexity>), Error> {
let mut all_fitted: Vec<Complexity> = Vec::new();
for name in name::all_names() {
let complexity = complexity::fit(name, data.clone())?;
Expand Down
7 changes: 5 additions & 2 deletions src/linalg.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::error::Error;

/// Fits a line `f(x) = gain * x + offset` into input `data` points.
///
/// Returns linear coeffs `gain`, `offset` and `residuals`.
pub fn fit_line(data: Vec<(f64, f64)>) -> Result<(f64, f64, f64), &'static str> {
pub fn fit_line(data: Vec<(f64, f64)>) -> Result<(f64, f64, f64), Error> {
use nalgebra::{Dynamic, OMatrix, OVector, U2};

let (xs, ys): (Vec<f64>, Vec<f64>) = data.iter().cloned().unzip();
Expand All @@ -14,7 +16,8 @@ pub fn fit_line(data: Vec<(f64, f64)>) -> Result<(f64, f64, f64), &'static str>
let b = OVector::<f64, Dynamic>::from_row_slice(&ys);

let epsilon = 1e-10;
let results = lstsq::lstsq(&a, &b, epsilon)?;
let results =
lstsq::lstsq(&a, &b, epsilon).map_err(|msg| Error::LSTSQError(msg.to_string()))?;

let gain = results.solution[0];
let offset = results.solution[1];
Expand Down
10 changes: 5 additions & 5 deletions src/name.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::error::Error;
use std::fmt;
use std::str::FromStr;

Expand Down Expand Up @@ -49,7 +50,7 @@ impl From<Name> for &str {
}

impl TryFrom<&str> for Name {
type Error = &'static str;
type Error = Error;

fn try_from(string: &str) -> Result<Self, Self::Error> {
match &string.to_lowercase()[..] {
Expand All @@ -61,13 +62,13 @@ impl TryFrom<&str> for Name {
"o(n^3)" | "cubic" => Ok(Name::Cubic),
"o(n^m)" | "polynomial" => Ok(Name::Polynomial),
"o(c^n)" | "exponential" => Ok(Name::Exponential),
_ => Err("Can't convert string to Name"),
_ => Err(Error::ParseNotationError),
}
}
}

impl FromStr for Name {
type Err = &'static str;
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
s.try_into()
Expand All @@ -80,9 +81,8 @@ impl fmt::Display for Name {
}
}


#[cfg(test)]
mod tests{
mod tests {
use super::*;

const NOTATION_TEST_CASES: [(&str, Name); 8] = [
Expand Down
14 changes: 14 additions & 0 deletions tests/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,17 @@ fn infer_exponential() {
assert_approx_eq!(complexity.params.base.unwrap(), base, EPSILON);
assert!(complexity.rank <= big_o::complexity("O(c^n)").unwrap().rank);
}

#[test]
#[should_panic]
fn empty_input_failure() {
let data: Vec<(f64, f64)> = vec![];
let (_complexity, _all) = big_o::infer_complexity(data).unwrap();
}

#[test]
#[should_panic]
fn zero_input_failure() {
let data: Vec<(f64, f64)> = vec![(0.0, 0.0), (0.0, 0.0), (0.0, 0.0)];
let (_complexity, _all) = big_o::infer_complexity(data).unwrap();
}
4 changes: 2 additions & 2 deletions tests/execution_time.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt::Write as _;
use std::time::Instant;
use std::{thread, time};
use std::fmt::Write as _; // import without risk of name clashing
use std::{thread, time}; // import without risk of name clashing

#[allow(dead_code)]
fn write_csv(data: &Vec<(f64, f64)>, path: &str) {
Expand Down

0 comments on commit 9d8eeaa

Please sign in to comment.