From 00a924922df0aa84ffff91d5a1dc568203b71cd8 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 17 Oct 2021 11:53:51 -0700 Subject: [PATCH] Use enum for errors --- src/error.rs | 6 +++++- src/model.rs | 4 ++-- src/params.rs | 22 +++++++++++----------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/error.rs b/src/error.rs index acd52f0..5a795bc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,2 +1,6 @@ #[derive(Debug)] -pub struct Error(pub(crate) String); +pub enum Error { + Io, + Parameter(String), + Unknown +} diff --git a/src/model.rs b/src/model.rs index 6f4349a..f8f3ead 100644 --- a/src/model.rs +++ b/src/model.rs @@ -15,7 +15,7 @@ impl Model { let cpath = CString::new(path).expect("CString::new failed"); let model = unsafe { mf_load_model(cpath.as_ptr()) }; if model.is_null() { - return Err(Error("Cannot open model".to_string())); + return Err(Error::Io); } Ok(Model { model }) } @@ -28,7 +28,7 @@ impl Model { let cpath = CString::new(path).expect("CString::new failed"); let status = unsafe { mf_save_model(self.model, cpath.as_ptr()) }; if status != 0 { - return Err(Error("Cannot save model".to_string())); + return Err(Error::Io); } Ok(()) } diff --git a/src/params.rs b/src/params.rs index 8325ddc..809fa34 100644 --- a/src/params.rs +++ b/src/params.rs @@ -89,7 +89,7 @@ impl Params { let param = self.build_param()?; let model = unsafe { mf_train(&prob, param) }; if model.is_null() { - return Err(Error("fit failed".to_string())); + return Err(Error::Unknown); } Ok(Model { model }) } @@ -100,7 +100,7 @@ impl Params { let param = self.build_param()?; let model = unsafe { mf_train_with_validation(&tr, &va, param) }; if model.is_null() { - return Err(Error("fit_eval failed".to_string())); + return Err(Error::Unknown); } Ok(Model { model }) } @@ -111,7 +111,7 @@ impl Params { let avg_error = unsafe { mf_cross_validation(&prob, folds, param) }; // TODO update fork to differentiate between bad parameters and zero error if avg_error == 0.0 { - return Err(Error("cv failed".to_string())); + return Err(Error::Unknown); } Ok(avg_error) } @@ -121,35 +121,35 @@ impl Params { let param = self.param; if param.k < 1 { - return Err(Error("number of factors must be greater than zero".to_string())); + return Err(Error::Parameter("number of factors must be greater than zero".to_string())); } if param.nr_threads < 1 { - return Err(Error("number of threads must be greater than zero".to_string())); + return Err(Error::Parameter("number of threads must be greater than zero".to_string())); } if param.nr_bins < 1 || param.nr_bins < param.nr_threads { - return Err(Error("number of bins must be greater than number of threads".to_string())); + return Err(Error::Parameter("number of bins must be greater than number of threads".to_string())); } if param.nr_iters < 1 { - return Err(Error("number of iterations must be greater than zero".to_string())); + return Err(Error::Parameter("number of iterations must be greater than zero".to_string())); } if param.lambda_p1 < 0.0 || param.lambda_p2 < 0.0 || param.lambda_q1 < 0.0 || param.lambda_q2 < 0.0 { - return Err(Error("regularization coefficient must be non-negative".to_string())); + return Err(Error::Parameter("regularization coefficient must be non-negative".to_string())); } if param.eta <= 0.0 { - return Err(Error("learning rate must be greater than zero".to_string())); + return Err(Error::Parameter("learning rate must be greater than zero".to_string())); } if matches!(param.fun, Loss::RealKL) && !param.do_nmf { - return Err(Error("nmf must be set when using generalized KL-divergence".to_string())); + return Err(Error::Parameter("nmf must be set when using generalized KL-divergence".to_string())); } if param.alpha < 0.0 { - return Err(Error("alpha must be a non-negative number".to_string())); + return Err(Error::Parameter("alpha must be a non-negative number".to_string())); } Ok(param)