Add argument parsing for data loading configuration
- Introduced a new `args.rs` file defining the command-line arguments for data loading: source and target table details, matrix factorization settings, and an optional interaction limit.
- Refactored `main.rs` to use the new argument structure, improving code organization and readability.
- Removed the previous inline argument definitions, streamlining the main application logic.

These changes improve the configurability and maintainability of the data loading process.
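For orientation, a minimal sketch of the calling side after this refactor; the table name printed here is a placeholder, and the real `main.rs` does far more than this:

```rust
mod args;

use args::Args;
use clap::Parser;

fn main() {
    // clap derives the whole CLI from the struct in src/args.rs shown below
    let args = Args::parse();
    println!(
        "factorizing {} with {} factors over {} iterations",
        args.source_table, args.factors, args.iterations
    );
}
```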
src/args.rs (new file, 65 lines)
@@ -0,0 +1,65 @@
use clap::Parser;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    /// Source table name
    #[arg(long)]
    pub source_table: String,

    /// User ID column name in source table
    #[arg(long)]
    pub source_user_id_column: String,

    /// Item ID column name in source table
    #[arg(long)]
    pub source_item_id_column: String,

    /// Target table for item embeddings
    #[arg(long)]
    pub target_table: String,

    /// Target ID column name in the target table
    #[arg(long)]
    pub target_id_column: String,

    /// Target column name for embeddings array
    #[arg(long)]
    pub target_embedding_column: String,

    /// Number of iterations for matrix factorization
    #[arg(long, short = 'i', default_value = "100")]
    pub iterations: i32,

    /// Batch size for loading data
    #[arg(long, default_value = "10000")]
    pub batch_size: i32,

    /// Learning rate
    #[arg(long, default_value = "0.01")]
    pub learning_rate: f32,

    /// Number of factors for matrix factorization
    #[arg(long, default_value = "8")]
    pub factors: i32,

    /// Lambda for regularization
    #[arg(long, default_value = "0.0")]
    pub lambda1: f32,

    /// Lambda for regularization
    #[arg(long, default_value = "0.1")]
    pub lambda2: f32,

    /// Number of threads for matrix factorization (defaults to number of CPU cores)
    #[arg(long, default_value_t = num_cpus::get() as i32)]
    pub threads: i32,

    /// Number of bins to use for training
    #[arg(long, default_value = "10")]
    pub bins: i32,

    /// Maximum number of interactions to load (optional)
    #[arg(long)]
    pub max_interactions: Option<usize>,
}
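A quick illustration of what the derive gives you: clap turns each snake_case field into a kebab-case `--flag`, and the `default_value` attributes fill in anything omitted. The test below is hypothetical (the binary name `loader` and the flag values are made up for the example):

```rust
#[cfg(test)]
mod tests {
    use super::Args;
    use clap::Parser;

    #[test]
    fn defaults_apply_when_optional_flags_are_omitted() {
        // Only the six required table/column flags are supplied here;
        // the numeric options should come back with their declared defaults.
        let args = Args::parse_from([
            "loader",
            "--source-table", "interactions",
            "--source-user-id-column", "user_id",
            "--source-item-id-column", "item_id",
            "--target-table", "items",
            "--target-id-column", "id",
            "--target-embedding-column", "embedding",
        ]);
        assert_eq!(args.iterations, 100);
        assert_eq!(args.factors, 8);
        assert_eq!(args.max_interactions, None);
    }
}
```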
src/main.rs (199 lines changed)
@@ -7,75 +7,10 @@ use libmf::{Loss, Matrix, Model};
use log::info;
use std::collections::HashSet;
use std::env;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
-use tokio_postgres::NoTls;
+use tokio_postgres::{types::Type, NoTls};

-static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);

-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Source table name
-    #[arg(long)]
-    source_table: String,
-
-    /// User ID column name in source table
-    #[arg(long)]
-    source_user_id_column: String,
-
-    /// Item ID column name in source table
-    #[arg(long)]
-    source_item_id_column: String,
-
-    /// Target table for item embeddings
-    #[arg(long)]
-    target_table: String,
-
-    /// Target ID column name in the target table
-    #[arg(long)]
-    target_id_column: String,
-
-    /// Target column name for embeddings array
-    #[arg(long)]
-    target_embedding_column: String,
-
-    /// Number of iterations for matrix factorization
-    #[arg(long, short = 'i', default_value = "100")]
-    iterations: i32,
-
-    /// Batch size for loading data
-    #[arg(long, default_value = "10000")]
-    batch_size: i32,
-
-    /// Learning rate
-    #[arg(long, default_value = "0.01")]
-    learning_rate: f32,
-
-    /// Number of factors for matrix factorization
-    #[arg(long, default_value = "8")]
-    factors: i32,
-
-    /// Lambda for regularization
-    #[arg(long, default_value = "0.0")]
-    lambda1: f32,
-
-    /// Lambda for regularization
-    #[arg(long, default_value = "0.1")]
-    lambda2: f32,
-
-    /// Number of threads for matrix factorization (defaults to number of CPU cores)
-    #[arg(long, default_value_t = num_cpus::get() as i32)]
-    threads: i32,
-
-    /// Number of bins to use for training
-    #[arg(long, default_value = "10")]
-    bins: i32,
-
-    /// Maximum number of interactions to load (optional)
-    #[arg(long)]
-    max_interactions: Option<usize>,
-}
+mod args;
+use args::Args;

async fn create_pool() -> Result<Pool> {
    let mut config = Config::new();
@@ -93,30 +28,90 @@ async fn create_pool() -> Result<Pool> {
    Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
}

async fn get_column_types(
    client: &deadpool_postgres::Client,
    table: &str,
    columns: &[&str],
) -> Result<Vec<Type>> {
    let column_list = columns.join("', '");
    let query = format!(
        "SELECT data_type \
         FROM information_schema.columns \
         WHERE table_schema = 'public' \
         AND table_name = $1 \
         AND column_name IN ('{}') \
         ORDER BY column_name",
        column_list
    );

    let rows = client.query(&query, &[&table]).await?;
    let mut types = Vec::new();
    for row in rows {
        let data_type: &str = row.get(0);
        let pg_type = match data_type {
            "integer" => Type::INT4,
            "bigint" => Type::INT8,
            _ => anyhow::bail!("Unsupported column type: {}", data_type),
        };
        types.push(pg_type);
    }

    if types.len() != columns.len() {
        anyhow::bail!("Not all columns were found in the table");
    }

    Ok(types)
}
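A hypothetical call site, assuming a checked-out `deadpool_postgres` client and an `interactions(user_id, item_id)` table with integer columns; the returned `Type`s are what the binary COPY decoder in `load_data` below is seeded with:

```rust
use anyhow::Result;
use tokio_postgres::types::Type;

// Illustrative only: the table and column names are assumptions.
async fn lookup_types(client: &deadpool_postgres::Client) -> Result<Vec<Type>> {
    let types = get_column_types(client, "interactions", &["user_id", "item_id"]).await?;
    assert_eq!(types.len(), 2); // one Type per requested column
    assert!(types.iter().all(|t| *t == Type::INT4 || *t == Type::INT8));
    Ok(types)
}
```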
async fn load_data(
    client: &deadpool_postgres::Client,
    source_table: &str,
    source_user_id_column: &str,
    source_item_id_column: &str,
    max_interactions: Option<usize>,
) -> Result<Vec<(i32, i32)>> {
    let types = get_column_types(
        client,
        source_table,
        &[source_user_id_column, source_item_id_column],
    )
    .await?;

    let query = format!(
-        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT text, DELIMITER '\t')",
+        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT binary)",
        table = source_table,
        user = source_user_id_column,
        item = source_item_id_column,
    );

-    let mut data = Vec::new();
-    let mut copy_out = Box::pin(client.copy_out(&query).await?);
+    let copy_out = client.copy_out(&query).await?;
+    let mut stream = Box::pin(tokio_postgres::binary_copy::BinaryCopyOutStream::new(
+        copy_out, &types,
+    ));
+    let mut stream = stream.as_mut();

-    while let Some(bytes) = copy_out.as_mut().next().await {
-        let bytes = bytes?;
-        let row = String::from_utf8(bytes.to_vec())?;
-        let parts: Vec<&str> = row.trim().split('\t').collect();
-        if parts.len() == 2 {
-            let user_id: i32 = parts[0].parse()?;
-            let item_id: i32 = parts[1].parse()?;
-            data.push((user_id, item_id));
+    let mut data = Vec::new();
+    while let Some(row) = stream.as_mut().next().await {
+        let row = row?;

+        // Handle both int4 and int8 types
+        let user_id = if types[0] == Type::INT4 {
+            row.try_get::<i32>(0)?
+        } else {
+            row.try_get::<i64>(0)? as i32
+        };

+        let item_id = if types[1] == Type::INT4 {
+            row.try_get::<i32>(1)?
+        } else {
+            row.try_get::<i64>(1)? as i32
+        };

+        data.push((user_id, item_id));
        if let Some(max) = max_interactions {
            if data.len() >= max {
                break;
            }
        }
    }
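The `Box::pin` / `as_mut` dance above exists because the binary COPY row stream has to be pinned before it can be polled. A stripped-down version of the same pattern, assuming `tokio-postgres`, `futures-util` for `StreamExt`, and a hypothetical `events` table with two `int4` columns:

```rust
use anyhow::Result;
use futures_util::StreamExt;
use tokio_postgres::binary_copy::BinaryCopyOutStream;
use tokio_postgres::types::Type;

// Minimal sketch; the table name and column types are assumptions.
async fn read_pairs(client: &tokio_postgres::Client) -> Result<Vec<(i32, i32)>> {
    let raw = client
        .copy_out("COPY events (user_id, item_id) TO STDOUT (FORMAT binary)")
        .await?;
    let stream = BinaryCopyOutStream::new(raw, &[Type::INT4, Type::INT4]);
    tokio::pin!(stream); // the row stream must be pinned before polling

    let mut pairs = Vec::new();
    while let Some(row) = stream.next().await {
        let row = row?;
        pairs.push((row.try_get::<i32>(0)?, row.try_get::<i32>(1)?));
    }
    Ok(pairs)
}
```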
@@ -309,25 +304,13 @@ async fn main() -> Result<()> {
    let mut matrix = Matrix::new();
    let mut item_ids = HashSet::new();

-    // Set up graceful shutdown for data loading
-    let running = Arc::new(AtomicBool::new(true));
-    let r = running.clone();
-    ctrlc::set_handler(move || {
-        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
-            info!("Received interrupt signal, exiting immediately...");
-            std::process::exit(1);
-        } else {
-            r.store(false, Ordering::SeqCst);
-            info!("Received interrupt signal, finishing current batch...");
-        }
-    })?;

    info!("Starting data loading...");
    let data = load_data(
        &pool.get().await?,
        &args.source_table,
        &args.source_user_id_column,
        &args.source_item_id_column,
        args.max_interactions,
    )
    .await?;
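For reference, the two-phase interrupt handling built around `IMMEDIATE_EXIT` boils down to the following: a first Ctrl-C only asks the loading loop to wind down via a shared flag, and once `IMMEDIATE_EXIT` has been flipped a Ctrl-C terminates the process outright. A condensed sketch, assuming the `ctrlc` crate (`install_interrupt_handler` is a name invented for this example):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);

// While loading, Ctrl-C just clears `running` so the loop can stop at a safe
// point; once IMMEDIATE_EXIT is set (e.g. before training starts), Ctrl-C
// exits the process immediately.
fn install_interrupt_handler(running: Arc<AtomicBool>) -> Result<(), ctrlc::Error> {
    ctrlc::set_handler(move || {
        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
            std::process::exit(1);
        }
        running.store(false, Ordering::SeqCst);
    })
}
```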
@@ -337,38 +320,20 @@ async fn main() -> Result<()> {
    }

    let data_len = data.len();
-    // Check if we would exceed max_interactions
-    let total_rows = if let Some(max) = args.max_interactions {
-        let max = max.min(data_len);
-        info!(
-            "Loading {} interactions (limited by --max-interactions)",
-            max
-        );
+    info!("Loaded {} total rows", data_len);

-        // Only process up to max_interactions
-        for (user_id, item_id) in data.into_iter().take(max) {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-        max
-    } else {
-        // Process all data
-        for (user_id, item_id) in data {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-        data_len
-    };
+    // Process all data
+    for (user_id, item_id) in data {
+        matrix.push(user_id, item_id, 1.0f32);
+        item_ids.insert(item_id);
+    }

    info!(
        "Loaded {} total rows with {} unique items",
-        total_rows,
+        data_len,
        item_ids.len()
    );

-    // Switch to immediate exit mode for training
-    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);

    let model = Model::params()
        .factors(args.factors)
        .lambda_p1(args.lambda1)
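The hunk is cut off mid-way through the model builder. For completeness, the implicit-feedback step just above can be read as a small helper, using only the `Matrix` calls already shown in this diff (the function name is made up for the example):

```rust
use libmf::Matrix;
use std::collections::HashSet;

// Every (user, item) interaction becomes one sparse-matrix entry with a unit
// rating; the item ids are collected so their embeddings can later be written
// back to the target table.
fn build_matrix(pairs: Vec<(i32, i32)>) -> (Matrix, HashSet<i32>) {
    let mut matrix = Matrix::new();
    let mut item_ids = HashSet::new();
    for (user_id, item_id) in pairs {
        matrix.push(user_id, item_id, 1.0f32);
        item_ids.insert(item_id);
    }
    (matrix, item_ids)
}
```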