Changed fit and fit_eval to return Result

This commit is contained in:
Andrew Kane
2021-10-16 17:43:37 -07:00
parent af64e749cd
commit 26160cfb9a
6 changed files with 37 additions and 25 deletions

View File

@@ -1,7 +1,7 @@
## 0.2.0 (unreleased)
- Changed pattern for fitting models - use `Model::params()` instead of `Model::new()`
- Changed `Model::load` to return `Result`
- Changed `fit`, `fit_eval`, and `load` to return `Result`
## 0.1.1 (2021-07-27)

View File

@@ -26,7 +26,7 @@ data.push(1, 1, 4.0);
Fit a model
```rust
let model = libmf::Model::params().fit(&data);
let model = libmf::Model::params().fit(&data).unwrap();
```
Make predictions
@@ -63,7 +63,7 @@ let model = libmf::Model::load("model.txt").unwrap();
Pass a validation set
```rust
let model = libmf::Model::params().fit_eval(&train_set, &eval_set);
let model = libmf::Model::params().fit_eval(&train_set, &eval_set).unwrap();
```
## Cross-Validation
@@ -179,7 +179,7 @@ let mut data = libmf::Matrix::with_capacity(3);
Use
```rust
let model = libmf::Model::params().factors(20).fit(&data);
let model = libmf::Model::params().factors(20).fit(&data).unwrap();
```
instead of

2
src/error.rs Normal file
View File

@@ -0,0 +1,2 @@
#[derive(Debug)]
pub struct Error(pub(crate) String);

View File

@@ -3,10 +3,12 @@
//! [View the docs](https://github.com/ankane/libmf-rust)
mod bindings;
mod error;
mod matrix;
mod model;
mod params;
pub use error::Error;
pub use matrix::Matrix;
pub use model::Model;
pub use params::Params;

View File

@@ -1,10 +1,7 @@
use crate::bindings::*;
use crate::{Matrix, Params};
use crate::{Error, Matrix, Params};
use std::ffi::CString;
#[derive(Debug)]
pub struct Error(String);
pub struct Model {
pub(crate) model: *mut MfModel,
}
@@ -20,9 +17,7 @@ impl Model {
if model.is_null() {
Err(Error("Cannot open model".to_string()))
} else {
Ok(Model {
model: model
})
Ok(Model { model })
}
}
@@ -117,7 +112,7 @@ mod tests {
#[test]
fn test_fit() {
let data = generate_data();
let model = Model::params().quiet(true).fit(&data);
let model = Model::params().quiet(true).fit(&data).unwrap();
model.predict(0, 1);
model.p_factors();
@@ -128,7 +123,7 @@ mod tests {
#[test]
fn test_fit_eval() {
let data = generate_data();
Model::params().quiet(true).fit_eval(&data, &data);
Model::params().quiet(true).fit_eval(&data, &data).unwrap();
}
#[test]
@@ -140,7 +135,7 @@ mod tests {
#[test]
fn test_save_load() {
let data = generate_data();
let model = Model::params().quiet(true).fit(&data);
let model = Model::params().quiet(true).fit(&data).unwrap();
model.save("/tmp/model.txt");
let model = Model::load("/tmp/model.txt").unwrap();
@@ -159,7 +154,7 @@ mod tests {
#[test]
fn test_metrics() {
let data = generate_data();
let model = Model::params().quiet(true).fit(&data);
let model = Model::params().quiet(true).fit(&data).unwrap();
assert!(model.rmse(&data) < 0.15);
assert!(model.mae(&data) < 0.15);
@@ -173,14 +168,14 @@ mod tests {
#[test]
fn test_predict_out_of_range() {
let data = generate_data();
let model = Model::params().quiet(true).fit(&data);
let model = Model::params().quiet(true).fit(&data).unwrap();
assert_eq!(model.bias(), model.predict(1000, 1000));
}
#[test]
fn test_fit_empty() {
let data = Matrix::new();
let model = Model::params().quiet(true).fit(&data);
let model = Model::params().quiet(true).fit(&data).unwrap();
assert!(model.p_factors().is_empty());
assert!(model.q_factors().is_empty());
assert!(model.bias().is_nan());
@@ -189,9 +184,16 @@ mod tests {
#[test]
fn test_fit_eval_empty() {
let data = Matrix::new();
let model = Model::params().quiet(true).fit_eval(&data, &data);
let model = Model::params().quiet(true).fit_eval(&data, &data).unwrap();
assert!(model.p_factors().is_empty());
assert!(model.q_factors().is_empty());
assert!(model.bias().is_nan());
}
#[test]
fn test_bad_loss() {
let data = generate_data();
let result = Model::params().loss(13).fit(&data);
assert!(result.is_err());
}
}

View File

@@ -1,5 +1,5 @@
use crate::bindings::*;
use crate::{Matrix, Model};
use crate::{Error, Matrix, Model};
pub struct Params {
param: MfParameter
@@ -85,18 +85,24 @@ impl Params {
self
}
pub fn fit(&mut self, data: &Matrix) -> Model {
pub fn fit(&mut self, data: &Matrix) -> Result<Model, Error> {
let prob = data.to_problem();
Model {
model: unsafe { mf_train(&prob, self.param) }
let model = unsafe { mf_train(&prob, self.param) };
if model.is_null() {
Err(Error("Bad parameters".to_string()))
} else {
Ok(Model { model })
}
}
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Model {
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Result<Model, Error> {
let tr = train_set.to_problem();
let va = eval_set.to_problem();
Model {
model: unsafe { mf_train_with_validation(&tr, &va, self.param) }
let model = unsafe { mf_train_with_validation(&tr, &va, self.param) };
if model.is_null() {
Err(Error("Bad parameters".to_string()))
} else {
Ok(Model { model })
}
}