Use enum for errors

Andrew Kane
2021-10-17 11:53:51 -07:00
parent 527e8e8b2e
commit 00a924922d
3 changed files with 18 additions and 14 deletions

View File

@@ -1,2 +1,6 @@
 #[derive(Debug)]
-pub struct Error(pub(crate) String);
+pub enum Error {
+    Io,
+    Parameter(String),
+    Unknown
+}
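
With the enum, callers can branch on the kind of failure instead of matching on a message string. A minimal sketch of downstream handling, assuming a `Model::load`-style constructor (the function name and the caller are illustrative; the hunks below only show function bodies):

    // Hypothetical caller; Error::Io, Error::Parameter, and Error::Unknown
    // are the variants introduced above.
    match Model::load("model.bin") {
        Ok(_model) => { /* use the trained model */ }
        Err(Error::Io) => eprintln!("could not open the model file"),
        Err(Error::Parameter(msg)) => eprintln!("invalid parameter: {}", msg),
        Err(Error::Unknown) => eprintln!("operation failed for an unknown reason"),
    }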

View File

@@ -15,7 +15,7 @@ impl Model {
         let cpath = CString::new(path).expect("CString::new failed");
         let model = unsafe { mf_load_model(cpath.as_ptr()) };
         if model.is_null() {
-            return Err(Error("Cannot open model".to_string()));
+            return Err(Error::Io);
         }
         Ok(Model { model })
     }
@@ -28,7 +28,7 @@ impl Model {
         let cpath = CString::new(path).expect("CString::new failed");
         let status = unsafe { mf_save_model(self.model, cpath.as_ptr()) };
         if status != 0 {
-            return Err(Error("Cannot save model".to_string()));
+            return Err(Error::Io);
         }
         Ok(())
     }

View File

@@ -89,7 +89,7 @@ impl Params {
         let param = self.build_param()?;
         let model = unsafe { mf_train(&prob, param) };
         if model.is_null() {
-            return Err(Error("fit failed".to_string()));
+            return Err(Error::Unknown);
         }
         Ok(Model { model })
     }
@@ -100,7 +100,7 @@ impl Params {
         let param = self.build_param()?;
         let model = unsafe { mf_train_with_validation(&tr, &va, param) };
         if model.is_null() {
-            return Err(Error("fit_eval failed".to_string()));
+            return Err(Error::Unknown);
         }
         Ok(Model { model })
     }
@@ -111,7 +111,7 @@ impl Params {
         let avg_error = unsafe { mf_cross_validation(&prob, folds, param) };
         // TODO update fork to differentiate between bad parameters and zero error
         if avg_error == 0.0 {
-            return Err(Error("cv failed".to_string()));
+            return Err(Error::Unknown);
         }
         Ok(avg_error)
     }
@@ -121,35 +121,35 @@ impl Params {
         let param = self.param;
         if param.k < 1 {
-            return Err(Error("number of factors must be greater than zero".to_string()));
+            return Err(Error::Parameter("number of factors must be greater than zero".to_string()));
         }
         if param.nr_threads < 1 {
-            return Err(Error("number of threads must be greater than zero".to_string()));
+            return Err(Error::Parameter("number of threads must be greater than zero".to_string()));
         }
         if param.nr_bins < 1 || param.nr_bins < param.nr_threads {
-            return Err(Error("number of bins must be greater than number of threads".to_string()));
+            return Err(Error::Parameter("number of bins must be greater than number of threads".to_string()));
         }
         if param.nr_iters < 1 {
-            return Err(Error("number of iterations must be greater than zero".to_string()));
+            return Err(Error::Parameter("number of iterations must be greater than zero".to_string()));
         }
         if param.lambda_p1 < 0.0 || param.lambda_p2 < 0.0 || param.lambda_q1 < 0.0 || param.lambda_q2 < 0.0 {
-            return Err(Error("regularization coefficient must be non-negative".to_string()));
+            return Err(Error::Parameter("regularization coefficient must be non-negative".to_string()));
         }
         if param.eta <= 0.0 {
-            return Err(Error("learning rate must be greater than zero".to_string()));
+            return Err(Error::Parameter("learning rate must be greater than zero".to_string()));
         }
         if matches!(param.fun, Loss::RealKL) && !param.do_nmf {
-            return Err(Error("nmf must be set when using generalized KL-divergence".to_string()));
+            return Err(Error::Parameter("nmf must be set when using generalized KL-divergence".to_string()));
         }
         if param.alpha < 0.0 {
-            return Err(Error("alpha must be a non-negative number".to_string()));
+            return Err(Error::Parameter("alpha must be a non-negative number".to_string()));
         }
         Ok(param)
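
Note that `Io` and `Unknown` drop the per-call message strings the old tuple struct carried. A sketch of one way downstream code (or a later commit) could restore readable messages via `Display` — this impl is not part of the commit:

    use std::fmt;

    // Illustrative only: render each variant as a human-readable message.
    impl fmt::Display for Error {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                Error::Io => write!(f, "input/output error"),
                Error::Parameter(msg) => write!(f, "{}", msg),
                Error::Unknown => write!(f, "unknown error"),
            }
        }
    }

    // With Debug (derived above) and Display in place, the enum can also
    // implement the standard error trait.
    impl std::error::Error for Error {}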