Files
libmf-rust/src/bindings.rs
2024-10-05 10:17:43 -07:00

98 lines
2.7 KiB
Rust

use std::ffi::{c_char, c_double, c_float, c_int, c_longlong};
/// A single entry of the sparse matrix passed across the FFI boundary.
///
/// The struct is `#[repr(C)]` so its field order and layout follow the C ABI;
/// the fields must therefore stay in this exact order. It is plain data
/// (two C ints and a C float), so it is `Copy` and can be compared and
/// debug-printed like any other value type.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MfNode {
    /// First index of the entry (presumably the row — confirm against the C header).
    pub u: c_int,
    /// Second index of the entry (presumably the column — confirm against the C header).
    pub v: c_int,
    /// Value stored at position (`u`, `v`).
    pub r: c_float,
}
/// Description of a sparse matrix as handed to the C library.
///
/// Laid out with `#[repr(C)]`, so the field order below is part of the ABI
/// and must not be rearranged.
///
/// NOTE(review): the field meanings are inferred from their names and from
/// upstream LIBMF's documentation — confirm against the vendored C header.
#[repr(C)]
pub struct MfProblem {
    /// Presumably the number of rows of the matrix.
    pub m: c_int,
    /// Presumably the number of columns of the matrix.
    pub n: c_int,
    /// Presumably the number of entries reachable through `r`.
    pub nnz: c_longlong,
    /// Raw pointer to the first [`MfNode`]. The struct does not own this
    /// memory; whoever builds an `MfProblem` must keep the nodes alive for as
    /// long as the C side may read them.
    pub r: *const MfNode,
}
/// Training options passed by value to the C training entry points.
///
/// The struct is `#[repr(C)]`: field order and types define the ABI and must
/// stay exactly as written. It is `Copy`, which is what allows it to be passed
/// by value to several foreign functions.
///
/// NOTE(review): the per-field descriptions follow upstream LIBMF's README and
/// cannot be checked from this file — confirm against the vendored C header.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct MfParameter {
    /// Loss function to optimise.
    pub fun: Loss,
    /// Presumably the number of latent factors.
    pub k: c_int,
    /// Presumably the number of worker threads.
    pub nr_threads: c_int,
    /// Presumably the number of bins (blocks) the matrix is split into.
    pub nr_bins: c_int,
    /// Presumably the number of training iterations.
    pub nr_iters: c_int,
    /// Presumably the L1 regularisation coefficient for P.
    pub lambda_p1: c_float,
    /// Presumably the L2 regularisation coefficient for P.
    pub lambda_p2: c_float,
    /// Presumably the L1 regularisation coefficient for Q.
    pub lambda_q1: c_float,
    /// Presumably the L2 regularisation coefficient for Q.
    pub lambda_q2: c_float,
    /// Presumably the learning rate.
    pub eta: c_float,
    /// Presumably a one-class-only coefficient — TODO confirm.
    pub alpha: c_float,
    /// Presumably a one-class-only constant — TODO confirm.
    pub c: c_float,
    /// Presumably enables non-negative matrix factorisation.
    pub do_nmf: bool,
    /// Presumably suppresses progress output from the C library.
    pub quiet: bool,
    /// Presumably makes the C library work on its own copy of the input data.
    pub copy_data: bool,
}
/// A model as exposed by the C library.
///
/// `#[repr(C)]` fixes the layout; the field order is part of the ABI and must
/// not change. Values of this type are only ever seen behind raw pointers in
/// the foreign function signatures of this module.
///
/// NOTE(review): `fun` is typed as the Rust enum [`Loss`], so the C side must
/// only ever store one of the `Loss` discriminants there; reading any other
/// integer through this field is undefined behaviour in Rust.
///
/// NOTE(review): the field meanings follow upstream LIBMF's documentation —
/// confirm against the vendored C header.
#[repr(C)]
pub struct MfModel {
    /// Loss function the model is associated with.
    pub fun: Loss,
    /// Presumably the number of rows.
    pub m: c_int,
    /// Presumably the number of columns.
    pub n: c_int,
    /// Presumably the number of latent factors.
    pub k: c_int,
    /// Presumably a global bias / average term — TODO confirm.
    pub b: c_float,
    /// Raw pointer to the first factor matrix (presumably `m` x `k` floats).
    pub p: *const c_float,
    /// Raw pointer to the second factor matrix (presumably `n` x `k` floats).
    pub q: *const c_float,
}
/// Loss functions.
///
/// The enum is `#[repr(C)]` and every variant carries an explicit
/// discriminant: that integer is the value exchanged with the C library, so
/// the numbers must not be changed. The gaps in the numbering (3-4 and 8-9)
/// are intentional and simply unused here.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Loss {
    /// Squared error (L2-norm). Code `0`.
    RealL2 = 0,
    /// Absolute error (L1-norm). Code `1`.
    RealL1 = 1,
    /// Generalized KL-divergence. Code `2`.
    RealKL = 2,
    /// Logarithmic error. Code `5`.
    BinaryLog = 5,
    /// Squared hinge loss. Code `6`.
    BinaryL2 = 6,
    /// Hinge loss. Code `7`.
    BinaryL1 = 7,
    /// Row-oriented pair-wise logarithmic loss. Code `10`.
    OneClassRow = 10,
    /// Column-oriented pair-wise logarithmic loss. Code `11`.
    OneClassCol = 11,
    /// Squared error (L2-norm), one-class variant. Code `12`.
    OneClassL2 = 12,
}
// Foreign functions provided by the C library. Every call is `unsafe`: the
// caller is responsible for passing valid, properly initialised pointers and
// NUL-terminated C strings.
//
// NOTE(review): the ownership and return-value remarks below follow upstream
// LIBMF's documentation and cannot be verified from this file — confirm
// against the vendored C sources.
extern "C" {
    /// Returns a parameter set, by value, filled in by the C library.
    pub fn mf_get_default_param() -> MfParameter;

    /// Writes `model` to the file at `path` and returns a C status code
    /// (presumably `0` on success — TODO confirm).
    pub fn mf_save_model(model: *const MfModel, path: *const c_char) -> c_int;

    /// Reads a model from the file at `path`. The returned pointer is owned by
    /// the C library (presumably null on failure; presumably to be released
    /// with `mf_destroy_model`).
    pub fn mf_load_model(path: *const c_char) -> *mut MfModel;

    /// Releases a model. Takes a pointer *to* the model pointer, so the C side
    /// can modify the caller's pointer (presumably resetting it to null).
    pub fn mf_destroy_model(model: *mut *mut MfModel);

    /// Trains on `prob` with the options in `param` and returns a new model.
    pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel;

    /// Trains on `tr` while also being given the validation problem `va`, and
    /// returns a new model.
    pub fn mf_train_with_validation(
        tr: *const MfProblem,
        va: *const MfProblem,
        param: MfParameter,
    ) -> *mut MfModel;

    /// Runs cross-validation over `prob` with `nr_folds` folds and returns a
    /// single score as a C double.
    pub fn mf_cross_validation(
        prob: *const MfProblem,
        nr_folds: c_int,
        param: MfParameter,
    ) -> c_double;

    /// Returns the model's prediction for the index pair (`u`, `v`).
    pub fn mf_predict(model: *const MfModel, u: c_int, v: c_int) -> c_float;

    /// Evaluation metric of `model` on `prob` (by name: root-mean-square error).
    pub fn calc_rmse(prob: *const MfProblem, model: *const MfModel) -> c_double;

    /// Evaluation metric of `model` on `prob` (by name: mean absolute error).
    pub fn calc_mae(prob: *const MfProblem, model: *const MfModel) -> c_double;

    /// Evaluation metric of `model` on `prob` (by name: generalized KL-divergence).
    pub fn calc_gkl(prob: *const MfProblem, model: *const MfModel) -> c_double;

    /// Evaluation metric of `model` on `prob` (by name: logarithmic loss).
    pub fn calc_logloss(prob: *const MfProblem, model: *const MfModel) -> c_double;

    /// Evaluation metric of `model` on `prob` (by name: accuracy).
    pub fn calc_accuracy(prob: *const MfProblem, model: *const MfModel) -> c_double;

    /// Evaluation metric of `model` on `prob` (by name: MPR); `transpose`
    /// presumably switches between row- and column-oriented evaluation.
    pub fn calc_mpr(prob: *const MfProblem, model: *const MfModel, transpose: bool) -> c_double;

    /// Evaluation metric of `model` on `prob` (by name: AUC); `transpose`
    /// presumably switches between row- and column-oriented evaluation.
    pub fn calc_auc(prob: *const MfProblem, model: *const MfModel, transpose: bool) -> c_double;
}