better test data generation
This commit is contained in:
@@ -4,9 +4,10 @@ use deadpool_postgres::{Config, Pool, Runtime};
|
||||
use dotenv::dotenv;
|
||||
use log::{info, warn};
|
||||
use rand::distributions::{Distribution, Uniform};
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
use rand_distr::Normal;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::env;
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
@@ -33,9 +34,9 @@ struct Args {
|
||||
#[arg(long, default_value = "50")]
|
||||
avg_interactions: i32,
|
||||
|
||||
/// Minimum interactions per item
|
||||
#[arg(long, default_value = "10")]
|
||||
min_item_interactions: i32,
|
||||
/// Noise level (0.0 to 1.0) - fraction of interactions that are random
|
||||
#[arg(long, default_value = "0.05")]
|
||||
noise_level: f64,
|
||||
|
||||
/// Source table name for interactions
|
||||
#[arg(long, default_value = "user_interactions")]
|
||||
@@ -62,25 +63,6 @@ async fn create_pool() -> Result<Pool> {
|
||||
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct ClusterInfo {
|
||||
center: f64,
|
||||
std_dev: f64,
|
||||
}
|
||||
|
||||
fn generate_cluster_info(num_clusters: i32) -> Vec<ClusterInfo> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let center_dist = Uniform::new(-1.0, 1.0);
|
||||
let std_dev_dist = Uniform::new(0.1, 0.3);
|
||||
|
||||
(0..num_clusters)
|
||||
.map(|_| ClusterInfo {
|
||||
center: center_dist.sample(&mut rng),
|
||||
std_dev: std_dev_dist.sample(&mut rng),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
dotenv().ok();
|
||||
@@ -131,22 +113,29 @@ async fn main() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Generate cluster information
|
||||
let user_clusters = generate_cluster_info(args.user_clusters);
|
||||
let item_clusters = generate_cluster_info(args.item_clusters);
|
||||
|
||||
// Store item cluster assignments
|
||||
let mut item_to_cluster = HashMap::new();
|
||||
// Generate cluster centers that are well-separated
|
||||
let mut rng = rand::thread_rng();
|
||||
let user_cluster_dist = Uniform::new(0, args.user_clusters);
|
||||
let item_cluster_dist = Uniform::new(0, args.item_clusters);
|
||||
let mut cluster_centers = Vec::new();
|
||||
for i in 0..args.item_clusters {
|
||||
// Place centers evenly around a circle in 2D space, then project to 1D
|
||||
let angle = 2.0 * std::f64::consts::PI * (i as f64) / (args.item_clusters as f64);
|
||||
let center = angle.cos(); // Project to 1D by taking cosine
|
||||
cluster_centers.push(center);
|
||||
}
|
||||
|
||||
// Assign and store cluster information for each item
|
||||
for item_id in 0..args.num_items {
|
||||
let cluster_id = item_cluster_dist.sample(&mut rng);
|
||||
let cluster = item_clusters[cluster_id as usize];
|
||||
item_to_cluster.insert(item_id, (cluster_id, cluster));
|
||||
// Create shuffled list of items
|
||||
let mut all_items: Vec<i32> = (0..args.num_items).collect();
|
||||
all_items.shuffle(&mut rng);
|
||||
|
||||
// Assign items to clusters in contiguous blocks
|
||||
let mut items_per_cluster = vec![Vec::new(); args.item_clusters as usize];
|
||||
let items_per_cluster_count = args.num_items / args.item_clusters;
|
||||
for (i, &item_id) in all_items.iter().enumerate() {
|
||||
let cluster_id = (i as i32) / items_per_cluster_count;
|
||||
if cluster_id < args.item_clusters {
|
||||
items_per_cluster[cluster_id as usize].push(item_id);
|
||||
|
||||
// Store cluster assignment with meaningful center
|
||||
client
|
||||
.execute(
|
||||
"INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
|
||||
@@ -154,51 +143,58 @@ async fn main() -> Result<()> {
|
||||
ON CONFLICT (item_id) DO UPDATE
|
||||
SET cluster_id = EXCLUDED.cluster_id,
|
||||
cluster_center = EXCLUDED.cluster_center",
|
||||
&[&item_id, &cluster_id, &cluster.center],
|
||||
&[&item_id, &cluster_id, &cluster_centers[cluster_id as usize]],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let user_cluster_dist = Uniform::new(0, args.user_clusters);
|
||||
let item_cluster_dist = Uniform::new(0, args.item_clusters);
|
||||
let num_interactions_dist = Normal::new(
|
||||
args.avg_interactions as f64,
|
||||
args.avg_interactions as f64 * 0.2,
|
||||
)
|
||||
.context("Failed to create normal distribution")?;
|
||||
|
||||
// Track item interaction counts
|
||||
// Generate interactions with strong cluster affinity
|
||||
info!("Generating interactions...");
|
||||
let mut item_interactions = HashMap::new();
|
||||
|
||||
info!("Generating interactions...");
|
||||
// First pass: generate normal user-item interactions
|
||||
for user_id in 0..args.num_users {
|
||||
// Assign user to a cluster
|
||||
let user_cluster = user_clusters[user_cluster_dist.sample(&mut rng) as usize];
|
||||
// For each user cluster
|
||||
for user_cluster in 0..args.user_clusters {
|
||||
// Each user cluster strongly prefers 1-2 item clusters
|
||||
let preferred_item_clusters: HashSet<i32> = (0..args.item_clusters)
|
||||
.filter(|&i| i % args.user_clusters == user_cluster)
|
||||
.collect();
|
||||
|
||||
// Generate number of interactions for this user
|
||||
let num_interactions = num_interactions_dist.sample(&mut rng).max(1.0) as i32;
|
||||
// For each user in this cluster
|
||||
let cluster_start = user_cluster * (args.num_users / args.user_clusters);
|
||||
let cluster_end = if user_cluster == args.user_clusters - 1 {
|
||||
args.num_users
|
||||
} else {
|
||||
(user_cluster + 1) * (args.num_users / args.user_clusters)
|
||||
};
|
||||
|
||||
// Generate interactions
|
||||
let mut interactions = Vec::new();
|
||||
while interactions.len() < num_interactions as usize {
|
||||
for user_id in cluster_start..cluster_end {
|
||||
// Get all items from preferred clusters
|
||||
let mut preferred_items = HashSet::new();
|
||||
for &item_cluster in &preferred_item_clusters {
|
||||
preferred_items.extend(items_per_cluster[item_cluster as usize].iter());
|
||||
}
|
||||
|
||||
// Interact with most preferred items (1 - noise_level of items in preferred clusters)
|
||||
let num_interactions =
|
||||
(preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize;
|
||||
let selected_items: Vec<_> = preferred_items.into_iter().collect();
|
||||
let mut interactions: HashSet<_> = selected_items
|
||||
.choose_multiple(&mut rng, num_interactions)
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
// Add random items from non-preferred clusters (noise)
|
||||
let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32;
|
||||
while interactions.len() < num_interactions + num_noise as usize {
|
||||
let item_id = rng.gen_range(0..args.num_items);
|
||||
if !interactions.contains(&item_id) {
|
||||
interactions.push(item_id);
|
||||
interactions.insert(item_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Insert interactions
|
||||
for &item_id in &interactions {
|
||||
let item_cluster = item_clusters[item_cluster_dist.sample(&mut rng) as usize];
|
||||
|
||||
// Higher probability of interaction if user and item clusters are similar
|
||||
let interaction_prob = (-(user_cluster.center - item_cluster.center).powi(2)
|
||||
/ (user_cluster.std_dev.powi(2) + item_cluster.std_dev.powi(2)))
|
||||
.exp();
|
||||
|
||||
if rng.gen::<f64>() < interaction_prob {
|
||||
client
|
||||
.execute(
|
||||
&format!(
|
||||
@@ -213,51 +209,7 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
if user_id % 100 == 0 {
|
||||
info!("Generated interactions for {} users", user_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: ensure minimum interactions per item
|
||||
info!("Ensuring minimum interactions per item...");
|
||||
for item_id in 0..args.num_items {
|
||||
let current_interactions = item_interactions.get(&item_id).copied().unwrap_or(0);
|
||||
if current_interactions < args.min_item_interactions {
|
||||
let needed = args.min_item_interactions - current_interactions;
|
||||
let item_cluster = item_clusters[item_cluster_dist.sample(&mut rng) as usize];
|
||||
|
||||
// Add interactions from random users
|
||||
let mut added = 0;
|
||||
while added < needed {
|
||||
let user_id = rng.gen_range(0..args.num_users);
|
||||
let user_cluster = user_clusters[user_cluster_dist.sample(&mut rng) as usize];
|
||||
|
||||
// Use higher base probability to ensure we get enough interactions
|
||||
let interaction_prob = 0.5
|
||||
+ 0.5
|
||||
* (-(user_cluster.center - item_cluster.center).powi(2)
|
||||
/ (user_cluster.std_dev.powi(2) + item_cluster.std_dev.powi(2)))
|
||||
.exp();
|
||||
|
||||
if rng.gen::<f64>() < interaction_prob {
|
||||
match client
|
||||
.execute(
|
||||
&format!(
|
||||
"INSERT INTO {} (user_id, item_id) VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING",
|
||||
args.interactions_table
|
||||
),
|
||||
&[&user_id, &item_id],
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(1) => added += 1, // Only increment if we actually inserted
|
||||
Ok(_) => continue, // Row already existed
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Generated interactions for user cluster {}", user_cluster);
|
||||
}
|
||||
|
||||
// Get final statistics
|
||||
@@ -266,21 +218,43 @@ async fn main() -> Result<()> {
|
||||
&format!(
|
||||
"SELECT COUNT(*) as total_interactions,
|
||||
COUNT(DISTINCT user_id) as unique_users,
|
||||
COUNT(DISTINCT item_id) as unique_items
|
||||
COUNT(DISTINCT item_id) as unique_items,
|
||||
COUNT(DISTINCT user_id)::float / {} as user_coverage,
|
||||
COUNT(DISTINCT item_id)::float / {} as item_coverage
|
||||
FROM {}",
|
||||
args.interactions_table
|
||||
args.num_users, args.num_items, args.interactions_table
|
||||
),
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!(
|
||||
"Generated {} total interactions between {} users and {} items",
|
||||
"Generated {} total interactions between {} users ({:.1}%) and {} items ({:.1}%)",
|
||||
stats.get::<_, i64>(0),
|
||||
stats.get::<_, i64>(1),
|
||||
stats.get::<_, i64>(2)
|
||||
stats.get::<_, f64>(3) * 100.0,
|
||||
stats.get::<_, i64>(2),
|
||||
stats.get::<_, f64>(4) * 100.0
|
||||
);
|
||||
|
||||
// Print cluster sizes
|
||||
let cluster_sizes = client
|
||||
.query(
|
||||
"SELECT cluster_id, COUNT(*) as size
|
||||
FROM item_clusters
|
||||
GROUP BY cluster_id
|
||||
ORDER BY cluster_id",
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!("Cluster sizes:");
|
||||
for row in cluster_sizes {
|
||||
let cluster_id: i32 = row.get(0);
|
||||
let size: i64 = row.get(1);
|
||||
info!("Cluster {}: {} items", cluster_id, size);
|
||||
}
|
||||
|
||||
info!("Done!");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -46,13 +46,15 @@ async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args)
|
||||
"WITH distances AS (
|
||||
SELECT
|
||||
a.cluster_id = b.cluster_id as same_cluster,
|
||||
a.cluster_id as cluster1,
|
||||
b.cluster_id as cluster2,
|
||||
SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
|
||||
FROM {} a
|
||||
JOIN {} b ON a.item_id < b.item_id
|
||||
JOIN {} ie1 ON a.item_id = ie1.item_id
|
||||
JOIN {} ie2 ON b.item_id = ie2.item_id,
|
||||
UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
|
||||
GROUP BY a.item_id, b.item_id, same_cluster
|
||||
GROUP BY a.item_id, b.item_id, same_cluster, a.cluster_id, b.cluster_id
|
||||
)
|
||||
SELECT
|
||||
CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison,
|
||||
@@ -82,6 +84,87 @@ async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args)
|
||||
);
|
||||
}
|
||||
|
||||
// Calculate per-cluster statistics
|
||||
let query = format!(
|
||||
"WITH distances AS (
|
||||
SELECT
|
||||
a.cluster_id,
|
||||
SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
|
||||
FROM {} a
|
||||
JOIN {} b ON a.item_id < b.item_id
|
||||
JOIN {} ie1 ON a.item_id = ie1.item_id
|
||||
JOIN {} ie2 ON b.item_id = ie2.item_id,
|
||||
UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
|
||||
WHERE a.cluster_id = b.cluster_id
|
||||
GROUP BY a.item_id, b.item_id, a.cluster_id
|
||||
)
|
||||
SELECT
|
||||
cluster_id,
|
||||
AVG(distance)::float8 as avg_distance,
|
||||
STDDEV(distance)::float8 as stddev_distance,
|
||||
COUNT(*) as num_pairs
|
||||
FROM distances
|
||||
GROUP BY cluster_id
|
||||
ORDER BY cluster_id",
|
||||
args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
|
||||
);
|
||||
|
||||
info!("\nPer-cluster cohesion:");
|
||||
let rows = client.query(&query, &[]).await?;
|
||||
for row in rows {
|
||||
let cluster: i32 = row.get(0);
|
||||
let avg: f64 = row.get(1);
|
||||
let stddev: f64 = row.get(2);
|
||||
let count: i64 = row.get(3);
|
||||
|
||||
info!(
|
||||
"Cluster {}: avg_distance={:.3}±{:.3} ({} pairs)",
|
||||
cluster, avg, stddev, count
|
||||
);
|
||||
}
|
||||
|
||||
// Calculate separation between specific cluster pairs
|
||||
let query = format!(
|
||||
"WITH distances AS (
|
||||
SELECT
|
||||
a.cluster_id as cluster1,
|
||||
b.cluster_id as cluster2,
|
||||
SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
|
||||
FROM {} a
|
||||
JOIN {} b ON a.item_id < b.item_id
|
||||
JOIN {} ie1 ON a.item_id = ie1.item_id
|
||||
JOIN {} ie2 ON b.item_id = ie2.item_id,
|
||||
UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
|
||||
WHERE a.cluster_id < b.cluster_id
|
||||
GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id
|
||||
)
|
||||
SELECT
|
||||
cluster1,
|
||||
cluster2,
|
||||
AVG(distance)::float8 as avg_distance,
|
||||
STDDEV(distance)::float8 as stddev_distance,
|
||||
COUNT(*) as num_pairs
|
||||
FROM distances
|
||||
GROUP BY cluster1, cluster2
|
||||
ORDER BY cluster1, cluster2",
|
||||
args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
|
||||
);
|
||||
|
||||
info!("\nBetween-cluster separation:");
|
||||
let rows = client.query(&query, &[]).await?;
|
||||
for row in rows {
|
||||
let cluster1: i32 = row.get(0);
|
||||
let cluster2: i32 = row.get(1);
|
||||
let avg: f64 = row.get(2);
|
||||
let stddev: f64 = row.get(3);
|
||||
let count: i64 = row.get(4);
|
||||
|
||||
info!(
|
||||
"Clusters {} <-> {}: avg_distance={:.3}±{:.3} ({} pairs)",
|
||||
cluster1, cluster2, avg, stddev, count
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user