Changed loss to use enum

Andrew Kane
2021-10-16 18:50:06 -07:00
parent f11e49a272
commit 40accfa870
6 changed files with 57 additions and 35 deletions


@@ -3,6 +3,7 @@
- Changed pattern for fitting models - use `Model::params()` instead of `Model::new()`
- Changed `fit`, `fit_eval`, `cv`, `save`, and `load` to return `Result`
- Changed `cv` to return average error
+- Changed `loss` to use enum
## 0.1.1 (2021-07-27)


@@ -80,41 +80,41 @@ Set parameters - default values below
```rust
libmf::Model::params()
-    .loss(0)                   // loss function
-    .factors(8)                // number of latent factors
-    .threads(12)               // number of threads used
-    .bins(25)                  // number of bins
-    .iterations(20)            // number of iterations
-    .lambda_p1(0.0)            // coefficient of L1-norm regularization on P
-    .lambda_p2(0.1)            // coefficient of L2-norm regularization on P
-    .lambda_q1(0.0)            // coefficient of L1-norm regularization on Q
-    .lambda_q2(0.1)            // coefficient of L2-norm regularization on Q
-    .learning_rate(0.1)        // learning rate
-    .alpha(0.1)                // importance of negative entries
-    .c(0.0001)                 // desired value of negative entries
-    .nmf(false)                // perform non-negative MF (NMF)
-    .quiet(false);             // no outputs to stdout
+    .loss(libmf::Loss::RealL2) // loss function
+    .factors(8)                // number of latent factors
+    .threads(12)               // number of threads used
+    .bins(25)                  // number of bins
+    .iterations(20)            // number of iterations
+    .lambda_p1(0.0)            // coefficient of L1-norm regularization on P
+    .lambda_p2(0.1)            // coefficient of L2-norm regularization on P
+    .lambda_q1(0.0)            // coefficient of L1-norm regularization on Q
+    .lambda_q2(0.1)            // coefficient of L2-norm regularization on Q
+    .learning_rate(0.1)        // learning rate
+    .alpha(0.1)                // importance of negative entries
+    .c(0.0001)                 // desired value of negative entries
+    .nmf(false)                // perform non-negative MF (NMF)
+    .quiet(false);             // no outputs to stdout
```
### Loss Functions
For real-valued matrix factorization
-- 0 - squared error (L2-norm)
-- 1 - absolute error (L1-norm)
-- 2 - generalized KL-divergence
+- `Loss::RealL2` - squared error (L2-norm)
+- `Loss::RealL1` - absolute error (L1-norm)
+- `Loss::RealKL` - generalized KL-divergence
For binary matrix factorization
-- 5 - logarithmic error
-- 6 - squared hinge loss
-- 7 - hinge loss
+- `Loss::BinaryLog` - logarithmic error
+- `Loss::BinaryL2` - squared hinge loss
+- `Loss::BinaryL1` - hinge loss
For one-class matrix factorization
-- 10 - row-oriented pair-wise logarithmic loss
-- 11 - column-oriented pair-wise logarithmic loss
-- 12 - squared error (L2-norm)
+- `Loss::OneClassRow` - row-oriented pair-wise logarithmic loss
+- `Loss::OneClassCol` - column-oriented pair-wise logarithmic loss
+- `Loss::OneClassL2` - squared error (L2-norm)
## Metrics


@@ -20,7 +20,7 @@ pub struct MfProblem {
#[repr(C)]
#[derive(Clone, Copy)]
pub struct MfParameter {
-    pub fun: c_int,
+    pub fun: Loss,
pub k: c_int,
pub nr_threads: c_int,
pub nr_bins: c_int,
@@ -39,7 +39,7 @@ pub struct MfParameter {
#[repr(C)]
pub struct MfModel {
-    pub fun: c_int,
+    pub fun: Loss,
pub m: c_int,
pub n: c_int,
pub k: c_int,
@@ -48,6 +48,20 @@ pub struct MfModel {
pub q: *const c_float
}
+#[repr(C)]
+#[derive(Clone, Copy)]
+pub enum Loss {
+    RealL2 = 0,
+    RealL1 = 1,
+    RealKL = 2,
+    BinaryLog = 5,
+    BinaryL2 = 6,
+    BinaryL1 = 7,
+    OneClassRow = 10,
+    OneClassCol = 11,
+    OneClassL2 = 12
+}
extern "C" {
pub fn mf_get_default_param() -> MfParameter;
pub fn mf_save_model(model: *const MfModel, path: *const c_char) -> c_int;
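Because `Loss` is `#[repr(C)]` with explicit discriminants, it is laid out like a C enum and each variant keeps the integer code the old `c_int` field carried, so storing it in `fun` crosses the FFI boundary unchanged. A small test-style sketch of the correspondence (the test name is hypothetical):

```rust
#[test]
fn loss_codes_match_libmf() {
    // Casting a fieldless enum recovers the discriminant LIBMF expects.
    assert_eq!(Loss::RealL2 as i32, 0);
    assert_eq!(Loss::BinaryLog as i32, 5);
    assert_eq!(Loss::OneClassCol as i32, 11);
}
```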


@@ -8,6 +8,7 @@ mod matrix;
mod model;
mod params;
+pub use bindings::Loss;
pub use error::Error;
pub use matrix::Matrix;
pub use model::Model;
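With the re-export, callers name the enum from the crate root instead of the private `bindings` module. A usage sketch; the empty matrix is only for illustration:

```rust
use libmf::{Loss, Matrix, Model};

// All three names now resolve at the crate root.
let data = Matrix::new();
let _result = Model::params().loss(Loss::RealL2).quiet(true).fit(&data);
```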


@@ -104,7 +104,7 @@ impl Drop for Model {
#[cfg(test)]
mod tests {
-    use crate::{Matrix, Model};
+    use crate::{Loss, Matrix, Model};
fn generate_data() -> Matrix {
let mut data = Matrix::new();
@@ -139,6 +139,13 @@ mod tests {
assert!(avg_error.is_nan());
}
+    #[test]
+    fn test_loss() {
+        let data = generate_data();
+        let model = Model::params().loss(Loss::OneClassL2).quiet(true).fit(&data).unwrap();
+        assert_eq!(model.bias(), 0.0);
+    }
#[test]
fn test_save_load() {
let data = generate_data();
@@ -206,23 +213,23 @@ mod tests {
}
#[test]
-    fn test_fit_bad_loss() {
+    fn test_fit_bad_params() {
let data = generate_data();
-        let result = Model::params().loss(13).fit(&data);
+        let result = Model::params().factors(0).fit(&data);
assert!(result.is_err());
}
#[test]
-    fn test_fit_eval_bad_loss() {
+    fn test_fit_eval_bad_params() {
let data = generate_data();
-        let result = Model::params().loss(13).fit_eval(&data, &data);
+        let result = Model::params().factors(0).fit_eval(&data, &data);
assert!(result.is_err());
}
#[test]
-    fn test_cv_bad_loss() {
+    fn test_cv_bad_params() {
let data = generate_data();
-        let result = Model::params().loss(13).cv(&data, 5);
+        let result = Model::params().factors(0).cv(&data, 5);
assert!(result.is_err());
}
}


@@ -1,5 +1,5 @@
use crate::bindings::*;
-use crate::{Error, Matrix, Model};
+use crate::{Error, Loss, Matrix, Model};
pub struct Params {
param: MfParameter
@@ -14,8 +14,7 @@ impl Params {
}
}
-    // TODO use enum
-    pub fn loss(&mut self, value: i32) -> &mut Self {
+    pub fn loss(&mut self, value: Loss) -> &mut Self {
self.param.fun = value;
self
}
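A practical effect of the typed setter: out-of-range codes such as the old `loss(13)` are now unrepresentable, which is why the `*_bad_loss` tests above switched to a different invalid parameter (`factors(0)`). A minimal caller sketch:

```rust
use libmf::{Loss, Model};

let mut params = Model::params();
params.loss(Loss::RealL1); // accepted: a `Loss` variant
// params.loss(13);        // no longer compiles: expected `Loss`, found `{integer}`
```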