batch loading for computed rows

This commit is contained in:
Dylan Knutson
2024-12-28 04:40:09 +00:00
parent 9aece9c740
commit 66165a7eee
4 changed files with 66 additions and 55 deletions

View File

@@ -3,8 +3,6 @@ rustflags = [
"-C", "-C",
"link-arg=-fuse-ld=lld", "link-arg=-fuse-ld=lld",
"-C", "-C",
"link-arg=-L/usr/lib/x86_64-linux-gnu",
"-C",
"link-arg=-lgomp", "link-arg=-lgomp",
"-C", "-C",
"link-arg=-fopenmp", "link-arg=-fopenmp",

View File

@@ -34,7 +34,7 @@ struct Args {
/// Noise level (0.0 to 1.0) - fraction of interactions that are random /// Noise level (0.0 to 1.0) - fraction of interactions that are random
#[arg(long, default_value = "0.05")] #[arg(long, default_value = "0.05")]
noise_level: f64, noise_level: f32,
/// Source table name for interactions /// Source table name for interactions
#[arg(long, default_value = "user_interactions")] #[arg(long, default_value = "user_interactions")]
@@ -91,7 +91,7 @@ async fn main() -> Result<()> {
&format!( &format!(
"CREATE TABLE IF NOT EXISTS {} ( "CREATE TABLE IF NOT EXISTS {} (
item_id INTEGER PRIMARY KEY, item_id INTEGER PRIMARY KEY,
embedding FLOAT[] embedding FLOAT4[]
)", )",
args.embeddings_table args.embeddings_table
), ),
@@ -105,7 +105,7 @@ async fn main() -> Result<()> {
// Base vertices of a regular octahedron // Base vertices of a regular octahedron
// These vertices form three orthogonal golden rectangles // These vertices form three orthogonal golden rectangles
let phi = (1.0 + 5.0_f64.sqrt()) / 2.0; // golden ratio let phi = (1.0 + 5.0_f32.sqrt()) / 2.0; // golden ratio
let scale = 2.0; let scale = 2.0;
let base_centers = vec![ let base_centers = vec![
vec![1.0, phi, 0.0], // front top vec![1.0, phi, 0.0], // front top
@@ -123,7 +123,7 @@ async fn main() -> Result<()> {
]; ];
// Normalize and scale the vectors to ensure equal distances // Normalize and scale the vectors to ensure equal distances
let base_centers: Vec<Vec<f64>> = base_centers let base_centers: Vec<Vec<f32>> = base_centers
.into_iter() .into_iter()
.map(|v| { .map(|v| {
let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt(); let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
@@ -140,9 +140,9 @@ async fn main() -> Result<()> {
let base = &base_centers[i % base_centers.len()]; let base = &base_centers[i % base_centers.len()];
// Add very small jitter (1% of scale) to make it more natural // Add very small jitter (1% of scale) to make it more natural
let jitter = 0.01; let jitter = 0.01;
let x = base[0] + (rng.gen::<f64>() - 0.5) * jitter * scale; let x = base[0] + (rng.gen::<f32>() - 0.5) * jitter * scale;
let y = base[1] + (rng.gen::<f64>() - 0.5) * jitter * scale; let y = base[1] + (rng.gen::<f32>() - 0.5) * jitter * scale;
let z = base[2] + (rng.gen::<f64>() - 0.5) * jitter * scale; let z = base[2] + (rng.gen::<f32>() - 0.5) * jitter * scale;
cluster_centers.push(vec![x, y, z]); cluster_centers.push(vec![x, y, z]);
} }
@@ -164,7 +164,7 @@ async fn main() -> Result<()> {
"CREATE TABLE item_clusters ( "CREATE TABLE item_clusters (
item_id INTEGER PRIMARY KEY, item_id INTEGER PRIMARY KEY,
cluster_id INTEGER, cluster_id INTEGER,
cluster_affinities FLOAT[] -- Array of affinities to archetypal items cluster_affinities FLOAT4[] -- Array of affinities to archetypal items
)", )",
&[], &[],
) )
@@ -177,7 +177,7 @@ async fn main() -> Result<()> {
// Create cluster affinity vectors based on item co-occurrence patterns // Create cluster affinity vectors based on item co-occurrence patterns
let mut cluster_affinities = let mut cluster_affinities =
vec![vec![0.0; args.num_items as usize]; args.item_clusters as usize]; vec![vec![0.0f32; args.num_items as usize]; args.item_clusters as usize];
// For each cluster, select a set of archetypal items that define the cluster // For each cluster, select a set of archetypal items that define the cluster
for cluster_id in 0..args.item_clusters { for cluster_id in 0..args.item_clusters {
@@ -214,7 +214,7 @@ async fn main() -> Result<()> {
items_per_cluster[cluster_id as usize].push(item_id); items_per_cluster[cluster_id as usize].push(item_id);
// Store cluster assignment with affinity vector // Store cluster assignment with affinity vector
let affinities: Vec<f64> = cluster_affinities[cluster_id as usize].clone(); let affinities: Vec<f32> = cluster_affinities[cluster_id as usize].clone();
client client
.execute( .execute(
"INSERT INTO item_clusters (item_id, cluster_id, cluster_affinities) "INSERT INTO item_clusters (item_id, cluster_id, cluster_affinities)
@@ -251,7 +251,7 @@ async fn main() -> Result<()> {
// Interact with most preferred items (1 - noise_level of items in preferred cluster) // Interact with most preferred items (1 - noise_level of items in preferred cluster)
let num_interactions = let num_interactions =
(preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize; (preferred_items.len() as f32 * (1.0 - args.noise_level)) as usize;
let selected_items: Vec<_> = preferred_items.into_iter().collect(); let selected_items: Vec<_> = preferred_items.into_iter().collect();
let mut interactions: HashSet<_> = selected_items let mut interactions: HashSet<_> = selected_items
.choose_multiple(&mut rng, num_interactions) .choose_multiple(&mut rng, num_interactions)
@@ -259,7 +259,7 @@ async fn main() -> Result<()> {
.collect(); .collect();
// Add very few random items from non-preferred clusters (noise) // Add very few random items from non-preferred clusters (noise)
let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32; let num_noise = (args.avg_interactions as f32 * args.noise_level) as i32;
while interactions.len() < num_interactions + num_noise as usize { while interactions.len() < num_interactions + num_noise as usize {
let item_id = rng.gen_range(0..args.num_items); let item_id = rng.gen_range(0..args.num_items);
if !interactions.contains(&item_id) { if !interactions.contains(&item_id) {
@@ -293,8 +293,8 @@ async fn main() -> Result<()> {
"SELECT COUNT(*) as total_interactions, "SELECT COUNT(*) as total_interactions,
COUNT(DISTINCT user_id) as unique_users, COUNT(DISTINCT user_id) as unique_users,
COUNT(DISTINCT item_id) as unique_items, COUNT(DISTINCT item_id) as unique_items,
COUNT(DISTINCT user_id)::float / {} as user_coverage, COUNT(DISTINCT user_id)::float4 / {} as user_coverage,
COUNT(DISTINCT item_id)::float / {} as item_coverage COUNT(DISTINCT item_id)::float4 / {} as item_coverage
FROM {}", FROM {}",
args.num_users, args.num_items, args.interactions_table args.num_users, args.num_items, args.interactions_table
), ),

View File

@@ -47,7 +47,7 @@ async fn create_pool() -> Result<Pool> {
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?) Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
} }
fn perform_pca(data: &Array2<f64>, n_components: usize) -> Result<Array2<f64>> { fn perform_pca(data: &Array2<f32>, n_components: usize) -> Result<Array2<f32>> {
// Center the data // Center the data
let means = data.mean_axis(ndarray::Axis(0)).unwrap(); let means = data.mean_axis(ndarray::Axis(0)).unwrap();
let centered = data.clone() - means.view().insert_axis(ndarray::Axis(0)); let centered = data.clone() - means.view().insert_axis(ndarray::Axis(0));
@@ -82,8 +82,8 @@ async fn main() -> Result<()> {
let rows = client.query(&query, &[]).await?; let rows = client.query(&query, &[]).await?;
let n_items = rows.len(); let n_items = rows.len();
let (n_dims, affinity_dims) = if let Some(first_row) = rows.first() { let (n_dims, affinity_dims) = if let Some(first_row) = rows.first() {
let embedding: Vec<f64> = first_row.get(1); let embedding: Vec<f32> = first_row.get(1);
let affinities: Vec<f64> = first_row.get(3); let affinities: Vec<f32> = first_row.get(3);
(embedding.len(), affinities.len()) (embedding.len(), affinities.len())
} else { } else {
return Ok(()); return Ok(());
@@ -102,9 +102,9 @@ async fn main() -> Result<()> {
for (i, row) in rows.iter().enumerate() { for (i, row) in rows.iter().enumerate() {
let item_id: i32 = row.get(0); let item_id: i32 = row.get(0);
let embedding: Vec<f64> = row.get(1); let embedding: Vec<f32> = row.get(1);
let cluster_id: i32 = row.get(2); let cluster_id: i32 = row.get(2);
let affinities: Vec<f64> = row.get(3); let affinities: Vec<f32> = row.get(3);
item_ids.push(item_id); item_ids.push(item_id);
cluster_ids.push(cluster_id); cluster_ids.push(cluster_id);

View File

@@ -121,42 +121,45 @@ async fn load_data_batch(
Ok(batch) Ok(batch)
} }
// TODO - don't load all item IDs at once async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()> {
async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i32]) -> Result<()> {
let client = pool.get().await?; let client = pool.get().await?;
// Create the target table if it doesn't exist // Create the target table if it doesn't exist
let create_table = format!( let create_table = format!(
"CREATE TABLE IF NOT EXISTS {} (item_id INTEGER PRIMARY KEY, embedding FLOAT[])", "CREATE TABLE IF NOT EXISTS {} (item_id INTEGER PRIMARY KEY, embedding FLOAT4[])",
args.target_table args.target_table
); );
client.execute(&create_table, &[]).await?; client.execute(&create_table, &[]).await?;
// Get all factors at once
let all_factors: Vec<_> = model.q_iter().collect();
info!("Generated {} item embeddings", all_factors.len());
let mut valid_embeddings = 0; let mut valid_embeddings = 0;
let mut invalid_embeddings = 0; let mut invalid_embeddings = 0;
let batch_size = 128;
let mut current_batch = Vec::with_capacity(batch_size);
let mut current_idx = 0;
for &item_id in item_ids { // Process factors in chunks using the iterator directly
let factors = &all_factors[item_id as usize]; for factors in model.q_iter() {
let factors_array: Vec<f64> = factors.iter().map(|&x| x as f64).collect(); // Skip invalid embeddings
if factors.iter().any(|&x| x.is_nan()) {
// Check if the embedding contains any NaN values
if factors_array.iter().any(|x| x.is_nan()) {
invalid_embeddings += 1; invalid_embeddings += 1;
current_idx += 1;
continue; continue;
} }
valid_embeddings += 1; valid_embeddings += 1;
let query = format!( current_batch.push((current_idx, factors));
"INSERT INTO {} (item_id, embedding) VALUES ($1, $2) current_idx += 1;
ON CONFLICT (item_id) DO UPDATE SET embedding = EXCLUDED.embedding",
args.target_table
);
client.execute(&query, &[&item_id, &factors_array]).await?; // When batch is full, save it
if current_batch.len() >= batch_size {
save_batch(&client, &args.target_table, &current_batch).await?;
current_batch.clear();
}
}
// Save any remaining items in the last batch
if !current_batch.is_empty() {
save_batch(&client, &args.target_table, &current_batch).await?;
} }
info!( info!(
@@ -167,17 +170,33 @@ async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i3
Ok(()) Ok(())
} }
async fn get_unique_item_ids( async fn save_batch(
client: &deadpool_postgres::Client, client: &deadpool_postgres::Client,
source_table: &str, target_table: &str,
item_id_column: &str, batch_values: &[(i32, &[f32])],
) -> Result<Vec<i32>> { ) -> Result<()> {
// Build the batch insert query
let placeholders: Vec<String> = (0..batch_values.len())
.map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2))
.collect();
let query = format!( let query = format!(
"SELECT DISTINCT {} FROM {} ORDER BY {}", r#"
item_id_column, source_table, item_id_column INSERT INTO {target_table} (item_id, embedding) VALUES {placeholders}
ON CONFLICT (item_id)
DO UPDATE SET embedding = EXCLUDED.embedding
"#,
placeholders = placeholders.join(",")
); );
let rows = client.query(&query, &[]).await?;
Ok(rows.iter().map(|row| row.get(0)).collect()) // Flatten parameters for the query
let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new();
for (item_id, factors) in batch_values {
params.push(item_id);
params.push(factors);
}
client.execute(&query, &params[..]).await?;
Ok(())
} }
#[tokio::main] #[tokio::main]
@@ -243,12 +262,6 @@ async fn main() -> Result<()> {
return Ok(()); return Ok(());
} }
// Get unique item IDs from database
let client = pool.get().await?;
let unique_item_ids =
get_unique_item_ids(&client, &args.source_table, &args.item_id_column).await?;
info!("Found {} unique items", unique_item_ids.len());
// Set up training parameters // Set up training parameters
let model = Model::params() let model = Model::params()
.factors(args.factors) .factors(args.factors)
@@ -266,7 +279,7 @@ async fn main() -> Result<()> {
.fit(&matrix)?; .fit(&matrix)?;
info!("Saving embeddings..."); info!("Saving embeddings...");
save_embeddings(&pool, &args, &model, &unique_item_ids).await?; save_embeddings(&pool, &args, &model).await?;
Ok(()) Ok(())
} }