98 lines
2.7 KiB
Rust
98 lines
2.7 KiB
Rust
use std::ffi::{c_char, c_double, c_float, c_int, c_longlong};
|
|
|
|
#[repr(C)]
|
|
pub struct MfNode {
|
|
pub u: c_int,
|
|
pub v: c_int,
|
|
pub r: c_float,
|
|
}
|
|
|
|
#[repr(C)]
|
|
pub struct MfProblem {
|
|
pub m: c_int,
|
|
pub n: c_int,
|
|
pub nnz: c_longlong,
|
|
pub r: *const MfNode,
|
|
}
|
|
|
|
#[repr(C)]
|
|
#[derive(Clone, Copy)]
|
|
pub struct MfParameter {
|
|
pub fun: Loss,
|
|
pub k: c_int,
|
|
pub nr_threads: c_int,
|
|
pub nr_bins: c_int,
|
|
pub nr_iters: c_int,
|
|
pub lambda_p1: c_float,
|
|
pub lambda_p2: c_float,
|
|
pub lambda_q1: c_float,
|
|
pub lambda_q2: c_float,
|
|
pub eta: c_float,
|
|
pub alpha: c_float,
|
|
pub c: c_float,
|
|
pub do_nmf: bool,
|
|
pub quiet: bool,
|
|
pub copy_data: bool,
|
|
}
|
|
|
|
#[repr(C)]
|
|
pub struct MfModel {
|
|
pub fun: Loss,
|
|
pub m: c_int,
|
|
pub n: c_int,
|
|
pub k: c_int,
|
|
pub b: c_float,
|
|
pub p: *const c_float,
|
|
pub q: *const c_float,
|
|
}
|
|
|
|
/// Loss functions.
|
|
#[repr(C)]
|
|
#[derive(Clone, Copy)]
|
|
pub enum Loss {
|
|
/// Squared error (L2-norm).
|
|
RealL2 = 0,
|
|
/// Absolute error (L1-norm).
|
|
RealL1 = 1,
|
|
/// Generalized KL-divergence.
|
|
RealKL = 2,
|
|
/// Logarithmic error.
|
|
BinaryLog = 5,
|
|
/// Squared hinge loss.
|
|
BinaryL2 = 6,
|
|
/// Hinge loss.
|
|
BinaryL1 = 7,
|
|
/// Row-oriented pair-wise logarithmic loss.
|
|
OneClassRow = 10,
|
|
/// Column-oriented pair-wise logarithmic loss.
|
|
OneClassCol = 11,
|
|
/// Squared error (L2-norm).
|
|
OneClassL2 = 12,
|
|
}
|
|
|
|
extern "C" {
|
|
pub fn mf_get_default_param() -> MfParameter;
|
|
pub fn mf_save_model(model: *const MfModel, path: *const c_char) -> c_int;
|
|
pub fn mf_load_model(path: *const c_char) -> *mut MfModel;
|
|
pub fn mf_destroy_model(model: *mut *mut MfModel);
|
|
pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel;
|
|
pub fn mf_train_with_validation(
|
|
tr: *const MfProblem,
|
|
va: *const MfProblem,
|
|
param: MfParameter,
|
|
) -> *mut MfModel;
|
|
pub fn mf_cross_validation(
|
|
prob: *const MfProblem,
|
|
nr_folds: c_int,
|
|
param: MfParameter,
|
|
) -> c_double;
|
|
pub fn mf_predict(model: *const MfModel, u: c_int, v: c_int) -> c_float;
|
|
pub fn calc_rmse(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
|
pub fn calc_mae(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
|
pub fn calc_gkl(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
|
pub fn calc_logloss(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
|
pub fn calc_accuracy(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
|
pub fn calc_mpr(prob: *const MfProblem, model: *const MfModel, transpose: bool) -> c_double;
|
|
pub fn calc_auc(prob: *const MfProblem, model: *const MfModel, transpose: bool) -> c_double;
|
|
}
|