From a0c0714c72cb26d5ad75fcdb08e9390061148615 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 17 Oct 2021 12:08:08 -0700 Subject: [PATCH] Improved error tests --- src/error.rs | 2 +- src/model.rs | 22 ++++++++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/error.rs b/src/error.rs index 3315b0f..fafa464 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ use std::ffi::NulError; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum Error { Io, Parameter(String), diff --git a/src/model.rs b/src/model.rs index 590debc..447008d 100644 --- a/src/model.rs +++ b/src/model.rs @@ -2,6 +2,7 @@ use crate::bindings::*; use crate::{Error, Matrix, Params}; use std::ffi::CString; +#[derive(Debug)] pub struct Model { pub(crate) model: *mut MfModel, } @@ -102,7 +103,7 @@ impl Drop for Model { #[cfg(test)] mod tests { - use crate::{Loss, Matrix, Model}; + use crate::{Error, Loss, Matrix, Model}; fn generate_data() -> Matrix { let mut data = Matrix::new(); @@ -168,13 +169,13 @@ mod tests { let data = generate_data(); let model = Model::params().quiet(true).fit(&data).unwrap(); let result = model.save("/tmp/missing/model.txt"); - assert!(result.is_err()); + assert_eq!(result.unwrap_err(), Error::Io); } #[test] fn test_load_missing() { let result = Model::load("/tmp/missing.txt"); - assert!(result.is_err()); + assert_eq!(result.unwrap_err(), Error::Io); } #[test] @@ -220,20 +221,29 @@ mod tests { fn test_fit_bad_params() { let data = generate_data(); let result = Model::params().factors(0).fit(&data); - assert!(result.is_err()); + assert_eq!( + result.unwrap_err(), + Error::Parameter("number of factors must be greater than zero".to_string()) + ); } #[test] fn test_fit_eval_bad_params() { let data = generate_data(); let result = Model::params().factors(0).fit_eval(&data, &data); - assert!(result.is_err()); + assert_eq!( + result.unwrap_err(), + Error::Parameter("number of factors must be greater than zero".to_string()) + ); } #[test] fn test_cv_bad_params() { let data = generate_data(); let result = Model::params().factors(0).cv(&data, 5); - assert!(result.is_err()); + assert_eq!( + result.unwrap_err(), + Error::Parameter("number of factors must be greater than zero".to_string()) + ); } }