Ran cargo fmt [skip ci]
@@ -76,8 +76,16 @@ extern "C" {
     pub fn mf_load_model(path: *const c_char) -> *mut MfModel;
     pub fn mf_destroy_model(model: *mut *mut MfModel);
     pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel;
-    pub fn mf_train_with_validation(tr: *const MfProblem, va: *const MfProblem, param: MfParameter) -> *mut MfModel;
-    pub fn mf_cross_validation(prob: *const MfProblem, nr_folds: c_int, param: MfParameter) -> c_double;
+    pub fn mf_train_with_validation(
+        tr: *const MfProblem,
+        va: *const MfProblem,
+        param: MfParameter,
+    ) -> *mut MfModel;
+    pub fn mf_cross_validation(
+        prob: *const MfProblem,
+        nr_folds: c_int,
+        param: MfParameter,
+    ) -> c_double;
     pub fn mf_predict(model: *const MfModel, u: c_int, v: c_int) -> c_float;
     pub fn calc_rmse(prob: *const MfProblem, model: *const MfModel) -> c_double;
     pub fn calc_mae(prob: *const MfProblem, model: *const MfModel) -> c_double;
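Note on the bindings this hunk reflows: `mf_destroy_model` takes `*mut *mut MfModel` rather than `*mut MfModel`, so the C side can free the model and null the caller's handle in one call. A minimal sketch of a Drop-based wrapper over that binding (the `ModelHandle` type is hypothetical, not part of the crate):

    // Hypothetical safe-wrapper sketch (not in this commit), assuming the
    // MfModel type and mf_destroy_model binding from the extern block above.
    struct ModelHandle {
        ptr: *mut MfModel,
    }

    impl Drop for ModelHandle {
        fn drop(&mut self) {
            if !self.ptr.is_null() {
                // The double pointer lets the C side free the model and
                // null out our handle in the same call.
                unsafe { mf_destroy_model(&mut self.ptr) };
            }
        }
    }

Tying cleanup to Drop keeps the raw pointer from outliving the C allocation even on early returns.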
src/model.rs
@@ -65,12 +65,16 @@ impl Model {
 
     /// Returns the latent factors for rows.
     pub fn p_factors(&self) -> &[f32] {
-        unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) }
+        unsafe {
+            std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize)
+        }
     }
 
     /// Returns the latent factors for columns.
     pub fn q_factors(&self) -> &[f32] {
-        unsafe { std::slice::from_raw_parts((*self.model).q, (self.columns() * self.factors()) as usize) }
+        unsafe {
+            std::slice::from_raw_parts((*self.model).q, (self.columns() * self.factors()) as usize)
+        }
     }
 
     /// Returns the latent factors for a row.
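Note on the slices rebuilt above: `p` is a flat buffer of `rows() * factors()` values laid out row-major, so per-row factors fall out by chunking. A usage sketch, assuming a trained `model` in scope (not shown in the diff):

    // Sketch only: P is rows() x factors() stored flat, so row i's factors
    // are the i-th chunk of length factors().
    let k = model.factors() as usize;
    let row0 = &model.p_factors()[..k];              // factors for row 0
    let row2 = model.p_factors().chunks(k).nth(2);   // Option<&[f32]> for row 2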
@@ -226,14 +230,23 @@ mod tests {
     #[test]
     fn test_loss() {
         let data = generate_data();
-        let model = Model::params().loss(Loss::OneClassL2).quiet(true).fit(&data).unwrap();
+        let model = Model::params()
+            .loss(Loss::OneClassL2)
+            .quiet(true)
+            .fit(&data)
+            .unwrap();
         assert_eq!(model.bias(), 0.0);
     }
 
     #[test]
     fn test_loss_real_kl() {
         let data = generate_data();
-        assert!(Model::params().loss(Loss::RealKL).nmf(true).quiet(true).fit(&data).is_ok());
+        assert!(Model::params()
+            .loss(Loss::RealKL)
+            .nmf(true)
+            .quiet(true)
+            .fit(&data)
+            .is_ok());
     }
 
     #[test]
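Note: both tests lean on a `generate_data` helper this hunk does not show. A hedged sketch of equivalent setup, assuming a `Matrix` type that collects `(row, column, value)` triples via `push`; that API is an assumption here, not confirmed by the diff:

    // Assumed Matrix API; adjust to the crate's actual data type.
    let mut data = Matrix::new();
    data.push(0, 0, 1.0);
    data.push(0, 1, 3.0);
    data.push(1, 0, 4.0);
    let model = Model::params().quiet(true).fit(&data).unwrap();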
@@ -137,35 +137,55 @@ impl Params {
         let param = self.param;
 
         if param.k < 1 {
-            return Err(Error::Parameter("number of factors must be greater than zero".to_string()));
+            return Err(Error::Parameter(
+                "number of factors must be greater than zero".to_string(),
+            ));
         }
 
         if param.nr_threads < 1 {
-            return Err(Error::Parameter("number of threads must be greater than zero".to_string()));
+            return Err(Error::Parameter(
+                "number of threads must be greater than zero".to_string(),
+            ));
         }
 
         if param.nr_bins < 1 || param.nr_bins < param.nr_threads {
-            return Err(Error::Parameter("number of bins must be greater than number of threads".to_string()));
+            return Err(Error::Parameter(
+                "number of bins must be greater than number of threads".to_string(),
+            ));
         }
 
         if param.nr_iters < 1 {
-            return Err(Error::Parameter("number of iterations must be greater than zero".to_string()));
+            return Err(Error::Parameter(
+                "number of iterations must be greater than zero".to_string(),
+            ));
         }
 
-        if param.lambda_p1 < 0.0 || param.lambda_p2 < 0.0 || param.lambda_q1 < 0.0 || param.lambda_q2 < 0.0 {
-            return Err(Error::Parameter("regularization coefficient must be non-negative".to_string()));
+        if param.lambda_p1 < 0.0
+            || param.lambda_p2 < 0.0
+            || param.lambda_q1 < 0.0
+            || param.lambda_q2 < 0.0
+        {
+            return Err(Error::Parameter(
+                "regularization coefficient must be non-negative".to_string(),
+            ));
         }
 
         if param.eta <= 0.0 {
-            return Err(Error::Parameter("learning rate must be greater than zero".to_string()));
+            return Err(Error::Parameter(
+                "learning rate must be greater than zero".to_string(),
+            ));
         }
 
         if matches!(param.fun, Loss::RealKL) && !param.do_nmf {
-            return Err(Error::Parameter("nmf must be set when using generalized KL-divergence".to_string()));
+            return Err(Error::Parameter(
+                "nmf must be set when using generalized KL-divergence".to_string(),
+            ));
         }
 
         if param.alpha < 0.0 {
-            return Err(Error::Parameter("alpha must be a non-negative number".to_string()));
+            return Err(Error::Parameter(
+                "alpha must be a non-negative number".to_string(),
+            ));
         }
 
         Ok(param)
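Note on the validation flow reformatted above: every check reports through the same `Err(Error::Parameter(..))` path before training starts. A small caller-side sketch, assuming the builder exposes a `factors` setter for `param.k` (a hypothetical name based on the field):

    // Sketch: an invalid parameter surfaces as Err(Error::Parameter(..)).
    let result = Model::params().factors(0).quiet(true).fit(&data);
    assert!(result.is_err()); // "number of factors must be greater than zero"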