From 6ebbd6aaa93ef1bb9322ce0d1886b64f16dc7f84 Mon Sep 17 00:00:00 2001 From: Dylan Knutson Date: Sat, 28 Dec 2024 03:39:24 +0000 Subject: [PATCH] better visualization --- .gitignore | 5 ++++- src/bin/visualize_embeddings.rs | 38 ++++++++++++++++++--------------- src/main.rs | 18 ++++++++++++---- 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 8b707c9..6545f05 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,7 @@ *.swp *.swo -embeddings_visualization.html \ No newline at end of file +embeddings_visualization.html + +# Coredumps +/core diff --git a/src/bin/visualize_embeddings.rs b/src/bin/visualize_embeddings.rs index 8bfab3b..14e39a4 100644 --- a/src/bin/visualize_embeddings.rs +++ b/src/bin/visualize_embeddings.rs @@ -88,11 +88,19 @@ async fn main() -> Result<()> { return Ok(()); }; + // Get affinity dimension (should be number of items) + let affinity_dims = if let Some(first_row) = rows.first() { + let affinities: Vec = first_row.get(3); + affinities.len() + } else { + return Ok(()); + }; + // Convert data to ndarray format let mut data = Array2::zeros((n_items, n_dims)); let mut item_ids = Vec::with_capacity(n_items); let mut cluster_ids = Vec::with_capacity(n_items); - let mut affinity_data = Array2::zeros((n_items, n_dims)); // Changed from n_items to n_dims for affinity dimension + let mut affinity_data = Array2::zeros((n_items, affinity_dims)); // Use full affinity dimension for (i, row) in rows.iter().enumerate() { let item_id: i32 = row.get(0); @@ -103,15 +111,9 @@ async fn main() -> Result<()> { item_ids.push(item_id); cluster_ids.push(cluster_id); data.row_mut(i).assign(&ArrayView1::from(&embedding)); - - // Ensure affinity vector has the right length by truncating or padding if necessary - let mut affinity_vec = vec![0.0; n_dims]; - for (j, &val) in affinities.iter().take(n_dims).enumerate() { - affinity_vec[j] = val; - } affinity_data .row_mut(i) - .assign(&ArrayView1::from(&affinity_vec)); + .assign(&ArrayView1::from(&affinities)); } // Perform PCA on both embeddings and affinity vectors @@ -160,15 +162,17 @@ async fn main() -> Result<()> { .text_array(text) .marker( plotly::common::Marker::new() - .size(8) - .symbol(plotly::common::MarkerSymbol::Circle), + .size(10) + .symbol(plotly::common::MarkerSymbol::Circle) + .opacity(0.9) + .line(plotly::common::Line::new().width(1.0).color("white")), ) .show_legend(true); plot.add_trace(trace); } - // Plot affinity vectors + // Plot affinity vectors with distinct appearance for cluster_id in &unique_clusters { let indices: Vec<_> = cluster_ids .iter() @@ -191,7 +195,7 @@ async fn main() -> Result<()> { .collect(); let text: Vec<_> = indices .iter() - .map(|&i| format!("Item {}", item_ids[i])) + .map(|&i| format!("Item {} (Affinity)", item_ids[i])) .collect(); let trace = Scatter3D::new(x, y, z) @@ -200,8 +204,10 @@ async fn main() -> Result<()> { .text_array(text) .marker( plotly::common::Marker::new() - .size(8) - .symbol(plotly::common::MarkerSymbol::Square), + .size(7) + .symbol(plotly::common::MarkerSymbol::Diamond) + .opacity(0.7) + .line(plotly::common::Line::new().width(1.0).color("black")), ) .show_legend(true); @@ -210,9 +216,7 @@ async fn main() -> Result<()> { plot.set_layout( Layout::new() - .title(Title::new( - "Item Embeddings and Affinities Visualization (PCA)", - )) + .title(Title::new("Item Embeddings (●) vs Affinity Vectors (◆)")) .show_legend(true) .legend(Legend::new().x(1.0).y(0.5)) .margin(Margin::new().left(100).right(100).top(100).bottom(100)) diff --git a/src/main.rs b/src/main.rs index 05a06dd..39f169e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,13 +33,21 @@ struct Args { #[arg(long, default_value = "10000")] batch_size: i32, + /// Learning rate + #[arg(long, default_value = "0.01")] + learning_rate: f32, + /// Number of factors for matrix factorization #[arg(long, default_value = "8")] factors: i32, + /// Lambda for regularization + #[arg(long, default_value = "0.0")] + lambda1: f32, + /// Lambda for regularization #[arg(long, default_value = "0.1")] - lambda: f32, + lambda2: f32, /// Number of threads for matrix factorization (defaults to number of CPU cores) #[arg(long, default_value_t = num_cpus::get() as i32)] @@ -236,9 +244,11 @@ async fn main() -> Result<()> { // Set up training parameters let model = Model::params() .factors(args.factors as i32) - .lambda_p2(args.lambda) - .lambda_q2(args.lambda) - .learning_rate(0.01) + .lambda_p1(args.lambda1) + .lambda_q1(args.lambda1) + .lambda_p2(args.lambda2) + .lambda_q2(args.lambda2) + .learning_rate(args.learning_rate) .iterations(100) .loss(Loss::OneClassL2) .c(0.00001)