better test data generation

@@ -4,9 +4,10 @@ use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
 use log::{info, warn};
 use rand::distributions::{Distribution, Uniform};
+use rand::seq::SliceRandom;
 use rand::Rng;
 use rand_distr::Normal;
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use std::env;
 use tokio_postgres::NoTls;
 
@@ -33,9 +34,9 @@ struct Args {
     #[arg(long, default_value = "50")]
     avg_interactions: i32,
 
-    /// Minimum interactions per item
-    #[arg(long, default_value = "10")]
-    min_item_interactions: i32,
+    /// Noise level (0.0 to 1.0) - fraction of interactions that are random
+    #[arg(long, default_value = "0.05")]
+    noise_level: f64,
 
     /// Source table name for interactions
     #[arg(long, default_value = "user_interactions")]
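
The new noise_level flag replaces min_item_interactions. Since the #[arg] attributes above come from clap's derive API, the snake_case field should surface on the command line as --noise-level. A minimal standalone sketch (not part of the commit; assumes Args derives clap::Parser as the attributes suggest):

    use clap::Parser;

    #[derive(Parser)]
    struct Args {
        /// Noise level (0.0 to 1.0) - fraction of interactions that are random
        #[arg(long, default_value = "0.05")]
        noise_level: f64,
    }

    fn main() {
        // Hypothetical invocation; with the flag omitted, noise_level == 0.05.
        let args = Args::parse_from(["generate", "--noise-level", "0.2"]);
        assert_eq!(args.noise_level, 0.2);
    }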

@@ -62,25 +63,6 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }
 
-#[derive(Copy, Clone)]
-struct ClusterInfo {
-    center: f64,
-    std_dev: f64,
-}
-
-fn generate_cluster_info(num_clusters: i32) -> Vec<ClusterInfo> {
-    let mut rng = rand::thread_rng();
-    let center_dist = Uniform::new(-1.0, 1.0);
-    let std_dev_dist = Uniform::new(0.1, 0.3);
-
-    (0..num_clusters)
-        .map(|_| ClusterInfo {
-            center: center_dist.sample(&mut rng),
-            std_dev: std_dev_dist.sample(&mut rng),
-        })
-        .collect()
-}
-
 #[tokio::main]
 async fn main() -> Result<()> {
     dotenv().ok();
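
For context on the deleted generate_cluster_info: with each center drawn uniformly from [-1, 1] while the sampled std_dev stayed in 0.1-0.3, two clusters could land arbitrarily close together and overlap. A standalone sketch (not part of the commit) that measures the smallest gap between random centers:

    use rand::distributions::{Distribution, Uniform};

    // Smallest gap between k centers drawn uniformly from [-1, 1].
    fn min_gap(k: usize) -> f64 {
        let mut rng = rand::thread_rng();
        let dist = Uniform::new(-1.0f64, 1.0);
        let mut centers: Vec<f64> = (0..k).map(|_| dist.sample(&mut rng)).collect();
        centers.sort_by(|a, b| a.partial_cmp(b).unwrap());
        centers
            .windows(2)
            .map(|w| w[1] - w[0])
            .fold(f64::INFINITY, f64::min)
    }

    fn main() {
        // For 8 clusters the smallest gap is typically a few hundredths,
        // well below the 0.1-0.3 std_dev range the deleted code sampled.
        for _ in 0..5 {
            println!("min gap: {:.3}", min_gap(8));
        }
    }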

@@ -131,74 +113,88 @@ async fn main() -> Result<()> {
         )
         .await?;
 
-    // Generate cluster information
-    let user_clusters = generate_cluster_info(args.user_clusters);
-    let item_clusters = generate_cluster_info(args.item_clusters);
-
-    // Store item cluster assignments
-    let mut item_to_cluster = HashMap::new();
+    // Generate cluster centers that are well-separated
     let mut rng = rand::thread_rng();
-    let user_cluster_dist = Uniform::new(0, args.user_clusters);
-    let item_cluster_dist = Uniform::new(0, args.item_clusters);
-
-    // Assign and store cluster information for each item
-    for item_id in 0..args.num_items {
-        let cluster_id = item_cluster_dist.sample(&mut rng);
-        let cluster = item_clusters[cluster_id as usize];
-        item_to_cluster.insert(item_id, (cluster_id, cluster));
-
-        client
-            .execute(
-                "INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
-                 VALUES ($1, $2, $3)
-                 ON CONFLICT (item_id) DO UPDATE
-                 SET cluster_id = EXCLUDED.cluster_id,
-                     cluster_center = EXCLUDED.cluster_center",
-                &[&item_id, &cluster_id, &cluster.center],
-            )
-            .await?;
+    let mut cluster_centers = Vec::new();
+    for i in 0..args.item_clusters {
+        // Place centers evenly around a circle in 2D space, then project to 1D
+        let angle = 2.0 * std::f64::consts::PI * (i as f64) / (args.item_clusters as f64);
+        let center = angle.cos(); // Project to 1D by taking cosine
+        cluster_centers.push(center);
     }
 
-    let mut rng = rand::thread_rng();
-    let user_cluster_dist = Uniform::new(0, args.user_clusters);
-    let item_cluster_dist = Uniform::new(0, args.item_clusters);
-    let num_interactions_dist = Normal::new(
-        args.avg_interactions as f64,
-        args.avg_interactions as f64 * 0.2,
-    )
-    .context("Failed to create normal distribution")?;
+    // Create shuffled list of items
+    let mut all_items: Vec<i32> = (0..args.num_items).collect();
+    all_items.shuffle(&mut rng);
 
-    // Track item interaction counts
+    // Assign items to clusters in contiguous blocks
+    let mut items_per_cluster = vec![Vec::new(); args.item_clusters as usize];
+    let items_per_cluster_count = args.num_items / args.item_clusters;
+    for (i, &item_id) in all_items.iter().enumerate() {
+        let cluster_id = (i as i32) / items_per_cluster_count;
+        if cluster_id < args.item_clusters {
+            items_per_cluster[cluster_id as usize].push(item_id);
+
+            // Store cluster assignment with meaningful center
+            client
+                .execute(
+                    "INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
+                     VALUES ($1, $2, $3)
+                     ON CONFLICT (item_id) DO UPDATE
+                     SET cluster_id = EXCLUDED.cluster_id,
+                         cluster_center = EXCLUDED.cluster_center",
+                    &[&item_id, &cluster_id, &cluster_centers[cluster_id as usize]],
+                )
+                .await?;
+        }
+    }
+
+    // Generate interactions with strong cluster affinity
+    info!("Generating interactions...");
     let mut item_interactions = HashMap::new();
 
-    info!("Generating interactions...");
-    // First pass: generate normal user-item interactions
-    for user_id in 0..args.num_users {
-        // Assign user to a cluster
-        let user_cluster = user_clusters[user_cluster_dist.sample(&mut rng) as usize];
+    // For each user cluster
+    for user_cluster in 0..args.user_clusters {
+        // Each user cluster strongly prefers 1-2 item clusters
+        let preferred_item_clusters: HashSet<i32> = (0..args.item_clusters)
+            .filter(|&i| i % args.user_clusters == user_cluster)
+            .collect();
 
-        // Generate number of interactions for this user
-        let num_interactions = num_interactions_dist.sample(&mut rng).max(1.0) as i32;
+        // For each user in this cluster
+        let cluster_start = user_cluster * (args.num_users / args.user_clusters);
+        let cluster_end = if user_cluster == args.user_clusters - 1 {
+            args.num_users
+        } else {
+            (user_cluster + 1) * (args.num_users / args.user_clusters)
+        };
 
-        // Generate interactions
-        let mut interactions = Vec::new();
-        while interactions.len() < num_interactions as usize {
-            let item_id = rng.gen_range(0..args.num_items);
-            if !interactions.contains(&item_id) {
-                interactions.push(item_id);
+        for user_id in cluster_start..cluster_end {
+            // Get all items from preferred clusters
+            let mut preferred_items = HashSet::new();
+            for &item_cluster in &preferred_item_clusters {
+                preferred_items.extend(items_per_cluster[item_cluster as usize].iter());
             }
-        }
 
-        // Insert interactions
-        for &item_id in &interactions {
-            let item_cluster = item_clusters[item_cluster_dist.sample(&mut rng) as usize];
+            // Interact with most preferred items (1 - noise_level of items in preferred clusters)
+            let num_interactions =
+                (preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize;
+            let selected_items: Vec<_> = preferred_items.into_iter().collect();
+            let mut interactions: HashSet<_> = selected_items
+                .choose_multiple(&mut rng, num_interactions)
+                .copied()
+                .collect();
 
-            // Higher probability of interaction if user and item clusters are similar
-            let interaction_prob = (-(user_cluster.center - item_cluster.center).powi(2)
-                / (user_cluster.std_dev.powi(2) + item_cluster.std_dev.powi(2)))
-            .exp();
+            // Add random items from non-preferred clusters (noise)
+            let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32;
+            while interactions.len() < num_interactions + num_noise as usize {
+                let item_id = rng.gen_range(0..args.num_items);
+                if !interactions.contains(&item_id) {
+                    interactions.insert(item_id);
+                }
+            }
 
-            if rng.gen::<f64>() < interaction_prob {
+            // Insert interactions
+            for &item_id in &interactions {
                 client
                     .execute(
                         &format!(
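
A note on the new center placement: cos(2*pi*i/k) spaces the angles evenly, but cosine is even around pi, so clusters i and k - i project to the same 1D center. A standalone sketch (not part of the commit) showing the collision for k = 4:

    fn main() {
        // Centers for k = 4 item clusters under the new scheme.
        // cos is symmetric, so clusters 1 and 3 both project to 0.0;
        // separation in the 1D projection only holds for i <= k / 2.
        let k = 4;
        for i in 0..k {
            let angle = 2.0 * std::f64::consts::PI * (i as f64) / (k as f64);
            println!("cluster {} -> center {:.3}", i, angle.cos());
        }
    }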

@@ -213,51 +209,7 @@ async fn main() -> Result<()> {
             }
         }
 
-        if user_id % 100 == 0 {
-            info!("Generated interactions for {} users", user_id);
-        }
-    }
-
-    // Second pass: ensure minimum interactions per item
-    info!("Ensuring minimum interactions per item...");
-    for item_id in 0..args.num_items {
-        let current_interactions = item_interactions.get(&item_id).copied().unwrap_or(0);
-        if current_interactions < args.min_item_interactions {
-            let needed = args.min_item_interactions - current_interactions;
-            let item_cluster = item_clusters[item_cluster_dist.sample(&mut rng) as usize];
-
-            // Add interactions from random users
-            let mut added = 0;
-            while added < needed {
-                let user_id = rng.gen_range(0..args.num_users);
-                let user_cluster = user_clusters[user_cluster_dist.sample(&mut rng) as usize];
-
-                // Use higher base probability to ensure we get enough interactions
-                let interaction_prob = 0.5
-                    + 0.5
-                        * (-(user_cluster.center - item_cluster.center).powi(2)
-                            / (user_cluster.std_dev.powi(2) + item_cluster.std_dev.powi(2)))
-                        .exp();
-
-                if rng.gen::<f64>() < interaction_prob {
-                    match client
-                        .execute(
-                            &format!(
-                                "INSERT INTO {} (user_id, item_id) VALUES ($1, $2)
-                                 ON CONFLICT DO NOTHING",
-                                args.interactions_table
-                            ),
-                            &[&user_id, &item_id],
-                        )
-                        .await
-                    {
-                        Ok(1) => added += 1, // Only increment if we actually inserted
-                        Ok(_) => continue,   // Row already existed
-                        Err(e) => return Err(e.into()),
-                    }
-                }
-            }
-        }
-    }
+        info!("Generated interactions for user cluster {}", user_cluster);
     }
 
     // Get final statistics
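
Under the new scheme a user's interaction count is determined by the preferred-cluster size and noise_level rather than by the old Normal draw; avg_interactions now only scales the noise term. A standalone arithmetic sketch with hypothetical flag values:

    fn main() {
        // Hypothetical flags: --num-items 1000 --item-clusters 10
        // --user-clusters 5 --avg-interactions 50 --noise-level 0.05
        let (num_items, item_clusters, user_clusters) = (1000, 10, 5);
        let (avg_interactions, noise_level) = (50.0_f64, 0.05_f64);

        // Each user cluster prefers item_clusters / user_clusters = 2 item
        // clusters of num_items / item_clusters = 100 items each.
        let preferred_items = (item_clusters / user_clusters) * (num_items / item_clusters);
        let num_interactions = (preferred_items as f64 * (1.0 - noise_level)) as usize;
        let num_noise = (avg_interactions * noise_level) as i32;

        // Prints "190 preferred + 2 noise" interactions per user.
        println!("{} preferred + {} noise", num_interactions, num_noise);
    }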

@@ -266,21 +218,43 @@ async fn main() -> Result<()> {
         &format!(
             "SELECT COUNT(*) as total_interactions,
                     COUNT(DISTINCT user_id) as unique_users,
-                    COUNT(DISTINCT item_id) as unique_items
+                    COUNT(DISTINCT item_id) as unique_items,
+                    COUNT(DISTINCT user_id)::float / {} as user_coverage,
+                    COUNT(DISTINCT item_id)::float / {} as item_coverage
              FROM {}",
-            args.interactions_table
+            args.num_users, args.num_items, args.interactions_table
         ),
         &[],
     )
     .await?;
 
     info!(
-        "Generated {} total interactions between {} users and {} items",
+        "Generated {} total interactions between {} users ({:.1}%) and {} items ({:.1}%)",
         stats.get::<_, i64>(0),
         stats.get::<_, i64>(1),
-        stats.get::<_, i64>(2)
+        stats.get::<_, f64>(3) * 100.0,
+        stats.get::<_, i64>(2),
+        stats.get::<_, f64>(4) * 100.0
     );
 
+    // Print cluster sizes
+    let cluster_sizes = client
+        .query(
+            "SELECT cluster_id, COUNT(*) as size
+             FROM item_clusters
+             GROUP BY cluster_id
+             ORDER BY cluster_id",
+            &[],
+        )
+        .await?;
+
+    info!("Cluster sizes:");
+    for row in cluster_sizes {
+        let cluster_id: i32 = row.get(0);
+        let size: i64 = row.get(1);
+        info!("Cluster {}: {} items", cluster_id, size);
+    }
+
     info!("Done!");
     Ok(())
 }

@@ -46,13 +46,15 @@ async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args)
         "WITH distances AS (
             SELECT
                 a.cluster_id = b.cluster_id as same_cluster,
+                a.cluster_id as cluster1,
+                b.cluster_id as cluster2,
                 SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
             FROM {} a
             JOIN {} b ON a.item_id < b.item_id
             JOIN {} ie1 ON a.item_id = ie1.item_id
             JOIN {} ie2 ON b.item_id = ie2.item_id,
             UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
-            GROUP BY a.item_id, b.item_id, same_cluster
+            GROUP BY a.item_id, b.item_id, same_cluster, a.cluster_id, b.cluster_id
         )
         SELECT
             CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison,

@@ -82,6 +84,87 @@ async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args)
         );
     }
 
+    // Calculate per-cluster statistics
+    let query = format!(
+        "WITH distances AS (
+            SELECT
+                a.cluster_id,
+                SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
+            FROM {} a
+            JOIN {} b ON a.item_id < b.item_id
+            JOIN {} ie1 ON a.item_id = ie1.item_id
+            JOIN {} ie2 ON b.item_id = ie2.item_id,
+            UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
+            WHERE a.cluster_id = b.cluster_id
+            GROUP BY a.item_id, b.item_id, a.cluster_id
+        )
+        SELECT
+            cluster_id,
+            AVG(distance)::float8 as avg_distance,
+            STDDEV(distance)::float8 as stddev_distance,
+            COUNT(*) as num_pairs
+        FROM distances
+        GROUP BY cluster_id
+        ORDER BY cluster_id",
+        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
+    );
+
+    info!("\nPer-cluster cohesion:");
+    let rows = client.query(&query, &[]).await?;
+    for row in rows {
+        let cluster: i32 = row.get(0);
+        let avg: f64 = row.get(1);
+        let stddev: f64 = row.get(2);
+        let count: i64 = row.get(3);
+
+        info!(
+            "Cluster {}: avg_distance={:.3}±{:.3} ({} pairs)",
+            cluster, avg, stddev, count
+        );
+    }
+
+    // Calculate separation between specific cluster pairs
+    let query = format!(
+        "WITH distances AS (
+            SELECT
+                a.cluster_id as cluster1,
+                b.cluster_id as cluster2,
+                SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
+            FROM {} a
+            JOIN {} b ON a.item_id < b.item_id
+            JOIN {} ie1 ON a.item_id = ie1.item_id
+            JOIN {} ie2 ON b.item_id = ie2.item_id,
+            UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
+            WHERE a.cluster_id < b.cluster_id
+            GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id
+        )
+        SELECT
+            cluster1,
+            cluster2,
+            AVG(distance)::float8 as avg_distance,
+            STDDEV(distance)::float8 as stddev_distance,
+            COUNT(*) as num_pairs
+        FROM distances
+        GROUP BY cluster1, cluster2
+        ORDER BY cluster1, cluster2",
+        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
+    );
+
+    info!("\nBetween-cluster separation:");
+    let rows = client.query(&query, &[]).await?;
+    for row in rows {
+        let cluster1: i32 = row.get(0);
+        let cluster2: i32 = row.get(1);
+        let avg: f64 = row.get(2);
+        let stddev: f64 = row.get(3);
+        let count: i64 = row.get(4);
+
+        info!(
+            "Clusters {} <-> {}: avg_distance={:.3}±{:.3} ({} pairs)",
+            cluster1, cluster2, avg, stddev, count
+        );
+    }
+
     Ok(())
 }
 
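
All three cohesion queries compute pairwise Euclidean distance by unnesting the two embedding arrays in parallel and summing squared differences. The same computation in Rust, as a standalone sketch (assumes embeddings fetched as equal-length Vec<f64>):

    // Rust equivalent of SQRT(SUM((e1 - e2) * (e1 - e2))) over
    // UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2).
    fn euclidean(e1: &[f64], e2: &[f64]) -> f64 {
        e1.iter()
            .zip(e2.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f64>()
            .sqrt()
    }

    fn main() {
        let (a, b) = (vec![1.0, 0.0], vec![0.0, 1.0]);
        println!("{:.3}", euclidean(&a, &b)); // 1.414
    }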