better visualization

This commit is contained in:
Dylan Knutson
2024-12-28 03:39:24 +00:00
parent ab5f379b94
commit 6ebbd6aaa9
3 changed files with 39 additions and 22 deletions

3
.gitignore vendored
View File

@@ -11,3 +11,6 @@
*.swo *.swo
embeddings_visualization.html embeddings_visualization.html
# Coredumps
/core

View File

@@ -88,11 +88,19 @@ async fn main() -> Result<()> {
return Ok(()); return Ok(());
}; };
// Get affinity dimension (should be number of items)
let affinity_dims = if let Some(first_row) = rows.first() {
let affinities: Vec<f64> = first_row.get(3);
affinities.len()
} else {
return Ok(());
};
// Convert data to ndarray format // Convert data to ndarray format
let mut data = Array2::zeros((n_items, n_dims)); let mut data = Array2::zeros((n_items, n_dims));
let mut item_ids = Vec::with_capacity(n_items); let mut item_ids = Vec::with_capacity(n_items);
let mut cluster_ids = Vec::with_capacity(n_items); let mut cluster_ids = Vec::with_capacity(n_items);
let mut affinity_data = Array2::zeros((n_items, n_dims)); // Changed from n_items to n_dims for affinity dimension let mut affinity_data = Array2::zeros((n_items, affinity_dims)); // Use full affinity dimension
for (i, row) in rows.iter().enumerate() { for (i, row) in rows.iter().enumerate() {
let item_id: i32 = row.get(0); let item_id: i32 = row.get(0);
@@ -103,15 +111,9 @@ async fn main() -> Result<()> {
item_ids.push(item_id); item_ids.push(item_id);
cluster_ids.push(cluster_id); cluster_ids.push(cluster_id);
data.row_mut(i).assign(&ArrayView1::from(&embedding)); data.row_mut(i).assign(&ArrayView1::from(&embedding));
// Ensure affinity vector has the right length by truncating or padding if necessary
let mut affinity_vec = vec![0.0; n_dims];
for (j, &val) in affinities.iter().take(n_dims).enumerate() {
affinity_vec[j] = val;
}
affinity_data affinity_data
.row_mut(i) .row_mut(i)
.assign(&ArrayView1::from(&affinity_vec)); .assign(&ArrayView1::from(&affinities));
} }
// Perform PCA on both embeddings and affinity vectors // Perform PCA on both embeddings and affinity vectors
@@ -160,15 +162,17 @@ async fn main() -> Result<()> {
.text_array(text) .text_array(text)
.marker( .marker(
plotly::common::Marker::new() plotly::common::Marker::new()
.size(8) .size(10)
.symbol(plotly::common::MarkerSymbol::Circle), .symbol(plotly::common::MarkerSymbol::Circle)
.opacity(0.9)
.line(plotly::common::Line::new().width(1.0).color("white")),
) )
.show_legend(true); .show_legend(true);
plot.add_trace(trace); plot.add_trace(trace);
} }
// Plot affinity vectors // Plot affinity vectors with distinct appearance
for cluster_id in &unique_clusters { for cluster_id in &unique_clusters {
let indices: Vec<_> = cluster_ids let indices: Vec<_> = cluster_ids
.iter() .iter()
@@ -191,7 +195,7 @@ async fn main() -> Result<()> {
.collect(); .collect();
let text: Vec<_> = indices let text: Vec<_> = indices
.iter() .iter()
.map(|&i| format!("Item {}", item_ids[i])) .map(|&i| format!("Item {} (Affinity)", item_ids[i]))
.collect(); .collect();
let trace = Scatter3D::new(x, y, z) let trace = Scatter3D::new(x, y, z)
@@ -200,8 +204,10 @@ async fn main() -> Result<()> {
.text_array(text) .text_array(text)
.marker( .marker(
plotly::common::Marker::new() plotly::common::Marker::new()
.size(8) .size(7)
.symbol(plotly::common::MarkerSymbol::Square), .symbol(plotly::common::MarkerSymbol::Diamond)
.opacity(0.7)
.line(plotly::common::Line::new().width(1.0).color("black")),
) )
.show_legend(true); .show_legend(true);
@@ -210,9 +216,7 @@ async fn main() -> Result<()> {
plot.set_layout( plot.set_layout(
Layout::new() Layout::new()
.title(Title::new( .title(Title::new("Item Embeddings (●) vs Affinity Vectors (◆)"))
"Item Embeddings and Affinities Visualization (PCA)",
))
.show_legend(true) .show_legend(true)
.legend(Legend::new().x(1.0).y(0.5)) .legend(Legend::new().x(1.0).y(0.5))
.margin(Margin::new().left(100).right(100).top(100).bottom(100)) .margin(Margin::new().left(100).right(100).top(100).bottom(100))

View File

@@ -33,13 +33,21 @@ struct Args {
#[arg(long, default_value = "10000")] #[arg(long, default_value = "10000")]
batch_size: i32, batch_size: i32,
/// Learning rate
#[arg(long, default_value = "0.01")]
learning_rate: f32,
/// Number of factors for matrix factorization /// Number of factors for matrix factorization
#[arg(long, default_value = "8")] #[arg(long, default_value = "8")]
factors: i32, factors: i32,
/// Lambda for regularization
#[arg(long, default_value = "0.0")]
lambda1: f32,
/// Lambda for regularization /// Lambda for regularization
#[arg(long, default_value = "0.1")] #[arg(long, default_value = "0.1")]
lambda: f32, lambda2: f32,
/// Number of threads for matrix factorization (defaults to number of CPU cores) /// Number of threads for matrix factorization (defaults to number of CPU cores)
#[arg(long, default_value_t = num_cpus::get() as i32)] #[arg(long, default_value_t = num_cpus::get() as i32)]
@@ -236,9 +244,11 @@ async fn main() -> Result<()> {
// Set up training parameters // Set up training parameters
let model = Model::params() let model = Model::params()
.factors(args.factors as i32) .factors(args.factors as i32)
.lambda_p2(args.lambda) .lambda_p1(args.lambda1)
.lambda_q2(args.lambda) .lambda_q1(args.lambda1)
.learning_rate(0.01) .lambda_p2(args.lambda2)
.lambda_q2(args.lambda2)
.learning_rate(args.learning_rate)
.iterations(100) .iterations(100)
.loss(Loss::OneClassL2) .loss(Loss::OneClassL2)
.c(0.00001) .c(0.00001)