Added tests
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
## 0.2.2 (unreleased)
|
## 0.2.2 (unreleased)
|
||||||
|
|
||||||
- Added `p_iter` and `q_iter`
|
- Added `p_iter` and `q_iter`
|
||||||
|
- Added `p_row` and `q_col`
|
||||||
|
|
||||||
## 0.2.1 (2021-11-15)
|
## 0.2.1 (2021-11-15)
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ model.p_factors();
|
|||||||
model.q_factors();
|
model.q_factors();
|
||||||
```
|
```
|
||||||
|
|
||||||
Get the latent factors of a specific row or column
|
Get the latent factors of a specific row or column [unreleased]
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
model.p_row(row_index);
|
model.p_row(row_index);
|
||||||
|
|||||||
14
src/model.rs
14
src/model.rs
@@ -75,17 +75,17 @@ impl Model {
|
|||||||
let factors = self.factors();
|
let factors = self.factors();
|
||||||
let start_index = factors as usize * row_index as usize;
|
let start_index = factors as usize * row_index as usize;
|
||||||
let end_index = factors as usize * (row_index as usize + 1);
|
let end_index = factors as usize * (row_index as usize + 1);
|
||||||
return Some(&self.p_factors()[start_index..end_index]);
|
return Some(&self.p_factors()[start_index..end_index]);
|
||||||
}
|
}
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn q_col(&self, column_index: i32) -> Option<&[f32]>{
|
pub fn q_col(&self, column_index: i32) -> Option<&[f32]>{
|
||||||
if column_index >= 0 && column_index < self.columns() {
|
if column_index >= 0 && column_index < self.columns() {
|
||||||
let factors = self.factors();
|
let factors = self.factors();
|
||||||
let start_index = factors as usize * column_index as usize;
|
let start_index = factors as usize * column_index as usize;
|
||||||
let end_index = factors as usize * (column_index as usize + 1);
|
let end_index = factors as usize * (column_index as usize + 1);
|
||||||
return Some(&self.q_factors()[start_index..end_index]);
|
return Some(&self.q_factors()[start_index..end_index]);
|
||||||
}
|
}
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -177,6 +177,14 @@ mod tests {
|
|||||||
for (i, factors) in model.q_iter().enumerate() {
|
for (i, factors) in model.q_iter().enumerate() {
|
||||||
assert_eq!(factors, q_vec[i]);
|
assert_eq!(factors, q_vec[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert_eq!(model.p_row(0), Some(p_vec[0]));
|
||||||
|
assert_eq!(model.p_row(1), Some(p_vec[1]));
|
||||||
|
assert_eq!(model.p_row(2), None);
|
||||||
|
|
||||||
|
assert_eq!(model.q_col(0), Some(q_vec[0]));
|
||||||
|
assert_eq!(model.q_col(1), Some(q_vec[1]));
|
||||||
|
assert_eq!(model.q_col(2), None);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
Reference in New Issue
Block a user