Added p_iter and q_iter
This commit is contained in:
@@ -1,3 +1,7 @@
|
||||
## 0.2.2 (unreleased)
|
||||
|
||||
- Added `p_iter` and `q_iter`
|
||||
|
||||
## 0.2.1 (2021-11-15)
|
||||
|
||||
- Added `Error` trait to errors
|
||||
|
||||
@@ -40,8 +40,8 @@ model.predict(row_index, column_index);
|
||||
Get the latent factors (these approximate the training matrix)
|
||||
|
||||
```rust
|
||||
model.p_factors();
|
||||
model.q_factors();
|
||||
model.p_iter().collect::<Vec<&[f32]>>();
|
||||
model.q_iter().collect::<Vec<&[f32]>>();
|
||||
```
|
||||
|
||||
Get the bias (average of all elements in the training matrix)
|
||||
|
||||
37
src/model.rs
37
src/model.rs
@@ -2,6 +2,7 @@ use crate::bindings::*;
|
||||
use crate::{Error, Matrix, Params};
|
||||
use std::ffi::CString;
|
||||
use std::path::Path;
|
||||
use std::slice::Chunks;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Model {
|
||||
@@ -53,6 +54,14 @@ impl Model {
|
||||
unsafe { (*self.model).b }
|
||||
}
|
||||
|
||||
pub fn p_iter(&self) -> Chunks<'_, f32> {
|
||||
self.p_factors().chunks(self.factors() as usize)
|
||||
}
|
||||
|
||||
pub fn q_iter(&self) -> Chunks<'_, f32> {
|
||||
self.q_factors().chunks(self.factors() as usize)
|
||||
}
|
||||
|
||||
pub fn p_factors(&self) -> &[f32] {
|
||||
unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) }
|
||||
}
|
||||
@@ -123,9 +132,31 @@ mod tests {
|
||||
let model = Model::params().quiet(true).fit(&data).unwrap();
|
||||
model.predict(0, 1);
|
||||
|
||||
model.p_factors();
|
||||
model.q_factors();
|
||||
model.bias();
|
||||
// TODO assert in delta
|
||||
assert_eq!(4.0 / 3.0, model.bias());
|
||||
|
||||
let p_factors = model.p_factors();
|
||||
let q_factors = model.q_factors();
|
||||
|
||||
assert_eq!(model.p_iter().len(), 2);
|
||||
assert_eq!(model.q_iter().len(), 2);
|
||||
|
||||
let p_vec = model.p_iter().collect::<Vec<&[f32]>>();
|
||||
let q_vec = model.q_iter().collect::<Vec<&[f32]>>();
|
||||
|
||||
assert_eq!(p_vec[0], &p_factors[0..8]);
|
||||
assert_eq!(p_vec[1], &p_factors[8..]);
|
||||
|
||||
assert_eq!(q_vec[0], &q_factors[0..8]);
|
||||
assert_eq!(q_vec[1], &q_factors[8..]);
|
||||
|
||||
for (i, factors) in model.p_iter().enumerate() {
|
||||
assert_eq!(factors, p_vec[i]);
|
||||
}
|
||||
|
||||
for (i, factors) in model.q_iter().enumerate() {
|
||||
assert_eq!(factors, q_vec[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user