better visualization
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,3 +11,6 @@
|
|||||||
*.swo
|
*.swo
|
||||||
|
|
||||||
embeddings_visualization.html
|
embeddings_visualization.html
|
||||||
|
|
||||||
|
# Coredumps
|
||||||
|
/core
|
||||||
|
|||||||
@@ -88,11 +88,19 @@ async fn main() -> Result<()> {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Read the affinity vector length from the first row (expected to equal the number of items)
|
||||||
|
let affinity_dims = if let Some(first_row) = rows.first() {
|
||||||
|
let affinities: Vec<f64> = first_row.get(3);
|
||||||
|
affinities.len()
|
||||||
|
} else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
// Convert data to ndarray format
|
// Convert data to ndarray format
|
||||||
let mut data = Array2::zeros((n_items, n_dims));
|
let mut data = Array2::zeros((n_items, n_dims));
|
||||||
let mut item_ids = Vec::with_capacity(n_items);
|
let mut item_ids = Vec::with_capacity(n_items);
|
||||||
let mut cluster_ids = Vec::with_capacity(n_items);
|
let mut cluster_ids = Vec::with_capacity(n_items);
|
||||||
let mut affinity_data = Array2::zeros((n_items, n_dims)); // Changed from n_items to n_dims for affinity dimension
|
let mut affinity_data = Array2::zeros((n_items, affinity_dims)); // Use full affinity dimension
|
||||||
|
|
||||||
for (i, row) in rows.iter().enumerate() {
|
for (i, row) in rows.iter().enumerate() {
|
||||||
let item_id: i32 = row.get(0);
|
let item_id: i32 = row.get(0);
|
||||||
@@ -103,15 +111,9 @@ async fn main() -> Result<()> {
|
|||||||
item_ids.push(item_id);
|
item_ids.push(item_id);
|
||||||
cluster_ids.push(cluster_id);
|
cluster_ids.push(cluster_id);
|
||||||
data.row_mut(i).assign(&ArrayView1::from(&embedding));
|
data.row_mut(i).assign(&ArrayView1::from(&embedding));
|
||||||
|
|
||||||
// Ensure affinity vector has the right length by truncating or padding if necessary
|
|
||||||
let mut affinity_vec = vec![0.0; n_dims];
|
|
||||||
for (j, &val) in affinities.iter().take(n_dims).enumerate() {
|
|
||||||
affinity_vec[j] = val;
|
|
||||||
}
|
|
||||||
affinity_data
|
affinity_data
|
||||||
.row_mut(i)
|
.row_mut(i)
|
||||||
.assign(&ArrayView1::from(&affinity_vec));
|
.assign(&ArrayView1::from(&affinities));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform PCA on both embeddings and affinity vectors
|
// Perform PCA on both embeddings and affinity vectors
|
||||||
@@ -160,15 +162,17 @@ async fn main() -> Result<()> {
|
|||||||
.text_array(text)
|
.text_array(text)
|
||||||
.marker(
|
.marker(
|
||||||
plotly::common::Marker::new()
|
plotly::common::Marker::new()
|
||||||
.size(8)
|
.size(10)
|
||||||
.symbol(plotly::common::MarkerSymbol::Circle),
|
.symbol(plotly::common::MarkerSymbol::Circle)
|
||||||
|
.opacity(0.9)
|
||||||
|
.line(plotly::common::Line::new().width(1.0).color("white")),
|
||||||
)
|
)
|
||||||
.show_legend(true);
|
.show_legend(true);
|
||||||
|
|
||||||
plot.add_trace(trace);
|
plot.add_trace(trace);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Plot affinity vectors
|
// Plot affinity vectors with distinct appearance
|
||||||
for cluster_id in &unique_clusters {
|
for cluster_id in &unique_clusters {
|
||||||
let indices: Vec<_> = cluster_ids
|
let indices: Vec<_> = cluster_ids
|
||||||
.iter()
|
.iter()
|
||||||
@@ -191,7 +195,7 @@ async fn main() -> Result<()> {
|
|||||||
.collect();
|
.collect();
|
||||||
let text: Vec<_> = indices
|
let text: Vec<_> = indices
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&i| format!("Item {}", item_ids[i]))
|
.map(|&i| format!("Item {} (Affinity)", item_ids[i]))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let trace = Scatter3D::new(x, y, z)
|
let trace = Scatter3D::new(x, y, z)
|
||||||
@@ -200,8 +204,10 @@ async fn main() -> Result<()> {
|
|||||||
.text_array(text)
|
.text_array(text)
|
||||||
.marker(
|
.marker(
|
||||||
plotly::common::Marker::new()
|
plotly::common::Marker::new()
|
||||||
.size(8)
|
.size(7)
|
||||||
.symbol(plotly::common::MarkerSymbol::Square),
|
.symbol(plotly::common::MarkerSymbol::Diamond)
|
||||||
|
.opacity(0.7)
|
||||||
|
.line(plotly::common::Line::new().width(1.0).color("black")),
|
||||||
)
|
)
|
||||||
.show_legend(true);
|
.show_legend(true);
|
||||||
|
|
||||||
@@ -210,9 +216,7 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
plot.set_layout(
|
plot.set_layout(
|
||||||
Layout::new()
|
Layout::new()
|
||||||
.title(Title::new(
|
.title(Title::new("Item Embeddings (●) vs Affinity Vectors (◆)"))
|
||||||
"Item Embeddings and Affinities Visualization (PCA)",
|
|
||||||
))
|
|
||||||
.show_legend(true)
|
.show_legend(true)
|
||||||
.legend(Legend::new().x(1.0).y(0.5))
|
.legend(Legend::new().x(1.0).y(0.5))
|
||||||
.margin(Margin::new().left(100).right(100).top(100).bottom(100))
|
.margin(Margin::new().left(100).right(100).top(100).bottom(100))
|
||||||
|
|||||||
18
src/main.rs
18
src/main.rs
@@ -33,13 +33,21 @@ struct Args {
|
|||||||
#[arg(long, default_value = "10000")]
|
#[arg(long, default_value = "10000")]
|
||||||
batch_size: i32,
|
batch_size: i32,
|
||||||
|
|
||||||
|
/// Learning rate
|
||||||
|
#[arg(long, default_value = "0.01")]
|
||||||
|
learning_rate: f32,
|
||||||
|
|
||||||
/// Number of factors for matrix factorization
|
/// Number of factors for matrix factorization
|
||||||
#[arg(long, default_value = "8")]
|
#[arg(long, default_value = "8")]
|
||||||
factors: i32,
|
factors: i32,
|
||||||
|
|
||||||
|
/// L1 regularization strength (passed to both lambda_p1 and lambda_q1)
|
||||||
|
#[arg(long, default_value = "0.0")]
|
||||||
|
lambda1: f32,
|
||||||
|
|
||||||
/// Lambda for regularization
|
/// L2 regularization strength (passed to both lambda_p2 and lambda_q2)
|
||||||
#[arg(long, default_value = "0.1")]
|
#[arg(long, default_value = "0.1")]
|
||||||
lambda: f32,
|
lambda2: f32,
|
||||||
|
|
||||||
/// Number of threads for matrix factorization (defaults to number of CPU cores)
|
/// Number of threads for matrix factorization (defaults to number of CPU cores)
|
||||||
#[arg(long, default_value_t = num_cpus::get() as i32)]
|
#[arg(long, default_value_t = num_cpus::get() as i32)]
|
||||||
@@ -236,9 +244,11 @@ async fn main() -> Result<()> {
|
|||||||
// Set up training parameters
|
// Set up training parameters
|
||||||
let model = Model::params()
|
let model = Model::params()
|
||||||
.factors(args.factors as i32)
|
.factors(args.factors as i32)
|
||||||
.lambda_p2(args.lambda)
|
.lambda_p1(args.lambda1)
|
||||||
.lambda_q2(args.lambda)
|
.lambda_q1(args.lambda1)
|
||||||
.learning_rate(0.01)
|
.lambda_p2(args.lambda2)
|
||||||
|
.lambda_q2(args.lambda2)
|
||||||
|
.learning_rate(args.learning_rate)
|
||||||
.iterations(100)
|
.iterations(100)
|
||||||
.loss(Loss::OneClassL2)
|
.loss(Loss::OneClassL2)
|
||||||
.c(0.00001)
|
.c(0.00001)
|
||||||
|
|||||||
Reference in New Issue
Block a user