Improved docs [skip ci]

Andrew Kane
2024-06-30 22:09:13 -07:00
parent 762ec863e2
commit 9066cb543d
4 changed files with 42 additions and 2 deletions

View File

@@ -59,7 +59,7 @@ Save the model to a file
 model.save("model.txt").unwrap();
 ```
-Load the model from a file
+Load a model from a file
 ```rust
 let model = libmf::Model::load("model.txt").unwrap();
@@ -87,7 +87,7 @@ Set parameters - default values below
 libmf::Model::params()
     .loss(libmf::Loss::RealL2)  // loss function
     .factors(8)                 // number of latent factors
-    .threads(12)                // number of threads used
+    .threads(12)                // number of threads
     .bins(25)                   // number of bins
     .iterations(20)             // number of iterations
     .lambda_p1(0.0)             // coefficient of L1-norm regularization on P
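
For context, a minimal end-to-end sketch of how these pieces fit together. The `Matrix::new()` and `push(row, col, value)` calls for building the sparse input are assumptions not shown in this diff; everything else follows the snippets above.

```rust
fn main() {
    // Assumed constructors for the sparse input (not shown in this diff)
    let mut data = libmf::Matrix::new();
    data.push(0, 0, 5.0);
    data.push(0, 2, 3.5);
    data.push(1, 1, 4.0);

    // Fit with a few explicit parameters (defaults shown above)
    let model = libmf::Model::params()
        .loss(libmf::Loss::RealL2)
        .factors(8)
        .iterations(20)
        .fit(&data)
        .unwrap();

    // Predict the value at row 0, column 2
    println!("prediction: {}", model.predict(0, 2));

    // Round-trip the model through a file, as documented above
    model.save("model.txt").unwrap();
    let reloaded = libmf::Model::load("model.txt").unwrap();
    assert_eq!(reloaded.factors(), 8);
}
```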

View File

@@ -50,14 +50,23 @@ pub struct MfModel {
 #[repr(C)]
 #[derive(Clone, Copy)]
 pub enum Loss {
+    /// Squared error (L2-norm).
     RealL2 = 0,
+    /// Absolute error (L1-norm).
     RealL1 = 1,
+    /// Generalized KL-divergence.
     RealKL = 2,
+    /// Logarithmic error.
     BinaryLog = 5,
+    /// Squared hinge loss.
     BinaryL2 = 6,
+    /// Hinge loss.
     BinaryL1 = 7,
+    /// Row-oriented pair-wise logarithmic loss.
     OneClassRow = 10,
+    /// Column-oriented pair-wise logarithmic loss.
     OneClassCol = 11,
+    /// Squared error (L2-norm).
     OneClassL2 = 12,
 }
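
The explicit discriminants mirror the loss codes of the underlying LIBMF C library, which is why the numbering is non-contiguous. A short sketch of picking a loss to match the data; it assumes the crate re-exports `Error` alongside `Model`:

```rust
use libmf::{Error, Loss, Matrix, Model};

// For binary labels (e.g. 1.0 / -1.0), logarithmic error is the usual
// choice; real-valued ratings would keep the RealL2 default.
fn fit_binary(data: &Matrix) -> Result<Model, Error> {
    Model::params()
        .loss(Loss::BinaryLog) // logarithmic error, per the doc above
        .fit(data)
}
```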

View File

@@ -15,6 +15,7 @@ impl Model {
         Params::new()
     }

+    /// Loads a model from a file.
     pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
         // TODO better conversion
         let cpath = CString::new(path.as_ref().to_str().unwrap())?;
@@ -25,10 +26,12 @@ impl Model {
         Ok(Model { model })
     }

+    /// Returns the predicted value for a specific row and column.
     pub fn predict(&self, row_index: i32, column_index: i32) -> f32 {
         unsafe { mf_predict(self.model, row_index, column_index) }
     }

+    /// Saves the model to a file.
     pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
         // TODO better conversion
         let cpath = CString::new(path.as_ref().to_str().unwrap())?;
@@ -39,18 +42,22 @@ impl Model {
         Ok(())
     }

+    /// Returns the number of rows.
     pub fn rows(&self) -> i32 {
         unsafe { (*self.model).m }
     }

+    /// Returns the number of columns.
     pub fn columns(&self) -> i32 {
         unsafe { (*self.model).n }
     }

+    /// Returns the number of factors.
     pub fn factors(&self) -> i32 {
         unsafe { (*self.model).k }
     }

+    /// Returns the bias.
     pub fn bias(&self) -> f32 {
         unsafe { (*self.model).b }
     }
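
These accessors read the `m`, `n`, `k`, and `b` fields of the underlying C model, so they work on fitted and loaded models alike. A small hypothetical helper for sanity-checking a model:

```rust
// Print a trained model's dimensions and bias.
fn describe(model: &libmf::Model) {
    println!(
        "{} rows x {} columns, {} factors, bias {}",
        model.rows(),    // m
        model.columns(), // n
        model.factors(), // k
        model.bias()     // b
    );
}
```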
@@ -91,36 +98,43 @@ impl Model {
         self.q_factors().chunks(self.factors() as usize)
     }

+    /// Calculates RMSE (for real-valued MF).
     pub fn rmse(&self, data: &Matrix) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_rmse(&prob, self.model) }
     }

+    /// Calculates MAE (for real-valued MF).
     pub fn mae(&self, data: &Matrix) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_mae(&prob, self.model) }
     }

+    /// Calculates generalized KL-divergence (for non-negative real-valued MF).
     pub fn gkl(&self, data: &Matrix) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_gkl(&prob, self.model) }
     }

+    /// Calculates logarithmic loss (for binary MF).
     pub fn logloss(&self, data: &Matrix) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_logloss(&prob, self.model) }
     }

+    /// Calculates accuracy (for binary MF).
     pub fn accuracy(&self, data: &Matrix) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_accuracy(&prob, self.model) }
     }

+    /// Calculates MPR (for one-class MF).
     pub fn mpr(&self, data: &Matrix, transpose: bool) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_mpr(&prob, self.model, transpose) }
     }

+    /// Calculates AUC (for one-class MF).
     pub fn auc(&self, data: &Matrix, transpose: bool) -> f64 {
         let prob = data.to_problem();
         unsafe { calc_auc(&prob, self.model, transpose) }
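
Each metric converts the matrix to a LIBMF problem and calls the matching `calc_*` function, so the metric should match the loss family the model was trained with. A hypothetical sketch, assuming `model` was trained with a real-valued loss and `oc_model` with a one-class loss:

```rust
fn report(model: &libmf::Model, oc_model: &libmf::Model, eval_set: &libmf::Matrix) {
    // Real-valued MF: error metrics
    println!("rmse: {}", model.rmse(eval_set));
    println!("mae:  {}", model.mae(eval_set));

    // One-class MF: ranking metrics; transpose = false evaluates
    // row-wise, transpose = true column-wise
    println!("mpr: {}", oc_model.mpr(eval_set, false));
    println!("auc: {}", oc_model.auc(eval_set, false));
}
```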

View File

@@ -13,76 +13,91 @@ impl Params {
         Self { param }
     }

+    /// Sets the loss function.
     pub fn loss(&mut self, value: Loss) -> &mut Self {
         self.param.fun = value;
         self
     }

+    /// Sets the number of latent factors.
     pub fn factors(&mut self, value: i32) -> &mut Self {
         self.param.k = value;
         self
     }

+    /// Sets the number of threads.
     pub fn threads(&mut self, value: i32) -> &mut Self {
         self.param.nr_threads = value;
         self
     }

+    /// Sets the number of bins.
     pub fn bins(&mut self, value: i32) -> &mut Self {
         self.param.nr_bins = value;
         self
     }

+    /// Sets the number of iterations.
     pub fn iterations(&mut self, value: i32) -> &mut Self {
         self.param.nr_iters = value;
         self
     }

+    /// Sets the coefficient of L1-norm regularization on P.
     pub fn lambda_p1(&mut self, value: f32) -> &mut Self {
         self.param.lambda_p1 = value;
         self
     }

+    /// Sets the coefficient of L2-norm regularization on P.
     pub fn lambda_p2(&mut self, value: f32) -> &mut Self {
         self.param.lambda_p2 = value;
         self
     }

+    /// Sets the coefficient of L1-norm regularization on Q.
     pub fn lambda_q1(&mut self, value: f32) -> &mut Self {
         self.param.lambda_q1 = value;
         self
     }

+    /// Sets the coefficient of L2-norm regularization on Q.
     pub fn lambda_q2(&mut self, value: f32) -> &mut Self {
         self.param.lambda_q2 = value;
         self
     }

+    /// Sets the learning rate.
     pub fn learning_rate(&mut self, value: f32) -> &mut Self {
         self.param.eta = value;
         self
     }

+    /// Sets the importance of negative entries.
     pub fn alpha(&mut self, value: f32) -> &mut Self {
         self.param.alpha = value;
         self
     }

+    /// Sets the desired value of negative entries.
     pub fn c(&mut self, value: f32) -> &mut Self {
         self.param.c = value;
         self
     }

+    /// Sets whether to perform non-negative MF (NMF).
     pub fn nmf(&mut self, value: bool) -> &mut Self {
         self.param.do_nmf = value;
         self
     }

+    /// Sets whether to suppress output to stdout.
     pub fn quiet(&mut self, value: bool) -> &mut Self {
         self.param.quiet = value;
         self
     }

+    /// Fits a model.
     pub fn fit(&mut self, data: &Matrix) -> Result<Model, Error> {
         let prob = data.to_problem();
         let param = self.build_param()?;
@@ -93,6 +108,7 @@ impl Params {
         Ok(Model { model })
     }

+    /// Fits a model with a separate evaluation set.
     pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Result<Model, Error> {
         let tr = train_set.to_problem();
         let va = eval_set.to_problem();
@@ -104,6 +120,7 @@ impl Params {
         Ok(Model { model })
     }

+    /// Performs cross-validation.
     pub fn cv(&mut self, data: &Matrix, folds: i32) -> Result<f64, Error> {
         let prob = data.to_problem();
         let param = self.build_param()?;
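
To close the loop, a sketch contrasting the two validation paths: `fit_eval` trains against a held-out evaluation set, while `cv` averages a score across folds. The `Matrix` values are assumed to be built elsewhere:

```rust
fn validate(
    train_set: &libmf::Matrix,
    eval_set: &libmf::Matrix,
    data: &libmf::Matrix,
) -> Result<(), libmf::Error> {
    // Train while scoring against the held-out evaluation set
    let model = libmf::Model::params().fit_eval(train_set, eval_set)?;
    println!("eval rmse: {}", model.rmse(eval_set));

    // 5-fold cross-validation; returns an average score across folds
    let score = libmf::Model::params().cv(data, 5)?;
    println!("cv score: {}", score);
    Ok(())
}
```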