From 66165a7eee99244a295e653fcdd33aff93244f44 Mon Sep 17 00:00:00 2001
From: Dylan Knutson
Date: Sat, 28 Dec 2024 04:40:09 +0000
Subject: [PATCH] batch loading for computed rows

---
 .cargo/config.toml              |  2 -
 src/bin/generate_test_data.rs   | 28 ++++++------
 src/bin/visualize_embeddings.rs | 10 ++--
 src/main.rs                     | 81 +++++++++++++++++++--------------
 4 files changed, 66 insertions(+), 55 deletions(-)

diff --git a/.cargo/config.toml b/.cargo/config.toml
index a883339..4b4d39b 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -3,8 +3,6 @@ rustflags = [
     "-C",
     "link-arg=-fuse-ld=lld",
     "-C",
-    "link-arg=-L/usr/lib/x86_64-linux-gnu",
-    "-C",
     "link-arg=-lgomp",
     "-C",
     "link-arg=-fopenmp",
diff --git a/src/bin/generate_test_data.rs b/src/bin/generate_test_data.rs
index 49e033c..b891056 100644
--- a/src/bin/generate_test_data.rs
+++ b/src/bin/generate_test_data.rs
@@ -34,7 +34,7 @@ struct Args {
 
     /// Noise level (0.0 to 1.0) - fraction of interactions that are random
     #[arg(long, default_value = "0.05")]
-    noise_level: f64,
+    noise_level: f32,
 
     /// Source table name for interactions
     #[arg(long, default_value = "user_interactions")]
@@ -91,7 +91,7 @@ async fn main() -> Result<()> {
             &format!(
                 "CREATE TABLE IF NOT EXISTS {} (
                     item_id INTEGER PRIMARY KEY,
-                    embedding FLOAT[]
+                    embedding FLOAT4[]
                 )",
                 args.embeddings_table
             ),
@@ -105,7 +105,7 @@ async fn main() -> Result<()> {
 
     // Base vertices of a regular octahedron
     // These vertices form three orthogonal golden rectangles
-    let phi = (1.0 + 5.0_f64.sqrt()) / 2.0; // golden ratio
+    let phi = (1.0 + 5.0_f32.sqrt()) / 2.0; // golden ratio
     let scale = 2.0;
     let base_centers = vec![
         vec![1.0, phi, 0.0], // front top
@@ -123,7 +123,7 @@ async fn main() -> Result<()> {
     ];
 
     // Normalize and scale the vectors to ensure equal distances
-    let base_centers: Vec<Vec<f64>> = base_centers
+    let base_centers: Vec<Vec<f32>> = base_centers
         .into_iter()
         .map(|v| {
             let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
@@ -140,9 +140,9 @@ async fn main() -> Result<()> {
         let base = &base_centers[i % base_centers.len()];
         // Add very small jitter (1% of scale) to make it more natural
         let jitter = 0.01;
-        let x = base[0] + (rng.gen::<f64>() - 0.5) * jitter * scale;
-        let y = base[1] + (rng.gen::<f64>() - 0.5) * jitter * scale;
-        let z = base[2] + (rng.gen::<f64>() - 0.5) * jitter * scale;
+        let x = base[0] + (rng.gen::<f32>() - 0.5) * jitter * scale;
+        let y = base[1] + (rng.gen::<f32>() - 0.5) * jitter * scale;
+        let z = base[2] + (rng.gen::<f32>() - 0.5) * jitter * scale;
         cluster_centers.push(vec![x, y, z]);
     }
 
@@ -164,7 +164,7 @@ async fn main() -> Result<()> {
             "CREATE TABLE item_clusters (
                 item_id INTEGER PRIMARY KEY,
                 cluster_id INTEGER,
-                cluster_affinities FLOAT[] -- Array of affinities to archetypal items
+                cluster_affinities FLOAT4[] -- Array of affinities to archetypal items
             )",
             &[],
         )
@@ -177,7 +177,7 @@ async fn main() -> Result<()> {
 
     // Create cluster affinity vectors based on item co-occurrence patterns
     let mut cluster_affinities =
-        vec![vec![0.0; args.num_items as usize]; args.item_clusters as usize];
+        vec![vec![0.0f32; args.num_items as usize]; args.item_clusters as usize];
 
     // For each cluster, select a set of archetypal items that define the cluster
     for cluster_id in 0..args.item_clusters {
@@ -214,7 +214,7 @@ async fn main() -> Result<()> {
            items_per_cluster[cluster_id as usize].push(item_id);
 
            // Store cluster assignment with affinity vector
-           let affinities: Vec<f64> = cluster_affinities[cluster_id as usize].clone();
+           let affinities: Vec<f32> = cluster_affinities[cluster_id as usize].clone();
            client
                .execute(
                    "INSERT INTO item_clusters (item_id, cluster_id, cluster_affinities)
@@ -251,7 +251,7 @@ async fn main() -> Result<()> {
 
         // Interact with most preferred items (1 - noise_level of items in preferred cluster)
         let num_interactions =
-            (preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize;
+            (preferred_items.len() as f32 * (1.0 - args.noise_level)) as usize;
         let selected_items: Vec<_> = preferred_items.into_iter().collect();
         let mut interactions: HashSet<_> = selected_items
             .choose_multiple(&mut rng, num_interactions)
@@ -259,7 +259,7 @@ async fn main() -> Result<()> {
             .collect();
 
         // Add very few random items from non-preferred clusters (noise)
-        let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32;
+        let num_noise = (args.avg_interactions as f32 * args.noise_level) as i32;
         while interactions.len() < num_interactions + num_noise as usize {
             let item_id = rng.gen_range(0..args.num_items);
             if !interactions.contains(&item_id) {
@@ -293,8 +293,8 @@ async fn main() -> Result<()> {
             "SELECT COUNT(*) as total_interactions,
                 COUNT(DISTINCT user_id) as unique_users,
                 COUNT(DISTINCT item_id) as unique_items,
-                COUNT(DISTINCT user_id)::float / {} as user_coverage,
-                COUNT(DISTINCT item_id)::float / {} as item_coverage
+                COUNT(DISTINCT user_id)::float4 / {} as user_coverage,
+                COUNT(DISTINCT item_id)::float4 / {} as item_coverage
              FROM {}",
             args.num_users, args.num_items, args.interactions_table
         ),
diff --git a/src/bin/visualize_embeddings.rs b/src/bin/visualize_embeddings.rs
index 0726c2f..472a07d 100644
--- a/src/bin/visualize_embeddings.rs
+++ b/src/bin/visualize_embeddings.rs
@@ -47,7 +47,7 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }
 
-fn perform_pca(data: &Array2<f64>, n_components: usize) -> Result<Array2<f64>> {
+fn perform_pca(data: &Array2<f32>, n_components: usize) -> Result<Array2<f32>> {
     // Center the data
     let means = data.mean_axis(ndarray::Axis(0)).unwrap();
     let centered = data.clone() - means.view().insert_axis(ndarray::Axis(0));
@@ -82,8 +82,8 @@ async fn main() -> Result<()> {
     let rows = client.query(&query, &[]).await?;
     let n_items = rows.len();
     let (n_dims, affinity_dims) = if let Some(first_row) = rows.first() {
-        let embedding: Vec<f64> = first_row.get(1);
-        let affinities: Vec<f64> = first_row.get(3);
+        let embedding: Vec<f32> = first_row.get(1);
+        let affinities: Vec<f32> = first_row.get(3);
         (embedding.len(), affinities.len())
     } else {
         return Ok(());
@@ -102,9 +102,9 @@ async fn main() -> Result<()> {
 
     for (i, row) in rows.iter().enumerate() {
         let item_id: i32 = row.get(0);
-        let embedding: Vec<f64> = row.get(1);
+        let embedding: Vec<f32> = row.get(1);
         let cluster_id: i32 = row.get(2);
-        let affinities: Vec<f64> = row.get(3);
+        let affinities: Vec<f32> = row.get(3);
 
         item_ids.push(item_id);
         cluster_ids.push(cluster_id);
diff --git a/src/main.rs b/src/main.rs
index 6b7a56c..493ad97 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -121,42 +121,45 @@ async fn load_data_batch(
     Ok(batch)
 }
 
-// TODO - don't load all item IDs at once
-async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i32]) -> Result<()> {
+async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()> {
     let client = pool.get().await?;
 
     // Create the target table if it doesn't exist
     let create_table = format!(
-        "CREATE TABLE IF NOT EXISTS {} (item_id INTEGER PRIMARY KEY, embedding FLOAT[])",
+        "CREATE TABLE IF NOT EXISTS {} (item_id INTEGER PRIMARY KEY, embedding FLOAT4[])",
         args.target_table
     );
     client.execute(&create_table, &[]).await?;
 
-    // Get all factors at once
-    let all_factors: Vec<_> = model.q_iter().collect();
-    info!("Generated {} item embeddings", all_factors.len());
-
     let mut valid_embeddings = 0;
     let mut invalid_embeddings = 0;
+    let batch_size = 128;
+    let mut current_batch = Vec::with_capacity(batch_size);
+    let mut current_idx = 0;
 
-    for &item_id in item_ids {
-        let factors = &all_factors[item_id as usize];
-        let factors_array: Vec<f64> = factors.iter().map(|&x| x as f64).collect();
-
-        // Check if the embedding contains any NaN values
-        if factors_array.iter().any(|x| x.is_nan()) {
+    // Process factors in chunks using the iterator directly
+    for factors in model.q_iter() {
+        // Skip invalid embeddings
+        if factors.iter().any(|&x| x.is_nan()) {
             invalid_embeddings += 1;
+            current_idx += 1;
             continue;
         }
 
         valid_embeddings += 1;
-        let query = format!(
-            "INSERT INTO {} (item_id, embedding) VALUES ($1, $2)
-             ON CONFLICT (item_id) DO UPDATE SET embedding = EXCLUDED.embedding",
-            args.target_table
-        );
+        current_batch.push((current_idx, factors));
+        current_idx += 1;
 
-        client.execute(&query, &[&item_id, &factors_array]).await?;
+        // When batch is full, save it
+        if current_batch.len() >= batch_size {
+            save_batch(&client, &args.target_table, &current_batch).await?;
+            current_batch.clear();
+        }
+    }
+
+    // Save any remaining items in the last batch
+    if !current_batch.is_empty() {
+        save_batch(&client, &args.target_table, &current_batch).await?;
     }
 
     info!(
@@ -167,17 +170,33 @@ async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i3
     Ok(())
 }
 
-async fn get_unique_item_ids(
+async fn save_batch(
     client: &deadpool_postgres::Client,
-    source_table: &str,
-    item_id_column: &str,
-) -> Result<Vec<i32>> {
+    target_table: &str,
+    batch_values: &[(i32, &[f32])],
+) -> Result<()> {
+    // Build the batch insert query
+    let placeholders: Vec<String> = (0..batch_values.len())
+        .map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2))
+        .collect();
     let query = format!(
-        "SELECT DISTINCT {} FROM {} ORDER BY {}",
-        item_id_column, source_table, item_id_column
+        r#"
+        INSERT INTO {target_table} (item_id, embedding) VALUES {placeholders}
+        ON CONFLICT (item_id)
+        DO UPDATE SET embedding = EXCLUDED.embedding
+        "#,
+        placeholders = placeholders.join(",")
     );
-    let rows = client.query(&query, &[]).await?;
-    Ok(rows.iter().map(|row| row.get(0)).collect())
+
+    // Flatten parameters for the query
+    let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new();
+    for (item_id, factors) in batch_values {
+        params.push(item_id);
+        params.push(factors);
+    }
+
+    client.execute(&query, &params[..]).await?;
+    Ok(())
 }
 
 #[tokio::main]
@@ -243,12 +262,6 @@ async fn main() -> Result<()> {
         return Ok(());
     }
 
-    // Get unique item IDs from database
-    let client = pool.get().await?;
-    let unique_item_ids =
-        get_unique_item_ids(&client, &args.source_table, &args.item_id_column).await?;
-    info!("Found {} unique items", unique_item_ids.len());
-
     // Set up training parameters
     let model = Model::params()
         .factors(args.factors)
@@ -266,7 +279,7 @@ async fn main() -> Result<()> {
         .fit(&matrix)?;
 
     info!("Saving embeddings...");
-    save_embeddings(&pool, &args, &model, &unique_item_ids).await?;
+    save_embeddings(&pool, &args, &model).await?;
 
     Ok(())
 }
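
Reviewer note (not part of the patch): a minimal standalone sketch of how save_batch numbers its VALUES
placeholders. Each row contributes an (item_id, embedding) pair, so row i binds parameters $2i+1 and $2i+2.
The batch length of 3 below is an arbitrary example; running it prints "($1, $2),($3, $4),($5, $6)", which
is the fragment spliced into the INSERT ... VALUES {placeholders} statement for a batch of three rows.

fn main() {
    // Hypothetical batch of three rows, mirroring the i * 2 + 1 / i * 2 + 2
    // placeholder arithmetic used in save_batch.
    let batch_len = 3;
    let placeholders: Vec<String> = (0..batch_len)
        .map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2))
        .collect();
    // Prints: ($1, $2),($3, $4),($5, $6)
    println!("{}", placeholders.join(","));
}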