Use enum for errors

This commit is contained in:
Andrew Kane
2021-10-17 11:53:51 -07:00
parent 527e8e8b2e
commit 00a924922d
3 changed files with 18 additions and 14 deletions

View File

@@ -1,2 +1,6 @@
#[derive(Debug)]
pub struct Error(pub(crate) String);
pub enum Error {
Io,
Parameter(String),
Unknown
}

View File

@@ -15,7 +15,7 @@ impl Model {
let cpath = CString::new(path).expect("CString::new failed");
let model = unsafe { mf_load_model(cpath.as_ptr()) };
if model.is_null() {
return Err(Error("Cannot open model".to_string()));
return Err(Error::Io);
}
Ok(Model { model })
}
@@ -28,7 +28,7 @@ impl Model {
let cpath = CString::new(path).expect("CString::new failed");
let status = unsafe { mf_save_model(self.model, cpath.as_ptr()) };
if status != 0 {
return Err(Error("Cannot save model".to_string()));
return Err(Error::Io);
}
Ok(())
}

View File

@@ -89,7 +89,7 @@ impl Params {
let param = self.build_param()?;
let model = unsafe { mf_train(&prob, param) };
if model.is_null() {
return Err(Error("fit failed".to_string()));
return Err(Error::Unknown);
}
Ok(Model { model })
}
@@ -100,7 +100,7 @@ impl Params {
let param = self.build_param()?;
let model = unsafe { mf_train_with_validation(&tr, &va, param) };
if model.is_null() {
return Err(Error("fit_eval failed".to_string()));
return Err(Error::Unknown);
}
Ok(Model { model })
}
@@ -111,7 +111,7 @@ impl Params {
let avg_error = unsafe { mf_cross_validation(&prob, folds, param) };
// TODO update fork to differentiate between bad parameters and zero error
if avg_error == 0.0 {
return Err(Error("cv failed".to_string()));
return Err(Error::Unknown);
}
Ok(avg_error)
}
@@ -121,35 +121,35 @@ impl Params {
let param = self.param;
if param.k < 1 {
return Err(Error("number of factors must be greater than zero".to_string()));
return Err(Error::Parameter("number of factors must be greater than zero".to_string()));
}
if param.nr_threads < 1 {
return Err(Error("number of threads must be greater than zero".to_string()));
return Err(Error::Parameter("number of threads must be greater than zero".to_string()));
}
if param.nr_bins < 1 || param.nr_bins < param.nr_threads {
return Err(Error("number of bins must be greater than number of threads".to_string()));
return Err(Error::Parameter("number of bins must be greater than number of threads".to_string()));
}
if param.nr_iters < 1 {
return Err(Error("number of iterations must be greater than zero".to_string()));
return Err(Error::Parameter("number of iterations must be greater than zero".to_string()));
}
if param.lambda_p1 < 0.0 || param.lambda_p2 < 0.0 || param.lambda_q1 < 0.0 || param.lambda_q2 < 0.0 {
return Err(Error("regularization coefficient must be non-negative".to_string()));
return Err(Error::Parameter("regularization coefficient must be non-negative".to_string()));
}
if param.eta <= 0.0 {
return Err(Error("learning rate must be greater than zero".to_string()));
return Err(Error::Parameter("learning rate must be greater than zero".to_string()));
}
if matches!(param.fun, Loss::RealKL) && !param.do_nmf {
return Err(Error("nmf must be set when using generalized KL-divergence".to_string()));
return Err(Error::Parameter("nmf must be set when using generalized KL-divergence".to_string()));
}
if param.alpha < 0.0 {
return Err(Error("alpha must be a non-negative number".to_string()));
return Err(Error::Parameter("alpha must be a non-negative number".to_string()));
}
Ok(param)