285 lines
9.2 KiB
Rust
285 lines
9.2 KiB
Rust
use anyhow::{Context, Result};
|
|
use clap::Parser;
|
|
use deadpool_postgres::{Config, Pool, Runtime};
|
|
use dotenv::dotenv;
|
|
use log::info;
|
|
use rand::seq::SliceRandom;
|
|
use rand::Rng;
|
|
use std::collections::{HashMap, HashSet};
|
|
use std::env;
|
|
use tokio_postgres::NoTls;
|
|
|
|
#[derive(Parser, Debug)]
|
|
#[command(author, version, about, long_about = None)]
|
|
struct Args {
|
|
/// Number of users to generate
|
|
#[arg(long, default_value = "1000")]
|
|
num_users: i32,
|
|
|
|
/// Number of items to generate
|
|
#[arg(long, default_value = "500")]
|
|
num_items: i32,
|
|
|
|
/// Number of user clusters
|
|
#[arg(long, default_value = "5")]
|
|
user_clusters: i32,
|
|
|
|
/// Number of item clusters
|
|
#[arg(long, default_value = "10")]
|
|
item_clusters: i32,
|
|
|
|
/// Average number of interactions per user
|
|
#[arg(long, default_value = "50")]
|
|
avg_interactions: i32,
|
|
|
|
/// Noise level (0.0 to 1.0) - fraction of interactions that are random
|
|
#[arg(long, default_value = "0.05")]
|
|
noise_level: f64,
|
|
|
|
/// Source table name for interactions
|
|
#[arg(long, default_value = "user_interactions")]
|
|
interactions_table: String,
|
|
|
|
/// Target table name for embeddings
|
|
#[arg(long, default_value = "item_embeddings")]
|
|
embeddings_table: String,
|
|
}
|
|
|
|
async fn create_pool() -> Result<Pool> {
|
|
let mut config = Config::new();
|
|
config.host = Some(env::var("POSTGRES_HOST").context("POSTGRES_HOST not set")?);
|
|
config.port = Some(
|
|
env::var("POSTGRES_PORT")
|
|
.context("POSTGRES_PORT not set")?
|
|
.parse()
|
|
.context("Invalid POSTGRES_PORT")?,
|
|
);
|
|
config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
|
|
config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
|
|
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
|
|
|
|
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
dotenv().ok();
|
|
pretty_env_logger::init();
|
|
let args = Args::parse();
|
|
|
|
let pool = create_pool().await?;
|
|
let client = pool.get().await?;
|
|
|
|
// Create tables
|
|
info!("Creating tables...");
|
|
client
|
|
.execute(
|
|
&format!(
|
|
"CREATE TABLE IF NOT EXISTS {} (
|
|
user_id INTEGER,
|
|
item_id INTEGER,
|
|
PRIMARY KEY (user_id, item_id)
|
|
)",
|
|
args.interactions_table
|
|
),
|
|
&[],
|
|
)
|
|
.await?;
|
|
|
|
client
|
|
.execute(
|
|
&format!(
|
|
"CREATE TABLE IF NOT EXISTS {} (
|
|
item_id INTEGER PRIMARY KEY,
|
|
embedding FLOAT[]
|
|
)",
|
|
args.embeddings_table
|
|
),
|
|
&[],
|
|
)
|
|
.await?;
|
|
|
|
// Generate cluster centers that are well-separated in 3D space
|
|
let mut rng = rand::thread_rng();
|
|
let mut cluster_centers = Vec::new();
|
|
|
|
// Base vertices of a regular octahedron
|
|
let base_centers = vec![
|
|
vec![1.0, 0.0, 0.0], // +x
|
|
vec![-1.0, 0.0, 0.0], // -x
|
|
vec![0.0, 1.0, 0.0], // +y
|
|
vec![0.0, -1.0, 0.0], // -y
|
|
vec![0.0, 0.0, 1.0], // +z
|
|
vec![0.0, 0.0, -1.0], // -z
|
|
];
|
|
|
|
// Scale factor to control separation
|
|
let scale = 2.0;
|
|
|
|
// Add jittered versions of the base centers
|
|
for i in 0..args.item_clusters as usize {
|
|
let base = &base_centers[i % base_centers.len()];
|
|
// Add controlled random jitter (up to 10% of the scale)
|
|
let jitter = 0.1;
|
|
let x = base[0] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
|
|
let y = base[1] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
|
|
let z = base[2] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
|
|
|
|
cluster_centers.push(vec![x, y, z]);
|
|
}
|
|
|
|
// Create shuffled list of items
|
|
let mut all_items: Vec<i32> = (0..args.num_items).collect();
|
|
all_items.shuffle(&mut rng);
|
|
|
|
// Assign items to clusters in contiguous blocks
|
|
let mut items_per_cluster = vec![Vec::new(); args.item_clusters as usize];
|
|
let items_per_cluster_count = args.num_items / args.item_clusters;
|
|
|
|
// Drop existing item_clusters table and recreate with 3D centers
|
|
client
|
|
.execute("DROP TABLE IF EXISTS item_clusters", &[])
|
|
.await?;
|
|
client
|
|
.execute(
|
|
"CREATE TABLE item_clusters (
|
|
item_id INTEGER PRIMARY KEY,
|
|
cluster_id INTEGER,
|
|
center_x FLOAT,
|
|
center_y FLOAT,
|
|
center_z FLOAT
|
|
)",
|
|
&[],
|
|
)
|
|
.await?;
|
|
|
|
// Clear existing interactions
|
|
client
|
|
.execute(&format!("TRUNCATE TABLE {}", args.interactions_table), &[])
|
|
.await?;
|
|
|
|
for (i, &item_id) in all_items.iter().enumerate() {
|
|
let cluster_id = (i as i32) / items_per_cluster_count;
|
|
if cluster_id < args.item_clusters {
|
|
items_per_cluster[cluster_id as usize].push(item_id);
|
|
|
|
// Store cluster assignment with 3D center
|
|
let center = &cluster_centers[cluster_id as usize];
|
|
client
|
|
.execute(
|
|
"INSERT INTO item_clusters (item_id, cluster_id, center_x, center_y, center_z)
|
|
VALUES ($1, $2, $3, $4, $5)",
|
|
&[&item_id, &cluster_id, ¢er[0], ¢er[1], ¢er[2]],
|
|
)
|
|
.await?;
|
|
}
|
|
}
|
|
|
|
// Generate interactions with very strong cluster affinity
|
|
info!("Generating interactions...");
|
|
let mut item_interactions = HashMap::new();
|
|
|
|
// For each user cluster
|
|
for user_cluster in 0..args.user_clusters {
|
|
// Each user cluster strongly prefers exactly one item cluster with wraparound
|
|
let preferred_item_cluster = user_cluster % args.item_clusters;
|
|
|
|
// For each user in this cluster
|
|
let cluster_start = user_cluster * (args.num_users / args.user_clusters);
|
|
let cluster_end = if user_cluster == args.user_clusters - 1 {
|
|
args.num_users
|
|
} else {
|
|
(user_cluster + 1) * (args.num_users / args.user_clusters)
|
|
};
|
|
|
|
for user_id in cluster_start..cluster_end {
|
|
// Get all items from preferred cluster
|
|
let preferred_items: HashSet<_> = items_per_cluster[preferred_item_cluster as usize]
|
|
.iter()
|
|
.copied()
|
|
.collect();
|
|
|
|
// Interact with most preferred items (1 - noise_level of items in preferred cluster)
|
|
let num_interactions =
|
|
(preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize;
|
|
let selected_items: Vec<_> = preferred_items.into_iter().collect();
|
|
let mut interactions: HashSet<_> = selected_items
|
|
.choose_multiple(&mut rng, num_interactions)
|
|
.copied()
|
|
.collect();
|
|
|
|
// Add very few random items from non-preferred clusters (noise)
|
|
let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32;
|
|
while interactions.len() < num_interactions + num_noise as usize {
|
|
let item_id = rng.gen_range(0..args.num_items);
|
|
if !interactions.contains(&item_id) {
|
|
interactions.insert(item_id);
|
|
}
|
|
}
|
|
|
|
// Insert interactions
|
|
for &item_id in &interactions {
|
|
client
|
|
.execute(
|
|
&format!(
|
|
"INSERT INTO {} (user_id, item_id) VALUES ($1, $2)
|
|
ON CONFLICT DO NOTHING",
|
|
args.interactions_table
|
|
),
|
|
&[&user_id, &item_id],
|
|
)
|
|
.await?;
|
|
*item_interactions.entry(item_id).or_insert(0) += 1;
|
|
}
|
|
}
|
|
|
|
info!("Generated interactions for user cluster {}", user_cluster);
|
|
}
|
|
|
|
// Get final statistics
|
|
let stats = client
|
|
.query_one(
|
|
&format!(
|
|
"SELECT COUNT(*) as total_interactions,
|
|
COUNT(DISTINCT user_id) as unique_users,
|
|
COUNT(DISTINCT item_id) as unique_items,
|
|
COUNT(DISTINCT user_id)::float / {} as user_coverage,
|
|
COUNT(DISTINCT item_id)::float / {} as item_coverage
|
|
FROM {}",
|
|
args.num_users, args.num_items, args.interactions_table
|
|
),
|
|
&[],
|
|
)
|
|
.await?;
|
|
|
|
info!(
|
|
"Generated {} total interactions between {} users ({:.1}%) and {} items ({:.1}%)",
|
|
stats.get::<_, i64>(0),
|
|
stats.get::<_, i64>(1),
|
|
stats.get::<_, f64>(3) * 100.0,
|
|
stats.get::<_, i64>(2),
|
|
stats.get::<_, f64>(4) * 100.0
|
|
);
|
|
|
|
// Print cluster sizes
|
|
let cluster_sizes = client
|
|
.query(
|
|
"SELECT cluster_id, COUNT(*) as size
|
|
FROM item_clusters
|
|
GROUP BY cluster_id
|
|
ORDER BY cluster_id",
|
|
&[],
|
|
)
|
|
.await?;
|
|
|
|
info!("Cluster sizes:");
|
|
for row in cluster_sizes {
|
|
let cluster_id: i32 = row.get(0);
|
|
let size: i64 = row.get(1);
|
|
info!("Cluster {}: {} items", cluster_id, size);
|
|
}
|
|
|
|
info!("Done!");
|
|
Ok(())
|
|
}
|