Updated style [skip ci]
This commit is contained in:
10
src/model.rs
10
src/model.rs
@@ -15,10 +15,9 @@ impl Model {
|
||||
let cpath = CString::new(path).expect("CString::new failed");
|
||||
let model = unsafe { mf_load_model(cpath.as_ptr()) };
|
||||
if model.is_null() {
|
||||
Err(Error("Cannot open model".to_string()))
|
||||
} else {
|
||||
Ok(Model { model })
|
||||
return Err(Error("Cannot open model".to_string()));
|
||||
}
|
||||
Ok(Model { model })
|
||||
}
|
||||
|
||||
pub fn predict(&self, row_index: i32, column_index: i32) -> f32 {
|
||||
@@ -29,10 +28,9 @@ impl Model {
|
||||
let cpath = CString::new(path).expect("CString::new failed");
|
||||
let status = unsafe { mf_save_model(self.model, cpath.as_ptr()) };
|
||||
if status != 0 {
|
||||
Err(Error("Cannot save model".to_string()))
|
||||
} else {
|
||||
Ok(())
|
||||
return Err(Error("Cannot save model".to_string()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn rows(&self) -> i32 {
|
||||
|
||||
@@ -88,10 +88,9 @@ impl Params {
|
||||
let prob = data.to_problem();
|
||||
let model = unsafe { mf_train(&prob, self.build_param()?) };
|
||||
if model.is_null() {
|
||||
Err(Error("fit failed".to_string()))
|
||||
} else {
|
||||
Ok(Model { model })
|
||||
return Err(Error("fit failed".to_string()));
|
||||
}
|
||||
Ok(Model { model })
|
||||
}
|
||||
|
||||
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Result<Model, Error> {
|
||||
@@ -99,10 +98,9 @@ impl Params {
|
||||
let va = eval_set.to_problem();
|
||||
let model = unsafe { mf_train_with_validation(&tr, &va, self.build_param()?) };
|
||||
if model.is_null() {
|
||||
Err(Error("fit_eval failed".to_string()))
|
||||
} else {
|
||||
Ok(Model { model })
|
||||
return Err(Error("fit_eval failed".to_string()));
|
||||
}
|
||||
Ok(Model { model })
|
||||
}
|
||||
|
||||
pub fn cv(&mut self, data: &Matrix, folds: i32) -> Result<f64, Error> {
|
||||
@@ -110,10 +108,9 @@ impl Params {
|
||||
let avg_error = unsafe { mf_cross_validation(&prob, folds, self.build_param()?) };
|
||||
// TODO update fork to differentiate between bad parameters and zero error
|
||||
if avg_error == 0.0 {
|
||||
Err(Error("cv failed".to_string()))
|
||||
} else {
|
||||
Ok(avg_error)
|
||||
return Err(Error("cv failed".to_string()));
|
||||
}
|
||||
Ok(avg_error)
|
||||
}
|
||||
|
||||
// check parameters in Rust for better error message
|
||||
|
||||
Reference in New Issue
Block a user