Refactor data loading and embedding saving process

- Updated `.cargo/config.toml` to add `-O3 -march=native` to the C++ compiler and linker flags, plus `-flto` at link time, for better native-code performance.
- Enhanced `main.rs` by:
  - Renaming the user and item ID column arguments to `source_user_id_column` and `source_item_id_column`, and adding `target_id_column` and `target_embedding_column` arguments so the target table's columns are configurable.
  - Adding validation functions that verify the source and target tables and their configured columns exist in the database schema before any data is loaded.
  - Making the interrupt handler finish the current batch during data loading and exit immediately once training has started.
  - Modifying the `save_embeddings` function to accept the set of item IDs seen in the source data so that only embeddings for those items are written to the target table.
  - Improving error handling by attaching context messages (file and line) to database operations.

These changes improve the readability, robustness, and performance of the data loading and embedding saving pipeline. An example invocation with the updated command-line flags is shown below.
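For reference, a hypothetical invocation with the renamed and newly added flags could look like the following sketch (table and column names are placeholders, and the flag spellings assume clap's default kebab-case mapping of the `Args` field names):

```sh
# Hypothetical example -- substitute your own table and column names.
cargo run --release -- \
  --source-table user_item_interactions \
  --source-user-id-column user_id \
  --source-item-id-column item_id \
  --target-table item_embeddings \
  --target-id-column item_id \
  --target-embedding-column embedding \
  --iterations 100
```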
Dylan Knutson
2024-12-28 06:42:28 +00:00
parent c791203d1c
commit 350c61c313
2 changed files with 181 additions and 55 deletions

.cargo/config.toml

@@ -9,7 +9,7 @@ rustflags = [
 ]
 [env]
-CXXFLAGS = "-fopenmp -pthread -DUSEOMP=1"
-LDFLAGS = "-fopenmp -pthread -DUSEOMP=1"
+CXXFLAGS = "-O3 -march=native -fopenmp -pthread -DUSEOMP=1"
+LDFLAGS = "-O3 -march=native -fopenmp -pthread -DUSEOMP=1 -flto"
 CC = "gcc"
 CXX = "g++"

main.rs

@@ -4,11 +4,14 @@ use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
 use libmf::{Loss, Matrix, Model};
 use log::info;
+use std::collections::HashSet;
 use std::env;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use tokio_postgres::NoTls;
+
+static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -16,18 +19,26 @@ struct Args {
     #[arg(long)]
     source_table: String,
-    /// User ID column name
+    /// User ID column name in source table
     #[arg(long)]
-    user_id_column: String,
+    source_user_id_column: String,
-    /// Item ID column name
+    /// Item ID column name in source table
     #[arg(long)]
-    item_id_column: String,
+    source_item_id_column: String,
     /// Target table for item embeddings
     #[arg(long)]
     target_table: String,
+    /// Target ID column name in the target table
+    #[arg(long)]
+    target_id_column: String,
+    /// Target column name for embeddings array
+    #[arg(long)]
+    target_embedding_column: String,
     /// Number of iterations for matrix factorization
     #[arg(long, short = 'i', default_value = "100")]
     iterations: i32,
@@ -80,35 +91,39 @@ async fn create_pool() -> Result<Pool> {
 async fn load_data_batch(
     client: &deadpool_postgres::Client,
     source_table: &str,
-    user_id_column: &str,
-    item_id_column: &str,
+    source_user_id_column: &str,
+    source_item_id_column: &str,
     batch_size: usize,
     last_user_id: Option<i32>,
     last_item_id: Option<i32>,
 ) -> Result<Vec<(i32, i32)>> {
     let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
         let query = format!(
-            "SELECT {user}, {item} FROM {table} \
-             WHERE ({user}, {item}) > ($1, $2) \
+            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
+             WHERE ({user}, {item}) > ($1::int4, $2::int4) \
              ORDER BY {user}, {item} \
-             LIMIT $3",
-            user = user_id_column,
-            item = item_id_column,
+             LIMIT $3::bigint",
+            user = source_user_id_column,
+            item = source_item_id_column,
             table = source_table,
         );
         client
             .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
-            .await?
+            .await
+            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
     } else {
         let query = format!(
-            "SELECT {user}, {item} FROM {table} \
+            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
              ORDER BY {user}, {item} \
-             LIMIT $1",
-            user = user_id_column,
-            item = item_id_column,
+             LIMIT $1::bigint",
+            user = source_user_id_column,
+            item = source_item_id_column,
             table = source_table,
         );
-        client.query(&query, &[&(batch_size as i64)]).await?
+        client
+            .query(&query, &[&(batch_size as i64)])
+            .await
+            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
     };
     let mut batch = Vec::with_capacity(rows.len());
@@ -121,45 +136,59 @@ async fn load_data_batch(
     Ok(batch)
 }
-async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()> {
+async fn save_embeddings(
+    pool: &Pool,
+    args: &Args,
+    model: &Model,
+    item_ids: &HashSet<i32>,
+) -> Result<()> {
     let client = pool.get().await?;
-    // Create the target table if it doesn't exist
-    let create_table = format!(
-        "CREATE TABLE IF NOT EXISTS {} (item_id INTEGER PRIMARY KEY, embedding FLOAT4[])",
-        args.target_table
-    );
-    client.execute(&create_table, &[]).await?;
     let mut valid_embeddings = 0;
     let mut invalid_embeddings = 0;
     let batch_size = 128;
     let mut current_batch = Vec::with_capacity(batch_size);
-    let mut current_idx = 0;
-    // Process factors in chunks using the iterator directly
-    for factors in model.q_iter() {
+    // Process factors for items that appear in the source table
+    for (idx, factors) in model.q_iter().enumerate() {
+        let item_id = idx as i32;
+        if !item_ids.contains(&item_id) {
+            continue;
+        }
         // Skip invalid embeddings
         if factors.iter().any(|&x| x.is_nan()) {
             invalid_embeddings += 1;
-            current_idx += 1;
             continue;
         }
         valid_embeddings += 1;
-        current_batch.push((current_idx, factors));
-        current_idx += 1;
+        current_batch.push((item_id as i64, factors));
         // When batch is full, save it
         if current_batch.len() >= batch_size {
-            save_batch(&client, &args.target_table, &current_batch).await?;
+            save_batch(
+                &client,
+                &args.target_table,
+                &args.target_id_column,
+                &args.target_embedding_column,
+                &current_batch,
+            )
+            .await?;
             current_batch.clear();
         }
     }
     // Save any remaining items in the last batch
     if !current_batch.is_empty() {
-        save_batch(&client, &args.target_table, &current_batch).await?;
+        save_batch(
+            &client,
+            &args.target_table,
+            &args.target_id_column,
+            &args.target_embedding_column,
+            &current_batch,
+        )
+        .await?;
     }
     info!(
@@ -173,21 +202,26 @@ async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()>
 async fn save_batch(
     client: &deadpool_postgres::Client,
     target_table: &str,
-    batch_values: &[(i32, &[f32])],
+    target_id_column: &str,
+    target_embedding_column: &str,
+    batch_values: &[(i64, &[f32])],
 ) -> Result<()> {
     // Build the batch insert query
     let placeholders: Vec<String> = (0..batch_values.len())
-        .map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2))
+        .map(|i| format!("(${}::int8, ${}::float4[])", i * 2 + 1, i * 2 + 2))
         .collect();
     let query = format!(
         r#"
-        INSERT INTO {target_table} (item_id, embedding) VALUES {placeholders}
-        ON CONFLICT (item_id)
-        DO UPDATE SET embedding = EXCLUDED.embedding
+        INSERT INTO {target_table} ({target_id_column}, {target_embedding_column})
+        VALUES {placeholders}
+        ON CONFLICT ({target_id_column})
+        DO UPDATE SET {target_embedding_column} = EXCLUDED.{target_embedding_column}
         "#,
         placeholders = placeholders.join(",")
     );
+    // info!("Executing query: {}", query);
     // Flatten parameters for the query
     let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new();
     for (item_id, factors) in batch_values {
@@ -195,7 +229,80 @@ async fn save_batch(
         params.push(factors);
     }
-    client.execute(&query, &params[..]).await?;
+    info!("Number of parameters: {}", params.len());
+    client.execute(&query, &params[..]).await.with_context(|| {
+        format!(
+            "Failed to execute batch insert at {}:{} with {} values",
+            file!(),
+            line!(),
+            batch_values.len()
+        )
+    })?;
+    Ok(())
+}
+
+async fn validate_table_exists(client: &deadpool_postgres::Client, table: &str) -> Result<()> {
+    let query = r#"SELECT EXISTS (
+        SELECT FROM pg_tables
+        WHERE schemaname = 'public'
+        AND tablename = $1
+    )"#;
+    let exists: bool = client.query_one(query, &[&table]).await?.get(0);
+    if !exists {
+        anyhow::bail!("Table '{}' does not exist", table);
+    }
+    Ok(())
+}
+
+async fn validate_column_exists(
+    client: &deadpool_postgres::Client,
+    table: &str,
+    column: &str,
+) -> Result<()> {
+    let query = r#"SELECT EXISTS (
+        SELECT FROM information_schema.columns
+        WHERE table_schema = 'public'
+        AND table_name = $1
+        AND column_name = $2
+    )"#;
+    let exists: bool = client.query_one(query, &[&table, &column]).await?.get(0);
+    if !exists {
+        anyhow::bail!("Column '{}' does not exist in table '{}'", column, table);
+    }
+    Ok(())
+}
+
+async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<()> {
+    // Validate source table exists
+    validate_table_exists(client, &args.source_table)
+        .await
+        .context("Failed to validate source table")?;
+    // Validate source columns exist
+    validate_column_exists(client, &args.source_table, &args.source_user_id_column)
+        .await
+        .context("Failed to validate source user ID column")?;
+    validate_column_exists(client, &args.source_table, &args.source_item_id_column)
+        .await
+        .context("Failed to validate source item ID column")?;
+    // Validate target table exists
+    validate_table_exists(client, &args.target_table)
+        .await
+        .context("Failed to validate target table")?;
+    // Validate target columns exist
+    validate_column_exists(client, &args.target_table, &args.target_id_column)
+        .await
+        .context("Failed to validate target ID column")?;
+    validate_column_exists(client, &args.target_table, &args.target_embedding_column)
+        .await
+        .context("Failed to validate target embedding column")?;
     Ok(())
 }
@@ -205,27 +312,39 @@ async fn main() -> Result<()> {
     pretty_env_logger::init();
     let args = Args::parse();
-    // Set up graceful shutdown
+    let pool = create_pool().await?;
+    // Validate schema before proceeding
+    info!("Validating database schema...");
+    validate_schema(&pool.get().await?, &args).await?;
+    info!("Schema validation successful");
+    let mut matrix = Matrix::new();
+    let mut last_user_id: Option<i32> = None;
+    let mut last_item_id: Option<i32> = None;
+    let mut total_rows = 0;
+    let mut item_ids = HashSet::new();
+    // Set up graceful shutdown for data loading
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
     ctrlc::set_handler(move || {
+        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
+            info!("Received interrupt signal, exiting immediately...");
+            std::process::exit(1);
+        } else {
             r.store(false, Ordering::SeqCst);
             info!("Received interrupt signal, finishing current batch...");
+        }
     })?;
-    let pool = create_pool().await?;
-    let mut matrix = Matrix::new();
-    let mut last_user_id = None;
-    let mut last_item_id = None;
-    let mut total_rows = 0;
     info!("Starting data loading...");
     while running.load(Ordering::SeqCst) {
         let batch = load_data_batch(
             &pool.get().await?,
             &args.source_table,
-            &args.user_id_column,
-            &args.item_id_column,
+            &args.source_user_id_column,
+            &args.source_item_id_column,
             args.batch_size as usize,
             last_user_id,
             last_item_id,
@@ -252,17 +371,24 @@ async fn main() -> Result<()> {
         // Process batch
         for (user_id, item_id) in batch {
             matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
         }
     }
-    info!("Loaded {} total rows", total_rows);
+    info!(
+        "Loaded {} total rows with {} unique items",
+        total_rows,
+        item_ids.len()
+    );
     if total_rows == 0 {
         info!("No data found in source table");
         return Ok(());
     }
-    // Set up training parameters
+    // Switch to immediate exit mode for training
+    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
     let model = Model::params()
         .factors(args.factors)
         .lambda_p1(args.lambda1)
@@ -279,7 +405,7 @@ async fn main() -> Result<()> {
         .fit(&matrix)?;
     info!("Saving embeddings...");
-    save_embeddings(&pool, &args, &model).await?;
+    save_embeddings(&pool, &args, &model, &item_ids).await?;
     Ok(())
 }