Free model

This commit is contained in:
Andrew Kane
2021-07-26 16:57:40 -07:00
parent c6fddcc303
commit 2041988959
2 changed files with 19 additions and 3 deletions

View File

@@ -54,6 +54,7 @@ extern "C" {
pub fn mf_get_default_param() -> MfParameter;
pub fn mf_save_model(model: *const MfModel, path: *const c_char) -> c_int;
pub fn mf_load_model(path: *const c_char) -> *mut MfModel;
pub fn mf_destroy_model(model: *mut *mut MfModel);
pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel;
pub fn mf_train_with_validation(tr: *const MfProblem, va: *const MfProblem, param: MfParameter) -> *mut MfModel;
pub fn mf_cross_validation(prob: *const MfProblem, nr_folds: c_int, param: MfParameter) -> c_double;

View File

@@ -3,7 +3,7 @@ use crate::Matrix;
use std::ffi::CString;
pub struct Model {
model: *const MfModel,
model: *mut MfModel,
pub loss: i32,
pub factors: i32,
pub threads: i32,
@@ -22,7 +22,7 @@ pub struct Model {
impl Model {
pub fn new() -> Self {
Self::with_model(std::ptr::null())
Self::with_model(std::ptr::null_mut())
}
pub fn load(path: &str) -> Self {
@@ -32,12 +32,14 @@ impl Model {
pub fn fit(&mut self, data: &Matrix) {
let prob = data.to_problem();
self.destroy_model();
self.model = unsafe { mf_train(&prob, self.param()) };
}
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) {
let tr = train_set.to_problem();
let va = eval_set.to_problem();
self.destroy_model();
self.model = unsafe { mf_train_with_validation(&tr, &va, self.param()) };
}
@@ -147,7 +149,7 @@ impl Model {
unsafe { calc_auc(&prob, self.model, transpose) }
}
fn with_model(model: *const MfModel) -> Self {
fn with_model(model: *mut MfModel) -> Self {
let param = unsafe { mf_get_default_param() };
Self {
model: model,
@@ -190,6 +192,19 @@ impl Model {
fn is_fit(&self) -> bool {
!self.model.is_null()
}
fn destroy_model(&mut self) {
if !self.model.is_null() {
unsafe { mf_destroy_model(&mut self.model) };
self.model = std::ptr::null_mut();
}
}
}
impl Drop for Model {
fn drop(&mut self) {
self.destroy_model();
}
}
#[cfg(test)]