Revert "Added disk-level training"

This reverts commit 7cbcf78a95.
This commit is contained in:
Andrew Kane
2021-07-26 16:51:08 -07:00
parent dcfbb2f867
commit c6fddcc303
4 changed files with 0 additions and 62 deletions

View File

@@ -55,11 +55,8 @@ extern "C" {
pub fn mf_save_model(model: *const MfModel, path: *const c_char) -> c_int;
pub fn mf_load_model(path: *const c_char) -> *mut MfModel;
pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel;
pub fn mf_train_on_disk(tr_path: *const c_char, param: MfParameter) -> *mut MfModel;
pub fn mf_train_with_validation(tr: *const MfProblem, va: *const MfProblem, param: MfParameter) -> *mut MfModel;
pub fn mf_train_with_validation_on_disk(tr_path: *const c_char, va_path: *const c_char, param: MfParameter) -> *mut MfModel;
pub fn mf_cross_validation(prob: *const MfProblem, nr_folds: c_int, param: MfParameter) -> c_double;
pub fn mf_cross_validation_on_disk(prob: *const c_char, nr_folds: c_int, param: MfParameter) -> c_double;
pub fn mf_predict(model: *const MfModel, u: c_int, v: c_int) -> c_float;
pub fn calc_rmse(prob: *const MfProblem, model: *const MfModel) -> c_double;
pub fn calc_mae(prob: *const MfProblem, model: *const MfModel) -> c_double;

View File

@@ -35,33 +35,17 @@ impl Model {
self.model = unsafe { mf_train(&prob, self.param()) };
}
pub fn fit_disk(&mut self, path: &str) {
let cpath = CString::new(path).expect("CString::new failed");
self.model = unsafe { mf_train_on_disk(cpath.as_ptr(), self.param()) };
}
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) {
let tr = train_set.to_problem();
let va = eval_set.to_problem();
self.model = unsafe { mf_train_with_validation(&tr, &va, self.param()) };
}
pub fn fit_eval_disk(&mut self, train_path: &str, eval_path: &str) {
let trpath = CString::new(train_path).expect("CString::new failed");
let vapath = CString::new(eval_path).expect("CString::new failed");
self.model = unsafe { mf_train_with_validation_on_disk(trpath.as_ptr(), vapath.as_ptr(), self.param()) };
}
pub fn cv(&mut self, data: &Matrix, folds: i32) {
let prob = data.to_problem();
unsafe { mf_cross_validation(&prob, folds, self.param()); }
}
pub fn cv_disk(&mut self, path: &str, folds: i32) {
let cpath = CString::new(path).expect("CString::new failed");
unsafe { mf_cross_validation_on_disk(cpath.as_ptr(), folds, self.param()); }
}
pub fn predict(&self, row_index: i32, column_index: i32) -> f32 {
assert!(self.is_fit());
unsafe { mf_predict(self.model, row_index, column_index) }
@@ -212,9 +196,6 @@ impl Model {
mod tests {
use crate::{Matrix, Model};
const TRAIN_PATH: &str = "vendor/libmf/demo/real_matrix.tr.txt";
const EVAL_PATH: &str = "vendor/libmf/demo/real_matrix.te.txt";
fn generate_data() -> Matrix {
let mut data = Matrix::new();
data.push(0, 0, 1.0);
@@ -236,13 +217,6 @@ mod tests {
model.bias();
}
#[test]
fn test_fit_disk() {
let mut model = Model::new();
model.quiet = true;
model.fit_disk(TRAIN_PATH);
}
#[test]
fn test_fit_eval() {
let data = generate_data();
@@ -251,13 +225,6 @@ mod tests {
model.fit_eval(&data, &data);
}
#[test]
fn test_fit_eval_disk() {
let mut model = Model::new();
model.quiet = true;
model.fit_eval_disk(TRAIN_PATH, EVAL_PATH);
}
#[test]
fn test_cv() {
let data = generate_data();
@@ -266,13 +233,6 @@ mod tests {
model.cv(&data, 5);
}
#[test]
fn test_cv_disk() {
let mut model = Model::new();
model.quiet = true;
model.cv_disk(TRAIN_PATH, 5);
}
#[test]
fn test_save_load() {
let data = generate_data();