embeddings visualization
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -9,3 +9,5 @@
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
embeddings_visualization.html
|
||||
1550
Cargo.lock
generated
1550
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
19
Cargo.toml
19
Cargo.toml
@@ -4,17 +4,18 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.35", features = ["full"] }
|
||||
tokio-postgres = "0.7"
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
log = "0.4"
|
||||
pretty_env_logger = "0.5"
|
||||
libmf = "0.3"
|
||||
anyhow = "1.0"
|
||||
futures = "0.3"
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
ctrlc = "3.4"
|
||||
deadpool-postgres = "0.11"
|
||||
tokio-util = "0.7"
|
||||
dotenv = "0.15"
|
||||
rand = { version = "0.8", features = ["std_rng"] }
|
||||
libmf = "0.3"
|
||||
log = "0.4"
|
||||
ndarray = { version = "0.15", features = ["blas"] }
|
||||
ndarray-linalg = { version = "0.16", features = ["openblas-system"] }
|
||||
plotly = { version = "0.8", features = ["kaleido"] }
|
||||
pretty_env_logger = "0.5"
|
||||
rand = "0.8"
|
||||
rand_distr = "0.4"
|
||||
tokio = { version = "1.35", features = ["full"] }
|
||||
tokio-postgres = "0.7"
|
||||
|
||||
165
src/bin/visualize_embeddings.rs
Normal file
165
src/bin/visualize_embeddings.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Parser;
|
||||
use deadpool_postgres::{Config, Pool, Runtime};
|
||||
use dotenv::dotenv;
|
||||
use log::info;
|
||||
use ndarray::{s, Array2, ArrayView1};
|
||||
use ndarray_linalg::SVD;
|
||||
use plotly::{Plot, Scatter};
|
||||
use std::env;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Table containing item embeddings
|
||||
#[arg(long)]
|
||||
embeddings_table: String,
|
||||
|
||||
/// Table containing item cluster information
|
||||
#[arg(long)]
|
||||
clusters_table: String,
|
||||
|
||||
/// Output HTML file path
|
||||
#[arg(long, default_value = "embeddings_visualization.html")]
|
||||
output_file: String,
|
||||
}
|
||||
|
||||
async fn create_pool() -> Result<Pool> {
|
||||
let mut config = Config::new();
|
||||
config.host = Some(env::var("POSTGRES_HOST").context("POSTGRES_HOST not set")?);
|
||||
config.port = Some(
|
||||
env::var("POSTGRES_PORT")
|
||||
.context("POSTGRES_PORT not set")?
|
||||
.parse()
|
||||
.context("Invalid POSTGRES_PORT")?,
|
||||
);
|
||||
config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
|
||||
config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
|
||||
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
|
||||
|
||||
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
|
||||
}
|
||||
|
||||
fn perform_pca(data: &Array2<f64>, n_components: usize) -> Result<Array2<f64>> {
|
||||
// Center the data
|
||||
let means = data.mean_axis(ndarray::Axis(0)).unwrap();
|
||||
let centered = data.clone() - &means.view().insert_axis(ndarray::Axis(0));
|
||||
|
||||
// Perform SVD
|
||||
let svd = centered.svd(true, true)?;
|
||||
let components = svd.2.unwrap();
|
||||
let projection = centered.dot(&components.slice(s![..n_components, ..]).t());
|
||||
|
||||
Ok(projection)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
dotenv().ok();
|
||||
pretty_env_logger::init();
|
||||
let args = Args::parse();
|
||||
|
||||
let pool = create_pool().await?;
|
||||
let client = pool.get().await?;
|
||||
|
||||
// Load embeddings and cluster information
|
||||
info!("Loading embeddings and cluster information...");
|
||||
let query = format!(
|
||||
"SELECT e.item_id, e.embedding, c.cluster_id
|
||||
FROM {} e
|
||||
JOIN {} c ON e.item_id = c.item_id
|
||||
ORDER BY e.item_id",
|
||||
args.embeddings_table, args.clusters_table
|
||||
);
|
||||
|
||||
let rows = client.query(&query, &[]).await?;
|
||||
let n_items = rows.len();
|
||||
let n_dims = if let Some(first_row) = rows.first() {
|
||||
let embedding: Vec<f64> = first_row.get(1);
|
||||
embedding.len()
|
||||
} else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// Convert data to ndarray format
|
||||
let mut data = Array2::zeros((n_items, n_dims));
|
||||
let mut item_ids = Vec::with_capacity(n_items);
|
||||
let mut cluster_ids = Vec::with_capacity(n_items);
|
||||
|
||||
for (i, row) in rows.iter().enumerate() {
|
||||
let item_id: i32 = row.get(0);
|
||||
let embedding: Vec<f64> = row.get(1);
|
||||
let cluster_id: i32 = row.get(2);
|
||||
|
||||
item_ids.push(item_id);
|
||||
cluster_ids.push(cluster_id);
|
||||
data.row_mut(i).assign(&ArrayView1::from(&embedding));
|
||||
}
|
||||
|
||||
// Perform PCA
|
||||
info!("Performing PCA...");
|
||||
let projected_data = perform_pca(&data, 2)?;
|
||||
|
||||
// Create scatter plot for each cluster
|
||||
let mut plot = Plot::new();
|
||||
let unique_clusters: Vec<_> = cluster_ids
|
||||
.iter()
|
||||
.copied()
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
for cluster_id in unique_clusters {
|
||||
let indices: Vec<_> = cluster_ids
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, &c)| c == cluster_id)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
let x: Vec<_> = indices.iter().map(|&i| projected_data[[i, 0]]).collect();
|
||||
let y: Vec<_> = indices.iter().map(|&i| projected_data[[i, 1]]).collect();
|
||||
let text: Vec<_> = indices
|
||||
.iter()
|
||||
.map(|&i| format!("Item {}", item_ids[i]))
|
||||
.collect();
|
||||
|
||||
let trace = Scatter::new(x, y)
|
||||
.name(&format!("Cluster {}", cluster_id))
|
||||
.mode(plotly::common::Mode::Markers)
|
||||
.text_array(text)
|
||||
.show_legend(true);
|
||||
|
||||
plot.add_trace(trace);
|
||||
}
|
||||
|
||||
plot.set_layout(
|
||||
plotly::Layout::new()
|
||||
.title(plotly::common::Title::new(
|
||||
"Item Embeddings Visualization (PCA)",
|
||||
))
|
||||
.x_axis(
|
||||
plotly::layout::Axis::new()
|
||||
.title(plotly::common::Title::new("First Principal Component")),
|
||||
)
|
||||
.y_axis(
|
||||
plotly::layout::Axis::new()
|
||||
.title(plotly::common::Title::new("Second Principal Component")),
|
||||
),
|
||||
);
|
||||
|
||||
// Save plot to HTML file
|
||||
info!("Saving visualization to {}...", args.output_file);
|
||||
let html = plot.to_html();
|
||||
let mut file = File::create(&args.output_file)?;
|
||||
file.write_all(html.as_bytes())?;
|
||||
|
||||
info!(
|
||||
"Done! Open {} in a web browser to view the visualization.",
|
||||
args.output_file
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user