Improved code
This commit is contained in:
@@ -86,7 +86,8 @@ impl Params {
|
|||||||
|
|
||||||
pub fn fit(&mut self, data: &Matrix) -> Result<Model, Error> {
|
pub fn fit(&mut self, data: &Matrix) -> Result<Model, Error> {
|
||||||
let prob = data.to_problem();
|
let prob = data.to_problem();
|
||||||
let model = unsafe { mf_train(&prob, self.build_param()?) };
|
let param = self.build_param()?;
|
||||||
|
let model = unsafe { mf_train(&prob, param) };
|
||||||
if model.is_null() {
|
if model.is_null() {
|
||||||
return Err(Error("fit failed".to_string()));
|
return Err(Error("fit failed".to_string()));
|
||||||
}
|
}
|
||||||
@@ -96,7 +97,8 @@ impl Params {
|
|||||||
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Result<Model, Error> {
|
pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Result<Model, Error> {
|
||||||
let tr = train_set.to_problem();
|
let tr = train_set.to_problem();
|
||||||
let va = eval_set.to_problem();
|
let va = eval_set.to_problem();
|
||||||
let model = unsafe { mf_train_with_validation(&tr, &va, self.build_param()?) };
|
let param = self.build_param()?;
|
||||||
|
let model = unsafe { mf_train_with_validation(&tr, &va, param) };
|
||||||
if model.is_null() {
|
if model.is_null() {
|
||||||
return Err(Error("fit_eval failed".to_string()));
|
return Err(Error("fit_eval failed".to_string()));
|
||||||
}
|
}
|
||||||
@@ -105,7 +107,8 @@ impl Params {
|
|||||||
|
|
||||||
pub fn cv(&mut self, data: &Matrix, folds: i32) -> Result<f64, Error> {
|
pub fn cv(&mut self, data: &Matrix, folds: i32) -> Result<f64, Error> {
|
||||||
let prob = data.to_problem();
|
let prob = data.to_problem();
|
||||||
let avg_error = unsafe { mf_cross_validation(&prob, folds, self.build_param()?) };
|
let param = self.build_param()?;
|
||||||
|
let avg_error = unsafe { mf_cross_validation(&prob, folds, param) };
|
||||||
// TODO update fork to differentiate between bad parameters and zero error
|
// TODO update fork to differentiate between bad parameters and zero error
|
||||||
if avg_error == 0.0 {
|
if avg_error == 0.0 {
|
||||||
return Err(Error("cv failed".to_string()));
|
return Err(Error("cv failed".to_string()));
|
||||||
|
|||||||
Reference in New Issue
Block a user