From fc1366cb9e317abaf36bf234f01209260d640167 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Fri, 23 Sep 2022 12:42:23 -0700 Subject: [PATCH] Added tests --- CHANGELOG.md | 1 + README.md | 2 +- src/model.rs | 14 +++++++++++--- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d5b45f..12396ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.2.2 (unreleased) - Added `p_iter` and `q_iter` +- Added `p_row` and `q_col` ## 0.2.1 (2021-11-15) diff --git a/README.md b/README.md index 633f5a1..79f0ce0 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ model.p_factors(); model.q_factors(); ``` -Get the latent factors of a specific row or column +Get the latent factors of a specific row or column [unreleased] ```rust model.p_row(row_index); diff --git a/src/model.rs b/src/model.rs index f0a4e82..dffcfe7 100644 --- a/src/model.rs +++ b/src/model.rs @@ -75,17 +75,17 @@ impl Model { let factors = self.factors(); let start_index = factors as usize * row_index as usize; let end_index = factors as usize * (row_index as usize + 1); - return Some(&self.p_factors()[start_index..end_index]); + return Some(&self.p_factors()[start_index..end_index]); } return None; } - + pub fn q_col(&self, column_index: i32) -> Option<&[f32]>{ if column_index >= 0 && column_index < self.columns() { let factors = self.factors(); let start_index = factors as usize * column_index as usize; let end_index = factors as usize * (column_index as usize + 1); - return Some(&self.q_factors()[start_index..end_index]); + return Some(&self.q_factors()[start_index..end_index]); } return None; } @@ -177,6 +177,14 @@ mod tests { for (i, factors) in model.q_iter().enumerate() { assert_eq!(factors, q_vec[i]); } + + assert_eq!(model.p_row(0), Some(p_vec[0])); + assert_eq!(model.p_row(1), Some(p_vec[1])); + assert_eq!(model.p_row(2), None); + + assert_eq!(model.q_col(0), Some(q_vec[0])); + assert_eq!(model.q_col(1), Some(q_vec[1])); + assert_eq!(model.q_col(2), None); } #[test]