diff --git a/.cargo/config.toml b/.cargo/config.toml
index 4b4d39b..2303fff 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -9,7 +9,7 @@ rustflags = [
 ]

 [env]
-CXXFLAGS = "-fopenmp -pthread -DUSEOMP=1"
-LDFLAGS = "-fopenmp -pthread -DUSEOMP=1"
+CXXFLAGS = "-O3 -march=native -fopenmp -pthread -DUSEOMP=1"
+LDFLAGS = "-O3 -march=native -fopenmp -pthread -DUSEOMP=1 -flto"
 CC = "gcc"
 CXX = "g++"
diff --git a/src/main.rs b/src/main.rs
index 493ad97..090645e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -4,11 +4,14 @@ use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
 use libmf::{Loss, Matrix, Model};
 use log::info;
+use std::collections::HashSet;
 use std::env;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use tokio_postgres::NoTls;

+static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -16,18 +19,26 @@ struct Args {
     #[arg(long)]
     source_table: String,

-    /// User ID column name
+    /// User ID column name in source table
     #[arg(long)]
-    user_id_column: String,
+    source_user_id_column: String,

-    /// Item ID column name
+    /// Item ID column name in source table
     #[arg(long)]
-    item_id_column: String,
+    source_item_id_column: String,

     /// Target table for item embeddings
     #[arg(long)]
     target_table: String,

+    /// Target ID column name in the target table
+    #[arg(long)]
+    target_id_column: String,
+
+    /// Target column name for embeddings array
+    #[arg(long)]
+    target_embedding_column: String,
+
     /// Number of iterations for matrix factorization
     #[arg(long, short = 'i', default_value = "100")]
     iterations: i32,
@@ -80,35 +91,39 @@ async fn create_pool() -> Result<Pool> {
 async fn load_data_batch(
     client: &deadpool_postgres::Client,
     source_table: &str,
-    user_id_column: &str,
-    item_id_column: &str,
+    source_user_id_column: &str,
+    source_item_id_column: &str,
     batch_size: usize,
     last_user_id: Option<i32>,
     last_item_id: Option<i32>,
 ) -> Result<Vec<(i32, i32)>> {
     let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
         let query = format!(
-            "SELECT {user}, {item} FROM {table} \
-             WHERE ({user}, {item}) > ($1, $2) \
+            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
+             WHERE ({user}, {item}) > ($1::int4, $2::int4) \
              ORDER BY {user}, {item} \
-             LIMIT $3",
-            user = user_id_column,
-            item = item_id_column,
+             LIMIT $3::bigint",
+            user = source_user_id_column,
+            item = source_item_id_column,
             table = source_table,
         );
         client
             .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
-            .await?
+            .await
+            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
     } else {
         let query = format!(
-            "SELECT {user}, {item} FROM {table} \
+            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
              ORDER BY {user}, {item} \
-             LIMIT $1",
-            user = user_id_column,
-            item = item_id_column,
+             LIMIT $1::bigint",
+            user = source_user_id_column,
+            item = source_item_id_column,
             table = source_table,
        );
-        client.query(&query, &[&(batch_size as i64)]).await?
+        client
+            .query(&query, &[&(batch_size as i64)])
+            .await
+            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
    };

     let mut batch = Vec::with_capacity(rows.len());
@@ -121,45 +136,59 @@ async fn load_data_batch(
     Ok(batch)
 }

-async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()> {
+async fn save_embeddings(
+    pool: &Pool,
+    args: &Args,
+    model: &Model,
+    item_ids: &HashSet<i32>,
+) -> Result<()> {
     let client = pool.get().await?;

-    // Create the target table if it doesn't exist
-    let create_table = format!(
-        "CREATE TABLE IF NOT EXISTS {} (item_id INTEGER PRIMARY KEY, embedding FLOAT4[])",
-        args.target_table
-    );
-    client.execute(&create_table, &[]).await?;
-
     let mut valid_embeddings = 0;
     let mut invalid_embeddings = 0;
     let batch_size = 128;
     let mut current_batch = Vec::with_capacity(batch_size);
-    let mut current_idx = 0;

-    // Process factors in chunks using the iterator directly
-    for factors in model.q_iter() {
+    // Process factors for items that appear in the source table
+    for (idx, factors) in model.q_iter().enumerate() {
+        let item_id = idx as i32;
+        if !item_ids.contains(&item_id) {
+            continue;
+        }
+
         // Skip invalid embeddings
         if factors.iter().any(|&x| x.is_nan()) {
             invalid_embeddings += 1;
-            current_idx += 1;
             continue;
         }

         valid_embeddings += 1;
-        current_batch.push((current_idx, factors));
-        current_idx += 1;
+        current_batch.push((item_id as i64, factors));

         // When batch is full, save it
         if current_batch.len() >= batch_size {
-            save_batch(&client, &args.target_table, &current_batch).await?;
+            save_batch(
+                &client,
+                &args.target_table,
+                &args.target_id_column,
+                &args.target_embedding_column,
+                &current_batch,
+            )
+            .await?;
             current_batch.clear();
         }
     }

     // Save any remaining items in the last batch
     if !current_batch.is_empty() {
-        save_batch(&client, &args.target_table, &current_batch).await?;
+        save_batch(
+            &client,
+            &args.target_table,
+            &args.target_id_column,
+            &args.target_embedding_column,
+            &current_batch,
+        )
+        .await?;
     }

     info!(
@@ -173,21 +202,26 @@ async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()>
 async fn save_batch(
     client: &deadpool_postgres::Client,
     target_table: &str,
-    batch_values: &[(i32, &[f32])],
+    target_id_column: &str,
+    target_embedding_column: &str,
+    batch_values: &[(i64, &[f32])],
 ) -> Result<()> {
     // Build the batch insert query
     let placeholders: Vec<String> = (0..batch_values.len())
-        .map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2))
+        .map(|i| format!("(${}::int8, ${}::float4[])", i * 2 + 1, i * 2 + 2))
        .collect();

     let query = format!(
         r#"
-        INSERT INTO {target_table} (item_id, embedding) VALUES {placeholders}
-        ON CONFLICT (item_id)
-        DO UPDATE SET embedding = EXCLUDED.embedding
+        INSERT INTO {target_table} ({target_id_column}, {target_embedding_column})
+        VALUES {placeholders}
+        ON CONFLICT ({target_id_column})
+        DO UPDATE SET {target_embedding_column} = EXCLUDED.{target_embedding_column}
         "#,
         placeholders = placeholders.join(",")
     );

+    // info!("Executing query: {}", query);
+
     // Flatten parameters for the query
     let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new();
     for (item_id, factors) in batch_values {
@@ -195,7 +229,80 @@ async fn save_batch(
         params.push(factors);
     }

-    client.execute(&query, &params[..]).await?;
+    info!("Number of parameters: {}", params.len());
+    client.execute(&query, &params[..]).await.with_context(|| {
+        format!(
+            "Failed to execute batch insert at {}:{} with {} values",
+            file!(),
+            line!(),
+            batch_values.len()
+        )
+    })?;
     Ok(())
+}
+
+async fn validate_table_exists(client: &deadpool_postgres::Client, table: &str) -> Result<()> {
+    let query = r#"SELECT EXISTS (
+        SELECT FROM pg_tables
+        WHERE schemaname = 'public'
+        AND tablename = $1
+    )"#;
+
+    let exists: bool = client.query_one(query, &[&table]).await?.get(0);
+
+    if !exists {
+        anyhow::bail!("Table '{}' does not exist", table);
+    }
+    Ok(())
+}
+
+async fn validate_column_exists(
+    client: &deadpool_postgres::Client,
+    table: &str,
+    column: &str,
+) -> Result<()> {
+    let query = r#"SELECT EXISTS (
+        SELECT FROM information_schema.columns
+        WHERE table_schema = 'public'
+        AND table_name = $1
+        AND column_name = $2
+    )"#;
+
+    let exists: bool = client.query_one(query, &[&table, &column]).await?.get(0);
+
+    if !exists {
+        anyhow::bail!("Column '{}' does not exist in table '{}'", column, table);
+    }
+    Ok(())
+}
+
+async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<()> {
+    // Validate source table exists
+    validate_table_exists(client, &args.source_table)
+        .await
+        .context("Failed to validate source table")?;
+
+    // Validate source columns exist
+    validate_column_exists(client, &args.source_table, &args.source_user_id_column)
+        .await
+        .context("Failed to validate source user ID column")?;
+    validate_column_exists(client, &args.source_table, &args.source_item_id_column)
+        .await
+        .context("Failed to validate source item ID column")?;
+
+    // Validate target table exists
+    validate_table_exists(client, &args.target_table)
+        .await
+        .context("Failed to validate target table")?;
+
+    // Validate target columns exist
+    validate_column_exists(client, &args.target_table, &args.target_id_column)
+        .await
+        .context("Failed to validate target ID column")?;
+    validate_column_exists(client, &args.target_table, &args.target_embedding_column)
+        .await
+        .context("Failed to validate target embedding column")?;
+    Ok(())
 }

@@ -205,27 +312,39 @@ async fn main() -> Result<()> {
     pretty_env_logger::init();
     let args = Args::parse();

-    // Set up graceful shutdown
+    let pool = create_pool().await?;
+
+    // Validate schema before proceeding
+    info!("Validating database schema...");
+    validate_schema(&pool.get().await?, &args).await?;
+    info!("Schema validation successful");
+
+    let mut matrix = Matrix::new();
+    let mut last_user_id: Option<i32> = None;
+    let mut last_item_id: Option<i32> = None;
+    let mut total_rows = 0;
+    let mut item_ids = HashSet::new();
+
+    // Set up graceful shutdown for data loading
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
     ctrlc::set_handler(move || {
-        r.store(false, Ordering::SeqCst);
-        info!("Received interrupt signal, finishing current batch...");
+        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
+            info!("Received interrupt signal, exiting immediately...");
+            std::process::exit(1);
+        } else {
+            r.store(false, Ordering::SeqCst);
+            info!("Received interrupt signal, finishing current batch...");
+        }
     })?;

-    let pool = create_pool().await?;
-    let mut matrix = Matrix::new();
-    let mut last_user_id = None;
-    let mut last_item_id = None;
-    let mut total_rows = 0;
-
     info!("Starting data loading...");
     while running.load(Ordering::SeqCst) {
         let batch = load_data_batch(
             &pool.get().await?,
             &args.source_table,
-            &args.user_id_column,
-            &args.item_id_column,
+            &args.source_user_id_column,
+            &args.source_item_id_column,
             args.batch_size as usize,
             last_user_id,
             last_item_id,
@@ -252,17 +371,24 @@ async fn main() -> Result<()> {
         // Process batch
         for (user_id, item_id) in batch {
             matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
         }
     }

-    info!("Loaded {} total rows", total_rows);
+    info!(
+        "Loaded {} total rows with {} unique items",
+        total_rows,
+        item_ids.len()
+    );

    if total_rows == 0 {
         info!("No data found in source table");
         return Ok(());
     }

-    // Set up training parameters
+    // Switch to immediate exit mode for training
+    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
+
     let model = Model::params()
         .factors(args.factors)
         .lambda_p1(args.lambda1)
@@ -279,7 +405,7 @@ async fn main() -> Result<()> {
         .fit(&matrix)?;

     info!("Saving embeddings...");
-    save_embeddings(&pool, &args, &model).await?;
+    save_embeddings(&pool, &args, &model, &item_ids).await?;

     Ok(())
 }
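
Note (not part of the patch): the two-phase interrupt handling added here is spread across three hunks — the IMMEDIATE_EXIT static, the branch inside the ctrlc handler, and the store(true) placed just before training. Pulled together, the pattern reads as the standalone sketch below, with the batch loop and the training call stubbed out as hypothetical placeholders; it needs only the `ctrlc` crate, which the program already depends on. The first Ctrl-C lets the loading loop drain its current batch; once IMMEDIATE_EXIT is armed ahead of the blocking fit() call, which never polls the `running` flag, a second Ctrl-C terminates the process outright.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);

fn main() {
    // Phase 1 flag: the loading loop polls this between batches.
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
            // Phase 2: training is underway and cannot be paused; bail out.
            std::process::exit(1);
        }
        // Phase 1: ask the loading loop to stop after the current batch.
        r.store(false, Ordering::SeqCst);
    })
    .expect("failed to install Ctrl-C handler");

    for batch in 0..5 {
        if !running.load(Ordering::SeqCst) {
            break; // graceful stop requested by the first interrupt
        }
        thread::sleep(Duration::from_millis(200)); // stands in for load_data_batch(...)
        println!("loaded batch {batch}");
    }

    // fit() blocks without checking `running`, so arm immediate exit first.
    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
    // model.fit(&matrix) would run here.
}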