Added p_iter and q_iter
This commit is contained in:
@@ -1,3 +1,7 @@
|
|||||||
|
## 0.2.2 (unreleased)
|
||||||
|
|
||||||
|
- Added `p_iter` and `q_iter`
|
||||||
|
|
||||||
## 0.2.1 (2021-11-15)
|
## 0.2.1 (2021-11-15)
|
||||||
|
|
||||||
- Added `Error` trait to errors
|
- Added `Error` trait to errors
|
||||||
|
|||||||
@@ -40,8 +40,8 @@ model.predict(row_index, column_index);
|
|||||||
Get the latent factors (these approximate the training matrix)
|
Get the latent factors (these approximate the training matrix)
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
model.p_factors();
|
model.p_iter().collect::<Vec<&[f32]>>();
|
||||||
model.q_factors();
|
model.q_iter().collect::<Vec<&[f32]>>();
|
||||||
```
|
```
|
||||||
|
|
||||||
Get the bias (average of all elements in the training matrix)
|
Get the bias (average of all elements in the training matrix)
|
||||||
|
|||||||
37
src/model.rs
37
src/model.rs
@@ -2,6 +2,7 @@ use crate::bindings::*;
|
|||||||
use crate::{Error, Matrix, Params};
|
use crate::{Error, Matrix, Params};
|
||||||
use std::ffi::CString;
|
use std::ffi::CString;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::slice::Chunks;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Model {
|
pub struct Model {
|
||||||
@@ -53,6 +54,14 @@ impl Model {
|
|||||||
unsafe { (*self.model).b }
|
unsafe { (*self.model).b }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn p_iter(&self) -> Chunks<'_, f32> {
|
||||||
|
self.p_factors().chunks(self.factors() as usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn q_iter(&self) -> Chunks<'_, f32> {
|
||||||
|
self.q_factors().chunks(self.factors() as usize)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn p_factors(&self) -> &[f32] {
|
pub fn p_factors(&self) -> &[f32] {
|
||||||
unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) }
|
unsafe { std::slice::from_raw_parts((*self.model).p, (self.rows() * self.factors()) as usize) }
|
||||||
}
|
}
|
||||||
@@ -123,9 +132,31 @@ mod tests {
|
|||||||
let model = Model::params().quiet(true).fit(&data).unwrap();
|
let model = Model::params().quiet(true).fit(&data).unwrap();
|
||||||
model.predict(0, 1);
|
model.predict(0, 1);
|
||||||
|
|
||||||
model.p_factors();
|
// TODO assert in delta
|
||||||
model.q_factors();
|
assert_eq!(4.0 / 3.0, model.bias());
|
||||||
model.bias();
|
|
||||||
|
let p_factors = model.p_factors();
|
||||||
|
let q_factors = model.q_factors();
|
||||||
|
|
||||||
|
assert_eq!(model.p_iter().len(), 2);
|
||||||
|
assert_eq!(model.q_iter().len(), 2);
|
||||||
|
|
||||||
|
let p_vec = model.p_iter().collect::<Vec<&[f32]>>();
|
||||||
|
let q_vec = model.q_iter().collect::<Vec<&[f32]>>();
|
||||||
|
|
||||||
|
assert_eq!(p_vec[0], &p_factors[0..8]);
|
||||||
|
assert_eq!(p_vec[1], &p_factors[8..]);
|
||||||
|
|
||||||
|
assert_eq!(q_vec[0], &q_factors[0..8]);
|
||||||
|
assert_eq!(q_vec[1], &q_factors[8..]);
|
||||||
|
|
||||||
|
for (i, factors) in model.p_iter().enumerate() {
|
||||||
|
assert_eq!(factors, p_vec[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, factors) in model.q_iter().enumerate() {
|
||||||
|
assert_eq!(factors, q_vec[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
Reference in New Issue
Block a user