better visualization
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,3 +11,6 @@
|
||||
*.swo
|
||||
|
||||
embeddings_visualization.html
|
||||
|
||||
# Coredumps
|
||||
/core
|
||||
|
||||
@@ -88,11 +88,19 @@ async fn main() -> Result<()> {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// Get affinity dimension (should be number of items)
|
||||
let affinity_dims = if let Some(first_row) = rows.first() {
|
||||
let affinities: Vec<f64> = first_row.get(3);
|
||||
affinities.len()
|
||||
} else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// Convert data to ndarray format
|
||||
let mut data = Array2::zeros((n_items, n_dims));
|
||||
let mut item_ids = Vec::with_capacity(n_items);
|
||||
let mut cluster_ids = Vec::with_capacity(n_items);
|
||||
let mut affinity_data = Array2::zeros((n_items, n_dims)); // Changed from n_items to n_dims for affinity dimension
|
||||
let mut affinity_data = Array2::zeros((n_items, affinity_dims)); // Use full affinity dimension
|
||||
|
||||
for (i, row) in rows.iter().enumerate() {
|
||||
let item_id: i32 = row.get(0);
|
||||
@@ -103,15 +111,9 @@ async fn main() -> Result<()> {
|
||||
item_ids.push(item_id);
|
||||
cluster_ids.push(cluster_id);
|
||||
data.row_mut(i).assign(&ArrayView1::from(&embedding));
|
||||
|
||||
// Ensure affinity vector has the right length by truncating or padding if necessary
|
||||
let mut affinity_vec = vec![0.0; n_dims];
|
||||
for (j, &val) in affinities.iter().take(n_dims).enumerate() {
|
||||
affinity_vec[j] = val;
|
||||
}
|
||||
affinity_data
|
||||
.row_mut(i)
|
||||
.assign(&ArrayView1::from(&affinity_vec));
|
||||
.assign(&ArrayView1::from(&affinities));
|
||||
}
|
||||
|
||||
// Perform PCA on both embeddings and affinity vectors
|
||||
@@ -160,15 +162,17 @@ async fn main() -> Result<()> {
|
||||
.text_array(text)
|
||||
.marker(
|
||||
plotly::common::Marker::new()
|
||||
.size(8)
|
||||
.symbol(plotly::common::MarkerSymbol::Circle),
|
||||
.size(10)
|
||||
.symbol(plotly::common::MarkerSymbol::Circle)
|
||||
.opacity(0.9)
|
||||
.line(plotly::common::Line::new().width(1.0).color("white")),
|
||||
)
|
||||
.show_legend(true);
|
||||
|
||||
plot.add_trace(trace);
|
||||
}
|
||||
|
||||
// Plot affinity vectors
|
||||
// Plot affinity vectors with distinct appearance
|
||||
for cluster_id in &unique_clusters {
|
||||
let indices: Vec<_> = cluster_ids
|
||||
.iter()
|
||||
@@ -191,7 +195,7 @@ async fn main() -> Result<()> {
|
||||
.collect();
|
||||
let text: Vec<_> = indices
|
||||
.iter()
|
||||
.map(|&i| format!("Item {}", item_ids[i]))
|
||||
.map(|&i| format!("Item {} (Affinity)", item_ids[i]))
|
||||
.collect();
|
||||
|
||||
let trace = Scatter3D::new(x, y, z)
|
||||
@@ -200,8 +204,10 @@ async fn main() -> Result<()> {
|
||||
.text_array(text)
|
||||
.marker(
|
||||
plotly::common::Marker::new()
|
||||
.size(8)
|
||||
.symbol(plotly::common::MarkerSymbol::Square),
|
||||
.size(7)
|
||||
.symbol(plotly::common::MarkerSymbol::Diamond)
|
||||
.opacity(0.7)
|
||||
.line(plotly::common::Line::new().width(1.0).color("black")),
|
||||
)
|
||||
.show_legend(true);
|
||||
|
||||
@@ -210,9 +216,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
plot.set_layout(
|
||||
Layout::new()
|
||||
.title(Title::new(
|
||||
"Item Embeddings and Affinities Visualization (PCA)",
|
||||
))
|
||||
.title(Title::new("Item Embeddings (●) vs Affinity Vectors (◆)"))
|
||||
.show_legend(true)
|
||||
.legend(Legend::new().x(1.0).y(0.5))
|
||||
.margin(Margin::new().left(100).right(100).top(100).bottom(100))
|
||||
|
||||
18
src/main.rs
18
src/main.rs
@@ -33,13 +33,21 @@ struct Args {
|
||||
#[arg(long, default_value = "10000")]
|
||||
batch_size: i32,
|
||||
|
||||
/// Learning rate
|
||||
#[arg(long, default_value = "0.01")]
|
||||
learning_rate: f32,
|
||||
|
||||
/// Number of factors for matrix factorization
|
||||
#[arg(long, default_value = "8")]
|
||||
factors: i32,
|
||||
|
||||
/// Lambda for regularization
|
||||
#[arg(long, default_value = "0.0")]
|
||||
lambda1: f32,
|
||||
|
||||
/// Lambda for regularization
|
||||
#[arg(long, default_value = "0.1")]
|
||||
lambda: f32,
|
||||
lambda2: f32,
|
||||
|
||||
/// Number of threads for matrix factorization (defaults to number of CPU cores)
|
||||
#[arg(long, default_value_t = num_cpus::get() as i32)]
|
||||
@@ -236,9 +244,11 @@ async fn main() -> Result<()> {
|
||||
// Set up training parameters
|
||||
let model = Model::params()
|
||||
.factors(args.factors as i32)
|
||||
.lambda_p2(args.lambda)
|
||||
.lambda_q2(args.lambda)
|
||||
.learning_rate(0.01)
|
||||
.lambda_p1(args.lambda1)
|
||||
.lambda_q1(args.lambda1)
|
||||
.lambda_p2(args.lambda2)
|
||||
.lambda_q2(args.lambda2)
|
||||
.learning_rate(args.learning_rate)
|
||||
.iterations(100)
|
||||
.loss(Loss::OneClassL2)
|
||||
.c(0.00001)
|
||||
|
||||
Reference in New Issue
Block a user