Use enum for errors
This commit is contained in:
@@ -1,2 +1,6 @@
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Error(pub(crate) String);
|
pub enum Error {
|
||||||
|
Io,
|
||||||
|
Parameter(String),
|
||||||
|
Unknown
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ impl Model {
|
|||||||
let cpath = CString::new(path).expect("CString::new failed");
|
let cpath = CString::new(path).expect("CString::new failed");
|
||||||
let model = unsafe { mf_load_model(cpath.as_ptr()) };
|
let model = unsafe { mf_load_model(cpath.as_ptr()) };
|
||||||
if model.is_null() {
|
if model.is_null() {
|
||||||
return Err(Error("Cannot open model".to_string()));
|
return Err(Error::Io);
|
||||||
}
|
}
|
||||||
Ok(Model { model })
|
Ok(Model { model })
|
||||||
}
|
}
|
||||||
@@ -28,7 +28,7 @@ impl Model {
|
|||||||
let cpath = CString::new(path).expect("CString::new failed");
|
let cpath = CString::new(path).expect("CString::new failed");
|
||||||
let status = unsafe { mf_save_model(self.model, cpath.as_ptr()) };
|
let status = unsafe { mf_save_model(self.model, cpath.as_ptr()) };
|
||||||
if status != 0 {
|
if status != 0 {
|
||||||
return Err(Error("Cannot save model".to_string()));
|
return Err(Error::Io);
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ impl Params {
|
|||||||
let param = self.build_param()?;
|
let param = self.build_param()?;
|
||||||
let model = unsafe { mf_train(&prob, param) };
|
let model = unsafe { mf_train(&prob, param) };
|
||||||
if model.is_null() {
|
if model.is_null() {
|
||||||
return Err(Error("fit failed".to_string()));
|
return Err(Error::Unknown);
|
||||||
}
|
}
|
||||||
Ok(Model { model })
|
Ok(Model { model })
|
||||||
}
|
}
|
||||||
@@ -100,7 +100,7 @@ impl Params {
|
|||||||
let param = self.build_param()?;
|
let param = self.build_param()?;
|
||||||
let model = unsafe { mf_train_with_validation(&tr, &va, param) };
|
let model = unsafe { mf_train_with_validation(&tr, &va, param) };
|
||||||
if model.is_null() {
|
if model.is_null() {
|
||||||
return Err(Error("fit_eval failed".to_string()));
|
return Err(Error::Unknown);
|
||||||
}
|
}
|
||||||
Ok(Model { model })
|
Ok(Model { model })
|
||||||
}
|
}
|
||||||
@@ -111,7 +111,7 @@ impl Params {
|
|||||||
let avg_error = unsafe { mf_cross_validation(&prob, folds, param) };
|
let avg_error = unsafe { mf_cross_validation(&prob, folds, param) };
|
||||||
// TODO update fork to differentiate between bad parameters and zero error
|
// TODO update fork to differentiate between bad parameters and zero error
|
||||||
if avg_error == 0.0 {
|
if avg_error == 0.0 {
|
||||||
return Err(Error("cv failed".to_string()));
|
return Err(Error::Unknown);
|
||||||
}
|
}
|
||||||
Ok(avg_error)
|
Ok(avg_error)
|
||||||
}
|
}
|
||||||
@@ -121,35 +121,35 @@ impl Params {
|
|||||||
let param = self.param;
|
let param = self.param;
|
||||||
|
|
||||||
if param.k < 1 {
|
if param.k < 1 {
|
||||||
return Err(Error("number of factors must be greater than zero".to_string()));
|
return Err(Error::Parameter("number of factors must be greater than zero".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.nr_threads < 1 {
|
if param.nr_threads < 1 {
|
||||||
return Err(Error("number of threads must be greater than zero".to_string()));
|
return Err(Error::Parameter("number of threads must be greater than zero".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.nr_bins < 1 || param.nr_bins < param.nr_threads {
|
if param.nr_bins < 1 || param.nr_bins < param.nr_threads {
|
||||||
return Err(Error("number of bins must be greater than number of threads".to_string()));
|
return Err(Error::Parameter("number of bins must be greater than number of threads".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.nr_iters < 1 {
|
if param.nr_iters < 1 {
|
||||||
return Err(Error("number of iterations must be greater than zero".to_string()));
|
return Err(Error::Parameter("number of iterations must be greater than zero".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.lambda_p1 < 0.0 || param.lambda_p2 < 0.0 || param.lambda_q1 < 0.0 || param.lambda_q2 < 0.0 {
|
if param.lambda_p1 < 0.0 || param.lambda_p2 < 0.0 || param.lambda_q1 < 0.0 || param.lambda_q2 < 0.0 {
|
||||||
return Err(Error("regularization coefficient must be non-negative".to_string()));
|
return Err(Error::Parameter("regularization coefficient must be non-negative".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.eta <= 0.0 {
|
if param.eta <= 0.0 {
|
||||||
return Err(Error("learning rate must be greater than zero".to_string()));
|
return Err(Error::Parameter("learning rate must be greater than zero".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if matches!(param.fun, Loss::RealKL) && !param.do_nmf {
|
if matches!(param.fun, Loss::RealKL) && !param.do_nmf {
|
||||||
return Err(Error("nmf must be set when using generalized KL-divergence".to_string()));
|
return Err(Error::Parameter("nmf must be set when using generalized KL-divergence".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if param.alpha < 0.0 {
|
if param.alpha < 0.0 {
|
||||||
return Err(Error("alpha must be a non-negative number".to_string()));
|
return Err(Error::Parameter("alpha must be a non-negative number".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(param)
|
Ok(param)
|
||||||
|
|||||||
Reference in New Issue
Block a user