improve embedding visualization
This commit is contained in:
@@ -5,7 +5,11 @@ use dotenv::dotenv;
|
||||
use log::info;
|
||||
use ndarray::{s, Array2, ArrayView1};
|
||||
use ndarray_linalg::SVD;
|
||||
use plotly::{Plot, Scatter};
|
||||
use plotly::{
|
||||
common::{Mode, Title},
|
||||
layout::{Legend, Margin},
|
||||
Layout, Plot, Scatter3D,
|
||||
};
|
||||
use std::env;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
@@ -99,9 +103,9 @@ async fn main() -> Result<()> {
|
||||
data.row_mut(i).assign(&ArrayView1::from(&embedding));
|
||||
}
|
||||
|
||||
// Perform PCA
|
||||
// Perform PCA with 3 components
|
||||
info!("Performing PCA...");
|
||||
let projected_data = perform_pca(&data, 2)?;
|
||||
let projected_data = perform_pca(&data, 3)?;
|
||||
|
||||
// Create scatter plot for each cluster
|
||||
let mut plot = Plot::new();
|
||||
@@ -122,33 +126,34 @@ async fn main() -> Result<()> {
|
||||
|
||||
let x: Vec<_> = indices.iter().map(|&i| projected_data[[i, 0]]).collect();
|
||||
let y: Vec<_> = indices.iter().map(|&i| projected_data[[i, 1]]).collect();
|
||||
let z: Vec<_> = indices.iter().map(|&i| projected_data[[i, 2]]).collect();
|
||||
let text: Vec<_> = indices
|
||||
.iter()
|
||||
.map(|&i| format!("Item {}", item_ids[i]))
|
||||
.collect();
|
||||
|
||||
let trace = Scatter::new(x, y)
|
||||
let trace = Scatter3D::new(x, y, z)
|
||||
.name(&format!("Cluster {}", cluster_id))
|
||||
.mode(plotly::common::Mode::Markers)
|
||||
.mode(Mode::Markers)
|
||||
.text_array(text)
|
||||
.marker(
|
||||
plotly::common::Marker::new()
|
||||
.size(8)
|
||||
.symbol(plotly::common::MarkerSymbol::Circle),
|
||||
)
|
||||
.show_legend(true);
|
||||
|
||||
plot.add_trace(trace);
|
||||
}
|
||||
|
||||
plot.set_layout(
|
||||
plotly::Layout::new()
|
||||
.title(plotly::common::Title::new(
|
||||
"Item Embeddings Visualization (PCA)",
|
||||
))
|
||||
.x_axis(
|
||||
plotly::layout::Axis::new()
|
||||
.title(plotly::common::Title::new("First Principal Component")),
|
||||
)
|
||||
.y_axis(
|
||||
plotly::layout::Axis::new()
|
||||
.title(plotly::common::Title::new("Second Principal Component")),
|
||||
),
|
||||
Layout::new()
|
||||
.title(Title::new("Item Embeddings Visualization (PCA)"))
|
||||
.show_legend(true)
|
||||
.legend(Legend::new().x(1.0).y(0.5))
|
||||
.margin(Margin::new().left(100).right(100).top(100).bottom(100))
|
||||
.height(900)
|
||||
.width(1600),
|
||||
);
|
||||
|
||||
// Save plot to HTML file
|
||||
|
||||
Reference in New Issue
Block a user