embeddings visualization

This commit is contained in:
Dylan Knutson
2024-12-28 01:59:11 +00:00
parent 61b9728fd8
commit e21541af46
4 changed files with 1672 additions and 64 deletions

2
.gitignore vendored
View File

@@ -9,3 +9,5 @@
.vscode/
*.swp
*.swo
embeddings_visualization.html

1550
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,17 +4,18 @@ version = "0.1.0"
edition = "2021"
[dependencies]
tokio = { version = "1.35", features = ["full"] }
tokio-postgres = "0.7"
clap = { version = "4.4", features = ["derive"] }
log = "0.4"
pretty_env_logger = "0.5"
libmf = "0.3"
anyhow = "1.0"
futures = "0.3"
clap = { version = "4.4", features = ["derive"] }
ctrlc = "3.4"
deadpool-postgres = "0.11"
tokio-util = "0.7"
dotenv = "0.15"
rand = { version = "0.8", features = ["std_rng"] }
libmf = "0.3"
log = "0.4"
ndarray = { version = "0.15", features = ["blas"] }
ndarray-linalg = { version = "0.16", features = ["openblas-system"] }
plotly = { version = "0.8", features = ["kaleido"] }
pretty_env_logger = "0.5"
rand = "0.8"
rand_distr = "0.4"
tokio = { version = "1.35", features = ["full"] }
tokio-postgres = "0.7"

View File

@@ -0,0 +1,165 @@
use anyhow::{Context, Result};
use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv;
use log::info;
use ndarray::{s, Array2, ArrayView1};
use ndarray_linalg::SVD;
use plotly::{Plot, Scatter};
use std::env;
use std::fs::File;
use std::io::Write;
use tokio_postgres::NoTls;
// Command-line arguments for the embeddings visualization tool, parsed by clap's derive API.
// NOTE: regular `//` comments are used here on purpose; `///` doc comments inside a clap
// derive struct become part of the generated `--help` output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
// Name of the table to read `item_id` and `embedding` from. `main` interpolates this
// value directly into the SQL text, so it must be a trusted identifier.
/// Table containing item embeddings
#[arg(long)]
embeddings_table: String,
// Name of the table to read `item_id` and `cluster_id` from; joined to the embeddings
// table on `item_id`. Also interpolated directly into the SQL text.
/// Table containing item cluster information
#[arg(long)]
clusters_table: String,
// Path the generated HTML plot is written to; defaults to a file in the working directory.
/// Output HTML file path
#[arg(long, default_value = "embeddings_visualization.html")]
output_file: String,
}
/// Builds a PostgreSQL connection pool from `POSTGRES_*` environment variables.
///
/// Reads `POSTGRES_HOST`, `POSTGRES_PORT`, `POSTGRES_DB`, `POSTGRES_USER` and
/// `POSTGRES_PASSWORD`, in that order.
///
/// # Errors
///
/// Returns an error if any of those variables is unset, if the port cannot be
/// parsed, or if the pool itself cannot be created.
async fn create_pool() -> Result<Pool> {
    // Gather every setting up front so a missing variable is reported before
    // any configuration object is touched.
    let host = env::var("POSTGRES_HOST").context("POSTGRES_HOST not set")?;
    let port = env::var("POSTGRES_PORT")
        .context("POSTGRES_PORT not set")?
        .parse()
        .context("Invalid POSTGRES_PORT")?;
    let dbname = env::var("POSTGRES_DB").context("POSTGRES_DB not set")?;
    let user = env::var("POSTGRES_USER").context("POSTGRES_USER not set")?;
    let password = env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?;

    let mut config = Config::new();
    config.host = Some(host);
    config.port = Some(port);
    config.dbname = Some(dbname);
    config.user = Some(user);
    config.password = Some(password);

    let pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?;
    Ok(pool)
}
fn perform_pca(data: &Array2<f64>, n_components: usize) -> Result<Array2<f64>> {
// Center the data
let means = data.mean_axis(ndarray::Axis(0)).unwrap();
let centered = data.clone() - &means.view().insert_axis(ndarray::Axis(0));
// Perform SVD
let svd = centered.svd(true, true)?;
let components = svd.2.unwrap();
let projection = centered.dot(&components.slice(s![..n_components, ..]).t());
Ok(projection)
}
#[tokio::main]
async fn main() -> Result<()> {
dotenv().ok();
pretty_env_logger::init();
let args = Args::parse();
let pool = create_pool().await?;
let client = pool.get().await?;
// Load embeddings and cluster information
info!("Loading embeddings and cluster information...");
let query = format!(
"SELECT e.item_id, e.embedding, c.cluster_id
FROM {} e
JOIN {} c ON e.item_id = c.item_id
ORDER BY e.item_id",
args.embeddings_table, args.clusters_table
);
let rows = client.query(&query, &[]).await?;
let n_items = rows.len();
let n_dims = if let Some(first_row) = rows.first() {
let embedding: Vec<f64> = first_row.get(1);
embedding.len()
} else {
return Ok(());
};
// Convert data to ndarray format
let mut data = Array2::zeros((n_items, n_dims));
let mut item_ids = Vec::with_capacity(n_items);
let mut cluster_ids = Vec::with_capacity(n_items);
for (i, row) in rows.iter().enumerate() {
let item_id: i32 = row.get(0);
let embedding: Vec<f64> = row.get(1);
let cluster_id: i32 = row.get(2);
item_ids.push(item_id);
cluster_ids.push(cluster_id);
data.row_mut(i).assign(&ArrayView1::from(&embedding));
}
// Perform PCA
info!("Performing PCA...");
let projected_data = perform_pca(&data, 2)?;
// Create scatter plot for each cluster
let mut plot = Plot::new();
let unique_clusters: Vec<_> = cluster_ids
.iter()
.copied()
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
for cluster_id in unique_clusters {
let indices: Vec<_> = cluster_ids
.iter()
.enumerate()
.filter(|(_, &c)| c == cluster_id)
.map(|(i, _)| i)
.collect();
let x: Vec<_> = indices.iter().map(|&i| projected_data[[i, 0]]).collect();
let y: Vec<_> = indices.iter().map(|&i| projected_data[[i, 1]]).collect();
let text: Vec<_> = indices
.iter()
.map(|&i| format!("Item {}", item_ids[i]))
.collect();
let trace = Scatter::new(x, y)
.name(&format!("Cluster {}", cluster_id))
.mode(plotly::common::Mode::Markers)
.text_array(text)
.show_legend(true);
plot.add_trace(trace);
}
plot.set_layout(
plotly::Layout::new()
.title(plotly::common::Title::new(
"Item Embeddings Visualization (PCA)",
))
.x_axis(
plotly::layout::Axis::new()
.title(plotly::common::Title::new("First Principal Component")),
)
.y_axis(
plotly::layout::Axis::new()
.title(plotly::common::Title::new("Second Principal Component")),
),
);
// Save plot to HTML file
info!("Saving visualization to {}...", args.output_file);
let html = plot.to_html();
let mut file = File::create(&args.output_file)?;
file.write_all(html.as_bytes())?;
info!(
"Done! Open {} in a web browser to view the visualization.",
args.output_file
);
Ok(())
}