From 32a7292481d90e825bac7e94779efd71e15924e9 Mon Sep 17 00:00:00 2001
From: Dylan Knutson
Date: Sat, 28 Dec 2024 03:04:50 +0000
Subject: [PATCH] more fixes

---
 Cargo.lock                     |   1 +
 Cargo.toml                     |   1 +
 src/bin/generate_test_data.rs  | 104 +++++++++-------
 src/bin/validate_embeddings.rs | 210 +++++++++++++++++----------------
 src/main.rs                    | 158 ++++++++++++++++---------
 5 files changed, 282 insertions(+), 192 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index f60708d..63e0ae4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1183,6 +1183,7 @@ dependencies = [
  "log",
  "ndarray",
  "ndarray-linalg",
+ "num_cpus",
  "plotly",
  "pretty_env_logger",
  "rand",
diff --git a/Cargo.toml b/Cargo.toml
index 7ce8ab2..3113f76 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ libmf = "0.3"
 log = "0.4"
 ndarray = { version = "0.15", features = ["blas"] }
 ndarray-linalg = { version = "0.16", features = ["openblas-system"] }
+num_cpus = "1.16"
 plotly = { version = "0.8", features = ["kaleido"] }
 pretty_env_logger = "0.5"
 rand = "0.8"
diff --git a/src/bin/generate_test_data.rs b/src/bin/generate_test_data.rs
index 4f1633a..81e0509 100644
--- a/src/bin/generate_test_data.rs
+++ b/src/bin/generate_test_data.rs
@@ -2,11 +2,9 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
-use log::{info, warn};
-use rand::distributions::{Distribution, Uniform};
+use log::info;
 use rand::seq::SliceRandom;
 use rand::Rng;
-use rand_distr::Normal;
 use std::collections::{HashMap, HashSet};
 use std::env;
 use tokio_postgres::NoTls;
@@ -101,26 +99,33 @@ async fn main() -> Result<()> {
         )
         .await?;
 
-    // Create table for cluster assignments
-    client
-        .execute(
-            "CREATE TABLE IF NOT EXISTS item_clusters (
-                item_id INTEGER PRIMARY KEY,
-                cluster_id INTEGER,
-                cluster_center FLOAT
-            )",
-            &[],
-        )
-        .await?;
-
-    // Generate cluster centers that are well-separated
+    // Generate cluster centers that are well-separated in 3D space
     let mut rng = rand::thread_rng();
     let mut cluster_centers = Vec::new();
-    for i in 0..args.item_clusters {
-        // Place centers evenly around a circle in 2D space, then project to 1D
-        let angle = 2.0 * std::f64::consts::PI * (i as f64) / (args.item_clusters as f64);
-        let center = angle.cos(); // Project to 1D by taking cosine
-        cluster_centers.push(center);
+
+    // Base vertices of a regular octahedron
+    let base_centers = vec![
+        vec![1.0, 0.0, 0.0],  // +x
+        vec![-1.0, 0.0, 0.0], // -x
+        vec![0.0, 1.0, 0.0],  // +y
+        vec![0.0, -1.0, 0.0], // -y
+        vec![0.0, 0.0, 1.0],  // +z
+        vec![0.0, 0.0, -1.0], // -z
+    ];
+
+    // Scale factor to control separation
+    let scale = 2.0;
+
+    // Add jittered versions of the base centers
+    for i in 0..args.item_clusters as usize {
+        let base = &base_centers[i % base_centers.len()];
+        // Add controlled random jitter (up to 10% of the scale)
+        let jitter = 0.1;
+        let x = base[0] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
+        let y = base[1] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
+        let z = base[2] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
+
+        cluster_centers.push(vec![x, y, z]);
     }
 
     // Create shuffled list of items
     let mut all_items: Vec<i32> = (0..args.num_items).collect();
     all_items.shuffle(&mut rng);
 
     // Assign items to clusters in contiguous blocks
     let mut items_per_cluster = vec![Vec::new(); args.item_clusters as usize];
     let items_per_cluster_count = args.num_items / args.item_clusters;
+
+    // Drop existing item_clusters table and recreate with 3D centers
+    client
+        .execute("DROP TABLE IF EXISTS item_clusters", &[])
+        .await?;
+    client
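+        // One FLOAT column per coordinate of the 3D cluster center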
+        .execute(
+            "CREATE TABLE item_clusters (
+                item_id INTEGER PRIMARY KEY,
+                cluster_id INTEGER,
+                center_x FLOAT,
+                center_y FLOAT,
+                center_z FLOAT
+            )",
+            &[],
+        )
+        .await?;
+
+    // Clear existing interactions
+    client
+        .execute(&format!("TRUNCATE TABLE {}", args.interactions_table), &[])
+        .await?;
+
     for (i, &item_id) in all_items.iter().enumerate() {
         let cluster_id = (i as i32) / items_per_cluster_count;
         if cluster_id < args.item_clusters {
             items_per_cluster[cluster_id as usize].push(item_id);
 
-            // Store cluster assignment with meaningful center
+            // Store cluster assignment with 3D center
+            let center = &cluster_centers[cluster_id as usize];
             client
                 .execute(
-                    "INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
-                     VALUES ($1, $2, $3)
-                     ON CONFLICT (item_id) DO UPDATE
-                     SET cluster_id = EXCLUDED.cluster_id,
-                         cluster_center = EXCLUDED.cluster_center",
-                    &[&item_id, &cluster_id, &cluster_centers[cluster_id as usize]],
+                    "INSERT INTO item_clusters (item_id, cluster_id, center_x, center_y, center_z)
+                     VALUES ($1, $2, $3, $4, $5)",
+                    &[&item_id, &cluster_id, &center[0], &center[1], &center[2]],
                 )
                 .await?;
         }
     }
 
-    // Generate interactions with strong cluster affinity
+    // Generate interactions with very strong cluster affinity
     info!("Generating interactions...");
     let mut item_interactions = HashMap::new();
 
     // For each user cluster
     for user_cluster in 0..args.user_clusters {
-        // Each user cluster strongly prefers 1-2 item clusters
-        let preferred_item_clusters: HashSet<i32> = (0..args.item_clusters)
-            .filter(|&i| i % args.user_clusters == user_cluster)
-            .collect();
+        // Each user cluster strongly prefers exactly one item cluster with wraparound
+        let preferred_item_cluster = user_cluster % args.item_clusters;
 
         // For each user in this cluster
         let cluster_start = user_cluster * (args.num_users / args.user_clusters);
         let cluster_end = if user_cluster == args.user_clusters - 1 {
             args.num_users
         } else {
             (user_cluster + 1) * (args.num_users / args.user_clusters)
         };
 
         for user_id in cluster_start..cluster_end {
-            // Get all items from preferred clusters
-            let mut preferred_items = HashSet::new();
-            for &item_cluster in &preferred_item_clusters {
-                preferred_items.extend(items_per_cluster[item_cluster as usize].iter());
-            }
+            // Get all items from preferred cluster
+            let preferred_items: HashSet<_> = items_per_cluster[preferred_item_cluster as usize]
+                .iter()
+                .copied()
+                .collect();
 
-            // Interact with most preferred items (1 - noise_level of items in preferred clusters)
+            // Interact with most preferred items (1 - noise_level of items in preferred cluster)
             let num_interactions =
                 (preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize;
             let selected_items: Vec<_> = preferred_items.into_iter().collect();
             let mut interactions: HashSet<_> = selected_items
                 .choose_multiple(&mut rng, num_interactions)
                 .copied()
                 .collect();
 
-            // Add random items from non-preferred clusters (noise)
+            // Add very few random items from non-preferred clusters (noise)
             let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32;
             while interactions.len() < num_interactions + num_noise as usize {
                 let item_id = rng.gen_range(0..args.num_items);
diff --git a/src/bin/validate_embeddings.rs b/src/bin/validate_embeddings.rs
index c37c84d..e964e54 100644
--- a/src/bin/validate_embeddings.rs
+++ b/src/bin/validate_embeddings.rs
@@ -41,85 +41,96 @@ async fn create_pool() -> Result<Pool> {
 async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
     info!("Analyzing cluster cohesion...");
 
-    // Calculate within-cluster and between-cluster distances
-    let query = format!(
-        "WITH distances AS (
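+    // Compare pairwise embedding distances within vs. between clusters.
+    // Only the first three embedding components are compared, mirroring
+    // the convention used in analyze_cluster_correlation below.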
+    let cohesion_stats = client
+        .query_one(
+            "WITH embedding_distances AS (
+                SELECT
+                    a.item_id as item1,
+                    b.item_id as item2,
+                    a.cluster_id as cluster1,
+                    b.cluster_id as cluster2,
+                    SQRT(
+                        POW(ae.embedding[1] - be.embedding[1], 2) +
+                        POW(ae.embedding[2] - be.embedding[2], 2) +
+                        POW(ae.embedding[3] - be.embedding[3], 2)
+                    ) as distance,
+                    CASE WHEN a.cluster_id = b.cluster_id THEN 'within' ELSE 'between' END as distance_type
+                FROM item_clusters a
+                CROSS JOIN item_clusters b
+                JOIN item_embeddings ae ON a.item_id = ae.item_id
+                JOIN item_embeddings be ON b.item_id = be.item_id
+                WHERE a.item_id < b.item_id
+            )
             SELECT
-            a.cluster_id = b.cluster_id as same_cluster,
-            a.cluster_id as cluster1,
-            b.cluster_id as cluster2,
-            SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
-        FROM {} a
-        JOIN {} b ON a.item_id < b.item_id
-        JOIN {} ie1 ON a.item_id = ie1.item_id
-        JOIN {} ie2 ON b.item_id = ie2.item_id,
-        UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
-        GROUP BY a.item_id, b.item_id, same_cluster, a.cluster_id, b.cluster_id
+                AVG(CASE WHEN distance_type = 'within' THEN distance END) as avg_within,
+                STDDEV(CASE WHEN distance_type = 'within' THEN distance END) as stddev_within,
+                MIN(CASE WHEN distance_type = 'within' THEN distance END) as min_within,
+                MAX(CASE WHEN distance_type = 'within' THEN distance END) as max_within,
+                COUNT(CASE WHEN distance_type = 'within' THEN 1 END) as count_within,
+                AVG(CASE WHEN distance_type = 'between' THEN distance END) as avg_between,
+                STDDEV(CASE WHEN distance_type = 'between' THEN distance END) as stddev_between,
+                MIN(CASE WHEN distance_type = 'between' THEN distance END) as min_between,
+                MAX(CASE WHEN distance_type = 'between' THEN distance END) as max_between,
+                COUNT(CASE WHEN distance_type = 'between' THEN 1 END) as count_between
+            FROM embedding_distances",
+            &[],
         )
-        SELECT
-            CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison,
-            AVG(distance)::float8 as avg_distance,
-            STDDEV(distance)::float8 as stddev_distance,
-            MIN(distance)::float8 as min_distance,
-            MAX(distance)::float8 as max_distance,
-            COUNT(*) as num_pairs
-        FROM distances
-        GROUP BY same_cluster
-        ORDER BY same_cluster DESC",
-        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
+        .await?;
+
+    // Print cohesion statistics
+    info!(
+        "Within Cluster: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
+        cohesion_stats.get::<_, f64>("avg_within"),
+        cohesion_stats.get::<_, f64>("stddev_within"),
+        cohesion_stats.get::<_, f64>("min_within"),
+        cohesion_stats.get::<_, f64>("max_within"),
+        cohesion_stats.get::<_, i64>("count_within")
     );
 
-    let rows = client.query(&query, &[]).await?;
-    for row in rows {
-        let comparison: &str = row.get(0);
-        let avg: f64 = row.get(1);
-        let stddev: f64 = row.get(2);
-        let min: f64 = row.get(3);
-        let max: f64 = row.get(4);
-        let count: i64 = row.get(5);
-
-        info!(
-            "{}: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
-            comparison, avg, stddev, min, max, count
-        );
-    }
-
-    // Calculate per-cluster statistics
-    let query = format!(
-        "WITH distances AS (
-            SELECT
-                a.cluster_id,
-                SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
-            FROM {} a
-            JOIN {} b ON a.item_id < b.item_id
-            JOIN {} ie1 ON a.item_id = ie1.item_id
-            JOIN {} ie2 ON b.item_id = ie2.item_id,
-            UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
-            WHERE a.cluster_id = b.cluster_id
-            GROUP BY a.item_id, b.item_id, a.cluster_id
-        )
-        SELECT
-            cluster_id,
-            AVG(distance)::float8 as avg_distance,
-            STDDEV(distance)::float8 as stddev_distance,
-            COUNT(*) as num_pairs
-        FROM distances
-        GROUP BY cluster_id
-        ORDER BY cluster_id",
-        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
+    info!(
+        "Between Clusters: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
+        cohesion_stats.get::<_, f64>("avg_between"),
+        cohesion_stats.get::<_, f64>("stddev_between"),
+        cohesion_stats.get::<_, f64>("min_between"),
+        cohesion_stats.get::<_, f64>("max_between"),
+        cohesion_stats.get::<_, i64>("count_between")
     );
 
+    // Print per-cluster statistics
     info!("\nPer-cluster cohesion:");
-    let rows = client.query(&query, &[]).await?;
-    for row in rows {
-        let cluster: i32 = row.get(0);
-        let avg: f64 = row.get(1);
-        let stddev: f64 = row.get(2);
-        let count: i64 = row.get(3);
+    let cluster_stats = client
+        .query(
+            "WITH cluster_distances AS (
+                SELECT
+                    a.cluster_id,
+                    SQRT(
+                        POW(ae.embedding[1] - be.embedding[1], 2) +
+                        POW(ae.embedding[2] - be.embedding[2], 2) +
+                        POW(ae.embedding[3] - be.embedding[3], 2)
+                    ) as distance
+                FROM item_clusters a
+                JOIN item_clusters b ON a.cluster_id = b.cluster_id AND a.item_id < b.item_id
+                JOIN item_embeddings ae ON a.item_id = ae.item_id
+                JOIN item_embeddings be ON b.item_id = be.item_id
+            )
+            SELECT
+                cluster_id,
+                AVG(distance) as avg_distance,
+                STDDEV(distance) as stddev_distance,
+                COUNT(*) as num_pairs
+            FROM cluster_distances
+            GROUP BY cluster_id
+            ORDER BY cluster_id",
+            &[],
+        )
+        .await?;
+    for row in cluster_stats {
+        let cluster_id: i32 = row.get("cluster_id");
+        let avg_distance: f64 = row.get("avg_distance");
+        let stddev_distance: f64 = row.get("stddev_distance");
+        let num_pairs: i64 = row.get("num_pairs");
 
         info!(
-            "Cluster {}: avg_distance={:.3}±{:.3} ({} pairs)",
-            cluster, avg, stddev, count
+            "Cluster {}: avg={:.3}, stddev={:.3}, pairs={}",
+            cluster_id, avg_distance, stddev_distance, num_pairs
        );
     }
@@ -223,39 +234,40 @@ async fn analyze_cluster_correlation(client: &tokio_postgres::Client, args: &Arg
     info!("Analyzing correlation between cluster centers and embedding distances...");
 
     // Calculate correlation between cluster center distances and embedding distances
-    let query = format!(
-        "WITH distances AS (
+    let correlation = client
+        .query_one(
+            "WITH distances AS (
+                SELECT
+                    a.cluster_id as cluster1,
+                    b.cluster_id as cluster2,
+                    SQRT(
+                        POW(a.center_x - b.center_x, 2) +
+                        POW(a.center_y - b.center_y, 2) +
+                        POW(a.center_z - b.center_z, 2)
+                    ) as center_distance,
+                    SQRT(
+                        POW(ae.embedding[1] - be.embedding[1], 2) +
+                        POW(ae.embedding[2] - be.embedding[2], 2) +
+                        POW(ae.embedding[3] - be.embedding[3], 2)
+                    ) as embedding_distance
+                FROM item_clusters a
+                JOIN item_clusters b ON a.cluster_id < b.cluster_id
+                JOIN item_embeddings ae ON a.item_id = ae.item_id
+                JOIN item_embeddings be ON b.item_id = be.item_id
+            )
             SELECT
-            a.cluster_id as cluster1,
-            b.cluster_id as cluster2,
-            ABS(a.cluster_center - b.cluster_center)::float8 as center_distance,
-            SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as embedding_distance
-        FROM {} a
-        JOIN {} b ON a.item_id < b.item_id
-        JOIN {} ie1 ON a.item_id = ie1.item_id
-        JOIN {} ie2 ON b.item_id = ie2.item_id,
-        UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
-        GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id, a.cluster_center, b.cluster_center
+                corr(center_distance, embedding_distance) as correlation,
+                COUNT(*) as num_pairs
+            FROM distances",
+            &[],
         )
-        SELECT
-            CORR(center_distance, embedding_distance)::float8 as correlation,
-            COUNT(*) as num_pairs,
-            AVG(center_distance)::float8 as avg_center_dist,
-            AVG(embedding_distance)::float8 as avg_emb_dist
-        FROM distances",
-        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
-    );
+        .await?;
-    let row = client.query_one(&query, &[]).await?;
+    let correlation_value: f64 = correlation.get("correlation");
+    let num_pairs: i64 = correlation.get("num_pairs");
     info!(
-        "Correlation between cluster center distances and embedding distances: {:.3} (from {} pairs)",
-        row.get::<_, f64>(0),
-        row.get::<_, i64>(1)
-    );
-    info!(
-        "Average distances: cluster centers={:.3}, embeddings={:.3}",
-        row.get::<_, f64>(2),
-        row.get::<_, f64>(3)
+        "Correlation between cluster center distances and embedding distances: {:.3} ({} pairs)",
+        correlation_value, num_pairs
     );
 
     Ok(())
 }
diff --git a/src/main.rs b/src/main.rs
index 916b25b..05a06dd 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,8 +2,9 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
-use libmf::Matrix;
-use log::{info, warn};
+use libmf::{Loss, Matrix, Model};
+use log::info;
+use num_cpus;
 use std::env;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
@@ -39,6 +40,10 @@ struct Args {
     /// Lambda for regularization
     #[arg(long, default_value = "0.1")]
     lambda: f32,
+
+    /// Number of threads for matrix factorization (defaults to number of CPU cores)
+    #[arg(long, default_value_t = num_cpus::get() as i32)]
+    threads: i32,
 }
 
 async fn create_pool() -> Result<Pool> {
@@ -57,33 +62,51 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }
 
-async fn load_data_batch(pool: &Pool, args: &Args, offset: i64) -> Result<Vec<(i32, i32, f32)>> {
-    let client = pool.get().await?;
-    let query = format!(
-        "SELECT {}, {} FROM {} OFFSET $1 LIMIT $2",
-        args.user_id_column, args.item_id_column, args.source_table
-    );
+async fn load_data_batch(
+    client: &deadpool_postgres::Client,
+    source_table: &str,
+    user_id_column: &str,
+    item_id_column: &str,
+    batch_size: usize,
+    last_user_id: Option<i32>,
+    last_item_id: Option<i32>,
+) -> Result<Vec<(i32, i32)>> {
+    let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
+        let query = format!(
+            "SELECT {user}, {item} FROM {table} \
+             WHERE ({user}, {item}) > ($1, $2) \
+             ORDER BY {user}, {item} \
+             LIMIT $3",
+            user = user_id_column,
+            item = item_id_column,
+            table = source_table,
+        );
+        client
+            .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
+            .await?
+    } else {
+        let query = format!(
+            "SELECT {user}, {item} FROM {table} \
+             ORDER BY {user}, {item} \
+             LIMIT $1",
+            user = user_id_column,
+            item = item_id_column,
+            table = source_table,
+        );
+        client.query(&query, &[&(batch_size as i64)]).await?
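+        // Both branches page in (user_id, item_id) order, so each batch
+        // resumes exactly where the previous one left off (keyset pagination).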
+    };
 
-    let rows = client
-        .query(&query, &[&offset, &(args.batch_size as i64)])
-        .await?;
+    let mut batch = Vec::with_capacity(rows.len());
+    for row in rows {
+        let user_id: i32 = row.get(0);
+        let item_id: i32 = row.get(1);
+        batch.push((user_id, item_id));
+    }
 
-    Ok(rows
-        .into_iter()
-        .map(|row| {
-            let user_id: i32 = row.get(0);
-            let item_id: i32 = row.get(1);
-            (user_id, item_id, 1.0) // Using 1.0 as interaction strength
-        })
-        .collect())
+    Ok(batch)
 }
 
-async fn save_embeddings(
-    pool: &Pool,
-    args: &Args,
-    model: &libmf::Model,
-    item_ids: &[i32],
-) -> Result<()> {
+async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i32]) -> Result<()> {
     let client = pool.get().await?;
 
     // Create the target table if it doesn't exist
@@ -128,6 +151,19 @@ async fn save_embeddings(
     Ok(())
 }
 
+async fn get_unique_item_ids(
+    client: &deadpool_postgres::Client,
+    source_table: &str,
+    item_id_column: &str,
+) -> Result<Vec<i32>> {
+    let query = format!(
+        "SELECT DISTINCT {} FROM {} ORDER BY {}",
+        item_id_column, source_table, item_id_column
+    );
+    let rows = client.query(&query, &[]).await?;
+    Ok(rows.iter().map(|row| row.get(0)).collect())
+}
+
 #[tokio::main]
 async fn main() -> Result<()> {
     dotenv().ok();
@@ -143,59 +179,75 @@ async fn main() -> Result<()> {
     })?;
 
     let pool = create_pool().await?;
-    let mut offset = 0i64;
-    let mut all_data = Vec::new();
-    let mut unique_item_ids = Vec::new();
+    let mut matrix = Matrix::new();
+    let mut last_user_id = None;
+    let mut last_item_id = None;
+    let mut total_rows = 0;
 
     info!("Starting data loading...");
     while running.load(Ordering::SeqCst) {
-        let batch = load_data_batch(&pool, &args, offset).await?;
+        let batch = load_data_batch(
+            &pool.get().await?,
+            &args.source_table,
+            &args.user_id_column,
+            &args.item_id_column,
+            args.batch_size as usize,
+            last_user_id,
+            last_item_id,
+        )
+        .await?;
+
         if batch.is_empty() {
             break;
         }
 
-        // Track unique item IDs
-        for &(_, item_id, _) in &batch {
-            if !unique_item_ids.contains(&item_id) {
-                unique_item_ids.push(item_id);
-            }
+        total_rows += batch.len();
+        info!(
+            "Loaded batch of {} rows (total: {})",
+            batch.len(),
+            total_rows
+        );
+
+        // Update last seen IDs for next batch
+        if let Some((user_id, item_id)) = batch.last() {
+            last_user_id = Some(*user_id);
+            last_item_id = Some(*item_id);
         }
 
-        all_data.extend(batch);
-        offset += args.batch_size as i64;
-        info!("Loaded {} rows so far", all_data.len());
+        // Process batch
+        for (user_id, item_id) in batch {
+            matrix.push(user_id as i32, item_id as i32, 1.0f32);
+        }
     }
 
-    if all_data.is_empty() {
-        warn!("No data was loaded. Exiting.");
+    info!("Loaded {} total rows", total_rows);
+
+    if total_rows == 0 {
+        info!("No data found in source table");
         return Ok(());
     }
 
-    info!("Creating libmf problem...");
-    let mut matrix = Matrix::new();
-    for (user_id, item_id, value) in all_data {
-        matrix.push(
-            user_id.try_into().unwrap(),
-            item_id.try_into().unwrap(),
-            value,
-        );
-    }
+    // Get unique item IDs from database
+    let client = pool.get().await?;
+    let unique_item_ids =
+        get_unique_item_ids(&client, &args.source_table, &args.item_id_column).await?;
+    info!("Found {} unique items", unique_item_ids.len());
 
-    info!("Training model with {} factors...", args.factors);
-    let model = libmf::Model::params()
-        .factors(args.factors.try_into().unwrap())
+    // Set up training parameters
+    let model = Model::params()
+        .factors(args.factors as i32)
         .lambda_p2(args.lambda)
         .lambda_q2(args.lambda)
         .learning_rate(0.01)
         .iterations(100)
-        .loss(libmf::Loss::OneClassL2)
+        .loss(Loss::OneClassL2)
         .c(0.00001)
         .quiet(false)
+        .threads(args.threads)
        .fit(&matrix)?;
 
     info!("Saving embeddings...");
     save_embeddings(&pool, &args, &model, &unique_item_ids).await?;
-
     info!("Done!");
     Ok(())
 }