diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c39a0d..1d5b45f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.2.2 (unreleased) + +- Added `p_iter` and `q_iter` + ## 0.2.1 (2021-11-15) - Added `Error` trait to errors diff --git a/README.md b/README.md index 3b1734e..8ee3046 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,8 @@ model.predict(row_index, column_index); Get the latent factors (these approximate the training matrix) ```rust -model.p_factors(); -model.q_factors(); +model.p_iter().collect::>(); +model.q_iter().collect::>(); ``` Get the bias (average of all elements in the training matrix) diff --git a/src/model.rs b/src/model.rs index a08d212..02cd215 100644 --- a/src/model.rs +++ b/src/model.rs @@ -2,6 +2,7 @@ use crate::bindings::*; use crate::{Error, Matrix, Params}; use std::ffi::CString; use std::path::Path; +use std::slice::Chunks; #[derive(Debug)] pub struct Model { @@ -53,6 +54,14 @@ impl Model { unsafe { (*self.model).b } } + pub fn p_iter(&self) -> Chunks<'_, f32> { + self.p_factors().chunks(self.factors() as usize) + } + + pub fn q_iter(&self) -> Chunks<'_, f32> { + self.q_factors().chunks(self.factors() as usize) + } + pub fn p_factors(&self) -> &[f32] { unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) } } @@ -123,9 +132,31 @@ mod tests { let model = Model::params().quiet(true).fit(&data).unwrap(); model.predict(0, 1); - model.p_factors(); - model.q_factors(); - model.bias(); + // TODO assert in delta + assert_eq!(4.0 / 3.0, model.bias()); + + let p_factors = model.p_factors(); + let q_factors = model.q_factors(); + + assert_eq!(model.p_iter().len(), 2); + assert_eq!(model.q_iter().len(), 2); + + let p_vec = model.p_iter().collect::>(); + let q_vec = model.q_iter().collect::>(); + + assert_eq!(p_vec[0], &p_factors[0..8]); + assert_eq!(p_vec[1], &p_factors[8..]); + + assert_eq!(q_vec[0], &q_factors[0..8]); + assert_eq!(q_vec[1], &q_factors[8..]); + + for (i, factors) in model.p_iter().enumerate() { + assert_eq!(factors, p_vec[i]); + } + + for (i, factors) in model.q_iter().enumerate() { + assert_eq!(factors, q_vec[i]); + } } #[test]