Add argument parsing for data loading configuration

- Introduced a new `args.rs` file defining the command-line arguments for data loading: source and target table details, matrix factorization settings, and an optional interaction limit.
- Refactored `main.rs` to utilize the new argument structure, enhancing code organization and readability.
- Removed the previous inline argument definitions, streamlining the main application logic.

These changes improve the configurability and maintainability of the data loading process.
Author: Dylan Knutson
Date:   2024-12-28 18:16:39 +00:00
Parent: 428ca89c92
Commit: c4e79a36f9

2 changed files with 147 additions and 117 deletions
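
For context, the refactored `main.rs` reduces to roughly the following shape. This is a condensed sketch assembled from the diff below, not code from the commit itself: the `#[tokio::main]` attribute is assumed, error handling and the training steps are elided, and `create_pool` and `load_data` are this repository's own functions.

mod args;
use args::Args;
use clap::Parser;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // All data-loading configuration now arrives through the generated parser.
    let args = Args::parse();
    let pool = create_pool().await?;
    let data = load_data(
        &pool.get().await?,
        &args.source_table,
        &args.source_user_id_column,
        &args.source_item_id_column,
        args.max_interactions,
    )
    .await?;
    // ...matrix construction and model training continue as in the diff...
    Ok(())
}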

src/args.rs (new file, 65 lines)

@@ -0,0 +1,65 @@
+use clap::Parser;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+    /// Source table name
+    #[arg(long)]
+    pub source_table: String,
+
+    /// User ID column name in source table
+    #[arg(long)]
+    pub source_user_id_column: String,
+
+    /// Item ID column name in source table
+    #[arg(long)]
+    pub source_item_id_column: String,
+
+    /// Target table for item embeddings
+    #[arg(long)]
+    pub target_table: String,
+
+    /// Target ID column name in the target table
+    #[arg(long)]
+    pub target_id_column: String,
+
+    /// Target column name for embeddings array
+    #[arg(long)]
+    pub target_embedding_column: String,
+
+    /// Number of iterations for matrix factorization
+    #[arg(long, short = 'i', default_value = "100")]
+    pub iterations: i32,
+
+    /// Batch size for loading data
+    #[arg(long, default_value = "10000")]
+    pub batch_size: i32,
+
+    /// Learning rate
+    #[arg(long, default_value = "0.01")]
+    pub learning_rate: f32,
+
+    /// Number of factors for matrix factorization
+    #[arg(long, default_value = "8")]
+    pub factors: i32,
+
+    /// Lambda for regularization
+    #[arg(long, default_value = "0.0")]
+    pub lambda1: f32,
+
+    /// Lambda for regularization
+    #[arg(long, default_value = "0.1")]
+    pub lambda2: f32,
+
+    /// Number of threads for matrix factorization (defaults to number of CPU cores)
+    #[arg(long, default_value_t = num_cpus::get() as i32)]
+    pub threads: i32,
+
+    /// Number of bins to use for training
+    #[arg(long, default_value = "10")]
+    pub bins: i32,
+
+    /// Maximum number of interactions to load (optional)
+    #[arg(long)]
+    pub max_interactions: Option<usize>,
+}
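
As a quick sanity check of these definitions (not part of the commit; the binary name and the table/column values below are made up), clap's derived parser can be driven in-process with `try_parse_from`. Any flag left unspecified falls back to its declared default:

#[cfg(test)]
mod tests {
    use super::Args;
    use clap::Parser;

    #[test]
    fn parses_required_flags_and_applies_defaults() {
        let args = Args::try_parse_from([
            "embedder", // argv[0] placeholder
            "--source-table", "interactions",
            "--source-user-id-column", "user_id",
            "--source-item-id-column", "item_id",
            "--target-table", "item_embeddings",
            "--target-id-column", "item_id",
            "--target-embedding-column", "embedding",
        ])
        .expect("the six required flags are supplied");
        assert_eq!(args.iterations, 100); // default_value = "100"
        assert_eq!(args.lambda2, 0.1); // default_value = "0.1"
        assert_eq!(args.max_interactions, None); // optional flag left unset
    }
}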

src/main.rs

@@ -7,75 +7,10 @@ use libmf::{Loss, Matrix, Model};
 use log::info;
 use std::collections::HashSet;
 use std::env;
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
-use tokio_postgres::NoTls;
+use tokio_postgres::{types::Type, NoTls};
-static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Source table name
-    #[arg(long)]
-    source_table: String,
-    /// User ID column name in source table
-    #[arg(long)]
-    source_user_id_column: String,
-    /// Item ID column name in source table
-    #[arg(long)]
-    source_item_id_column: String,
-    /// Target table for item embeddings
-    #[arg(long)]
-    target_table: String,
-    /// Target ID column name in the target table
-    #[arg(long)]
-    target_id_column: String,
-    /// Target column name for embeddings array
-    #[arg(long)]
-    target_embedding_column: String,
-    /// Number of iterations for matrix factorization
-    #[arg(long, short = 'i', default_value = "100")]
-    iterations: i32,
-    /// Batch size for loading data
-    #[arg(long, default_value = "10000")]
-    batch_size: i32,
-    /// Learning rate
-    #[arg(long, default_value = "0.01")]
-    learning_rate: f32,
-    /// Number of factors for matrix factorization
-    #[arg(long, default_value = "8")]
-    factors: i32,
-    /// Lambda for regularization
-    #[arg(long, default_value = "0.0")]
-    lambda1: f32,
-    /// Lambda for regularization
-    #[arg(long, default_value = "0.1")]
-    lambda2: f32,
-    /// Number of threads for matrix factorization (defaults to number of CPU cores)
-    #[arg(long, default_value_t = num_cpus::get() as i32)]
-    threads: i32,
-    /// Number of bins to use for training
-    #[arg(long, default_value = "10")]
-    bins: i32,
-    /// Maximum number of interactions to load (optional)
-    #[arg(long)]
-    max_interactions: Option<usize>,
-}
+mod args;
+use args::Args;
async fn create_pool() -> Result<Pool> {
let mut config = Config::new();
@@ -93,30 +28,90 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }
+async fn get_column_types(
+    client: &deadpool_postgres::Client,
+    table: &str,
+    columns: &[&str],
+) -> Result<Vec<Type>> {
+    let column_list = columns.join("', '");
+    let query = format!(
+        "SELECT data_type \
+         FROM information_schema.columns \
+         WHERE table_schema = 'public' \
+         AND table_name = $1 \
+         AND column_name IN ('{}') \
+         ORDER BY column_name",
+        column_list
+    );
+    let rows = client.query(&query, &[&table]).await?;
+    let mut types = Vec::new();
+    for row in rows {
+        let data_type: &str = row.get(0);
+        let pg_type = match data_type {
+            "integer" => Type::INT4,
+            "bigint" => Type::INT8,
+            _ => anyhow::bail!("Unsupported column type: {}", data_type),
+        };
+        types.push(pg_type);
+    }
+    if types.len() != columns.len() {
+        anyhow::bail!("Not all columns were found in the table");
+    }
+    Ok(types)
+}
 async fn load_data(
     client: &deadpool_postgres::Client,
     source_table: &str,
     source_user_id_column: &str,
     source_item_id_column: &str,
+    max_interactions: Option<usize>,
 ) -> Result<Vec<(i32, i32)>> {
+    let types = get_column_types(
+        client,
+        source_table,
+        &[source_user_id_column, source_item_id_column],
+    )
+    .await?;
     let query = format!(
-        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT text, DELIMITER '\t')",
+        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT binary)",
         table = source_table,
         user = source_user_id_column,
         item = source_item_id_column,
     );
-    let mut data = Vec::new();
-    let mut copy_out = Box::pin(client.copy_out(&query).await?);
+    let copy_out = client.copy_out(&query).await?;
+    let mut stream = Box::pin(tokio_postgres::binary_copy::BinaryCopyOutStream::new(
+        copy_out, &types,
+    ));
+    let mut stream = stream.as_mut();
-    while let Some(bytes) = copy_out.as_mut().next().await {
-        let bytes = bytes?;
-        let row = String::from_utf8(bytes.to_vec())?;
-        let parts: Vec<&str> = row.trim().split('\t').collect();
-        if parts.len() == 2 {
-            let user_id: i32 = parts[0].parse()?;
-            let item_id: i32 = parts[1].parse()?;
-            data.push((user_id, item_id));
+    let mut data = Vec::new();
+    while let Some(row) = stream.as_mut().next().await {
+        let row = row?;
+        // Handle both int4 and int8 types
+        let user_id = if types[0] == Type::INT4 {
+            row.try_get::<i32>(0)?
+        } else {
+            row.try_get::<i64>(0)? as i32
+        };
+        let item_id = if types[1] == Type::INT4 {
+            row.try_get::<i32>(1)?
+        } else {
+            row.try_get::<i64>(1)? as i32
+        };
+        data.push((user_id, item_id));
+        if let Some(max) = max_interactions {
+            if data.len() >= max {
+                break;
+            }
+        }
     }
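
Aside: the binary COPY pattern added above distills to the following self-contained sketch (not from the commit; assumed setup: a connected `tokio_postgres::Client` and a table `events(user_id int4, item_id int4)`, and the helper name `copy_pairs` is illustrative). Probing the column types first, as `get_column_types` does above, is what lets one decode path handle both int4 and int8 sources.

use futures_util::StreamExt;
use tokio_postgres::binary_copy::BinaryCopyOutStream;
use tokio_postgres::types::Type;

async fn copy_pairs(client: &tokio_postgres::Client) -> anyhow::Result<Vec<(i32, i32)>> {
    let copy_out = client
        .copy_out("COPY events (user_id, item_id) TO STDOUT (FORMAT binary)")
        .await?;
    // The binary decoder must be told every column's type up front.
    let mut stream = Box::pin(BinaryCopyOutStream::new(copy_out, &[Type::INT4, Type::INT4]));
    let mut pairs = Vec::new();
    while let Some(row) = stream.next().await {
        let row = row?; // stream items are Result<BinaryCopyOutRow, Error>
        pairs.push((row.try_get::<i32>(0)?, row.try_get::<i32>(1)?));
    }
    Ok(pairs)
}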
@@ -309,25 +304,13 @@ async fn main() -> Result<()> {
     let mut matrix = Matrix::new();
     let mut item_ids = HashSet::new();
-    // Set up graceful shutdown for data loading
-    let running = Arc::new(AtomicBool::new(true));
-    let r = running.clone();
-    ctrlc::set_handler(move || {
-        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
-            info!("Received interrupt signal, exiting immediately...");
-            std::process::exit(1);
-        } else {
-            r.store(false, Ordering::SeqCst);
-            info!("Received interrupt signal, finishing current batch...");
-        }
-    })?;
     info!("Starting data loading...");
     let data = load_data(
         &pool.get().await?,
         &args.source_table,
         &args.source_user_id_column,
         &args.source_item_id_column,
+        args.max_interactions,
     )
     .await?;
@@ -337,38 +320,20 @@ async fn main() -> Result<()> {
     }
     let data_len = data.len();
-    // Check if we would exceed max_interactions
-    let total_rows = if let Some(max) = args.max_interactions {
-        let max = max.min(data_len);
-        info!(
-            "Loading {} interactions (limited by --max-interactions)",
-            max
-        );
+    info!("Loaded {} total rows", data_len);
-        // Only process up to max_interactions
-        for (user_id, item_id) in data.into_iter().take(max) {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-        max
-    } else {
-        // Process all data
-        for (user_id, item_id) in data {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-        data_len
-    };
+    // Process all data
+    for (user_id, item_id) in data {
+        matrix.push(user_id, item_id, 1.0f32);
+        item_ids.insert(item_id);
+    }
     info!(
         "Loaded {} total rows with {} unique items",
-        total_rows,
+        data_len,
         item_ids.len()
     );
-    // Switch to immediate exit mode for training
-    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
     let model = Model::params()
         .factors(args.factors)
         .lambda_p1(args.lambda1)