mf-fitter/src/bin/generate_test_data.rs

use anyhow::{Context, Result};
use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv;
use log::info;
use rand::seq::SliceRandom;
use rand::Rng;
use std::collections::{HashMap, HashSet};
use std::env;
use tokio_postgres::NoTls;
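
// Synthetic test-data generator: places well-separated cluster centers in 3D,
// assigns items to those clusters, and generates user/item interactions with
// strong per-cluster affinity plus a configurable random-noise fraction.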

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Number of users to generate
    #[arg(long, default_value = "1000")]
    num_users: i32,
    /// Number of items to generate
    #[arg(long, default_value = "500")]
    num_items: i32,
    /// Number of user clusters
    #[arg(long, default_value = "5")]
    user_clusters: i32,
    /// Number of item clusters
    #[arg(long, default_value = "10")]
    item_clusters: i32,
    /// Average number of interactions per user
    #[arg(long, default_value = "50")]
    avg_interactions: i32,
    /// Noise level (0.0 to 1.0) - fraction of interactions that are random
    #[arg(long, default_value = "0.05")]
    noise_level: f64,
    /// Source table name for interactions
    #[arg(long, default_value = "user_interactions")]
    interactions_table: String,
    /// Target table name for embeddings
    #[arg(long, default_value = "item_embeddings")]
    embeddings_table: String,
}
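
// Illustrative invocation (flag names follow clap's default kebab-case mapping):
//   cargo run --bin generate_test_data -- \
//     --num-users 1000 --num-items 500 --user-clusters 5 --item-clusters 10

/// Builds a connection pool from the POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DB,
/// POSTGRES_USER and POSTGRES_PASSWORD environment variables; `dotenv().ok()` in
/// main lets a local .env file supply them.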
async fn create_pool() -> Result<Pool> {
    let mut config = Config::new();
    config.host = Some(env::var("POSTGRES_HOST").context("POSTGRES_HOST not set")?);
    config.port = Some(
        env::var("POSTGRES_PORT")
            .context("POSTGRES_PORT not set")?
            .parse()
            .context("Invalid POSTGRES_PORT")?,
    );
    config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
    config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
    config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
    Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
}

#[tokio::main]
async fn main() -> Result<()> {
    dotenv().ok();
    pretty_env_logger::init();
    let args = Args::parse();
    let pool = create_pool().await?;
    let client = pool.get().await?;

    // Create tables
    info!("Creating tables...");
    client
        .execute(
            &format!(
                "CREATE TABLE IF NOT EXISTS {} (
                    user_id INTEGER,
                    item_id INTEGER,
                    PRIMARY KEY (user_id, item_id)
                )",
                args.interactions_table
            ),
            &[],
        )
        .await?;
    client
        .execute(
            &format!(
                "CREATE TABLE IF NOT EXISTS {} (
                    item_id INTEGER PRIMARY KEY,
                    embedding FLOAT[]
                )",
                args.embeddings_table
            ),
            &[],
        )
        .await?;
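    // Note: the embeddings table is only created here, never filled by this binary;
    // presumably the fitter writes item embeddings into it later.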

    // Generate cluster centers that are well-separated in 3D space
    let mut rng = rand::thread_rng();
    let mut cluster_centers = Vec::new();
    // The 12 vertices of a regular icosahedron, built from three mutually
    // orthogonal golden rectangles
    let phi = (1.0 + 5.0_f64.sqrt()) / 2.0; // golden ratio
    let scale = 2.0;
    let base_centers = vec![
        // rectangle in the xy-plane
        vec![1.0, phi, 0.0],
        vec![-1.0, phi, 0.0],
        vec![1.0, -phi, 0.0],
        vec![-1.0, -phi, 0.0],
        // rectangle in the xz-plane
        vec![phi, 0.0, 1.0],
        vec![-phi, 0.0, 1.0],
        vec![phi, 0.0, -1.0],
        vec![-phi, 0.0, -1.0],
        // rectangle in the yz-plane
        vec![0.0, 1.0, phi],
        vec![0.0, -1.0, phi],
        vec![0.0, 1.0, -phi],
        vec![0.0, -1.0, -phi],
    ];
    // Normalize and scale so every center lies on a sphere of radius `scale`
    // (equidistant from the origin)
    let base_centers: Vec<Vec<f64>> = base_centers
        .into_iter()
        .map(|v| {
            let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
            vec![
                v[0] / norm * scale,
                v[1] / norm * scale,
                v[2] / norm * scale,
            ]
        })
        .collect();
    // Take the first `item_clusters` centers, reusing vertices (with wraparound)
    // if more than 12 clusters are requested
    for i in 0..args.item_clusters as usize {
        let base = &base_centers[i % base_centers.len()];
        // Add very small jitter (1% of scale) to make it more natural
        let jitter = 0.01;
        let x = base[0] + (rng.gen::<f64>() - 0.5) * jitter * scale;
        let y = base[1] + (rng.gen::<f64>() - 0.5) * jitter * scale;
        let z = base[2] + (rng.gen::<f64>() - 0.5) * jitter * scale;
        cluster_centers.push(vec![x, y, z]);
    }
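    // With more than 12 item clusters the wraparound above reuses vertices, so the
    // extra centers differ only by the 1% jitter and sit almost on top of each other.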

    // Create shuffled list of items
    let mut all_items: Vec<i32> = (0..args.num_items).collect();
    all_items.shuffle(&mut rng);
    // Assign items to clusters in contiguous blocks
    let mut items_per_cluster = vec![Vec::new(); args.item_clusters as usize];
    let items_per_cluster_count = args.num_items / args.item_clusters;

    // Drop existing item_clusters table and recreate with 3D centers
    client
        .execute("DROP TABLE IF EXISTS item_clusters", &[])
        .await?;
    client
        .execute(
            "CREATE TABLE item_clusters (
                item_id INTEGER PRIMARY KEY,
                cluster_id INTEGER,
                center_x FLOAT,
                center_y FLOAT,
                center_z FLOAT
            )",
            &[],
        )
        .await?;
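    // Note: unlike the interactions and embeddings tables, the item_clusters table
    // name is not configurable via the CLI.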

    // Clear existing interactions
    client
        .execute(&format!("TRUNCATE TABLE {}", args.interactions_table), &[])
        .await?;
    for (i, &item_id) in all_items.iter().enumerate() {
        let cluster_id = (i as i32) / items_per_cluster_count;
        if cluster_id < args.item_clusters {
            items_per_cluster[cluster_id as usize].push(item_id);
            // Store cluster assignment with 3D center
            let center = &cluster_centers[cluster_id as usize];
            client
                .execute(
                    "INSERT INTO item_clusters (item_id, cluster_id, center_x, center_y, center_z)
                     VALUES ($1, $2, $3, $4, $5)",
                    &[&item_id, &cluster_id, &center[0], &center[1], &center[2]],
                )
                .await?;
        }
    }
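    // Items left over by the integer division (num_items % item_clusters of them)
    // fall past the last cluster and are skipped: they get no item_clusters row and
    // only ever show up in interactions as random noise.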

    // Generate interactions with very strong cluster affinity
    info!("Generating interactions...");
    let mut item_interactions = HashMap::new();
    // For each user cluster
    for user_cluster in 0..args.user_clusters {
        // Each user cluster strongly prefers exactly one item cluster with wraparound
        let preferred_item_cluster = user_cluster % args.item_clusters;
        // For each user in this cluster
        let cluster_start = user_cluster * (args.num_users / args.user_clusters);
        let cluster_end = if user_cluster == args.user_clusters - 1 {
            args.num_users
        } else {
            (user_cluster + 1) * (args.num_users / args.user_clusters)
        };
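        // The last user cluster absorbs the remainder of the integer division, so it
        // can be slightly larger than the others.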
        for user_id in cluster_start..cluster_end {
            // Get all items from preferred cluster
            let preferred_items: HashSet<_> = items_per_cluster[preferred_item_cluster as usize]
                .iter()
                .copied()
                .collect();
            // Interact with a (1 - noise_level) fraction of the items in the preferred cluster
            let num_interactions =
                (preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize;
            let selected_items: Vec<_> = preferred_items.into_iter().collect();
            let mut interactions: HashSet<_> = selected_items
                .choose_multiple(&mut rng, num_interactions)
                .copied()
                .collect();
            // Add a few uniformly random items as noise (these are drawn from all
            // items, so they can occasionally land in the preferred cluster too)
            let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32;
            while interactions.len() < num_interactions + num_noise as usize {
                let item_id = rng.gen_range(0..args.num_items);
                interactions.insert(item_id);
            }
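            // Per-user interaction count is therefore roughly
            // (1 - noise_level) * |preferred cluster| + noise_level * avg_interactions;
            // avg_interactions only scales the noise term, not the total.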
            // Insert interactions
            for &item_id in &interactions {
                client
                    .execute(
                        &format!(
                            "INSERT INTO {} (user_id, item_id) VALUES ($1, $2)
                             ON CONFLICT DO NOTHING",
                            args.interactions_table
                        ),
                        &[&user_id, &item_id],
                    )
                    .await?;
                *item_interactions.entry(item_id).or_insert(0) += 1;
            }
        }
        info!("Generated interactions for user cluster {}", user_cluster);
    }
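    // Note: item_interactions is tallied above but never reported. Also, rows are
    // inserted one at a time; for large --num-users, batching (multi-row INSERTs or
    // COPY) would be the obvious speedup.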

    // Get final statistics
    let stats = client
        .query_one(
            &format!(
                "SELECT COUNT(*) as total_interactions,
                        COUNT(DISTINCT user_id) as unique_users,
                        COUNT(DISTINCT item_id) as unique_items,
                        COUNT(DISTINCT user_id)::float / {} as user_coverage,
                        COUNT(DISTINCT item_id)::float / {} as item_coverage
                 FROM {}",
                args.num_users, args.num_items, args.interactions_table
            ),
            &[],
        )
        .await?;
    info!(
        "Generated {} total interactions between {} users ({:.1}%) and {} items ({:.1}%)",
        stats.get::<_, i64>(0),
        stats.get::<_, i64>(1),
        stats.get::<_, f64>(3) * 100.0,
        stats.get::<_, i64>(2),
        stats.get::<_, f64>(4) * 100.0
    );

    // Print cluster sizes
    let cluster_sizes = client
        .query(
            "SELECT cluster_id, COUNT(*) as size
             FROM item_clusters
             GROUP BY cluster_id
             ORDER BY cluster_id",
            &[],
        )
        .await?;
    info!("Cluster sizes:");
    for row in cluster_sizes {
        let cluster_id: i32 = row.get(0);
        let size: i64 = row.get(1);
        info!("Cluster {}: {} items", cluster_id, size);
    }
    info!("Done!");
    Ok(())
}