Ran cargo fmt [skip ci]

This commit is contained in:
Andrew Kane
2024-10-05 10:17:43 -07:00
parent bf30f1336c
commit c79091a2fd
3 changed files with 56 additions and 15 deletions

View File

@@ -76,8 +76,16 @@ extern "C" {
pub fn mf_load_model(path: *const c_char) -> *mut MfModel;
pub fn mf_destroy_model(model: *mut *mut MfModel);
pub fn mf_train(prob: *const MfProblem, param: MfParameter) -> *mut MfModel;
pub fn mf_train_with_validation(tr: *const MfProblem, va: *const MfProblem, param: MfParameter) -> *mut MfModel;
pub fn mf_cross_validation(prob: *const MfProblem, nr_folds: c_int, param: MfParameter) -> c_double;
pub fn mf_train_with_validation(
tr: *const MfProblem,
va: *const MfProblem,
param: MfParameter,
) -> *mut MfModel;
pub fn mf_cross_validation(
prob: *const MfProblem,
nr_folds: c_int,
param: MfParameter,
) -> c_double;
pub fn mf_predict(model: *const MfModel, u: c_int, v: c_int) -> c_float;
pub fn calc_rmse(prob: *const MfProblem, model: *const MfModel) -> c_double;
pub fn calc_mae(prob: *const MfProblem, model: *const MfModel) -> c_double;

View File

@@ -65,12 +65,16 @@ impl Model {
/// Returns the latent factors for rows.
pub fn p_factors(&self) -> &[f32] {
unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) }
unsafe {
std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize)
}
}
/// Returns the latent factors for columns.
pub fn q_factors(&self) -> &[f32] {
unsafe { std::slice::from_raw_parts((*self.model).q, (self.columns() * self.factors()) as usize) }
unsafe {
std::slice::from_raw_parts((*self.model).q, (self.columns() * self.factors()) as usize)
}
}
/// Returns the latent factors for a row.
@@ -226,14 +230,23 @@ mod tests {
#[test]
fn test_loss() {
let data = generate_data();
let model = Model::params().loss(Loss::OneClassL2).quiet(true).fit(&data).unwrap();
let model = Model::params()
.loss(Loss::OneClassL2)
.quiet(true)
.fit(&data)
.unwrap();
assert_eq!(model.bias(), 0.0);
}
#[test]
fn test_loss_real_kl() {
let data = generate_data();
assert!(Model::params().loss(Loss::RealKL).nmf(true).quiet(true).fit(&data).is_ok());
assert!(Model::params()
.loss(Loss::RealKL)
.nmf(true)
.quiet(true)
.fit(&data)
.is_ok());
}
#[test]

View File

@@ -137,35 +137,55 @@ impl Params {
let param = self.param;
if param.k < 1 {
return Err(Error::Parameter("number of factors must be greater than zero".to_string()));
return Err(Error::Parameter(
"number of factors must be greater than zero".to_string(),
));
}
if param.nr_threads < 1 {
return Err(Error::Parameter("number of threads must be greater than zero".to_string()));
return Err(Error::Parameter(
"number of threads must be greater than zero".to_string(),
));
}
if param.nr_bins < 1 || param.nr_bins < param.nr_threads {
return Err(Error::Parameter("number of bins must be greater than number of threads".to_string()));
return Err(Error::Parameter(
"number of bins must be greater than number of threads".to_string(),
));
}
if param.nr_iters < 1 {
return Err(Error::Parameter("number of iterations must be greater than zero".to_string()));
return Err(Error::Parameter(
"number of iterations must be greater than zero".to_string(),
));
}
if param.lambda_p1 < 0.0 || param.lambda_p2 < 0.0 || param.lambda_q1 < 0.0 || param.lambda_q2 < 0.0 {
return Err(Error::Parameter("regularization coefficient must be non-negative".to_string()));
if param.lambda_p1 < 0.0
|| param.lambda_p2 < 0.0
|| param.lambda_q1 < 0.0
|| param.lambda_q2 < 0.0
{
return Err(Error::Parameter(
"regularization coefficient must be non-negative".to_string(),
));
}
if param.eta <= 0.0 {
return Err(Error::Parameter("learning rate must be greater than zero".to_string()));
return Err(Error::Parameter(
"learning rate must be greater than zero".to_string(),
));
}
if matches!(param.fun, Loss::RealKL) && !param.do_nmf {
return Err(Error::Parameter("nmf must be set when using generalized KL-divergence".to_string()));
return Err(Error::Parameter(
"nmf must be set when using generalized KL-divergence".to_string(),
));
}
if param.alpha < 0.0 {
return Err(Error::Parameter("alpha must be a non-negative number".to_string()));
return Err(Error::Parameter(
"alpha must be a non-negative number".to_string(),
));
}
Ok(param)