Changed loss to use enum
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
- Changed pattern for fitting models - use `Model::params()` instead of `Model::new()`
|
||||
- Changed `fit`, `fit_eval`, `cv`, `save`, and `load` to return `Result`
|
||||
- Changed `cv` to return average error
|
||||
- Changed `loss` to use enum
|
||||
|
||||
## 0.1.1 (2021-07-27)
|
||||
|
||||
|
||||
46
README.md
46
README.md
@@ -80,41 +80,41 @@ Set parameters - default values below
|
||||
|
||||
```rust
|
||||
libmf::Model::params()
|
||||
.loss(0) // loss function
|
||||
.factors(8) // number of latent factors
|
||||
.threads(12) // number of threads used
|
||||
.bins(25) // number of bins
|
||||
.iterations(20) // number of iterations
|
||||
.lambda_p1(0.0) // coefficient of L1-norm regularization on P
|
||||
.lambda_p2(0.1) // coefficient of L2-norm regularization on P
|
||||
.lambda_q1(0.0) // coefficient of L1-norm regularization on Q
|
||||
.lambda_q2(0.1) // coefficient of L2-norm regularization on Q
|
||||
.learning_rate(0.1) // learning rate
|
||||
.alpha(0.1) // importance of negative entries
|
||||
.c(0.0001) // desired value of negative entries
|
||||
.nmf(false) // perform non-negative MF (NMF)
|
||||
.quiet(false); // no outputs to stdout
|
||||
.loss(libmf::Loss::RealL2) // loss function
|
||||
.factors(8) // number of latent factors
|
||||
.threads(12) // number of threads used
|
||||
.bins(25) // number of bins
|
||||
.iterations(20) // number of iterations
|
||||
.lambda_p1(0.0) // coefficient of L1-norm regularization on P
|
||||
.lambda_p2(0.1) // coefficient of L2-norm regularization on P
|
||||
.lambda_q1(0.0) // coefficient of L1-norm regularization on Q
|
||||
.lambda_q2(0.1) // coefficient of L2-norm regularization on Q
|
||||
.learning_rate(0.1) // learning rate
|
||||
.alpha(0.1) // importance of negative entries
|
||||
.c(0.0001) // desired value of negative entries
|
||||
.nmf(false) // perform non-negative MF (NMF)
|
||||
.quiet(false); // no outputs to stdout
|
||||
```
|
||||
|
||||
### Loss Functions
|
||||
|
||||
For real-valued matrix factorization
|
||||
|
||||
- 0 - squared error (L2-norm)
|
||||
- 1 - absolute error (L1-norm)
|
||||
- 2 - generalized KL-divergence
|
||||
- `Loss::RealL2` - squared error (L2-norm)
|
||||
- `Loss::RealL1` - absolute error (L1-norm)
|
||||
- `Loss::RealKL` - generalized KL-divergence
|
||||
|
||||
For binary matrix factorization
|
||||
|
||||
- 5 - logarithmic error
|
||||
- 6 - squared hinge loss
|
||||
- 7 - hinge loss
|
||||
- `Loss::BinaryLog` - logarithmic error
|
||||
- `Loss::BinaryL2` - squared hinge loss
|
||||
- `Loss::BinaryL1` - hinge loss
|
||||
|
||||
For one-class matrix factorization
|
||||
|
||||
- 10 - row-oriented pair-wise logarithmic loss
|
||||
- 11 - column-oriented pair-wise logarithmic loss
|
||||
- 12 - squared error (L2-norm)
|
||||
- `Loss::OneClassRow` - row-oriented pair-wise logarithmic loss
|
||||
- `Loss::OneClassCol` - column-oriented pair-wise logarithmic loss
|
||||
- `Loss::OneClassL2` - squared error (L2-norm)
|
||||
|
||||
## Metrics
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ pub struct MfProblem {
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct MfParameter {
|
||||
pub fun: c_int,
|
||||
pub fun: Loss,
|
||||
pub k: c_int,
|
||||
pub nr_threads: c_int,
|
||||
pub nr_bins: c_int,
|
||||
@@ -39,7 +39,7 @@ pub struct MfParameter {
|
||||
|
||||
#[repr(C)]
|
||||
pub struct MfModel {
|
||||
pub fun: c_int,
|
||||
pub fun: Loss,
|
||||
pub m: c_int,
|
||||
pub n: c_int,
|
||||
pub k: c_int,
|
||||
@@ -48,6 +48,20 @@ pub struct MfModel {
|
||||
pub q: *const c_float
|
||||
}
|
||||
|
||||
/// Loss function selector stored in the `fun` field of the `#[repr(C)]`
/// parameter and model structs handed to LIBMF.
///
/// Each variant's explicit discriminant is the integer code that was
/// previously passed directly as the loss value, so casting a variant with
/// `as` yields that same code.
///
/// NOTE(review): a `fun` value written by the C side that is not one of the
/// discriminants below would not be a valid `Loss` — confirm that values
/// coming back from the library are always within this set.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Loss {
    /// Real-valued matrix factorization: squared error (L2-norm).
    RealL2 = 0,
    /// Real-valued matrix factorization: absolute error (L1-norm).
    RealL1 = 1,
    /// Real-valued matrix factorization: generalized KL-divergence.
    RealKL = 2,
    /// Binary matrix factorization: logarithmic error.
    BinaryLog = 5,
    /// Binary matrix factorization: squared hinge loss.
    BinaryL2 = 6,
    /// Binary matrix factorization: hinge loss.
    BinaryL1 = 7,
    /// One-class matrix factorization: row-oriented pair-wise logarithmic loss.
    OneClassRow = 10,
    /// One-class matrix factorization: column-oriented pair-wise logarithmic loss.
    OneClassCol = 11,
    /// One-class matrix factorization: squared error (L2-norm).
    OneClassL2 = 12,
}
|
||||
|
||||
extern "C" {
|
||||
pub fn mf_get_default_param() -> MfParameter;
|
||||
pub fn mf_save_model(model: *const MfModel, path: *const c_char) -> c_int;
|
||||
|
||||
@@ -8,6 +8,7 @@ mod matrix;
|
||||
mod model;
|
||||
mod params;
|
||||
|
||||
pub use bindings::Loss;
|
||||
pub use error::Error;
|
||||
pub use matrix::Matrix;
|
||||
pub use model::Model;
|
||||
|
||||
21
src/model.rs
21
src/model.rs
@@ -104,7 +104,7 @@ impl Drop for Model {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{Matrix, Model};
|
||||
use crate::{Loss, Matrix, Model};
|
||||
|
||||
fn generate_data() -> Matrix {
|
||||
let mut data = Matrix::new();
|
||||
@@ -139,6 +139,13 @@ mod tests {
|
||||
assert!(avg_error.is_nan());
|
||||
}
|
||||
|
||||
// Fits a model on the generated data with the loss set through the `Loss`
// enum (`Loss::OneClassL2`), output silenced via `quiet(true)`, and checks
// that the fitted model reports a bias of exactly 0.0.
#[test]
|
||||
fn test_loss() {
|
||||
let data = generate_data();
|
||||
let model = Model::params().loss(Loss::OneClassL2).quiet(true).fit(&data).unwrap();
|
||||
assert_eq!(model.bias(), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_load() {
|
||||
let data = generate_data();
|
||||
@@ -206,23 +213,23 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fit_bad_loss() {
|
||||
fn test_fit_bad_params() {
|
||||
let data = generate_data();
|
||||
let result = Model::params().loss(13).fit(&data);
|
||||
let result = Model::params().factors(0).fit(&data);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fit_eval_bad_loss() {
|
||||
fn test_fit_eval_bad_params() {
|
||||
let data = generate_data();
|
||||
let result = Model::params().loss(13).fit_eval(&data, &data);
|
||||
let result = Model::params().factors(0).fit_eval(&data, &data);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cv_bad_loss() {
|
||||
fn test_cv_bad_params() {
|
||||
let data = generate_data();
|
||||
let result = Model::params().loss(13).cv(&data, 5);
|
||||
let result = Model::params().factors(0).cv(&data, 5);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::bindings::*;
|
||||
use crate::{Error, Matrix, Model};
|
||||
use crate::{Error, Loss, Matrix, Model};
|
||||
|
||||
pub struct Params {
|
||||
param: MfParameter
|
||||
@@ -14,8 +14,7 @@ impl Params {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO use enum
|
||||
pub fn loss(&mut self, value: i32) -> &mut Self {
|
||||
pub fn loss(&mut self, value: Loss) -> &mut Self {
|
||||
self.param.fun = value;
|
||||
self
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user