Improved docs [skip ci]

This commit is contained in:
Andrew Kane
2024-06-30 22:09:13 -07:00
parent 762ec863e2
commit 9066cb543d
4 changed files with 42 additions and 2 deletions

View File

@@ -50,14 +50,23 @@ pub struct MfModel {
/// Loss function optimized by LIBMF during training.
///
/// The discriminant values mirror the integer constants expected by the
/// native library, so they must not be renumbered. The gaps in the
/// numbering (3-4, 8-9) are intentional — presumably reserved loss ids
/// in the native library that are not exposed here.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Loss {
    /// Squared error (L2-norm).
    RealL2 = 0,
    /// Absolute error (L1-norm).
    RealL1 = 1,
    /// Generalized KL-divergence.
    RealKL = 2,
    /// Logarithmic error.
    BinaryLog = 5,
    /// Squared hinge loss.
    BinaryL2 = 6,
    /// Hinge loss.
    BinaryL1 = 7,
    /// Row-oriented pair-wise logarithmic loss.
    OneClassRow = 10,
    /// Column-oriented pair-wise logarithmic loss.
    OneClassCol = 11,
    /// Squared error (L2-norm).
    OneClassL2 = 12,
}

View File

@@ -15,6 +15,7 @@ impl Model {
Params::new()
}
/// Loads a model from a file.
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
// TODO better conversion
let cpath = CString::new(path.as_ref().to_str().unwrap())?;
@@ -25,10 +26,12 @@ impl Model {
Ok(Model { model })
}
/// Returns the predicted value for a specific row and column.
///
/// The indices are forwarded unchecked to the native `mf_predict`
/// routine; no bounds validation is performed on the Rust side.
pub fn predict(&self, row_index: i32, column_index: i32) -> f32 {
// SAFETY: assumes self.model is a valid, live model pointer —
// presumably guaranteed by the constructors; confirm that invariant.
unsafe { mf_predict(self.model, row_index, column_index) }
}
/// Saves the model to a file.
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
// TODO better conversion
let cpath = CString::new(path.as_ref().to_str().unwrap())?;
@@ -39,18 +42,22 @@ impl Model {
Ok(())
}
/// Returns the number of rows.
pub fn rows(&self) -> i32 {
// SAFETY: dereferences the raw model pointer to read its `m` field;
// assumes self.model is valid and non-null — TODO confirm invariant.
unsafe { (*self.model).m }
}
/// Returns the number of columns.
pub fn columns(&self) -> i32 {
// SAFETY: dereferences the raw model pointer to read its `n` field;
// assumes self.model is valid and non-null — TODO confirm invariant.
unsafe { (*self.model).n }
}
/// Returns the number of latent factors.
pub fn factors(&self) -> i32 {
// SAFETY: dereferences the raw model pointer to read its `k` field;
// assumes self.model is valid and non-null — TODO confirm invariant.
unsafe { (*self.model).k }
}
/// Returns the bias.
pub fn bias(&self) -> f32 {
// SAFETY: dereferences the raw model pointer to read its `b` field;
// assumes self.model is valid and non-null — TODO confirm invariant.
unsafe { (*self.model).b }
}
@@ -91,36 +98,43 @@ impl Model {
self.q_factors().chunks(self.factors() as usize)
}
/// Calculates RMSE (for real-valued MF).
///
/// Converts `data` into the native problem representation and delegates
/// the metric computation to libmf.
pub fn rmse(&self, data: &Matrix) -> f64 {
    let problem = data.to_problem();
    unsafe { calc_rmse(&problem, self.model) }
}
/// Calculates MAE (for real-valued MF).
///
/// Converts `data` into the native problem representation and delegates
/// the metric computation to libmf.
pub fn mae(&self, data: &Matrix) -> f64 {
    let problem = data.to_problem();
    unsafe { calc_mae(&problem, self.model) }
}
/// Calculates generalized KL-divergence (for non-negative real-valued MF).
///
/// Converts `data` into the native problem representation and delegates
/// the metric computation to libmf.
pub fn gkl(&self, data: &Matrix) -> f64 {
    let problem = data.to_problem();
    unsafe { calc_gkl(&problem, self.model) }
}
/// Calculates logarithmic loss (for binary MF).
///
/// Converts `data` into the native problem representation and delegates
/// the metric computation to libmf.
pub fn logloss(&self, data: &Matrix) -> f64 {
    let problem = data.to_problem();
    unsafe { calc_logloss(&problem, self.model) }
}
/// Calculates accuracy (for binary MF).
///
/// Converts `data` into the native problem representation and delegates
/// the metric computation to libmf.
pub fn accuracy(&self, data: &Matrix) -> f64 {
    let problem = data.to_problem();
    unsafe { calc_accuracy(&problem, self.model) }
}
/// Calculates MPR (for one-class MF).
///
/// Converts `data` into the native problem representation and delegates
/// the metric computation to libmf. `transpose` is passed through to
/// the native routine.
pub fn mpr(&self, data: &Matrix, transpose: bool) -> f64 {
    let problem = data.to_problem();
    unsafe { calc_mpr(&problem, self.model, transpose) }
}
/// Calculates AUC (for one-class MF).
pub fn auc(&self, data: &Matrix, transpose: bool) -> f64 {
let prob = data.to_problem();
unsafe { calc_auc(&prob, self.model, transpose) }

View File

@@ -13,76 +13,91 @@ impl Params {
Self { param }
}
/// Sets the loss function.
pub fn loss(&mut self, value: Loss) -> &mut Self {
    let param = &mut self.param;
    param.fun = value;
    self
}
/// Sets the number of latent factors.
pub fn factors(&mut self, value: i32) -> &mut Self {
    let param = &mut self.param;
    param.k = value;
    self
}
/// Sets the number of threads used for training.
pub fn threads(&mut self, value: i32) -> &mut Self {
    let param = &mut self.param;
    param.nr_threads = value;
    self
}
/// Sets the number of bins.
pub fn bins(&mut self, value: i32) -> &mut Self {
    let param = &mut self.param;
    param.nr_bins = value;
    self
}
/// Sets the number of training iterations.
pub fn iterations(&mut self, value: i32) -> &mut Self {
    let param = &mut self.param;
    param.nr_iters = value;
    self
}
/// Sets the coefficient of L1-norm regularization on P.
pub fn lambda_p1(&mut self, value: f32) -> &mut Self {
    let param = &mut self.param;
    param.lambda_p1 = value;
    self
}
/// Sets the coefficient of L2-norm regularization on P.
pub fn lambda_p2(&mut self, value: f32) -> &mut Self {
    let param = &mut self.param;
    param.lambda_p2 = value;
    self
}
/// Sets the coefficient of L1-norm regularization on Q.
pub fn lambda_q1(&mut self, value: f32) -> &mut Self {
    let param = &mut self.param;
    param.lambda_q1 = value;
    self
}
/// Sets the coefficient of L2-norm regularization on Q.
pub fn lambda_q2(&mut self, value: f32) -> &mut Self {
    let param = &mut self.param;
    param.lambda_q2 = value;
    self
}
/// Sets the learning rate (libmf's `eta`).
pub fn learning_rate(&mut self, value: f32) -> &mut Self {
    let param = &mut self.param;
    param.eta = value;
    self
}
/// Sets the importance of negative entries.
pub fn alpha(&mut self, value: f32) -> &mut Self {
    let param = &mut self.param;
    param.alpha = value;
    self
}
/// Sets the desired value of negative entries.
pub fn c(&mut self, value: f32) -> &mut Self {
    let param = &mut self.param;
    param.c = value;
    self
}
/// Sets whether to perform non-negative MF (NMF).
pub fn nmf(&mut self, value: bool) -> &mut Self {
    let param = &mut self.param;
    param.do_nmf = value;
    self
}
/// Sets whether to suppress informational output to stdout
/// (`true` = quiet, i.e. no training progress is printed).
///
/// Note: the previous doc ("Sets whether to output to stdout") read as
/// the inverse of the flag's meaning; `quiet = true` disables output.
pub fn quiet(&mut self, value: bool) -> &mut Self {
    self.param.quiet = value;
    self
}
/// Fits a model.
pub fn fit(&mut self, data: &Matrix) -> Result<Model, Error> {
let prob = data.to_problem();
let param = self.build_param()?;
@@ -93,6 +108,7 @@ impl Params {
Ok(Model { model })
}
/// Fits a model and performs cross-validation.
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Result<Model, Error> {
let tr = train_set.to_problem();
let va = eval_set.to_problem();
@@ -104,6 +120,7 @@ impl Params {
Ok(Model { model })
}
/// Performs cross-validation.
pub fn cv(&mut self, data: &Matrix, folds: i32) -> Result<f64, Error> {
let prob = data.to_problem();
let param = self.build_param()?;