Added support for paths to save and load

This commit is contained in:
Andrew Kane
2021-11-14 23:01:41 -08:00
parent c155fbc2e7
commit 3638acc078
2 changed files with 8 additions and 4 deletions

View File

@@ -1,6 +1,7 @@
## 0.2.1 (unreleased) ## 0.2.1 (unreleased)
- Added `Error` trait to errors - Added `Error` trait to errors
- Added support for paths to `save` and `load`
## 0.2.0 (2021-10-17) ## 0.2.0 (2021-10-17)

View File

@@ -1,6 +1,7 @@
use crate::bindings::*; use crate::bindings::*;
use crate::{Error, Matrix, Params}; use crate::{Error, Matrix, Params};
use std::ffi::CString; use std::ffi::CString;
use std::path::Path;
#[derive(Debug)] #[derive(Debug)]
pub struct Model { pub struct Model {
@@ -12,8 +13,9 @@ impl Model {
Params::new() Params::new()
} }
pub fn load(path: &str) -> Result<Self, Error> { pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
let cpath = CString::new(path)?; // TODO better conversion
let cpath = CString::new(path.as_ref().to_str().unwrap())?;
let model = unsafe { mf_load_model(cpath.as_ptr()) }; let model = unsafe { mf_load_model(cpath.as_ptr()) };
if model.is_null() { if model.is_null() {
return Err(Error::Io); return Err(Error::Io);
@@ -25,8 +27,9 @@ impl Model {
unsafe { mf_predict(self.model, row_index, column_index) } unsafe { mf_predict(self.model, row_index, column_index) }
} }
pub fn save(&self, path: &str) -> Result<(), Error> { pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
let cpath = CString::new(path)?; // TODO better conversion
let cpath = CString::new(path.as_ref().to_str().unwrap())?;
let status = unsafe { mf_save_model(self.model, cpath.as_ptr()) }; let status = unsafe { mf_save_model(self.model, cpath.as_ptr()) };
if status != 0 { if status != 0 {
return Err(Error::Io); return Err(Error::Io);