diff --git a/src/error.rs b/src/error.rs index 3315b0f..fafa464 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ use std::ffi::NulError; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum Error { Io, Parameter(String), diff --git a/src/model.rs b/src/model.rs index 590debc..447008d 100644 --- a/src/model.rs +++ b/src/model.rs @@ -2,6 +2,7 @@ use crate::bindings::*; use crate::{Error, Matrix, Params}; use std::ffi::CString; +#[derive(Debug)] pub struct Model { pub(crate) model: *mut MfModel, } @@ -102,7 +103,7 @@ impl Drop for Model { #[cfg(test)] mod tests { - use crate::{Loss, Matrix, Model}; + use crate::{Error, Loss, Matrix, Model}; fn generate_data() -> Matrix { let mut data = Matrix::new(); @@ -168,13 +169,13 @@ mod tests { let data = generate_data(); let model = Model::params().quiet(true).fit(&data).unwrap(); let result = model.save("/tmp/missing/model.txt"); - assert!(result.is_err()); + assert_eq!(result.unwrap_err(), Error::Io); } #[test] fn test_load_missing() { let result = Model::load("/tmp/missing.txt"); - assert!(result.is_err()); + assert_eq!(result.unwrap_err(), Error::Io); } #[test] @@ -220,20 +221,29 @@ mod tests { fn test_fit_bad_params() { let data = generate_data(); let result = Model::params().factors(0).fit(&data); - assert!(result.is_err()); + assert_eq!( + result.unwrap_err(), + Error::Parameter("number of factors must be greater than zero".to_string()) + ); } #[test] fn test_fit_eval_bad_params() { let data = generate_data(); let result = Model::params().factors(0).fit_eval(&data, &data); - assert!(result.is_err()); + assert_eq!( + result.unwrap_err(), + Error::Parameter("number of factors must be greater than zero".to_string()) + ); } #[test] fn test_cv_bad_params() { let data = generate_data(); let result = Model::params().factors(0).cv(&data, 5); - assert!(result.is_err()); + assert_eq!( + result.unwrap_err(), + Error::Parameter("number of factors must be greater than zero".to_string()) + ); } }