Added tests
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
## 0.2.2 (unreleased)
|
||||
|
||||
- Added `p_iter` and `q_iter`
|
||||
- Added `p_row` and `q_col`
|
||||
|
||||
## 0.2.1 (2021-11-15)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ model.p_factors();
|
||||
model.q_factors();
|
||||
```
|
||||
|
||||
Get the latent factors of a specific row or column
|
||||
Get the latent factors of a specific row or column [unreleased]
|
||||
|
||||
```rust
|
||||
model.p_row(row_index);
|
||||
|
||||
14
src/model.rs
14
src/model.rs
@@ -75,17 +75,17 @@ impl Model {
|
||||
let factors = self.factors();
|
||||
let start_index = factors as usize * row_index as usize;
|
||||
let end_index = factors as usize * (row_index as usize + 1);
|
||||
return Some(&self.p_factors()[start_index..end_index]);
|
||||
return Some(&self.p_factors()[start_index..end_index]);
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
|
||||
pub fn q_col(&self, column_index: i32) -> Option<&[f32]>{
|
||||
if column_index >= 0 && column_index < self.columns() {
|
||||
let factors = self.factors();
|
||||
let start_index = factors as usize * column_index as usize;
|
||||
let end_index = factors as usize * (column_index as usize + 1);
|
||||
return Some(&self.q_factors()[start_index..end_index]);
|
||||
return Some(&self.q_factors()[start_index..end_index]);
|
||||
}
|
||||
return None;
|
||||
}
|
||||
@@ -177,6 +177,14 @@ mod tests {
|
||||
for (i, factors) in model.q_iter().enumerate() {
|
||||
assert_eq!(factors, q_vec[i]);
|
||||
}
|
||||
|
||||
assert_eq!(model.p_row(0), Some(p_vec[0]));
|
||||
assert_eq!(model.p_row(1), Some(p_vec[1]));
|
||||
assert_eq!(model.p_row(2), None);
|
||||
|
||||
assert_eq!(model.q_col(0), Some(q_vec[0]));
|
||||
assert_eq!(model.q_col(1), Some(q_vec[1]));
|
||||
assert_eq!(model.q_col(2), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user