Improved error tests
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
use std::ffi::NulError;
|
use std::ffi::NulError;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
Io,
|
Io,
|
||||||
Parameter(String),
|
Parameter(String),
|
||||||
|
|||||||
22
src/model.rs
22
src/model.rs
@@ -2,6 +2,7 @@ use crate::bindings::*;
|
|||||||
use crate::{Error, Matrix, Params};
|
use crate::{Error, Matrix, Params};
|
||||||
use std::ffi::CString;
|
use std::ffi::CString;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Model {
|
pub struct Model {
|
||||||
pub(crate) model: *mut MfModel,
|
pub(crate) model: *mut MfModel,
|
||||||
}
|
}
|
||||||
@@ -102,7 +103,7 @@ impl Drop for Model {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::{Loss, Matrix, Model};
|
use crate::{Error, Loss, Matrix, Model};
|
||||||
|
|
||||||
fn generate_data() -> Matrix {
|
fn generate_data() -> Matrix {
|
||||||
let mut data = Matrix::new();
|
let mut data = Matrix::new();
|
||||||
@@ -168,13 +169,13 @@ mod tests {
|
|||||||
let data = generate_data();
|
let data = generate_data();
|
||||||
let model = Model::params().quiet(true).fit(&data).unwrap();
|
let model = Model::params().quiet(true).fit(&data).unwrap();
|
||||||
let result = model.save("/tmp/missing/model.txt");
|
let result = model.save("/tmp/missing/model.txt");
|
||||||
assert!(result.is_err());
|
assert_eq!(result.unwrap_err(), Error::Io);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_load_missing() {
|
fn test_load_missing() {
|
||||||
let result = Model::load("/tmp/missing.txt");
|
let result = Model::load("/tmp/missing.txt");
|
||||||
assert!(result.is_err());
|
assert_eq!(result.unwrap_err(), Error::Io);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -220,20 +221,29 @@ mod tests {
|
|||||||
fn test_fit_bad_params() {
|
fn test_fit_bad_params() {
|
||||||
let data = generate_data();
|
let data = generate_data();
|
||||||
let result = Model::params().factors(0).fit(&data);
|
let result = Model::params().factors(0).fit(&data);
|
||||||
assert!(result.is_err());
|
assert_eq!(
|
||||||
|
result.unwrap_err(),
|
||||||
|
Error::Parameter("number of factors must be greater than zero".to_string())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_fit_eval_bad_params() {
|
fn test_fit_eval_bad_params() {
|
||||||
let data = generate_data();
|
let data = generate_data();
|
||||||
let result = Model::params().factors(0).fit_eval(&data, &data);
|
let result = Model::params().factors(0).fit_eval(&data, &data);
|
||||||
assert!(result.is_err());
|
assert_eq!(
|
||||||
|
result.unwrap_err(),
|
||||||
|
Error::Parameter("number of factors must be greater than zero".to_string())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_cv_bad_params() {
|
fn test_cv_bad_params() {
|
||||||
let data = generate_data();
|
let data = generate_data();
|
||||||
let result = Model::params().factors(0).cv(&data, 5);
|
let result = Model::params().factors(0).cv(&data, 5);
|
||||||
assert!(result.is_err());
|
assert_eq!(
|
||||||
|
result.unwrap_err(),
|
||||||
|
Error::Parameter("number of factors must be greater than zero".to_string())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user