diff --git a/src/params.rs b/src/params.rs index cf89a53..8325ddc 100644 --- a/src/params.rs +++ b/src/params.rs @@ -86,7 +86,8 @@ impl Params { pub fn fit(&mut self, data: &Matrix) -> Result { let prob = data.to_problem(); - let model = unsafe { mf_train(&prob, self.build_param()?) }; + let param = self.build_param()?; + let model = unsafe { mf_train(&prob, param) }; if model.is_null() { return Err(Error("fit failed".to_string())); } @@ -96,7 +97,8 @@ impl Params { pub fn fit_eval(&mut self, train_set: &Matrix, eval_set: &Matrix) -> Result { let tr = train_set.to_problem(); let va = eval_set.to_problem(); - let model = unsafe { mf_train_with_validation(&tr, &va, self.build_param()?) }; + let param = self.build_param()?; + let model = unsafe { mf_train_with_validation(&tr, &va, param) }; if model.is_null() { return Err(Error("fit_eval failed".to_string())); } @@ -105,7 +107,8 @@ impl Params { pub fn cv(&mut self, data: &Matrix, folds: i32) -> Result { let prob = data.to_problem(); - let avg_error = unsafe { mf_cross_validation(&prob, folds, self.build_param()?) }; + let param = self.build_param()?; + let avg_error = unsafe { mf_cross_validation(&prob, folds, param) }; // TODO update fork to differentiate between bad parameters and zero error if avg_error == 0.0 { return Err(Error("cv failed".to_string()));