Ran cargo fmt [skip ci]
This commit is contained in:
@@ -76,8 +76,16 @@ extern "C" {
|
|||||||
pub fn mf_load_model(path: *const c_char) -> *mut MfModel;
|
pub fn mf_load_model(path: *const c_char) -> *mut MfModel;
|
||||||
pub fn mf_destroy_model(model: *mut *mut MfModel);
|
pub fn mf_destroy_model(model: *mut *mut MfModel);
|
||||||
pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel;
|
pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel;
|
||||||
pub fn mf_train_with_validation(tr: *const MfProblem, va: *const MfProblem, param: MfParameter) -> *mut MfModel;
|
pub fn mf_train_with_validation(
|
||||||
pub fn mf_cross_validation(prob: *const MfProblem, nr_folds: c_int, param: MfParameter) -> c_double;
|
tr: *const MfProblem,
|
||||||
|
va: *const MfProblem,
|
||||||
|
param: MfParameter,
|
||||||
|
) -> *mut MfModel;
|
||||||
|
pub fn mf_cross_validation(
|
||||||
|
prob: *const MfProblem,
|
||||||
|
nr_folds: c_int,
|
||||||
|
param: MfParameter,
|
||||||
|
) -> c_double;
|
||||||
pub fn mf_predict(model: *const MfModel, u: c_int, v: c_int) -> c_float;
|
pub fn mf_predict(model: *const MfModel, u: c_int, v: c_int) -> c_float;
|
||||||
pub fn calc_rmse(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
pub fn calc_rmse(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
||||||
pub fn calc_mae(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
pub fn calc_mae(prob: *const MfProblem, model: *const MfModel) -> c_double;
|
||||||
|
|||||||
21
src/model.rs
21
src/model.rs
@@ -65,12 +65,16 @@ impl Model {
|
|||||||
|
|
||||||
/// Returns the latent factors for rows.
|
/// Returns the latent factors for rows.
|
||||||
pub fn p_factors(&self) -> &[f32] {
|
pub fn p_factors(&self) -> &[f32] {
|
||||||
unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) }
|
unsafe {
|
||||||
|
std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the latent factors for columns.
|
/// Returns the latent factors for columns.
|
||||||
pub fn q_factors(&self) -> &[f32] {
|
pub fn q_factors(&self) -> &[f32] {
|
||||||
unsafe { std::slice::from_raw_parts((*self.model).q, (self.columns() * self.factors()) as usize) }
|
unsafe {
|
||||||
|
std::slice::from_raw_parts((*self.model).q, (self.columns() * self.factors()) as usize)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the latent factors for a row.
|
/// Returns the latent factors for a row.
|
||||||
@@ -226,14 +230,23 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_loss() {
|
fn test_loss() {
|
||||||
let data = generate_data();
|
let data = generate_data();
|
||||||
let model = Model::params().loss(Loss::OneClassL2).quiet(true).fit(&data).unwrap();
|
let model = Model::params()
|
||||||
|
.loss(Loss::OneClassL2)
|
||||||
|
.quiet(true)
|
||||||
|
.fit(&data)
|
||||||
|
.unwrap();
|
||||||
assert_eq!(model.bias(), 0.0);
|
assert_eq!(model.bias(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_loss_real_kl() {
|
fn test_loss_real_kl() {
|
||||||
let data = generate_data();
|
let data = generate_data();
|
||||||
assert!(Model::params().loss(Loss::RealKL).nmf(true).quiet(true).fit(&data).is_ok());
|
assert!(Model::params()
|
||||||
|
.loss(Loss::RealKL)
|
||||||
|
.nmf(true)
|
||||||
|
.quiet(true)
|
||||||
|
.fit(&data)
|
||||||
|
.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -137,35 +137,55 @@ impl Params {
|
|||||||
let param = self.param;
|
let param = self.param;
|
||||||
|
|
||||||
if param.k < 1 {
|
if param.k < 1 {
|
||||||
return Err(Error::Parameter("number of factors must be greater than zero".to_string()));
|
return Err(Error::Parameter(
|
||||||
|
"number of factors must be greater than zero".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.nr_threads < 1 {
|
if param.nr_threads < 1 {
|
||||||
return Err(Error::Parameter("number of threads must be greater than zero".to_string()));
|
return Err(Error::Parameter(
|
||||||
|
"number of threads must be greater than zero".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.nr_bins < 1 || param.nr_bins < param.nr_threads {
|
if param.nr_bins < 1 || param.nr_bins < param.nr_threads {
|
||||||
return Err(Error::Parameter("number of bins must be greater than number of threads".to_string()));
|
return Err(Error::Parameter(
|
||||||
|
"number of bins must be greater than number of threads".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.nr_iters < 1 {
|
if param.nr_iters < 1 {
|
||||||
return Err(Error::Parameter("number of iterations must be greater than zero".to_string()));
|
return Err(Error::Parameter(
|
||||||
|
"number of iterations must be greater than zero".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.lambda_p1 < 0.0 || param.lambda_p2 < 0.0 || param.lambda_q1 < 0.0 || param.lambda_q2 < 0.0 {
|
if param.lambda_p1 < 0.0
|
||||||
return Err(Error::Parameter("regularization coefficient must be non-negative".to_string()));
|
|| param.lambda_p2 < 0.0
|
||||||
|
|| param.lambda_q1 < 0.0
|
||||||
|
|| param.lambda_q2 < 0.0
|
||||||
|
{
|
||||||
|
return Err(Error::Parameter(
|
||||||
|
"regularization coefficient must be non-negative".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.eta <= 0.0 {
|
if param.eta <= 0.0 {
|
||||||
return Err(Error::Parameter("learning rate must be greater than zero".to_string()));
|
return Err(Error::Parameter(
|
||||||
|
"learning rate must be greater than zero".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if matches!(param.fun, Loss::RealKL) && !param.do_nmf {
|
if matches!(param.fun, Loss::RealKL) && !param.do_nmf {
|
||||||
return Err(Error::Parameter("nmf must be set when using generalized KL-divergence".to_string()));
|
return Err(Error::Parameter(
|
||||||
|
"nmf must be set when using generalized KL-divergence".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.alpha < 0.0 {
|
if param.alpha < 0.0 {
|
||||||
return Err(Error::Parameter("alpha must be a non-negative number".to_string()));
|
return Err(Error::Parameter(
|
||||||
|
"alpha must be a non-negative number".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(param)
|
Ok(param)
|
||||||
|
|||||||
Reference in New Issue
Block a user