diff --git a/src/bin/generate_test_data.rs b/src/bin/generate_test_data.rs index 86ad93e..4f1633a 100644 --- a/src/bin/generate_test_data.rs +++ b/src/bin/generate_test_data.rs @@ -4,9 +4,10 @@ use deadpool_postgres::{Config, Pool, Runtime}; use dotenv::dotenv; use log::{info, warn}; use rand::distributions::{Distribution, Uniform}; +use rand::seq::SliceRandom; use rand::Rng; use rand_distr::Normal; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::env; use tokio_postgres::NoTls; @@ -33,9 +34,9 @@ struct Args { #[arg(long, default_value = "50")] avg_interactions: i32, - /// Minimum interactions per item - #[arg(long, default_value = "10")] - min_item_interactions: i32, + /// Noise level (0.0 to 1.0) - fraction of interactions that are random + #[arg(long, default_value = "0.05")] + noise_level: f64, /// Source table name for interactions #[arg(long, default_value = "user_interactions")] @@ -62,25 +63,6 @@ async fn create_pool() -> Result { Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?) } -#[derive(Copy, Clone)] -struct ClusterInfo { - center: f64, - std_dev: f64, -} - -fn generate_cluster_info(num_clusters: i32) -> Vec { - let mut rng = rand::thread_rng(); - let center_dist = Uniform::new(-1.0, 1.0); - let std_dev_dist = Uniform::new(0.1, 0.3); - - (0..num_clusters) - .map(|_| ClusterInfo { - center: center_dist.sample(&mut rng), - std_dev: std_dev_dist.sample(&mut rng), - }) - .collect() -} - #[tokio::main] async fn main() -> Result<()> { dotenv().ok(); @@ -131,74 +113,88 @@ async fn main() -> Result<()> { ) .await?; - // Generate cluster information - let user_clusters = generate_cluster_info(args.user_clusters); - let item_clusters = generate_cluster_info(args.item_clusters); - - // Store item cluster assignments - let mut item_to_cluster = HashMap::new(); + // Generate cluster centers that are well-separated let mut rng = rand::thread_rng(); - let user_cluster_dist = Uniform::new(0, args.user_clusters); - let item_cluster_dist = Uniform::new(0, args.item_clusters); - - // Assign and store cluster information for each item - for item_id in 0..args.num_items { - let cluster_id = item_cluster_dist.sample(&mut rng); - let cluster = item_clusters[cluster_id as usize]; - item_to_cluster.insert(item_id, (cluster_id, cluster)); - - client - .execute( - "INSERT INTO item_clusters (item_id, cluster_id, cluster_center) - VALUES ($1, $2, $3) - ON CONFLICT (item_id) DO UPDATE - SET cluster_id = EXCLUDED.cluster_id, - cluster_center = EXCLUDED.cluster_center", - &[&item_id, &cluster_id, &cluster.center], - ) - .await?; + let mut cluster_centers = Vec::new(); + for i in 0..args.item_clusters { + // Place centers evenly around a circle in 2D space, then project to 1D + let angle = 2.0 * std::f64::consts::PI * (i as f64) / (args.item_clusters as f64); + let center = angle.cos(); // Project to 1D by taking cosine + cluster_centers.push(center); } - let mut rng = rand::thread_rng(); - let user_cluster_dist = Uniform::new(0, args.user_clusters); - let item_cluster_dist = Uniform::new(0, args.item_clusters); - let num_interactions_dist = Normal::new( - args.avg_interactions as f64, - args.avg_interactions as f64 * 0.2, - ) - .context("Failed to create normal distribution")?; + // Create shuffled list of items + let mut all_items: Vec = (0..args.num_items).collect(); + all_items.shuffle(&mut rng); - // Track item interaction counts + // Assign items to clusters in contiguous blocks + let mut items_per_cluster = vec![Vec::new(); args.item_clusters as usize]; + let items_per_cluster_count = args.num_items / args.item_clusters; + for (i, &item_id) in all_items.iter().enumerate() { + let cluster_id = (i as i32) / items_per_cluster_count; + if cluster_id < args.item_clusters { + items_per_cluster[cluster_id as usize].push(item_id); + + // Store cluster assignment with meaningful center + client + .execute( + "INSERT INTO item_clusters (item_id, cluster_id, cluster_center) + VALUES ($1, $2, $3) + ON CONFLICT (item_id) DO UPDATE + SET cluster_id = EXCLUDED.cluster_id, + cluster_center = EXCLUDED.cluster_center", + &[&item_id, &cluster_id, &cluster_centers[cluster_id as usize]], + ) + .await?; + } + } + + // Generate interactions with strong cluster affinity + info!("Generating interactions..."); let mut item_interactions = HashMap::new(); - info!("Generating interactions..."); - // First pass: generate normal user-item interactions - for user_id in 0..args.num_users { - // Assign user to a cluster - let user_cluster = user_clusters[user_cluster_dist.sample(&mut rng) as usize]; + // For each user cluster + for user_cluster in 0..args.user_clusters { + // Each user cluster strongly prefers 1-2 item clusters + let preferred_item_clusters: HashSet = (0..args.item_clusters) + .filter(|&i| i % args.user_clusters == user_cluster) + .collect(); - // Generate number of interactions for this user - let num_interactions = num_interactions_dist.sample(&mut rng).max(1.0) as i32; + // For each user in this cluster + let cluster_start = user_cluster * (args.num_users / args.user_clusters); + let cluster_end = if user_cluster == args.user_clusters - 1 { + args.num_users + } else { + (user_cluster + 1) * (args.num_users / args.user_clusters) + }; - // Generate interactions - let mut interactions = Vec::new(); - while interactions.len() < num_interactions as usize { - let item_id = rng.gen_range(0..args.num_items); - if !interactions.contains(&item_id) { - interactions.push(item_id); + for user_id in cluster_start..cluster_end { + // Get all items from preferred clusters + let mut preferred_items = HashSet::new(); + for &item_cluster in &preferred_item_clusters { + preferred_items.extend(items_per_cluster[item_cluster as usize].iter()); } - } - // Insert interactions - for &item_id in &interactions { - let item_cluster = item_clusters[item_cluster_dist.sample(&mut rng) as usize]; + // Interact with most preferred items (1 - noise_level of items in preferred clusters) + let num_interactions = + (preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize; + let selected_items: Vec<_> = preferred_items.into_iter().collect(); + let mut interactions: HashSet<_> = selected_items + .choose_multiple(&mut rng, num_interactions) + .copied() + .collect(); - // Higher probability of interaction if user and item clusters are similar - let interaction_prob = (-(user_cluster.center - item_cluster.center).powi(2) - / (user_cluster.std_dev.powi(2) + item_cluster.std_dev.powi(2))) - .exp(); + // Add random items from non-preferred clusters (noise) + let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32; + while interactions.len() < num_interactions + num_noise as usize { + let item_id = rng.gen_range(0..args.num_items); + if !interactions.contains(&item_id) { + interactions.insert(item_id); + } + } - if rng.gen::() < interaction_prob { + // Insert interactions + for &item_id in &interactions { client .execute( &format!( @@ -213,51 +209,7 @@ async fn main() -> Result<()> { } } - if user_id % 100 == 0 { - info!("Generated interactions for {} users", user_id); - } - } - - // Second pass: ensure minimum interactions per item - info!("Ensuring minimum interactions per item..."); - for item_id in 0..args.num_items { - let current_interactions = item_interactions.get(&item_id).copied().unwrap_or(0); - if current_interactions < args.min_item_interactions { - let needed = args.min_item_interactions - current_interactions; - let item_cluster = item_clusters[item_cluster_dist.sample(&mut rng) as usize]; - - // Add interactions from random users - let mut added = 0; - while added < needed { - let user_id = rng.gen_range(0..args.num_users); - let user_cluster = user_clusters[user_cluster_dist.sample(&mut rng) as usize]; - - // Use higher base probability to ensure we get enough interactions - let interaction_prob = 0.5 - + 0.5 - * (-(user_cluster.center - item_cluster.center).powi(2) - / (user_cluster.std_dev.powi(2) + item_cluster.std_dev.powi(2))) - .exp(); - - if rng.gen::() < interaction_prob { - match client - .execute( - &format!( - "INSERT INTO {} (user_id, item_id) VALUES ($1, $2) - ON CONFLICT DO NOTHING", - args.interactions_table - ), - &[&user_id, &item_id], - ) - .await - { - Ok(1) => added += 1, // Only increment if we actually inserted - Ok(_) => continue, // Row already existed - Err(e) => return Err(e.into()), - } - } - } - } + info!("Generated interactions for user cluster {}", user_cluster); } // Get final statistics @@ -266,21 +218,43 @@ async fn main() -> Result<()> { &format!( "SELECT COUNT(*) as total_interactions, COUNT(DISTINCT user_id) as unique_users, - COUNT(DISTINCT item_id) as unique_items + COUNT(DISTINCT item_id) as unique_items, + COUNT(DISTINCT user_id)::float / {} as user_coverage, + COUNT(DISTINCT item_id)::float / {} as item_coverage FROM {}", - args.interactions_table + args.num_users, args.num_items, args.interactions_table ), &[], ) .await?; info!( - "Generated {} total interactions between {} users and {} items", + "Generated {} total interactions between {} users ({:.1}%) and {} items ({:.1}%)", stats.get::<_, i64>(0), stats.get::<_, i64>(1), - stats.get::<_, i64>(2) + stats.get::<_, f64>(3) * 100.0, + stats.get::<_, i64>(2), + stats.get::<_, f64>(4) * 100.0 ); + // Print cluster sizes + let cluster_sizes = client + .query( + "SELECT cluster_id, COUNT(*) as size + FROM item_clusters + GROUP BY cluster_id + ORDER BY cluster_id", + &[], + ) + .await?; + + info!("Cluster sizes:"); + for row in cluster_sizes { + let cluster_id: i32 = row.get(0); + let size: i64 = row.get(1); + info!("Cluster {}: {} items", cluster_id, size); + } + info!("Done!"); Ok(()) } diff --git a/src/bin/validate_embeddings.rs b/src/bin/validate_embeddings.rs index b43d8a8..c37c84d 100644 --- a/src/bin/validate_embeddings.rs +++ b/src/bin/validate_embeddings.rs @@ -46,13 +46,15 @@ async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) "WITH distances AS ( SELECT a.cluster_id = b.cluster_id as same_cluster, + a.cluster_id as cluster1, + b.cluster_id as cluster2, SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance FROM {} a JOIN {} b ON a.item_id < b.item_id JOIN {} ie1 ON a.item_id = ie1.item_id JOIN {} ie2 ON b.item_id = ie2.item_id, UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2) - GROUP BY a.item_id, b.item_id, same_cluster + GROUP BY a.item_id, b.item_id, same_cluster, a.cluster_id, b.cluster_id ) SELECT CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison, @@ -82,6 +84,87 @@ async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) ); } + // Calculate per-cluster statistics + let query = format!( + "WITH distances AS ( + SELECT + a.cluster_id, + SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance + FROM {} a + JOIN {} b ON a.item_id < b.item_id + JOIN {} ie1 ON a.item_id = ie1.item_id + JOIN {} ie2 ON b.item_id = ie2.item_id, + UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2) + WHERE a.cluster_id = b.cluster_id + GROUP BY a.item_id, b.item_id, a.cluster_id + ) + SELECT + cluster_id, + AVG(distance)::float8 as avg_distance, + STDDEV(distance)::float8 as stddev_distance, + COUNT(*) as num_pairs + FROM distances + GROUP BY cluster_id + ORDER BY cluster_id", + args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table + ); + + info!("\nPer-cluster cohesion:"); + let rows = client.query(&query, &[]).await?; + for row in rows { + let cluster: i32 = row.get(0); + let avg: f64 = row.get(1); + let stddev: f64 = row.get(2); + let count: i64 = row.get(3); + + info!( + "Cluster {}: avg_distance={:.3}±{:.3} ({} pairs)", + cluster, avg, stddev, count + ); + } + + // Calculate separation between specific cluster pairs + let query = format!( + "WITH distances AS ( + SELECT + a.cluster_id as cluster1, + b.cluster_id as cluster2, + SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance + FROM {} a + JOIN {} b ON a.item_id < b.item_id + JOIN {} ie1 ON a.item_id = ie1.item_id + JOIN {} ie2 ON b.item_id = ie2.item_id, + UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2) + WHERE a.cluster_id < b.cluster_id + GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id + ) + SELECT + cluster1, + cluster2, + AVG(distance)::float8 as avg_distance, + STDDEV(distance)::float8 as stddev_distance, + COUNT(*) as num_pairs + FROM distances + GROUP BY cluster1, cluster2 + ORDER BY cluster1, cluster2", + args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table + ); + + info!("\nBetween-cluster separation:"); + let rows = client.query(&query, &[]).await?; + for row in rows { + let cluster1: i32 = row.get(0); + let cluster2: i32 = row.get(1); + let avg: f64 = row.get(2); + let stddev: f64 = row.get(3); + let count: i64 = row.get(4); + + info!( + "Clusters {} <-> {}: avg_distance={:.3}±{:.3} ({} pairs)", + cluster1, cluster2, avg, stddev, count + ); + } + Ok(()) }