Added p_iter and q_iter

This commit is contained in:
Andrew Kane
2022-09-23 12:28:44 -07:00
parent 9608ab9172
commit e276db1cec
3 changed files with 40 additions and 5 deletions

View File

@@ -1,3 +1,7 @@
## 0.2.2 (unreleased)
- Added `p_iter` and `q_iter`
## 0.2.1 (2021-11-15) ## 0.2.1 (2021-11-15)
- Added `Error` trait to errors - Added `Error` trait to errors

View File

@@ -40,8 +40,8 @@ model.predict(row_index, column_index);
Get the latent factors (these approximate the training matrix) Get the latent factors (these approximate the training matrix)
```rust ```rust
model.p_factors(); model.p_iter().collect::<Vec<&[f32]>>();
model.q_factors(); model.q_iter().collect::<Vec<&[f32]>>();
``` ```
Get the bias (average of all elements in the training matrix) Get the bias (average of all elements in the training matrix)

View File

@@ -2,6 +2,7 @@ use crate::bindings::*;
use crate::{Error, Matrix, Params}; use crate::{Error, Matrix, Params};
use std::ffi::CString; use std::ffi::CString;
use std::path::Path; use std::path::Path;
use std::slice::Chunks;
#[derive(Debug)] #[derive(Debug)]
pub struct Model { pub struct Model {
@@ -53,6 +54,14 @@ impl Model {
unsafe { (*self.model).b } unsafe { (*self.model).b }
} }
pub fn p_iter(&self) -> Chunks<'_, f32> {
self.p_factors().chunks(self.factors() as usize)
}
pub fn q_iter(&self) -> Chunks<'_, f32> {
self.q_factors().chunks(self.factors() as usize)
}
pub fn p_factors(&self) -> &[f32] { pub fn p_factors(&self) -> &[f32] {
unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) } unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) }
} }
@@ -123,9 +132,31 @@ mod tests {
let model = Model::params().quiet(true).fit(&data).unwrap(); let model = Model::params().quiet(true).fit(&data).unwrap();
model.predict(0, 1); model.predict(0, 1);
model.p_factors(); // TODO assert in delta
model.q_factors(); assert_eq!(4.0 / 3.0, model.bias());
model.bias();
let p_factors = model.p_factors();
let q_factors = model.q_factors();
assert_eq!(model.p_iter().len(), 2);
assert_eq!(model.q_iter().len(), 2);
let p_vec = model.p_iter().collect::<Vec<&[f32]>>();
let q_vec = model.q_iter().collect::<Vec<&[f32]>>();
assert_eq!(p_vec[0], &p_factors[0..8]);
assert_eq!(p_vec[1], &p_factors[8..]);
assert_eq!(q_vec[0], &q_factors[0..8]);
assert_eq!(q_vec[1], &q_factors[8..]);
for (i, factors) in model.p_iter().enumerate() {
assert_eq!(factors, p_vec[i]);
}
for (i, factors) in model.q_iter().enumerate() {
assert_eq!(factors, q_vec[i]);
}
} }
#[test] #[test]