diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fc45da..e3c6b8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.2.1 (unreleased) - Added `Error` trait to errors +- Added support for paths to `save` and `load` ## 0.2.0 (2021-10-17) diff --git a/src/model.rs b/src/model.rs index 16ecd75..a08d212 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,6 +1,7 @@ use crate::bindings::*; use crate::{Error, Matrix, Params}; use std::ffi::CString; +use std::path::Path; #[derive(Debug)] pub struct Model { @@ -12,8 +13,9 @@ impl Model { Params::new() } - pub fn load(path: &str) -> Result { - let cpath = CString::new(path)?; + pub fn load>(path: P) -> Result { + // TODO better conversion + let cpath = CString::new(path.as_ref().to_str().unwrap())?; let model = unsafe { mf_load_model(cpath.as_ptr()) }; if model.is_null() { return Err(Error::Io); @@ -25,8 +27,9 @@ impl Model { unsafe { mf_predict(self.model, row_index, column_index) } } - pub fn save(&self, path: &str) -> Result<(), Error> { - let cpath = CString::new(path)?; + pub fn save>(&self, path: P) -> Result<(), Error> { + // TODO better conversion + let cpath = CString::new(path.as_ref().to_str().unwrap())?; let status = unsafe { mf_save_model(self.model, cpath.as_ptr()) }; if status != 0 { return Err(Error::Io);