improve embedding visualization

This commit is contained in:
Dylan Knutson
2024-12-28 02:09:32 +00:00
parent e21541af46
commit 56b6604142

View File

@@ -5,7 +5,11 @@ use dotenv::dotenv;
use log::info; use log::info;
use ndarray::{s, Array2, ArrayView1}; use ndarray::{s, Array2, ArrayView1};
use ndarray_linalg::SVD; use ndarray_linalg::SVD;
use plotly::{Plot, Scatter}; use plotly::{
common::{Mode, Title},
layout::{Legend, Margin},
Layout, Plot, Scatter3D,
};
use std::env; use std::env;
use std::fs::File; use std::fs::File;
use std::io::Write; use std::io::Write;
@@ -99,9 +103,9 @@ async fn main() -> Result<()> {
data.row_mut(i).assign(&ArrayView1::from(&embedding)); data.row_mut(i).assign(&ArrayView1::from(&embedding));
} }
// Perform PCA // Perform PCA with 3 components
info!("Performing PCA..."); info!("Performing PCA...");
let projected_data = perform_pca(&data, 2)?; let projected_data = perform_pca(&data, 3)?;
// Create scatter plot for each cluster // Create scatter plot for each cluster
let mut plot = Plot::new(); let mut plot = Plot::new();
@@ -122,33 +126,34 @@ async fn main() -> Result<()> {
let x: Vec<_> = indices.iter().map(|&i| projected_data[[i, 0]]).collect(); let x: Vec<_> = indices.iter().map(|&i| projected_data[[i, 0]]).collect();
let y: Vec<_> = indices.iter().map(|&i| projected_data[[i, 1]]).collect(); let y: Vec<_> = indices.iter().map(|&i| projected_data[[i, 1]]).collect();
let z: Vec<_> = indices.iter().map(|&i| projected_data[[i, 2]]).collect();
let text: Vec<_> = indices let text: Vec<_> = indices
.iter() .iter()
.map(|&i| format!("Item {}", item_ids[i])) .map(|&i| format!("Item {}", item_ids[i]))
.collect(); .collect();
let trace = Scatter::new(x, y) let trace = Scatter3D::new(x, y, z)
.name(&format!("Cluster {}", cluster_id)) .name(&format!("Cluster {}", cluster_id))
.mode(plotly::common::Mode::Markers) .mode(Mode::Markers)
.text_array(text) .text_array(text)
.marker(
plotly::common::Marker::new()
.size(8)
.symbol(plotly::common::MarkerSymbol::Circle),
)
.show_legend(true); .show_legend(true);
plot.add_trace(trace); plot.add_trace(trace);
} }
plot.set_layout( plot.set_layout(
plotly::Layout::new() Layout::new()
.title(plotly::common::Title::new( .title(Title::new("Item Embeddings Visualization (PCA)"))
"Item Embeddings Visualization (PCA)", .show_legend(true)
)) .legend(Legend::new().x(1.0).y(0.5))
.x_axis( .margin(Margin::new().left(100).right(100).top(100).bottom(100))
plotly::layout::Axis::new() .height(900)
.title(plotly::common::Title::new("First Principal Component")), .width(1600),
)
.y_axis(
plotly::layout::Axis::new()
.title(plotly::common::Title::new("Second Principal Component")),
),
); );
// Save plot to HTML file // Save plot to HTML file