From c4e79a36f94359f9866df7eee3f792296f9d501b Mon Sep 17 00:00:00 2001
From: Dylan Knutson
Date: Sat, 28 Dec 2024 18:16:39 +0000
Subject: [PATCH] Add argument parsing for data loading configuration

- Introduced a new `args.rs` file defining the command-line arguments for
  the data loading parameters: source and target table details, matrix
  factorization settings, and an optional interaction limit.
- Refactored `main.rs` to use the new argument structure, improving code
  organization and readability.
- Removed the previous inline argument definitions, streamlining the main
  application logic.

These changes improve the configurability and maintainability of the data
loading process.
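
Example invocation, for reference (run through cargo since the binary name
is not part of this patch; the table and column names are illustrative):

    cargo run --release -- \
        --source-table user_item_interactions \
        --source-user-id-column user_id \
        --source-item-id-column item_id \
        --target-table items \
        --target-id-column id \
        --target-embedding-column embedding \
        --max-interactions 1000000

clap derives the kebab-case flag names above from the snake_case field
names in `Args`; the numeric parameters fall back to their declared
defaults when omitted.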
---
 src/args.rs |  65 +++++++++++++++++
 src/main.rs | 199 ++++++++++++++++++++++------------------
 2 files changed, 147 insertions(+), 117 deletions(-)
 create mode 100644 src/args.rs

diff --git a/src/args.rs b/src/args.rs
new file mode 100644
index 0000000..f8edbeb
--- /dev/null
+++ b/src/args.rs
@@ -0,0 +1,65 @@
+use clap::Parser;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+    /// Source table name
+    #[arg(long)]
+    pub source_table: String,
+
+    /// User ID column name in source table
+    #[arg(long)]
+    pub source_user_id_column: String,
+
+    /// Item ID column name in source table
+    #[arg(long)]
+    pub source_item_id_column: String,
+
+    /// Target table for item embeddings
+    #[arg(long)]
+    pub target_table: String,
+
+    /// Target ID column name in the target table
+    #[arg(long)]
+    pub target_id_column: String,
+
+    /// Target column name for embeddings array
+    #[arg(long)]
+    pub target_embedding_column: String,
+
+    /// Number of iterations for matrix factorization
+    #[arg(long, short = 'i', default_value = "100")]
+    pub iterations: i32,
+
+    /// Batch size for loading data
+    #[arg(long, default_value = "10000")]
+    pub batch_size: i32,
+
+    /// Learning rate
+    #[arg(long, default_value = "0.01")]
+    pub learning_rate: f32,
+
+    /// Number of factors for matrix factorization
+    #[arg(long, default_value = "8")]
+    pub factors: i32,
+
+    /// L1 regularization lambda
+    #[arg(long, default_value = "0.0")]
+    pub lambda1: f32,
+
+    /// L2 regularization lambda
+    #[arg(long, default_value = "0.1")]
+    pub lambda2: f32,
+
+    /// Number of threads for matrix factorization (defaults to number of CPU cores)
+    #[arg(long, default_value_t = num_cpus::get() as i32)]
+    pub threads: i32,
+
+    /// Number of bins to use for training
+    #[arg(long, default_value = "10")]
+    pub bins: i32,
+
+    /// Maximum number of interactions to load (optional)
+    #[arg(long)]
+    pub max_interactions: Option<usize>,
+}
diff --git a/src/main.rs b/src/main.rs
index f1aca85..f3d99bb 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -7,75 +7,10 @@
 use libmf::{Loss, Matrix, Model};
 use log::info;
 use std::collections::HashSet;
 use std::env;
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
-use tokio_postgres::NoTls;
+use tokio_postgres::{types::Type, NoTls};
 
-static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Source table name
-    #[arg(long)]
-    source_table: String,
-
-    /// User ID column name in source table
-    #[arg(long)]
-    source_user_id_column: String,
-
-    /// Item ID column name in source table
-    #[arg(long)]
-    source_item_id_column: String,
-
-    /// Target table for item embeddings
-    #[arg(long)]
-    target_table: String,
-
-    /// Target ID column name in the target table
-    #[arg(long)]
-    target_id_column: String,
-
-    /// Target column name for embeddings array
-    #[arg(long)]
-    target_embedding_column: String,
-
-    /// Number of iterations for matrix factorization
-    #[arg(long, short = 'i', default_value = "100")]
-    iterations: i32,
-
-    /// Batch size for loading data
-    #[arg(long, default_value = "10000")]
-    batch_size: i32,
-
-    /// Learning rate
-    #[arg(long, default_value = "0.01")]
-    learning_rate: f32,
-
-    /// Number of factors for matrix factorization
-    #[arg(long, default_value = "8")]
-    factors: i32,
-
-    /// Lambda for regularization
-    #[arg(long, default_value = "0.0")]
-    lambda1: f32,
-
-    /// Lambda for regularization
-    #[arg(long, default_value = "0.1")]
-    lambda2: f32,
-
-    /// Number of threads for matrix factorization (defaults to number of CPU cores)
-    #[arg(long, default_value_t = num_cpus::get() as i32)]
-    threads: i32,
-
-    /// Number of bins to use for training
-    #[arg(long, default_value = "10")]
-    bins: i32,
-
-    /// Maximum number of interactions to load (optional)
-    #[arg(long)]
-    max_interactions: Option<usize>,
-}
+mod args;
+use args::Args;
 
 async fn create_pool() -> Result<Pool> {
     let mut config = Config::new();
@@ -93,30 +28,90 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }
 
+async fn get_column_types(
+    client: &deadpool_postgres::Client,
+    table: &str,
+    columns: &[&str],
+) -> Result<Vec<Type>> {
+    // Column names come from trusted CLI arguments, so interpolating them
+    // into the query is acceptable here.
+    let column_list = columns.join("', '");
+    let query = format!(
+        "SELECT column_name, data_type \
+         FROM information_schema.columns \
+         WHERE table_schema = 'public' \
+         AND table_name = $1 \
+         AND column_name IN ('{}')",
+        column_list
+    );
+
+    let rows = client.query(&query, &[&table]).await?;
+
+    // Return the types in the same order as the requested columns. The
+    // query's result order is unspecified, so match rows up by name rather
+    // than relying on any ordering.
+    let mut types = Vec::new();
+    for &column in columns {
+        let row = rows
+            .iter()
+            .find(|row| row.get::<_, &str>(0) == column)
+            .ok_or_else(|| anyhow::anyhow!("Column {} not found in table {}", column, table))?;
+        let data_type: &str = row.get(1);
+        let pg_type = match data_type {
+            "integer" => Type::INT4,
+            "bigint" => Type::INT8,
+            _ => anyhow::bail!("Unsupported column type: {}", data_type),
+        };
+        types.push(pg_type);
+    }
+
+    Ok(types)
+}
+
 async fn load_data(
     client: &deadpool_postgres::Client,
     source_table: &str,
     source_user_id_column: &str,
     source_item_id_column: &str,
+    max_interactions: Option<usize>,
 ) -> Result<Vec<(i32, i32)>> {
+    let types = get_column_types(
+        client,
+        source_table,
+        &[source_user_id_column, source_item_id_column],
+    )
+    .await?;
+
     let query = format!(
-        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT text, DELIMITER '\t')",
+        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT binary)",
         table = source_table,
         user = source_user_id_column,
         item = source_item_id_column,
    );
 
-    let mut data = Vec::new();
-    let mut copy_out = Box::pin(client.copy_out(&query).await?);
+    let copy_out = client.copy_out(&query).await?;
+    let mut stream = Box::pin(tokio_postgres::binary_copy::BinaryCopyOutStream::new(
+        copy_out, &types,
+    ));
 
-    while let Some(bytes) = copy_out.as_mut().next().await {
-        let bytes = bytes?;
-        let row = String::from_utf8(bytes.to_vec())?;
-        let parts: Vec<&str> = row.trim().split('\t').collect();
-        if parts.len() == 2 {
-            let user_id: i32 = parts[0].parse()?;
-            let item_id: i32 = parts[1].parse()?;
-            data.push((user_id, item_id));
+    let mut data = Vec::new();
+    while let Some(row) = stream.next().await {
+        let row = row?;
+
+        // Handle both int4 and int8 types (int8 values are cast down to i32)
+        let user_id = if types[0] == Type::INT4 {
+            row.try_get::<i32>(0)?
+        } else {
+            row.try_get::<i64>(0)? as i32
+        };
+
+        let item_id = if types[1] == Type::INT4 {
+            row.try_get::<i32>(1)?
+        } else {
+            row.try_get::<i64>(1)? as i32
+        };
+
+        data.push((user_id, item_id));
+        if let Some(max) = max_interactions {
+            if data.len() >= max {
+                break;
+            }
         }
     }
@@ -309,25 +304,13 @@ async fn main() -> Result<()> {
     let mut matrix = Matrix::new();
     let mut item_ids = HashSet::new();
 
-    // Set up graceful shutdown for data loading
-    let running = Arc::new(AtomicBool::new(true));
-    let r = running.clone();
-    ctrlc::set_handler(move || {
-        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
-            info!("Received interrupt signal, exiting immediately...");
-            std::process::exit(1);
-        } else {
-            r.store(false, Ordering::SeqCst);
-            info!("Received interrupt signal, finishing current batch...");
-        }
-    })?;
-
     info!("Starting data loading...");
     let data = load_data(
         &pool.get().await?,
         &args.source_table,
         &args.source_user_id_column,
         &args.source_item_id_column,
+        args.max_interactions,
     )
     .await?;
 
@@ -337,38 +320,20 @@
     }
 
     let data_len = data.len();
-    // Check if we would exceed max_interactions
-    let total_rows = if let Some(max) = args.max_interactions {
-        let max = max.min(data_len);
-        info!(
-            "Loading {} interactions (limited by --max-interactions)",
-            max
-        );
+    info!("Loaded {} total rows", data_len);
 
-        // Only process up to max_interactions
-        for (user_id, item_id) in data.into_iter().take(max) {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-        max
-    } else {
-        // Process all data
-        for (user_id, item_id) in data {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-        data_len
-    };
+    // Process all data
+    for (user_id, item_id) in data {
+        matrix.push(user_id, item_id, 1.0f32);
+        item_ids.insert(item_id);
+    }
 
     info!(
         "Loaded {} total rows with {} unique items",
-        total_rows,
+        data_len,
         item_ids.len()
     );
 
-    // Switch to immediate exit mode for training
-    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
-
     let model = Model::params()
         .factors(args.factors)
         .lambda_p1(args.lambda1)