Changed fit and fit_eval to return Result
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
## 0.2.0 (unreleased)
|
||||
|
||||
- Changed pattern for fitting models - use `Model::params()` instead of `Model::new()`
|
||||
- Changed `Model::load` to return `Result`
|
||||
- Changed `fit`, `fit_eval`, and `load` to return `Result`
|
||||
|
||||
## 0.1.1 (2021-07-27)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ data.push(1, 1, 4.0);
|
||||
Fit a model
|
||||
|
||||
```rust
|
||||
let model = libmf::Model::params().fit(&data);
|
||||
let model = libmf::Model::params().fit(&data).unwrap();
|
||||
```
|
||||
|
||||
Make predictions
|
||||
@@ -63,7 +63,7 @@ let model = libmf::Model::load("model.txt").unwrap();
|
||||
Pass a validation set
|
||||
|
||||
```rust
|
||||
let model = libmf::Model::params().fit_eval(&train_set, &eval_set);
|
||||
let model = libmf::Model::params().fit_eval(&train_set, &eval_set).unwrap();
|
||||
```
|
||||
|
||||
## Cross-Validation
|
||||
@@ -179,7 +179,7 @@ let mut data = libmf::Matrix::with_capacity(3);
|
||||
Use
|
||||
|
||||
```rust
|
||||
let model = libmf::Model::params().factors(20).fit(&data);
|
||||
let model = libmf::Model::params().factors(20).fit(&data).unwrap();
|
||||
```
|
||||
|
||||
instead of
|
||||
|
||||
2
src/error.rs
Normal file
2
src/error.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
#[derive(Debug)]
|
||||
pub struct Error(pub(crate) String);
|
||||
@@ -3,10 +3,12 @@
|
||||
//! [View the docs](https://github.com/ankane/libmf-rust)
|
||||
|
||||
mod bindings;
|
||||
mod error;
|
||||
mod matrix;
|
||||
mod model;
|
||||
mod params;
|
||||
|
||||
pub use error::Error;
|
||||
pub use matrix::Matrix;
|
||||
pub use model::Model;
|
||||
pub use params::Params;
|
||||
|
||||
30
src/model.rs
30
src/model.rs
@@ -1,10 +1,7 @@
|
||||
use crate::bindings::*;
|
||||
use crate::{Matrix, Params};
|
||||
use crate::{Error, Matrix, Params};
|
||||
use std::ffi::CString;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Error(String);
|
||||
|
||||
pub struct Model {
|
||||
pub(crate) model: *mut MfModel,
|
||||
}
|
||||
@@ -20,9 +17,7 @@ impl Model {
|
||||
if model.is_null() {
|
||||
Err(Error("Cannot open model".to_string()))
|
||||
} else {
|
||||
Ok(Model {
|
||||
model: model
|
||||
})
|
||||
Ok(Model { model })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,7 +112,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_fit() {
|
||||
let data = generate_data();
|
||||
let model = Model::params().quiet(true).fit(&data);
|
||||
let model = Model::params().quiet(true).fit(&data).unwrap();
|
||||
model.predict(0, 1);
|
||||
|
||||
model.p_factors();
|
||||
@@ -128,7 +123,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_fit_eval() {
|
||||
let data = generate_data();
|
||||
Model::params().quiet(true).fit_eval(&data, &data);
|
||||
Model::params().quiet(true).fit_eval(&data, &data).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -140,7 +135,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_save_load() {
|
||||
let data = generate_data();
|
||||
let model = Model::params().quiet(true).fit(&data);
|
||||
let model = Model::params().quiet(true).fit(&data).unwrap();
|
||||
|
||||
model.save("/tmp/model.txt");
|
||||
let model = Model::load("/tmp/model.txt").unwrap();
|
||||
@@ -159,7 +154,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_metrics() {
|
||||
let data = generate_data();
|
||||
let model = Model::params().quiet(true).fit(&data);
|
||||
let model = Model::params().quiet(true).fit(&data).unwrap();
|
||||
|
||||
assert!(model.rmse(&data) < 0.15);
|
||||
assert!(model.mae(&data) < 0.15);
|
||||
@@ -173,14 +168,14 @@ mod tests {
|
||||
#[test]
|
||||
fn test_predict_out_of_range() {
|
||||
let data = generate_data();
|
||||
let model = Model::params().quiet(true).fit(&data);
|
||||
let model = Model::params().quiet(true).fit(&data).unwrap();
|
||||
assert_eq!(model.bias(), model.predict(1000, 1000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fit_empty() {
|
||||
let data = Matrix::new();
|
||||
let model = Model::params().quiet(true).fit(&data);
|
||||
let model = Model::params().quiet(true).fit(&data).unwrap();
|
||||
assert!(model.p_factors().is_empty());
|
||||
assert!(model.q_factors().is_empty());
|
||||
assert!(model.bias().is_nan());
|
||||
@@ -189,9 +184,16 @@ mod tests {
|
||||
#[test]
|
||||
fn test_fit_eval_empty() {
|
||||
let data = Matrix::new();
|
||||
let model = Model::params().quiet(true).fit_eval(&data, &data);
|
||||
let model = Model::params().quiet(true).fit_eval(&data, &data).unwrap();
|
||||
assert!(model.p_factors().is_empty());
|
||||
assert!(model.q_factors().is_empty());
|
||||
assert!(model.bias().is_nan());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bad_loss() {
|
||||
let data = generate_data();
|
||||
let result = Model::params().loss(13).fit(&data);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::bindings::*;
|
||||
use crate::{Matrix, Model};
|
||||
use crate::{Error, Matrix, Model};
|
||||
|
||||
pub struct Params {
|
||||
param: MfParameter
|
||||
@@ -85,18 +85,24 @@ impl Params {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn fit(&mut self, data: &Matrix) -> Model {
|
||||
pub fn fit(&mut self, data: &Matrix) -> Result<Model, Error> {
|
||||
let prob = data.to_problem();
|
||||
Model {
|
||||
model: unsafe { mf_train(&prob, self.param) }
|
||||
let model = unsafe { mf_train(&prob, self.param) };
|
||||
if model.is_null() {
|
||||
Err(Error("Bad parameters".to_string()))
|
||||
} else {
|
||||
Ok(Model { model })
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Model {
|
||||
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Result<Model, Error> {
|
||||
let tr = train_set.to_problem();
|
||||
let va = eval_set.to_problem();
|
||||
Model {
|
||||
model: unsafe { mf_train_with_validation(&tr, &va, self.param) }
|
||||
let model = unsafe { mf_train_with_validation(&tr, &va, self.param) };
|
||||
if model.is_null() {
|
||||
Err(Error("Bad parameters".to_string()))
|
||||
} else {
|
||||
Ok(Model { model })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user