First commit
This commit is contained in:
11
.github/workflows/build.yml
vendored
Normal file
11
.github/workflows/build.yml
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
name: build
|
||||||
|
on: [push, pull_request]
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- uses: actions-rs/toolchain@v1
|
||||||
|
with:
|
||||||
|
toolchain: stable
|
||||||
|
- run: cargo test
|
||||||
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
/target
|
||||||
|
Cargo.lock
|
||||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "vendor/libmf"]
|
||||||
|
path = vendor/libmf
|
||||||
|
url = https://github.com/ankane/libmf-1.git
|
||||||
3
CHANGELOG.md
Normal file
3
CHANGELOG.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
## 0.1.0 (unreleased)
|
||||||
|
|
||||||
|
- First release
|
||||||
19
Cargo.toml
Normal file
19
Cargo.toml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
[package]
|
||||||
|
name = "libmf"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Large-scale sparse matrix factorization for Rust"
|
||||||
|
repository = "https://github.com/ankane/libmf-rust"
|
||||||
|
license = "BSD-3-Clause"
|
||||||
|
authors = ["Andrew Kane <andrew@ankane.org>"]
|
||||||
|
edition = "2018"
|
||||||
|
readme = "README.md"
|
||||||
|
exclude = ["vendor/libmf/demo/*", "vendor/libmf/windows/*"]
|
||||||
|
links = "mf"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
libc = "0.2"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
cc = "1.0"
|
||||||
33
LICENSE.txt
Normal file
33
LICENSE.txt
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
BSD 3-Clause License
|
||||||
|
|
||||||
|
Copyright (c) 2014-2015 The LIBMF Project.
|
||||||
|
Copyright (c) 2021, Andrew Kane.
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions
|
||||||
|
are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer in the
|
||||||
|
documentation and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither name of copyright holders nor the names of its contributors
|
||||||
|
may be used to endorse or promote products derived from this software
|
||||||
|
without specific prior written permission.
|
||||||
|
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
|
||||||
|
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||||
|
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||||
|
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||||
|
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||||
|
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
150
README.md
Normal file
150
README.md
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
# LIBMF Rust
|
||||||
|
|
||||||
|
[LIBMF](https://github.com/cjlin1/libmf) - large-scale sparse matrix factorization - for Rust
|
||||||
|
|
||||||
|
[](https://github.com/ankane/libmf-rust/actions)
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Add this line to your application’s `Cargo.toml` under `[dependencies]`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
libmf = { version = "0.1" }
|
||||||
|
```
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
Prep your data in the format `row_index, column_index, value`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let mut data = libmf::Matrix::new();
|
||||||
|
data.push(0, 0, 5.0);
|
||||||
|
data.push(0, 2, 3.5);
|
||||||
|
data.push(1, 1, 4.0);
|
||||||
|
```
|
||||||
|
|
||||||
|
Create a model
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let mut model = libmf::Model::new();
|
||||||
|
model.fit(&data);
|
||||||
|
```
|
||||||
|
|
||||||
|
Make predictions
|
||||||
|
|
||||||
|
```rust
|
||||||
|
model.predict(row_index, column_index);
|
||||||
|
```
|
||||||
|
|
||||||
|
Get the latent factors (these approximate the training matrix)
|
||||||
|
|
||||||
|
```rust
|
||||||
|
model.p_factors();
|
||||||
|
model.q_factors();
|
||||||
|
```
|
||||||
|
|
||||||
|
Get the bias (average of all elements in the training matrix)
|
||||||
|
|
||||||
|
```rust
|
||||||
|
model.bias();
|
||||||
|
```
|
||||||
|
|
||||||
|
Save the model to a file
|
||||||
|
|
||||||
|
```rust
|
||||||
|
model.save("model.txt");
|
||||||
|
```
|
||||||
|
|
||||||
|
Load the model from a file
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let model = libmf::Model::load("model.txt");
|
||||||
|
```
|
||||||
|
|
||||||
|
Pass a validation set
|
||||||
|
|
||||||
|
```rust
|
||||||
|
model.fit_eval(&train_set, &eval_set);
|
||||||
|
```
|
||||||
|
|
||||||
|
## Cross-Validation
|
||||||
|
|
||||||
|
Perform cross-validation
|
||||||
|
|
||||||
|
```rust
|
||||||
|
model.cv(&data, 5);
|
||||||
|
```
|
||||||
|
|
||||||
|
## Parameters
|
||||||
|
|
||||||
|
Set parameters - default values below
|
||||||
|
|
||||||
|
```rust
|
||||||
|
model.loss = 0; // loss function
|
||||||
|
model.factors = 8; // number of latent factors
|
||||||
|
model.threads = 12; // number of threads used
|
||||||
|
model.bins = 25; // number of bins
|
||||||
|
model.iterations = 20; // number of iterations
|
||||||
|
model.lambda_p1 = 0; // coefficient of L1-norm regularization on P
|
||||||
|
model.lambda_p2 = 0.1; // coefficient of L2-norm regularization on P
|
||||||
|
model.lambda_q1 = 0; // coefficient of L1-norm regularization on Q
|
||||||
|
model.lambda_q2 = 0.1; // coefficient of L2-norm regularization on Q
|
||||||
|
model.learning_rate = 0.1; // learning rate
|
||||||
|
model.alpha = 0.1; // importance of negative entries
|
||||||
|
model.c = 0.0001; // desired value of negative entries
|
||||||
|
model.nmf = false; // perform non-negative MF (NMF)
|
||||||
|
model.quiet = false; // no outputs to stdout
|
||||||
|
```
|
||||||
|
|
||||||
|
### Loss Functions
|
||||||
|
|
||||||
|
For real-valued matrix factorization
|
||||||
|
|
||||||
|
- 0 - squared error (L2-norm)
|
||||||
|
- 1 - absolute error (L1-norm)
|
||||||
|
- 2 - generalized KL-divergence
|
||||||
|
|
||||||
|
For binary matrix factorization
|
||||||
|
|
||||||
|
- 5 - logarithmic error
|
||||||
|
- 6 - squared hinge loss
|
||||||
|
- 7 - hinge loss
|
||||||
|
|
||||||
|
For one-class matrix factorization
|
||||||
|
|
||||||
|
- 10 - row-oriented pair-wise logarithmic loss
|
||||||
|
- 11 - column-oriented pair-wise logarithmic loss
|
||||||
|
- 12 - squared error (L2-norm)
|
||||||
|
|
||||||
|
## Reference
|
||||||
|
|
||||||
|
Specify the initial capacity for a matrix
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let mut data = libmf::Matrix::with_capacity(3);
|
||||||
|
```
|
||||||
|
|
||||||
|
## Resources
|
||||||
|
|
||||||
|
- [LIBMF: A Library for Parallel Matrix Factorization in Shared-memory Systems](https://www.csie.ntu.edu.tw/~cjlin/papers/libmf/libmf_open_source.pdf)
|
||||||
|
|
||||||
|
## History
|
||||||
|
|
||||||
|
View the [changelog](https://github.com/ankane/libmf-rust/blob/master/CHANGELOG.md)
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Everyone is encouraged to help improve this project. Here are a few ways you can help:
|
||||||
|
|
||||||
|
- [Report bugs](https://github.com/ankane/libmf-rust/issues)
|
||||||
|
- Fix bugs and [submit pull requests](https://github.com/ankane/libmf-rust/pulls)
|
||||||
|
- Write, clarify, or fix documentation
|
||||||
|
- Suggest or add new features
|
||||||
|
|
||||||
|
To get started with development:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
git clone --recursive https://github.com/ankane/libmf-rust.git
|
||||||
|
cd libmf-rust
|
||||||
|
cargo test
|
||||||
|
```
|
||||||
10
build.rs
Normal file
10
build.rs
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
extern crate cc;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
cc::Build::new()
|
||||||
|
.cpp(true)
|
||||||
|
.flag("-std=c++11")
|
||||||
|
.flag("-Wno-unused-parameter")
|
||||||
|
.file("vendor/libmf/mf.cpp")
|
||||||
|
.compile("mf");
|
||||||
|
}
|
||||||
63
src/bindings.rs
Normal file
63
src/bindings.rs
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
extern crate libc;
|
||||||
|
|
||||||
|
use std::os::raw::{c_char, c_double, c_float, c_int, c_longlong};
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
pub struct MfNode
|
||||||
|
{
|
||||||
|
pub u: c_int,
|
||||||
|
pub v: c_int,
|
||||||
|
pub r: c_float
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
pub struct MfProblem
|
||||||
|
{
|
||||||
|
pub m: c_int,
|
||||||
|
pub n: c_int,
|
||||||
|
pub nnz: c_longlong,
|
||||||
|
pub r: *const MfNode
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
pub struct MfParameter {
|
||||||
|
pub fun: c_int,
|
||||||
|
pub k: c_int,
|
||||||
|
pub nr_threads: c_int,
|
||||||
|
pub nr_bins: c_int,
|
||||||
|
pub nr_iters: c_int,
|
||||||
|
pub lambda_p1: c_float,
|
||||||
|
pub lambda_p2: c_float,
|
||||||
|
pub lambda_q1: c_float,
|
||||||
|
pub lambda_q2: c_float,
|
||||||
|
pub eta: c_float,
|
||||||
|
pub alpha: c_float,
|
||||||
|
pub c: c_float,
|
||||||
|
pub do_nmf: bool,
|
||||||
|
pub quiet: bool,
|
||||||
|
pub copy_data: bool
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
pub struct MfModel
|
||||||
|
{
|
||||||
|
pub fun: c_int,
|
||||||
|
pub m: c_int,
|
||||||
|
pub n: c_int,
|
||||||
|
pub k: c_int,
|
||||||
|
pub b: c_float,
|
||||||
|
pub p: *const c_float,
|
||||||
|
pub q: *const c_float
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
pub fn mf_get_default_param() -> MfParameter;
|
||||||
|
pub fn mf_save_model(model: *const MfModel, path: *const c_char) -> c_int;
|
||||||
|
pub fn mf_load_model(path: *const c_char) -> *mut MfModel;
|
||||||
|
pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel;
|
||||||
|
pub fn mf_train_with_validation(tr: *const MfProblem, va: *const MfProblem, param: MfParameter) -> *mut MfModel;
|
||||||
|
pub fn mf_cross_validation(prob: *const MfProblem, nr_folds: c_int, param: MfParameter) -> c_double;
|
||||||
|
pub fn mf_predict(model: *const MfModel, u: c_int, v: c_int) -> c_float;
|
||||||
|
pub fn calc_rmse(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
||||||
|
pub fn calc_mae(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
||||||
|
}
|
||||||
10
src/lib.rs
Normal file
10
src/lib.rs
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//! Large-scale sparse matrix factorization for Rust
|
||||||
|
//!
|
||||||
|
//! [View the docs](https://github.com/ankane/libmf-rust)
|
||||||
|
|
||||||
|
mod bindings;
|
||||||
|
mod matrix;
|
||||||
|
mod model;
|
||||||
|
|
||||||
|
pub use matrix::Matrix;
|
||||||
|
pub use model::Model;
|
||||||
53
src/matrix.rs
Normal file
53
src/matrix.rs
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
use crate::bindings::{MfNode, MfProblem};
|
||||||
|
|
||||||
|
pub struct Matrix {
|
||||||
|
data: Vec<MfNode>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Matrix {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
data: Vec::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_capacity(capacity: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
data: Vec::with_capacity(capacity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push(&mut self, row_index: i32, column_index: i32, value: f32) {
|
||||||
|
self.data.push(MfNode { u: row_index, v: column_index, r: value });
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn to_problem(&self) -> MfProblem {
|
||||||
|
let data = &self.data;
|
||||||
|
let m = data.iter().map(|x| x.u).max().unwrap_or(-1) + 1;
|
||||||
|
let n = data.iter().map(|x| x.v).max().unwrap_or(-1) + 1;
|
||||||
|
|
||||||
|
MfProblem {
|
||||||
|
m: m,
|
||||||
|
n: n,
|
||||||
|
nnz: data.len() as i64,
|
||||||
|
r: data.as_ptr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::Matrix;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_new() {
|
||||||
|
let mut data = Matrix::new();
|
||||||
|
data.push(0, 0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_with_capacity() {
|
||||||
|
let mut data = Matrix::with_capacity(1);
|
||||||
|
data.push(0, 0, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
262
src/model.rs
Normal file
262
src/model.rs
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
use crate::bindings::*;
|
||||||
|
use crate::Matrix;
|
||||||
|
use std::ffi::CString;
|
||||||
|
|
||||||
|
pub struct Model {
|
||||||
|
model: *const MfModel,
|
||||||
|
pub loss: i32,
|
||||||
|
pub factors: i32,
|
||||||
|
pub threads: i32,
|
||||||
|
pub bins: i32,
|
||||||
|
pub iterations: i32,
|
||||||
|
pub lambda_p1: f32,
|
||||||
|
pub lambda_p2: f32,
|
||||||
|
pub lambda_q1: f32,
|
||||||
|
pub lambda_q2: f32,
|
||||||
|
pub learning_rate: f32,
|
||||||
|
pub alpha: f32,
|
||||||
|
pub c: f32,
|
||||||
|
pub nmf: bool,
|
||||||
|
pub quiet: bool
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::with_model(std::ptr::null())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load(path: &str) -> Self {
|
||||||
|
let cpath = CString::new(path).expect("CString::new failed");
|
||||||
|
Self::with_model(unsafe { mf_load_model(cpath.as_ptr()) })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fit(&mut self, data: &Matrix) {
|
||||||
|
let prob = data.to_problem();
|
||||||
|
self.model = unsafe { mf_train(&prob, self.param()) };
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) {
|
||||||
|
let tr = train_set.to_problem();
|
||||||
|
let va = eval_set.to_problem();
|
||||||
|
self.model = unsafe { mf_train_with_validation(&tr, &va, self.param()) };
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cv(&mut self, data: &Matrix, folds: i32) {
|
||||||
|
let prob = data.to_problem();
|
||||||
|
unsafe { mf_cross_validation(&prob, folds, self.param()); }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn predict(&self, row_index: i32, column_index: i32) -> f32 {
|
||||||
|
assert!(self.is_fit());
|
||||||
|
unsafe { mf_predict(self.model, row_index, column_index) }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save(&self, path: &str) {
|
||||||
|
assert!(self.is_fit());
|
||||||
|
let cpath = CString::new(path).expect("CString::new failed");
|
||||||
|
unsafe { mf_save_model(self.model, cpath.as_ptr()); }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rows(&self) -> i32 {
|
||||||
|
if self.is_fit() {
|
||||||
|
unsafe { (*self.model).m }
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn columns(&self) -> i32 {
|
||||||
|
if self.is_fit() {
|
||||||
|
unsafe { (*self.model).n }
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn factors(&self) -> i32 {
|
||||||
|
if self.is_fit() {
|
||||||
|
unsafe { (*self.model).k }
|
||||||
|
} else {
|
||||||
|
self.factors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn bias(&self) -> f32 {
|
||||||
|
if self.is_fit() {
|
||||||
|
unsafe { (*self.model).b }
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn p_factors(&self) -> &[f32] {
|
||||||
|
if self.is_fit() {
|
||||||
|
unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) }
|
||||||
|
} else {
|
||||||
|
&[]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn q_factors(&self) -> &[f32] {
|
||||||
|
if self.is_fit() {
|
||||||
|
unsafe { std::slice::from_raw_parts((*self.model).q, (self.columns() * self.factors()) as usize) }
|
||||||
|
} else {
|
||||||
|
&[]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rmse(&self, data: &Matrix) -> f64 {
|
||||||
|
assert!(self.is_fit());
|
||||||
|
let prob = data.to_problem();
|
||||||
|
unsafe { calc_rmse(&prob, self.model) }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mae(&self, data: &Matrix) -> f64 {
|
||||||
|
assert!(self.is_fit());
|
||||||
|
let prob = data.to_problem();
|
||||||
|
unsafe { calc_mae(&prob, self.model) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_model(model: *const MfModel) -> Self {
|
||||||
|
let param = unsafe { mf_get_default_param() };
|
||||||
|
Self {
|
||||||
|
model: model,
|
||||||
|
loss: param.fun,
|
||||||
|
factors: param.k,
|
||||||
|
threads: param.nr_threads,
|
||||||
|
bins: 25, // prevent warning
|
||||||
|
iterations: param.nr_iters,
|
||||||
|
lambda_p1: param.lambda_p1,
|
||||||
|
lambda_p2: param.lambda_p2,
|
||||||
|
lambda_q1: param.lambda_q1,
|
||||||
|
lambda_q2: param.lambda_q2,
|
||||||
|
learning_rate: param.eta,
|
||||||
|
alpha: param.alpha,
|
||||||
|
c: param.c,
|
||||||
|
nmf: param.do_nmf,
|
||||||
|
quiet: param.quiet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn param(&self) -> MfParameter {
|
||||||
|
let mut param = unsafe { mf_get_default_param() };
|
||||||
|
param.fun = self.loss;
|
||||||
|
param.k = self.factors;
|
||||||
|
param.nr_threads = self.threads;
|
||||||
|
param.nr_bins = self.bins;
|
||||||
|
param.nr_iters = self.iterations;
|
||||||
|
param.lambda_p1 = self.lambda_p1;
|
||||||
|
param.lambda_p2 = self.lambda_p2;
|
||||||
|
param.lambda_q1 = self.lambda_q1;
|
||||||
|
param.lambda_q2 = self.lambda_q2;
|
||||||
|
param.eta = self.learning_rate;
|
||||||
|
param.alpha = self.alpha;
|
||||||
|
param.c = self.c;
|
||||||
|
param.do_nmf = self.nmf;
|
||||||
|
param.quiet = self.quiet;
|
||||||
|
param
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_fit(&self) -> bool {
|
||||||
|
!self.model.is_null()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::{Matrix, Model};
|
||||||
|
|
||||||
|
fn generate_data() -> Matrix {
|
||||||
|
let mut data = Matrix::new();
|
||||||
|
data.push(0, 0, 1.0);
|
||||||
|
data.push(1, 0, 2.0);
|
||||||
|
data.push(1, 1, 1.0);
|
||||||
|
data
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_fit() {
|
||||||
|
let data = generate_data();
|
||||||
|
let mut model = Model::new();
|
||||||
|
model.quiet = true;
|
||||||
|
model.fit(&data);
|
||||||
|
model.predict(0, 1);
|
||||||
|
|
||||||
|
model.p_factors();
|
||||||
|
model.q_factors();
|
||||||
|
model.bias();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_fit_eval() {
|
||||||
|
let data = generate_data();
|
||||||
|
let mut model = Model::new();
|
||||||
|
model.quiet = true;
|
||||||
|
model.fit_eval(&data, &data);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cv() {
|
||||||
|
let data = generate_data();
|
||||||
|
let mut model = Model::new();
|
||||||
|
model.quiet = true;
|
||||||
|
model.cv(&data, 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_save_load() {
|
||||||
|
let data = generate_data();
|
||||||
|
let mut model = Model::new();
|
||||||
|
model.quiet = true;
|
||||||
|
model.fit(&data);
|
||||||
|
|
||||||
|
model.save("/tmp/model.txt");
|
||||||
|
let model = Model::load("/tmp/model.txt");
|
||||||
|
|
||||||
|
model.p_factors();
|
||||||
|
model.q_factors();
|
||||||
|
model.bias();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_metrics() {
|
||||||
|
let data = generate_data();
|
||||||
|
let mut model = Model::new();
|
||||||
|
model.quiet = true;
|
||||||
|
model.fit(&data);
|
||||||
|
|
||||||
|
assert!(model.rmse(&data) < 0.15);
|
||||||
|
assert!(model.mae(&data) < 0.15);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_not_fit() {
|
||||||
|
let model = Model::new();
|
||||||
|
assert_eq!(0.0, model.bias());
|
||||||
|
assert!(model.p_factors().is_empty());
|
||||||
|
assert!(model.q_factors().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "assertion failed: self.is_fit()")]
|
||||||
|
fn test_predict_not_fit() {
|
||||||
|
let model = Model::new();
|
||||||
|
model.predict(0, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "assertion failed: self.is_fit()")]
|
||||||
|
fn test_save_not_fit() {
|
||||||
|
let model = Model::new();
|
||||||
|
model.save("/tmp/model.txt");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_predict_out_of_range() {
|
||||||
|
let data = generate_data();
|
||||||
|
let mut model = Model::new();
|
||||||
|
model.quiet = true;
|
||||||
|
model.fit(&data);
|
||||||
|
assert_eq!(model.bias(), model.predict(1000, 1000));
|
||||||
|
}
|
||||||
|
}
|
||||||
1
vendor/libmf
vendored
Submodule
1
vendor/libmf
vendored
Submodule
Submodule vendor/libmf added at 6fa534dae7
Reference in New Issue
Block a user