Improved docs [skip ci]
This commit is contained in:
@@ -50,14 +50,23 @@ pub struct MfModel {
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Loss {
|
||||
/// Squared error (L2-norm).
|
||||
RealL2 = 0,
|
||||
/// Absolute error (L1-norm).
|
||||
RealL1 = 1,
|
||||
/// Generalized KL-divergence.
|
||||
RealKL = 2,
|
||||
/// Logarithmic error.
|
||||
BinaryLog = 5,
|
||||
/// Squared hinge loss.
|
||||
BinaryL2 = 6,
|
||||
/// Hinge loss.
|
||||
BinaryL1 = 7,
|
||||
/// Row-oriented pair-wise logarithmic loss.
|
||||
OneClassRow = 10,
|
||||
/// Column-oriented pair-wise logarithmic loss.
|
||||
OneClassCol = 11,
|
||||
/// Squared error (L2-norm).
|
||||
OneClassL2 = 12,
|
||||
}
|
||||
|
||||
|
||||
14
src/model.rs
14
src/model.rs
@@ -15,6 +15,7 @@ impl Model {
|
||||
Params::new()
|
||||
}
|
||||
|
||||
/// Loads a model from a file.
|
||||
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
|
||||
// TODO better conversion
|
||||
let cpath = CString::new(path.as_ref().to_str().unwrap())?;
|
||||
@@ -25,10 +26,12 @@ impl Model {
|
||||
Ok(Model { model })
|
||||
}
|
||||
|
||||
/// Returns the predicted value for a specific row and column.
|
||||
pub fn predict(&self, row_index: i32, column_index: i32) -> f32 {
|
||||
unsafe { mf_predict(self.model, row_index, column_index) }
|
||||
}
|
||||
|
||||
/// Saves the model to a file.
|
||||
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
|
||||
// TODO better conversion
|
||||
let cpath = CString::new(path.as_ref().to_str().unwrap())?;
|
||||
@@ -39,18 +42,22 @@ impl Model {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the number of rows.
|
||||
pub fn rows(&self) -> i32 {
|
||||
unsafe { (*self.model).m }
|
||||
}
|
||||
|
||||
/// Returns the number of columns.
|
||||
pub fn columns(&self) -> i32 {
|
||||
unsafe { (*self.model).n }
|
||||
}
|
||||
|
||||
/// Returns the number of factors.
|
||||
pub fn factors(&self) -> i32 {
|
||||
unsafe { (*self.model).k }
|
||||
}
|
||||
|
||||
/// Returns the bias.
|
||||
pub fn bias(&self) -> f32 {
|
||||
unsafe { (*self.model).b }
|
||||
}
|
||||
@@ -91,36 +98,43 @@ impl Model {
|
||||
self.q_factors().chunks(self.factors() as usize)
|
||||
}
|
||||
|
||||
/// Calculates RMSE (for real-valued MF).
|
||||
pub fn rmse(&self, data: &Matrix) -> f64 {
|
||||
let prob = data.to_problem();
|
||||
unsafe { calc_rmse(&prob, self.model) }
|
||||
}
|
||||
|
||||
/// Calculates MAE (for real-valued MF).
|
||||
pub fn mae(&self, data: &Matrix) -> f64 {
|
||||
let prob = data.to_problem();
|
||||
unsafe { calc_mae(&prob, self.model) }
|
||||
}
|
||||
|
||||
/// Calculates generalized KL-divergence (for non-negative real-valued MF).
|
||||
pub fn gkl(&self, data: &Matrix) -> f64 {
|
||||
let prob = data.to_problem();
|
||||
unsafe { calc_gkl(&prob, self.model) }
|
||||
}
|
||||
|
||||
/// Calculates logarithmic loss (for binary MF).
|
||||
pub fn logloss(&self, data: &Matrix) -> f64 {
|
||||
let prob = data.to_problem();
|
||||
unsafe { calc_logloss(&prob, self.model) }
|
||||
}
|
||||
|
||||
/// Calculates accuracy (for binary MF).
|
||||
pub fn accuracy(&self, data: &Matrix) -> f64 {
|
||||
let prob = data.to_problem();
|
||||
unsafe { calc_accuracy(&prob, self.model) }
|
||||
}
|
||||
|
||||
/// Calculates MPR (for one-class MF).
|
||||
pub fn mpr(&self, data: &Matrix, transpose: bool) -> f64 {
|
||||
let prob = data.to_problem();
|
||||
unsafe { calc_mpr(&prob, self.model, transpose) }
|
||||
}
|
||||
|
||||
/// Calculates AUC (for one-class MF).
|
||||
pub fn auc(&self, data: &Matrix, transpose: bool) -> f64 {
|
||||
let prob = data.to_problem();
|
||||
unsafe { calc_auc(&prob, self.model, transpose) }
|
||||
|
||||
@@ -13,76 +13,91 @@ impl Params {
|
||||
Self { param }
|
||||
}
|
||||
|
||||
/// Sets the loss function.
|
||||
pub fn loss(&mut self, value: Loss) -> &mut Self {
|
||||
self.param.fun = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the number of latent factors.
|
||||
pub fn factors(&mut self, value: i32) -> &mut Self {
|
||||
self.param.k = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the number of threads.
|
||||
pub fn threads(&mut self, value: i32) -> &mut Self {
|
||||
self.param.nr_threads = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the number of bins.
|
||||
pub fn bins(&mut self, value: i32) -> &mut Self {
|
||||
self.param.nr_bins = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the number of iterations.
|
||||
pub fn iterations(&mut self, value: i32) -> &mut Self {
|
||||
self.param.nr_iters = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the coefficient of L1-norm regularization on P.
|
||||
pub fn lambda_p1(&mut self, value: f32) -> &mut Self {
|
||||
self.param.lambda_p1 = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the coefficient of L2-norm regularization on P.
|
||||
pub fn lambda_p2(&mut self, value: f32) -> &mut Self {
|
||||
self.param.lambda_p2 = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the coefficient of L1-norm regularization on Q.
|
||||
pub fn lambda_q1(&mut self, value: f32) -> &mut Self {
|
||||
self.param.lambda_q1 = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the coefficient of L2-norm regularization on Q.
|
||||
pub fn lambda_q2(&mut self, value: f32) -> &mut Self {
|
||||
self.param.lambda_q2 = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the learning rate.
|
||||
pub fn learning_rate(&mut self, value: f32) -> &mut Self {
|
||||
self.param.eta = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the importance of negative entries.
|
||||
pub fn alpha(&mut self, value: f32) -> &mut Self {
|
||||
self.param.alpha = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the desired value of negative entries.
|
||||
pub fn c(&mut self, value: f32) -> &mut Self {
|
||||
self.param.c = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets whether to perform non-negative MF (NMF).
|
||||
pub fn nmf(&mut self, value: bool) -> &mut Self {
|
||||
self.param.do_nmf = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets whether to output to stdout.
|
||||
pub fn quiet(&mut self, value: bool) -> &mut Self {
|
||||
self.param.quiet = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Fits a model.
|
||||
pub fn fit(&mut self, data: &Matrix) -> Result<Model, Error> {
|
||||
let prob = data.to_problem();
|
||||
let param = self.build_param()?;
|
||||
@@ -93,6 +108,7 @@ impl Params {
|
||||
Ok(Model { model })
|
||||
}
|
||||
|
||||
/// Fits a model and performs cross-validation.
|
||||
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Result<Model, Error> {
|
||||
let tr = train_set.to_problem();
|
||||
let va = eval_set.to_problem();
|
||||
@@ -104,6 +120,7 @@ impl Params {
|
||||
Ok(Model { model })
|
||||
}
|
||||
|
||||
/// Performs cross-validation.
|
||||
pub fn cv(&mut self, data: &Matrix, folds: i32) -> Result<f64, Error> {
|
||||
let prob = data.to_problem();
|
||||
let param = self.build_param()?;
|
||||
|
||||
Reference in New Issue
Block a user