Added p_iter and q_iter

This commit is contained in:
Andrew Kane
2022-09-23 12:28:44 -07:00
parent 9608ab9172
commit e276db1cec
3 changed files with 40 additions and 5 deletions

View File

@@ -2,6 +2,7 @@ use crate::bindings::*;
use crate::{Error, Matrix, Params};
use std::ffi::CString;
use std::path::Path;
use std::slice::Chunks;
#[derive(Debug)]
pub struct Model {
@@ -53,6 +54,14 @@ impl Model {
unsafe { (*self.model).b }
}
pub fn p_iter(&self) -> Chunks<'_, f32> {
self.p_factors().chunks(self.factors() as usize)
}
pub fn q_iter(&self) -> Chunks<'_, f32> {
self.q_factors().chunks(self.factors() as usize)
}
pub fn p_factors(&self) -> &[f32] {
unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) }
}
@@ -123,9 +132,31 @@ mod tests {
let model = Model::params().quiet(true).fit(&data).unwrap();
model.predict(0, 1);
model.p_factors();
model.q_factors();
model.bias();
// TODO assert in delta
assert_eq!(4.0 / 3.0, model.bias());
let p_factors = model.p_factors();
let q_factors = model.q_factors();
assert_eq!(model.p_iter().len(), 2);
assert_eq!(model.q_iter().len(), 2);
let p_vec = model.p_iter().collect::<Vec<&[f32]>>();
let q_vec = model.q_iter().collect::<Vec<&[f32]>>();
assert_eq!(p_vec[0], &p_factors[0..8]);
assert_eq!(p_vec[1], &p_factors[8..]);
assert_eq!(q_vec[0], &q_factors[0..8]);
assert_eq!(q_vec[1], &q_factors[8..]);
for (i, factors) in model.p_iter().enumerate() {
assert_eq!(factors, p_vec[i]);
}
for (i, factors) in model.q_iter().enumerate() {
assert_eq!(factors, q_vec[i]);
}
}
#[test]