diff --git a/CHANGELOG.md b/CHANGELOG.md index 1cb2e35..7420687 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.1.1 (unreleased) +- Added disk-level training - Added more metrics ## 0.1.0 (2021-07-26) diff --git a/README.md b/README.md index 7df5124..38d77cb 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,24 @@ Calculate AUC (for one-class MF) model.auc(&data, transpose); ``` +## Disk-Level Training + +Train directly from files + +```rust +model.fit_disk("train.txt"); +model.fit_eval_disk("train.txt", "validate.txt"); +model.cv_disk("train.txt", 5); +``` + +Data should be in the format `row_index column_index value`: + +```txt +0 0 5.0 +0 2 3.5 +1 1 4.0 +``` + ## Reference Specify the initial capacity for a matrix diff --git a/src/bindings.rs b/src/bindings.rs index a5df017..438a6ec 100644 --- a/src/bindings.rs +++ b/src/bindings.rs @@ -55,8 +55,11 @@ extern "C" { pub fn mf_save_model(model: *const MfModel, path: *const c_char) -> c_int; pub fn mf_load_model(path: *const c_char) -> *mut MfModel; pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel; + pub fn mf_train_on_disk(tr_path: *const c_char, param: MfParameter) -> *mut MfModel; pub fn mf_train_with_validation(tr: *const MfProblem, va: *const MfProblem, param: MfParameter) -> *mut MfModel; + pub fn mf_train_with_validation_on_disk(tr_path: *const c_char, va_path: *const c_char, param: MfParameter) -> *mut MfModel; pub fn mf_cross_validation(prob: *const MfProblem, nr_folds: c_int, param: MfParameter) -> c_double; + pub fn mf_cross_validation_on_disk(prob: *const c_char, nr_folds: c_int, param: MfParameter) -> c_double; pub fn mf_predict(model: *const MfModel, u: c_int, v: c_int) -> c_float; pub fn calc_rmse(prob: *const MfProblem, model: *const MfModel) -> c_double; pub fn calc_mae(prob: *const MfProblem, model: *const MfModel) -> c_double; diff --git a/src/model.rs b/src/model.rs index 495e7d4..3fe9190 100644 --- a/src/model.rs +++ b/src/model.rs @@ -35,17 +35,33 @@ 
impl Model { self.model = unsafe { mf_train(&prob, self.param()) }; } + /** Trains the model from the file at `path` by calling `mf_train_on_disk` and storing the returned pointer in `self.model`. Panics if `path` contains an interior NUL byte. NOTE(review): the returned pointer is not checked here; confirm what libmf returns for a missing or unreadable file. */ pub fn fit_disk(&mut self, path: &str) { + let cpath = CString::new(path).expect("CString::new failed"); + self.model = unsafe { mf_train_on_disk(cpath.as_ptr(), self.param()) }; + } + pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) { let tr = train_set.to_problem(); let va = eval_set.to_problem(); self.model = unsafe { mf_train_with_validation(&tr, &va, self.param()) }; } + /** Trains the model from the file at `train_path`, passing `eval_path` as the validation file to `mf_train_with_validation_on_disk`; the returned pointer is stored in `self.model` without being checked. Panics if either path contains an interior NUL byte. */ pub fn fit_eval_disk(&mut self, train_path: &str, eval_path: &str) { + let trpath = CString::new(train_path).expect("CString::new failed"); + let vapath = CString::new(eval_path).expect("CString::new failed"); + self.model = unsafe { mf_train_with_validation_on_disk(trpath.as_ptr(), vapath.as_ptr(), self.param()) }; + } + pub fn cv(&mut self, data: &Matrix, folds: i32) { let prob = data.to_problem(); unsafe { mf_cross_validation(&prob, folds, self.param()); } } + /** Runs cross-validation over the file at `path` by calling `mf_cross_validation_on_disk` with `folds` as the fold count. The value returned by the C call is discarded and `self.model` is not assigned. Panics if `path` contains an interior NUL byte. */ pub fn cv_disk(&mut self, path: &str, folds: i32) { + let cpath = CString::new(path).expect("CString::new failed"); + unsafe { mf_cross_validation_on_disk(cpath.as_ptr(), folds, self.param()); } + } + pub fn predict(&self, row_index: i32, column_index: i32) -> f32 { assert!(self.is_fit()); unsafe { mf_predict(self.model, row_index, column_index) } @@ -196,6 +212,9 @@ impl Model { mod tests { use crate::{Matrix, Model}; + /* Input file for the disk-level tests. The path is relative; presumably it is resolved from the crate root and points at demo data in the vendored libmf tree (confirm). */ const TRAIN_PATH: &str = "vendor/libmf/demo/real_matrix.tr.txt"; + /* Validation file passed to fit_eval_disk in the tests; same relative-path assumption as TRAIN_PATH (confirm). */ const EVAL_PATH: &str = "vendor/libmf/demo/real_matrix.te.txt"; + fn generate_data() -> Matrix { let mut data = Matrix::new(); data.push(0, 0, 1.0); @@ -217,6 +236,13 @@ mod tests { model.bias(); } + #[test] + /* Smoke test: calls fit_disk on TRAIN_PATH with quiet set; nothing about the resulting model is asserted. */ fn test_fit_disk() { + let mut model = Model::new(); + model.quiet = true; + model.fit_disk(TRAIN_PATH); + } + #[test] fn test_fit_eval() { let data = generate_data(); @@ -225,6 +251,13 @@ mod tests { model.fit_eval(&data, &data); } + #[test] + /* Smoke test: calls fit_eval_disk with the training and evaluation files; nothing about the resulting model is asserted. */ fn test_fit_eval_disk() { + let mut model = Model::new(); + model.quiet = true; + model.fit_eval_disk(TRAIN_PATH, 
EVAL_PATH); + } + #[test] fn test_cv() { let data = generate_data(); @@ -233,6 +266,13 @@ mod tests { model.cv(&data, 5); } + #[test] + /* Smoke test: calls cv_disk on TRAIN_PATH with 5 folds; the call returns nothing, so no result is asserted. */ fn test_cv_disk() { + let mut model = Model::new(); + model.quiet = true; + model.cv_disk(TRAIN_PATH, 5); + } + #[test] fn test_save_load() { let data = generate_data();