From 9066cb543d02b169219184922635a54ad3729d8b Mon Sep 17 00:00:00 2001
From: Andrew Kane
Date: Sun, 30 Jun 2024 22:09:13 -0700
Subject: [PATCH] Improved docs [skip ci]

---
 README.md       |  4 ++--
 src/bindings.rs |  9 +++++++++
 src/model.rs    | 14 ++++++++++++++
 src/params.rs   | 17 +++++++++++++++++
 4 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 28be52c..cb1af2a 100644
--- a/README.md
+++ b/README.md
@@ -59,7 +59,7 @@ Save the model to a file
 model.save("model.txt").unwrap();
 ```

-Load the model from a file
+Load a model from a file

 ```rust
 let model = libmf::Model::load("model.txt").unwrap();
@@ -87,7 +87,7 @@ Set parameters - default values below
 libmf::Model::params()
     .loss(libmf::Loss::RealL2)    // loss function
     .factors(8)                   // number of latent factors
-    .threads(12)                  // number of threads used
+    .threads(12)                  // number of threads
     .bins(25)                     // number of bins
     .iterations(20)               // number of iterations
     .lambda_p1(0.0)               // coefficient of L1-norm regularization on P
diff --git a/src/bindings.rs b/src/bindings.rs
index abcb6f2..1b43dca 100644
--- a/src/bindings.rs
+++ b/src/bindings.rs
@@ -50,14 +50,23 @@ pub struct MfModel {
 #[repr(C)]
 #[derive(Clone, Copy)]
 pub enum Loss {
+    /// Squared error (L2-norm).
     RealL2 = 0,
+    /// Absolute error (L1-norm).
     RealL1 = 1,
+    /// Generalized KL-divergence.
     RealKL = 2,
+    /// Logarithmic error.
     BinaryLog = 5,
+    /// Squared hinge loss.
     BinaryL2 = 6,
+    /// Hinge loss.
     BinaryL1 = 7,
+    /// Row-oriented pair-wise logarithmic loss.
     OneClassRow = 10,
+    /// Column-oriented pair-wise logarithmic loss.
     OneClassCol = 11,
+    /// Squared error (L2-norm).
     OneClassL2 = 12,
 }

diff --git a/src/model.rs b/src/model.rs
index 47d5a3d..55b04ad 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -15,6 +15,7 @@ impl Model {
         Params::new()
     }

+    /// Loads a model from a file.
     pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
         // TODO better conversion
         let cpath = CString::new(path.as_ref().to_str().unwrap())?;
@@ -25,10 +26,12 @@ impl Model {
         Ok(Model { model })
     }

+    /// Returns the predicted value for a specific row and column.
     pub fn predict(&self, row_index: i32, column_index: i32) -> f32 {
         unsafe { mf_predict(self.model, row_index, column_index) }
     }

+    /// Saves the model to a file.
     pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
         // TODO better conversion
         let cpath = CString::new(path.as_ref().to_str().unwrap())?;
@@ -39,18 +42,22 @@ impl Model {
         Ok(())
     }

+    /// Returns the number of rows.
     pub fn rows(&self) -> i32 {
         unsafe { (*self.model).m }
     }

+    /// Returns the number of columns.
     pub fn columns(&self) -> i32 {
         unsafe { (*self.model).n }
     }

+    /// Returns the number of factors.
     pub fn factors(&self) -> i32 {
         unsafe { (*self.model).k }
     }

+    /// Returns the bias.
     pub fn bias(&self) -> f32 {
         unsafe { (*self.model).b }
     }
@@ -91,36 +98,43 @@ impl Model {
         self.q_factors().chunks(self.factors() as usize)
     }

+    /// Calculates RMSE (for real-valued MF).
     pub fn rmse(&self, data: &Matrix) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_rmse(&prob, self.model) }
     }

+    /// Calculates MAE (for real-valued MF).
     pub fn mae(&self, data: &Matrix) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_mae(&prob, self.model) }
     }

+    /// Calculates generalized KL-divergence (for non-negative real-valued MF).
     pub fn gkl(&self, data: &Matrix) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_gkl(&prob, self.model) }
     }

+    /// Calculates logarithmic loss (for binary MF).
     pub fn logloss(&self, data: &Matrix) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_logloss(&prob, self.model) }
     }

+    /// Calculates accuracy (for binary MF).
     pub fn accuracy(&self, data: &Matrix) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_accuracy(&prob, self.model) }
     }

+    /// Calculates MPR (for one-class MF).
     pub fn mpr(&self, data: &Matrix, transpose: bool) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_mpr(&prob, self.model, transpose) }
     }

+    /// Calculates AUC (for one-class MF).
     pub fn auc(&self, data: &Matrix, transpose: bool) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_auc(&prob, self.model, transpose) }
diff --git a/src/params.rs b/src/params.rs
index e417317..ef5f345 100644
--- a/src/params.rs
+++ b/src/params.rs
@@ -13,76 +13,91 @@ impl Params {
         Self { param }
     }

+    /// Sets the loss function.
     pub fn loss(&mut self, value: Loss) -> &mut Self {
         self.param.fun = value;
         self
     }

+    /// Sets the number of latent factors.
     pub fn factors(&mut self, value: i32) -> &mut Self {
         self.param.k = value;
         self
     }

+    /// Sets the number of threads.
     pub fn threads(&mut self, value: i32) -> &mut Self {
         self.param.nr_threads = value;
         self
     }

+    /// Sets the number of bins.
     pub fn bins(&mut self, value: i32) -> &mut Self {
         self.param.nr_bins = value;
         self
     }

+    /// Sets the number of iterations.
     pub fn iterations(&mut self, value: i32) -> &mut Self {
         self.param.nr_iters = value;
         self
     }

+    /// Sets the coefficient of L1-norm regularization on P.
     pub fn lambda_p1(&mut self, value: f32) -> &mut Self {
         self.param.lambda_p1 = value;
         self
     }

+    /// Sets the coefficient of L2-norm regularization on P.
     pub fn lambda_p2(&mut self, value: f32) -> &mut Self {
         self.param.lambda_p2 = value;
         self
     }

+    /// Sets the coefficient of L1-norm regularization on Q.
     pub fn lambda_q1(&mut self, value: f32) -> &mut Self {
         self.param.lambda_q1 = value;
         self
     }

+    /// Sets the coefficient of L2-norm regularization on Q.
     pub fn lambda_q2(&mut self, value: f32) -> &mut Self {
         self.param.lambda_q2 = value;
         self
     }

+    /// Sets the learning rate.
     pub fn learning_rate(&mut self, value: f32) -> &mut Self {
         self.param.eta = value;
         self
     }

+    /// Sets the importance of negative entries.
     pub fn alpha(&mut self, value: f32) -> &mut Self {
         self.param.alpha = value;
         self
     }

+    /// Sets the desired value of negative entries.
     pub fn c(&mut self, value: f32) -> &mut Self {
         self.param.c = value;
         self
     }

+    /// Sets whether to perform non-negative MF (NMF).
     pub fn nmf(&mut self, value: bool) -> &mut Self {
         self.param.do_nmf = value;
         self
     }

+    /// Sets whether to suppress output to stdout.
     pub fn quiet(&mut self, value: bool) -> &mut Self {
         self.param.quiet = value;
         self
     }

+    /// Fits a model.
     pub fn fit(&mut self, data: &Matrix) -> Result<Model, Error> {
         let prob = data.to_problem();
         let param = self.build_param()?;
@@ -93,6 +108,7 @@ impl Params {
         Ok(Model { model })
     }

+    /// Fits a model, evaluating on a separate dataset.
     pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Result<Model, Error> {
         let tr = train_set.to_problem();
         let va = eval_set.to_problem();
@@ -104,6 +120,7 @@ impl Params {
         Ok(Model { model })
     }

+    /// Performs cross-validation.
     pub fn cv(&mut self, data: &Matrix, folds: i32) -> Result<f64, Error> {
         let prob = data.to_problem();
         let param = self.build_param()?;
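
The methods documented in this patch compose into a short workflow. A minimal sketch, assuming the `Matrix::new()`/`push(row, col, value)` API shown elsewhere in the crate's README; the data values are invented for illustration and error handling is elided with `unwrap()`:

```rust
// Build a sparse matrix of (row, column, value) entries.
let mut data = libmf::Matrix::new();
data.push(0, 0, 5.0);
data.push(0, 2, 3.5);
data.push(1, 1, 4.0);

// Fit using the builder documented in src/params.rs
// (the values shown are the README defaults).
let model = libmf::Model::params()
    .loss(libmf::Loss::RealL2)
    .factors(8)
    .iterations(20)
    .fit(&data)
    .unwrap();

// Save, reload, and predict, per the doc comments added in src/model.rs.
model.save("model.txt").unwrap();
let model = libmf::Model::load("model.txt").unwrap();
println!("prediction = {}", model.predict(0, 2));
```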
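
The evaluation helpers are loss-specific, as the doc comments note: RMSE, MAE, and GKL for real-valued MF; log loss and accuracy for binary MF; MPR and AUC for one-class MF. A sketch of `fit_eval` and `cv` under the same assumed `Matrix` API; the train/eval split and fold count are purely illustrative, and the `f64` return of `cv` follows the signature reconstructed above:

```rust
// Hypothetical train/eval split to illustrate fit_eval() and the metrics.
let mut train = libmf::Matrix::new();
let mut eval = libmf::Matrix::new();
train.push(0, 0, 5.0);
train.push(0, 2, 3.5);
train.push(1, 1, 4.0);
eval.push(1, 0, 2.0);

// fit_eval() trains on `train` while evaluating on `eval`.
let model = libmf::Model::params().fit_eval(&train, &eval).unwrap();

// For the default real-valued loss, RMSE and MAE are the relevant metrics.
println!("rmse = {}", model.rmse(&eval));
println!("mae  = {}", model.mae(&eval));

// cv() performs k-fold cross-validation and returns the average score.
let score = libmf::Model::params().cv(&train, 5).unwrap();
println!("cv score = {}", score);
```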