diff --git a/src/bin/visualize_embeddings.rs b/src/bin/visualize_embeddings.rs index f2eb31a..2ee0463 100644 --- a/src/bin/visualize_embeddings.rs +++ b/src/bin/visualize_embeddings.rs @@ -5,7 +5,11 @@ use dotenv::dotenv; use log::info; use ndarray::{s, Array2, ArrayView1}; use ndarray_linalg::SVD; -use plotly::{Plot, Scatter}; +use plotly::{ + common::{Mode, Title}, + layout::{Legend, Margin}, + Layout, Plot, Scatter3D, +}; use std::env; use std::fs::File; use std::io::Write; @@ -99,9 +103,9 @@ async fn main() -> Result<()> { data.row_mut(i).assign(&ArrayView1::from(&embedding)); } - // Perform PCA + // Perform PCA with 3 components info!("Performing PCA..."); - let projected_data = perform_pca(&data, 2)?; + let projected_data = perform_pca(&data, 3)?; // Create scatter plot for each cluster let mut plot = Plot::new(); @@ -122,33 +126,34 @@ async fn main() -> Result<()> { let x: Vec<_> = indices.iter().map(|&i| projected_data[[i, 0]]).collect(); let y: Vec<_> = indices.iter().map(|&i| projected_data[[i, 1]]).collect(); + let z: Vec<_> = indices.iter().map(|&i| projected_data[[i, 2]]).collect(); let text: Vec<_> = indices .iter() .map(|&i| format!("Item {}", item_ids[i])) .collect(); - let trace = Scatter::new(x, y) + let trace = Scatter3D::new(x, y, z) .name(&format!("Cluster {}", cluster_id)) - .mode(plotly::common::Mode::Markers) + .mode(Mode::Markers) .text_array(text) + .marker( + plotly::common::Marker::new() + .size(8) + .symbol(plotly::common::MarkerSymbol::Circle), + ) .show_legend(true); plot.add_trace(trace); } plot.set_layout( - plotly::Layout::new() - .title(plotly::common::Title::new( - "Item Embeddings Visualization (PCA)", - )) - .x_axis( - plotly::layout::Axis::new() - .title(plotly::common::Title::new("First Principal Component")), - ) - .y_axis( - plotly::layout::Axis::new() - .title(plotly::common::Title::new("Second Principal Component")), - ), + Layout::new() + .title(Title::new("Item Embeddings Visualization (PCA)")) + .show_legend(true) + .legend(Legend::new().x(1.0).y(0.5)) + .margin(Margin::new().left(100).right(100).top(100).bottom(100)) + .height(900) + .width(1600), ); // Save plot to HTML file