From 00b30ac28580b2e21f9e29f39be7f37432f6fbd5 Mon Sep 17 00:00:00 2001
From: Dylan Knutson
Date: Sat, 28 Dec 2024 01:46:48 +0000
Subject: [PATCH] Add cluster validation for generated test data

Persist item cluster assignments in a new item_clusters table during
test data generation, and add a validate_embeddings binary that reports
cluster cohesion, per-cluster embedding statistics, and the correlation
between cluster center distances and embedding distances. Also switch
the libmf loss to one-class L2 with a small negative-example weight.

---
 src/bin/generate_test_data.rs  |  34 ++++++
 src/bin/validate_embeddings.rs | 195 +++++++++++++++++++++++++++++++++
 src/main.rs                    |   3 +-
 3 files changed, 231 insertions(+), 1 deletion(-)
 create mode 100644 src/bin/validate_embeddings.rs

diff --git a/src/bin/generate_test_data.rs b/src/bin/generate_test_data.rs
index 3a1b6e5..86ad93e 100644
--- a/src/bin/generate_test_data.rs
+++ b/src/bin/generate_test_data.rs
@@ -119,10 +119,44 @@ async fn main() -> Result<()> {
         )
         .await?;
 
+    // Create table for cluster assignments
+    client
+        .execute(
+            "CREATE TABLE IF NOT EXISTS item_clusters (
+                item_id INTEGER PRIMARY KEY,
+                cluster_id INTEGER,
+                cluster_center FLOAT
+            )",
+            &[],
+        )
+        .await?;
+
     // Generate cluster information
     let user_clusters = generate_cluster_info(args.user_clusters);
     let item_clusters = generate_cluster_info(args.item_clusters);
 
+    // Store item cluster assignments
+    let mut item_to_cluster = HashMap::new();
+
     let mut rng = rand::thread_rng();
     let user_cluster_dist = Uniform::new(0, args.user_clusters);
     let item_cluster_dist = Uniform::new(0, args.item_clusters);
+
+    // Assign and store cluster information for each item
+    for item_id in 0..args.num_items {
+        let cluster_id = item_cluster_dist.sample(&mut rng);
+        let cluster = item_clusters[cluster_id as usize];
+        item_to_cluster.insert(item_id, (cluster_id, cluster));
+
+        client
+            .execute(
+                "INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
+                 VALUES ($1, $2, $3)
+                 ON CONFLICT (item_id) DO UPDATE
+                 SET cluster_id = EXCLUDED.cluster_id,
+                     cluster_center = EXCLUDED.cluster_center",
+                &[&item_id, &cluster_id, &cluster.center],
+            )
+            .await?;
+    }
+
diff --git a/src/bin/validate_embeddings.rs b/src/bin/validate_embeddings.rs
new file mode 100644
index 0000000..b43d8a8
--- /dev/null
+++ b/src/bin/validate_embeddings.rs
@@ -0,0 +1,195 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use deadpool_postgres::{Config, Pool, Runtime};
+use dotenv::dotenv;
+use log::info;
+use std::env;
+use tokio_postgres::NoTls;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Table containing user-item interactions
+    #[arg(long)]
+    interactions_table: String,
+
+    /// Table containing item embeddings
+    #[arg(long)]
+    embeddings_table: String,
+
+    /// Table containing item cluster information
+    #[arg(long)]
+    clusters_table: String,
+}
+
+async fn create_pool() -> Result<Pool> {
+    let mut config = Config::new();
+    config.host = Some(env::var("POSTGRES_HOST").context("POSTGRES_HOST not set")?);
+    config.port = Some(
+        env::var("POSTGRES_PORT")
+            .context("POSTGRES_PORT not set")?
+            .parse()
+            .context("Invalid POSTGRES_PORT")?,
+    );
+    config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
+    config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
+    config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
+
+    Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
+}
+
+async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
+    info!("Analyzing cluster cohesion...");
+
+    // Calculate within-cluster and between-cluster distances
+    let query = format!(
+        "WITH distances AS (
+            SELECT
+                a.cluster_id = b.cluster_id as same_cluster,
+                SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
+            FROM {} a
+            JOIN {} b ON a.item_id < b.item_id
+            JOIN {} ie1 ON a.item_id = ie1.item_id
+            JOIN {} ie2 ON b.item_id = ie2.item_id,
+            UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
+            GROUP BY a.item_id, b.item_id, same_cluster
+        )
+        SELECT
+            CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison,
+            AVG(distance)::float8 as avg_distance,
+            STDDEV(distance)::float8 as stddev_distance,
+            MIN(distance)::float8 as min_distance,
+            MAX(distance)::float8 as max_distance,
+            COUNT(*) as num_pairs
+        FROM distances
+        GROUP BY same_cluster
+        ORDER BY same_cluster DESC",
+        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
+    );
+
+    let rows = client.query(&query, &[]).await?;
+    for row in rows {
+        let comparison: &str = row.get(0);
+        let avg: f64 = row.get(1);
+        let stddev: f64 = row.get(2);
+        let min: f64 = row.get(3);
+        let max: f64 = row.get(4);
+        let count: i64 = row.get(5);
+
+        info!(
+            "{}: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
+            comparison, avg, stddev, min, max, count
+        );
+    }
+
+    Ok(())
+}
+
+async fn analyze_embedding_stats(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
+    info!("Analyzing embedding statistics...");
+
+    // Calculate embedding norms and component statistics
+    let query = format!(
+        "WITH stats AS (
+            SELECT
+                ie.item_id,
+                c.cluster_id,
+                SQRT(SUM(x * x))::float8 as norm,
+                AVG(x)::float8 as avg_component,
+                STDDEV(x)::float8 as stddev_component,
+                MIN(x)::float8 as min_component,
+                MAX(x)::float8 as max_component
+            FROM {} ie
+            JOIN {} c ON ie.item_id = c.item_id,
+            UNNEST(embedding) x
+            GROUP BY ie.item_id, c.cluster_id
+        )
+        SELECT
+            cluster_id,
+            AVG(norm)::float8 as avg_norm,
+            STDDEV(norm)::float8 as stddev_norm,
+            AVG(avg_component)::float8 as avg_component,
+            AVG(stddev_component)::float8 as avg_component_spread,
+            COUNT(*) as num_items
+        FROM stats
+        GROUP BY cluster_id
+        ORDER BY cluster_id",
+        args.embeddings_table, args.clusters_table
+    );
+
+    let rows = client.query(&query, &[]).await?;
+    info!("Per-cluster embedding statistics:");
+    for row in rows {
+        let cluster_id: i32 = row.get(0);
+        let avg_norm: f64 = row.get(1);
+        let stddev_norm: f64 = row.get(2);
+        let avg_component: f64 = row.get(3);
+        let avg_spread: f64 = row.get(4);
+        let count: i64 = row.get(5);
+
+        info!(
+            "Cluster {}: {} items, norm={:.3}±{:.3}, components={:.3}±{:.3}",
+            cluster_id, count, avg_norm, stddev_norm, avg_component, avg_spread
+        );
+    }
+
+    Ok(())
+}
+
+async fn analyze_cluster_correlation(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
+    info!("Analyzing correlation between cluster centers and embedding distances...");
+
+    // Calculate correlation between cluster center distances and embedding distances
+    let query = format!(
+        "WITH distances AS (
+            SELECT
+                a.cluster_id as cluster1,
+                b.cluster_id as cluster2,
+                ABS(a.cluster_center - b.cluster_center)::float8 as center_distance,
+                SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as embedding_distance
+            FROM {} a
+            JOIN {} b ON a.item_id < b.item_id
+            JOIN {} ie1 ON a.item_id = ie1.item_id
+            JOIN {} ie2 ON b.item_id = ie2.item_id,
+            UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
+            GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id, a.cluster_center, b.cluster_center
+        )
+        SELECT
+            CORR(center_distance, embedding_distance)::float8 as correlation,
+            COUNT(*) as num_pairs,
+            AVG(center_distance)::float8 as avg_center_dist,
+            AVG(embedding_distance)::float8 as avg_emb_dist
+        FROM distances",
+        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
+    );
+
+    let row = client.query_one(&query, &[]).await?;
+    info!(
+        "Correlation between cluster center distances and embedding distances: {:.3} (from {} pairs)",
+        row.get::<_, f64>(0),
+        row.get::<_, i64>(1)
+    );
+    info!(
+        "Average distances: cluster centers={:.3}, embeddings={:.3}",
+        row.get::<_, f64>(2),
+        row.get::<_, f64>(3)
+    );
+
+    Ok(())
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    dotenv().ok();
+    pretty_env_logger::init();
+    let args = Args::parse();
+
+    let pool = create_pool().await?;
+    let client = pool.get().await?;
+
+    analyze_cluster_cohesion(&client, &args).await?;
+    analyze_embedding_stats(&client, &args).await?;
+    analyze_cluster_correlation(&client, &args).await?;
+
+    Ok(())
+}
diff --git a/src/main.rs b/src/main.rs
index 10b6bd8..916b25b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -188,7 +188,8 @@ async fn main() -> Result<()> {
         .lambda_q2(args.lambda)
         .learning_rate(0.01)
         .iterations(100)
-        .loss(libmf::Loss::BinaryL2)
+        .loss(libmf::Loss::OneClassL2)
+        .c(0.00001)
         .quiet(false)
         .fit(&matrix)?;
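
Usage sketch for the new binary, assuming database credentials are set in
the environment as create_pool expects. Only item_clusters is created by
this patch; the interactions/embeddings table names below are placeholders
for whatever tables generate_test_data produced:

    cargo run --bin validate_embeddings -- \
        --interactions-table interactions \
        --embeddings-table item_embeddings \
        --clusters-table item_clusters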