better test data generation

This commit is contained in:
Dylan Knutson
2024-12-28 01:51:33 +00:00
parent 00b30ac285
commit 61b9728fd8
2 changed files with 186 additions and 129 deletions

View File

@@ -4,9 +4,10 @@ use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv; use dotenv::dotenv;
use log::{info, warn}; use log::{info, warn};
use rand::distributions::{Distribution, Uniform}; use rand::distributions::{Distribution, Uniform};
use rand::seq::SliceRandom;
use rand::Rng; use rand::Rng;
use rand_distr::Normal; use rand_distr::Normal;
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::env; use std::env;
use tokio_postgres::NoTls; use tokio_postgres::NoTls;
@@ -33,9 +34,9 @@ struct Args {
#[arg(long, default_value = "50")] #[arg(long, default_value = "50")]
avg_interactions: i32, avg_interactions: i32,
/// Minimum interactions per item /// Noise level (0.0 to 1.0) - fraction of interactions that are random
#[arg(long, default_value = "10")] #[arg(long, default_value = "0.05")]
min_item_interactions: i32, noise_level: f64,
/// Source table name for interactions /// Source table name for interactions
#[arg(long, default_value = "user_interactions")] #[arg(long, default_value = "user_interactions")]
@@ -62,25 +63,6 @@ async fn create_pool() -> Result<Pool> {
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?) Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
} }
#[derive(Copy, Clone)]
struct ClusterInfo {
center: f64,
std_dev: f64,
}
fn generate_cluster_info(num_clusters: i32) -> Vec<ClusterInfo> {
let mut rng = rand::thread_rng();
let center_dist = Uniform::new(-1.0, 1.0);
let std_dev_dist = Uniform::new(0.1, 0.3);
(0..num_clusters)
.map(|_| ClusterInfo {
center: center_dist.sample(&mut rng),
std_dev: std_dev_dist.sample(&mut rng),
})
.collect()
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
dotenv().ok(); dotenv().ok();
@@ -131,74 +113,88 @@ async fn main() -> Result<()> {
) )
.await?; .await?;
// Generate cluster information // Generate cluster centers that are well-separated
let user_clusters = generate_cluster_info(args.user_clusters);
let item_clusters = generate_cluster_info(args.item_clusters);
// Store item cluster assignments
let mut item_to_cluster = HashMap::new();
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let user_cluster_dist = Uniform::new(0, args.user_clusters); let mut cluster_centers = Vec::new();
let item_cluster_dist = Uniform::new(0, args.item_clusters); for i in 0..args.item_clusters {
// Place centers evenly around a circle in 2D space, then project to 1D
// Assign and store cluster information for each item let angle = 2.0 * std::f64::consts::PI * (i as f64) / (args.item_clusters as f64);
for item_id in 0..args.num_items { let center = angle.cos(); // Project to 1D by taking cosine
let cluster_id = item_cluster_dist.sample(&mut rng); cluster_centers.push(center);
let cluster = item_clusters[cluster_id as usize];
item_to_cluster.insert(item_id, (cluster_id, cluster));
client
.execute(
"INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
VALUES ($1, $2, $3)
ON CONFLICT (item_id) DO UPDATE
SET cluster_id = EXCLUDED.cluster_id,
cluster_center = EXCLUDED.cluster_center",
&[&item_id, &cluster_id, &cluster.center],
)
.await?;
} }
let mut rng = rand::thread_rng(); // Create shuffled list of items
let user_cluster_dist = Uniform::new(0, args.user_clusters); let mut all_items: Vec<i32> = (0..args.num_items).collect();
let item_cluster_dist = Uniform::new(0, args.item_clusters); all_items.shuffle(&mut rng);
let num_interactions_dist = Normal::new(
args.avg_interactions as f64,
args.avg_interactions as f64 * 0.2,
)
.context("Failed to create normal distribution")?;
// Track item interaction counts // Assign items to clusters in contiguous blocks
let mut items_per_cluster = vec![Vec::new(); args.item_clusters as usize];
let items_per_cluster_count = args.num_items / args.item_clusters;
for (i, &item_id) in all_items.iter().enumerate() {
let cluster_id = (i as i32) / items_per_cluster_count;
if cluster_id < args.item_clusters {
items_per_cluster[cluster_id as usize].push(item_id);
// Store cluster assignment with meaningful center
client
.execute(
"INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
VALUES ($1, $2, $3)
ON CONFLICT (item_id) DO UPDATE
SET cluster_id = EXCLUDED.cluster_id,
cluster_center = EXCLUDED.cluster_center",
&[&item_id, &cluster_id, &cluster_centers[cluster_id as usize]],
)
.await?;
}
}
// Generate interactions with strong cluster affinity
info!("Generating interactions...");
let mut item_interactions = HashMap::new(); let mut item_interactions = HashMap::new();
info!("Generating interactions..."); // For each user cluster
// First pass: generate normal user-item interactions for user_cluster in 0..args.user_clusters {
for user_id in 0..args.num_users { // Each user cluster strongly prefers 1-2 item clusters
// Assign user to a cluster let preferred_item_clusters: HashSet<i32> = (0..args.item_clusters)
let user_cluster = user_clusters[user_cluster_dist.sample(&mut rng) as usize]; .filter(|&i| i % args.user_clusters == user_cluster)
.collect();
// Generate number of interactions for this user // For each user in this cluster
let num_interactions = num_interactions_dist.sample(&mut rng).max(1.0) as i32; let cluster_start = user_cluster * (args.num_users / args.user_clusters);
let cluster_end = if user_cluster == args.user_clusters - 1 {
args.num_users
} else {
(user_cluster + 1) * (args.num_users / args.user_clusters)
};
// Generate interactions for user_id in cluster_start..cluster_end {
let mut interactions = Vec::new(); // Get all items from preferred clusters
while interactions.len() < num_interactions as usize { let mut preferred_items = HashSet::new();
let item_id = rng.gen_range(0..args.num_items); for &item_cluster in &preferred_item_clusters {
if !interactions.contains(&item_id) { preferred_items.extend(items_per_cluster[item_cluster as usize].iter());
interactions.push(item_id);
} }
}
// Insert interactions // Interact with most preferred items (1 - noise_level of items in preferred clusters)
for &item_id in &interactions { let num_interactions =
let item_cluster = item_clusters[item_cluster_dist.sample(&mut rng) as usize]; (preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize;
let selected_items: Vec<_> = preferred_items.into_iter().collect();
let mut interactions: HashSet<_> = selected_items
.choose_multiple(&mut rng, num_interactions)
.copied()
.collect();
// Higher probability of interaction if user and item clusters are similar // Add random items from non-preferred clusters (noise)
let interaction_prob = (-(user_cluster.center - item_cluster.center).powi(2) let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32;
/ (user_cluster.std_dev.powi(2) + item_cluster.std_dev.powi(2))) while interactions.len() < num_interactions + num_noise as usize {
.exp(); let item_id = rng.gen_range(0..args.num_items);
if !interactions.contains(&item_id) {
interactions.insert(item_id);
}
}
if rng.gen::<f64>() < interaction_prob { // Insert interactions
for &item_id in &interactions {
client client
.execute( .execute(
&format!( &format!(
@@ -213,51 +209,7 @@ async fn main() -> Result<()> {
} }
} }
if user_id % 100 == 0 { info!("Generated interactions for user cluster {}", user_cluster);
info!("Generated interactions for {} users", user_id);
}
}
// Second pass: ensure minimum interactions per item
info!("Ensuring minimum interactions per item...");
for item_id in 0..args.num_items {
let current_interactions = item_interactions.get(&item_id).copied().unwrap_or(0);
if current_interactions < args.min_item_interactions {
let needed = args.min_item_interactions - current_interactions;
let item_cluster = item_clusters[item_cluster_dist.sample(&mut rng) as usize];
// Add interactions from random users
let mut added = 0;
while added < needed {
let user_id = rng.gen_range(0..args.num_users);
let user_cluster = user_clusters[user_cluster_dist.sample(&mut rng) as usize];
// Use higher base probability to ensure we get enough interactions
let interaction_prob = 0.5
+ 0.5
* (-(user_cluster.center - item_cluster.center).powi(2)
/ (user_cluster.std_dev.powi(2) + item_cluster.std_dev.powi(2)))
.exp();
if rng.gen::<f64>() < interaction_prob {
match client
.execute(
&format!(
"INSERT INTO {} (user_id, item_id) VALUES ($1, $2)
ON CONFLICT DO NOTHING",
args.interactions_table
),
&[&user_id, &item_id],
)
.await
{
Ok(1) => added += 1, // Only increment if we actually inserted
Ok(_) => continue, // Row already existed
Err(e) => return Err(e.into()),
}
}
}
}
} }
// Get final statistics // Get final statistics
@@ -266,21 +218,43 @@ async fn main() -> Result<()> {
&format!( &format!(
"SELECT COUNT(*) as total_interactions, "SELECT COUNT(*) as total_interactions,
COUNT(DISTINCT user_id) as unique_users, COUNT(DISTINCT user_id) as unique_users,
COUNT(DISTINCT item_id) as unique_items COUNT(DISTINCT item_id) as unique_items,
COUNT(DISTINCT user_id)::float / {} as user_coverage,
COUNT(DISTINCT item_id)::float / {} as item_coverage
FROM {}", FROM {}",
args.interactions_table args.num_users, args.num_items, args.interactions_table
), ),
&[], &[],
) )
.await?; .await?;
info!( info!(
"Generated {} total interactions between {} users and {} items", "Generated {} total interactions between {} users ({:.1}%) and {} items ({:.1}%)",
stats.get::<_, i64>(0), stats.get::<_, i64>(0),
stats.get::<_, i64>(1), stats.get::<_, i64>(1),
stats.get::<_, i64>(2) stats.get::<_, f64>(3) * 100.0,
stats.get::<_, i64>(2),
stats.get::<_, f64>(4) * 100.0
); );
// Print cluster sizes
let cluster_sizes = client
.query(
"SELECT cluster_id, COUNT(*) as size
FROM item_clusters
GROUP BY cluster_id
ORDER BY cluster_id",
&[],
)
.await?;
info!("Cluster sizes:");
for row in cluster_sizes {
let cluster_id: i32 = row.get(0);
let size: i64 = row.get(1);
info!("Cluster {}: {} items", cluster_id, size);
}
info!("Done!"); info!("Done!");
Ok(()) Ok(())
} }

View File

@@ -46,13 +46,15 @@ async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args)
"WITH distances AS ( "WITH distances AS (
SELECT SELECT
a.cluster_id = b.cluster_id as same_cluster, a.cluster_id = b.cluster_id as same_cluster,
a.cluster_id as cluster1,
b.cluster_id as cluster2,
SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
FROM {} a FROM {} a
JOIN {} b ON a.item_id < b.item_id JOIN {} b ON a.item_id < b.item_id
JOIN {} ie1 ON a.item_id = ie1.item_id JOIN {} ie1 ON a.item_id = ie1.item_id
JOIN {} ie2 ON b.item_id = ie2.item_id, JOIN {} ie2 ON b.item_id = ie2.item_id,
UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2) UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
GROUP BY a.item_id, b.item_id, same_cluster GROUP BY a.item_id, b.item_id, same_cluster, a.cluster_id, b.cluster_id
) )
SELECT SELECT
CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison, CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison,
@@ -82,6 +84,87 @@ async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args)
); );
} }
// Calculate per-cluster statistics
let query = format!(
"WITH distances AS (
SELECT
a.cluster_id,
SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
FROM {} a
JOIN {} b ON a.item_id < b.item_id
JOIN {} ie1 ON a.item_id = ie1.item_id
JOIN {} ie2 ON b.item_id = ie2.item_id,
UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
WHERE a.cluster_id = b.cluster_id
GROUP BY a.item_id, b.item_id, a.cluster_id
)
SELECT
cluster_id,
AVG(distance)::float8 as avg_distance,
STDDEV(distance)::float8 as stddev_distance,
COUNT(*) as num_pairs
FROM distances
GROUP BY cluster_id
ORDER BY cluster_id",
args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
);
info!("\nPer-cluster cohesion:");
let rows = client.query(&query, &[]).await?;
for row in rows {
let cluster: i32 = row.get(0);
let avg: f64 = row.get(1);
let stddev: f64 = row.get(2);
let count: i64 = row.get(3);
info!(
"Cluster {}: avg_distance={:.3}±{:.3} ({} pairs)",
cluster, avg, stddev, count
);
}
// Calculate separation between specific cluster pairs
let query = format!(
"WITH distances AS (
SELECT
a.cluster_id as cluster1,
b.cluster_id as cluster2,
SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
FROM {} a
JOIN {} b ON a.item_id < b.item_id
JOIN {} ie1 ON a.item_id = ie1.item_id
JOIN {} ie2 ON b.item_id = ie2.item_id,
UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
WHERE a.cluster_id < b.cluster_id
GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id
)
SELECT
cluster1,
cluster2,
AVG(distance)::float8 as avg_distance,
STDDEV(distance)::float8 as stddev_distance,
COUNT(*) as num_pairs
FROM distances
GROUP BY cluster1, cluster2
ORDER BY cluster1, cluster2",
args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
);
info!("\nBetween-cluster separation:");
let rows = client.query(&query, &[]).await?;
for row in rows {
let cluster1: i32 = row.get(0);
let cluster2: i32 = row.get(1);
let avg: f64 = row.get(2);
let stddev: f64 = row.get(3);
let count: i64 = row.get(4);
info!(
"Clusters {} <-> {}: avg_distance={:.3}±{:.3} ({} pairs)",
cluster1, cluster2, avg, stddev, count
);
}
Ok(()) Ok(())
} }