Added disk-level training
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
## 0.1.1 (unreleased)
|
||||
|
||||
- Added disk-level training
|
||||
- Added more metrics
|
||||
|
||||
## 0.1.0 (2021-07-26)
|
||||
|
||||
18
README.md
18
README.md
@@ -160,6 +160,24 @@ Calculate AUC (for one-class MF)
|
||||
model.auc(&data, transpose);
|
||||
```
|
||||
|
||||
## Disk-Level Training
|
||||
|
||||
Train directly from files
|
||||
|
||||
```ruby
|
||||
model.fit_disk("train.txt")
|
||||
model.fit_eval_disk("train.txt", "validate.txt")
|
||||
model.cv_disk("train.txt")
|
||||
```
|
||||
|
||||
Data should be in the format `row_index column_index value`:
|
||||
|
||||
```txt
|
||||
0 0 5.0
|
||||
0 2 3.5
|
||||
1 1 4.0
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
Specify the initial capacity for a matrix
|
||||
|
||||
@@ -55,8 +55,11 @@ extern "C" {
|
||||
pub fn mf_save_model(model: *const MfModel, path: *const c_char) -> c_int;
|
||||
pub fn mf_load_model(path: *const c_char) -> *mut MfModel;
|
||||
pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel;
|
||||
pub fn mf_train_on_disk(tr_path: *const c_char, param: MfParameter) -> *mut MfModel;
|
||||
pub fn mf_train_with_validation(tr: *const MfProblem, va: *const MfProblem, param: MfParameter) -> *mut MfModel;
|
||||
pub fn mf_train_with_validation_on_disk(tr_path: *const c_char, va_path: *const c_char, param: MfParameter) -> *mut MfModel;
|
||||
pub fn mf_cross_validation(prob: *const MfProblem, nr_folds: c_int, param: MfParameter) -> c_double;
|
||||
pub fn mf_cross_validation_on_disk(prob: *const c_char, nr_folds: c_int, param: MfParameter) -> c_double;
|
||||
pub fn mf_predict(model: *const MfModel, u: c_int, v: c_int) -> c_float;
|
||||
pub fn calc_rmse(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
||||
pub fn calc_mae(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
||||
|
||||
40
src/model.rs
40
src/model.rs
@@ -35,17 +35,33 @@ impl Model {
|
||||
self.model = unsafe { mf_train(&prob, self.param()) };
|
||||
}
|
||||
|
||||
pub fn fit_disk(&mut self, path: &str) {
|
||||
let cpath = CString::new(path).expect("CString::new failed");
|
||||
self.model = unsafe { mf_train_on_disk(cpath.as_ptr(), self.param()) };
|
||||
}
|
||||
|
||||
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) {
|
||||
let tr = train_set.to_problem();
|
||||
let va = eval_set.to_problem();
|
||||
self.model = unsafe { mf_train_with_validation(&tr, &va, self.param()) };
|
||||
}
|
||||
|
||||
pub fn fit_eval_disk(&mut self, train_path: &str, eval_path: &str) {
|
||||
let trpath = CString::new(train_path).expect("CString::new failed");
|
||||
let vapath = CString::new(eval_path).expect("CString::new failed");
|
||||
self.model = unsafe { mf_train_with_validation_on_disk(trpath.as_ptr(), vapath.as_ptr(), self.param()) };
|
||||
}
|
||||
|
||||
pub fn cv(&mut self, data: &Matrix, folds: i32) {
|
||||
let prob = data.to_problem();
|
||||
unsafe { mf_cross_validation(&prob, folds, self.param()); }
|
||||
}
|
||||
|
||||
pub fn cv_disk(&mut self, path: &str, folds: i32) {
|
||||
let cpath = CString::new(path).expect("CString::new failed");
|
||||
unsafe { mf_cross_validation_on_disk(cpath.as_ptr(), folds, self.param()); }
|
||||
}
|
||||
|
||||
pub fn predict(&self, row_index: i32, column_index: i32) -> f32 {
|
||||
assert!(self.is_fit());
|
||||
unsafe { mf_predict(self.model, row_index, column_index) }
|
||||
@@ -196,6 +212,9 @@ impl Model {
|
||||
mod tests {
|
||||
use crate::{Matrix, Model};
|
||||
|
||||
const TRAIN_PATH: &str = "vendor/libmf/demo/real_matrix.tr.txt";
|
||||
const EVAL_PATH: &str = "vendor/libmf/demo/real_matrix.te.txt";
|
||||
|
||||
fn generate_data() -> Matrix {
|
||||
let mut data = Matrix::new();
|
||||
data.push(0, 0, 1.0);
|
||||
@@ -217,6 +236,13 @@ mod tests {
|
||||
model.bias();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fit_disk() {
|
||||
let mut model = Model::new();
|
||||
model.quiet = true;
|
||||
model.fit_disk(TRAIN_PATH);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fit_eval() {
|
||||
let data = generate_data();
|
||||
@@ -225,6 +251,13 @@ mod tests {
|
||||
model.fit_eval(&data, &data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fit_eval_disk() {
|
||||
let mut model = Model::new();
|
||||
model.quiet = true;
|
||||
model.fit_eval_disk(TRAIN_PATH, EVAL_PATH);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cv() {
|
||||
let data = generate_data();
|
||||
@@ -233,6 +266,13 @@ mod tests {
|
||||
model.cv(&data, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cv_disk() {
|
||||
let mut model = Model::new();
|
||||
model.quiet = true;
|
||||
model.cv_disk(TRAIN_PATH, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_load() {
|
||||
let data = generate_data();
|
||||
|
||||
Reference in New Issue
Block a user