improve embedding visualization
This commit is contained in:
@@ -5,7 +5,11 @@ use dotenv::dotenv;
|
|||||||
use log::info;
|
use log::info;
|
||||||
use ndarray::{s, Array2, ArrayView1};
|
use ndarray::{s, Array2, ArrayView1};
|
||||||
use ndarray_linalg::SVD;
|
use ndarray_linalg::SVD;
|
||||||
use plotly::{Plot, Scatter};
|
use plotly::{
|
||||||
|
common::{Mode, Title},
|
||||||
|
layout::{Legend, Margin},
|
||||||
|
Layout, Plot, Scatter3D,
|
||||||
|
};
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
@@ -99,9 +103,9 @@ async fn main() -> Result<()> {
|
|||||||
data.row_mut(i).assign(&ArrayView1::from(&embedding));
|
data.row_mut(i).assign(&ArrayView1::from(&embedding));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform PCA
|
// Perform PCA with 3 components
|
||||||
info!("Performing PCA...");
|
info!("Performing PCA...");
|
||||||
let projected_data = perform_pca(&data, 2)?;
|
let projected_data = perform_pca(&data, 3)?;
|
||||||
|
|
||||||
// Create scatter plot for each cluster
|
// Create scatter plot for each cluster
|
||||||
let mut plot = Plot::new();
|
let mut plot = Plot::new();
|
||||||
@@ -122,33 +126,34 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
let x: Vec<_> = indices.iter().map(|&i| projected_data[[i, 0]]).collect();
|
let x: Vec<_> = indices.iter().map(|&i| projected_data[[i, 0]]).collect();
|
||||||
let y: Vec<_> = indices.iter().map(|&i| projected_data[[i, 1]]).collect();
|
let y: Vec<_> = indices.iter().map(|&i| projected_data[[i, 1]]).collect();
|
||||||
|
let z: Vec<_> = indices.iter().map(|&i| projected_data[[i, 2]]).collect();
|
||||||
let text: Vec<_> = indices
|
let text: Vec<_> = indices
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&i| format!("Item {}", item_ids[i]))
|
.map(|&i| format!("Item {}", item_ids[i]))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let trace = Scatter::new(x, y)
|
let trace = Scatter3D::new(x, y, z)
|
||||||
.name(&format!("Cluster {}", cluster_id))
|
.name(&format!("Cluster {}", cluster_id))
|
||||||
.mode(plotly::common::Mode::Markers)
|
.mode(Mode::Markers)
|
||||||
.text_array(text)
|
.text_array(text)
|
||||||
|
.marker(
|
||||||
|
plotly::common::Marker::new()
|
||||||
|
.size(8)
|
||||||
|
.symbol(plotly::common::MarkerSymbol::Circle),
|
||||||
|
)
|
||||||
.show_legend(true);
|
.show_legend(true);
|
||||||
|
|
||||||
plot.add_trace(trace);
|
plot.add_trace(trace);
|
||||||
}
|
}
|
||||||
|
|
||||||
plot.set_layout(
|
plot.set_layout(
|
||||||
plotly::Layout::new()
|
Layout::new()
|
||||||
.title(plotly::common::Title::new(
|
.title(Title::new("Item Embeddings Visualization (PCA)"))
|
||||||
"Item Embeddings Visualization (PCA)",
|
.show_legend(true)
|
||||||
))
|
.legend(Legend::new().x(1.0).y(0.5))
|
||||||
.x_axis(
|
.margin(Margin::new().left(100).right(100).top(100).bottom(100))
|
||||||
plotly::layout::Axis::new()
|
.height(900)
|
||||||
.title(plotly::common::Title::new("First Principal Component")),
|
.width(1600),
|
||||||
)
|
|
||||||
.y_axis(
|
|
||||||
plotly::layout::Axis::new()
|
|
||||||
.title(plotly::common::Title::new("Second Principal Component")),
|
|
||||||
),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// Save plot to HTML file
|
// Save plot to HTML file
|
||||||
|
|||||||
Reference in New Issue
Block a user