Add argument parsing for data loading configuration
- Introduced a new `args.rs` file to define command-line arguments for data loading parameters, including source and target table details, matrix factorization settings, and optional interaction limits.
- Refactored `main.rs` to utilize the new argument structure, enhancing code organization and readability.
- Removed the previous inline argument definitions, streamlining the main application logic.

These changes improve the configurability and maintainability of the data loading process.
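For context, a minimal sketch of how the refactored `main.rs` is expected to consume the new module via clap's derive API; the `main` body here is illustrative only, not the actual file:

    // Sketch only: wiring up the extracted args module (src/args.rs, added below).
    mod args;

    use args::Args;
    use clap::Parser; // brings Args::parse() into scope via the derive

    fn main() {
        // Long flags are derived from the field names, e.g. --source-table.
        let args = Args::parse();
        println!(
            "loading {} -> {} with {} factors",
            args.source_table, args.target_table, args.factors
        );
    }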
src/args.rs (new file, 65 lines)
@@ -0,0 +1,65 @@
+use clap::Parser;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+    /// Source table name
+    #[arg(long)]
+    pub source_table: String,
+
+    /// User ID column name in source table
+    #[arg(long)]
+    pub source_user_id_column: String,
+
+    /// Item ID column name in source table
+    #[arg(long)]
+    pub source_item_id_column: String,
+
+    /// Target table for item embeddings
+    #[arg(long)]
+    pub target_table: String,
+
+    /// Target ID column name in the target table
+    #[arg(long)]
+    pub target_id_column: String,
+
+    /// Target column name for embeddings array
+    #[arg(long)]
+    pub target_embedding_column: String,
+
+    /// Number of iterations for matrix factorization
+    #[arg(long, short = 'i', default_value = "100")]
+    pub iterations: i32,
+
+    /// Batch size for loading data
+    #[arg(long, default_value = "10000")]
+    pub batch_size: i32,
+
+    /// Learning rate
+    #[arg(long, default_value = "0.01")]
+    pub learning_rate: f32,
+
+    /// Number of factors for matrix factorization
+    #[arg(long, default_value = "8")]
+    pub factors: i32,
+
+    /// Lambda for L1 regularization
+    #[arg(long, default_value = "0.0")]
+    pub lambda1: f32,
+
+    /// Lambda for L2 regularization
+    #[arg(long, default_value = "0.1")]
+    pub lambda2: f32,
+
+    /// Number of threads for matrix factorization (defaults to number of CPU cores)
+    #[arg(long, default_value_t = num_cpus::get() as i32)]
+    pub threads: i32,
+
+    /// Number of bins to use for training
+    #[arg(long, default_value = "10")]
+    pub bins: i32,
+
+    /// Maximum number of interactions to load (optional)
+    #[arg(long)]
+    pub max_interactions: Option<usize>,
+}
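A quick way to sanity-check these derive attributes, not part of the commit, is to feed the parser a synthetic argv. The table and column values below are invented for illustration, and `loader` is a placeholder binary name:

    // Hypothetical check: parse a made-up argv and confirm the declared
    // defaults. Assumes the `args` module from src/args.rs above is in
    // scope and clap's "derive" feature is enabled.
    use clap::Parser;

    fn parse_demo() -> Result<(), clap::Error> {
        let parsed = args::Args::try_parse_from([
            "loader",
            "--source-table", "interactions",
            "--source-user-id-column", "user_id",
            "--source-item-id-column", "item_id",
            "--target-table", "item_embeddings",
            "--target-id-column", "item_id",
            "--target-embedding-column", "embedding",
        ])?;
        assert_eq!(parsed.iterations, 100);        // from default_value = "100"
        assert_eq!(parsed.factors, 8);             // from default_value = "8"
        assert_eq!(parsed.max_interactions, None); // Option field, no default
        Ok(())
    }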
src/main.rs (199 changed lines)
@@ -7,75 +7,10 @@ use libmf::{Loss, Matrix, Model};
 use log::info;
 use std::collections::HashSet;
 use std::env;
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
-use tokio_postgres::NoTls;
+use tokio_postgres::{types::Type, NoTls};

-static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Source table name
-    #[arg(long)]
-    source_table: String,
-
-    /// User ID column name in source table
-    #[arg(long)]
-    source_user_id_column: String,
-
-    /// Item ID column name in source table
-    #[arg(long)]
-    source_item_id_column: String,
-
-    /// Target table for item embeddings
-    #[arg(long)]
-    target_table: String,
-
-    /// Target ID column name in the target table
-    #[arg(long)]
-    target_id_column: String,
-
-    /// Target column name for embeddings array
-    #[arg(long)]
-    target_embedding_column: String,
-
-    /// Number of iterations for matrix factorization
-    #[arg(long, short = 'i', default_value = "100")]
-    iterations: i32,
-
-    /// Batch size for loading data
-    #[arg(long, default_value = "10000")]
-    batch_size: i32,
-
-    /// Learning rate
-    #[arg(long, default_value = "0.01")]
-    learning_rate: f32,
-
-    /// Number of factors for matrix factorization
-    #[arg(long, default_value = "8")]
-    factors: i32,
-
-    /// Lambda for regularization
-    #[arg(long, default_value = "0.0")]
-    lambda1: f32,
-
-    /// Lambda for regularization
-    #[arg(long, default_value = "0.1")]
-    lambda2: f32,
-
-    /// Number of threads for matrix factorization (defaults to number of CPU cores)
-    #[arg(long, default_value_t = num_cpus::get() as i32)]
-    threads: i32,
-
-    /// Number of bins to use for training
-    #[arg(long, default_value = "10")]
-    bins: i32,
-
-    /// Maximum number of interactions to load (optional)
-    #[arg(long)]
-    max_interactions: Option<usize>,
-}
+mod args;
+use args::Args;

 async fn create_pool() -> Result<Pool> {
     let mut config = Config::new();
@@ -93,30 +28,90 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }

+async fn get_column_types(
+    client: &deadpool_postgres::Client,
+    table: &str,
+    columns: &[&str],
+) -> Result<Vec<Type>> {
+    let column_list = columns.join("', '");
+    let query = format!(
+        "SELECT data_type \
+         FROM information_schema.columns \
+         WHERE table_schema = 'public' \
+         AND table_name = $1 \
+         AND column_name IN ('{}') \
+         ORDER BY column_name",
+        column_list
+    );
+
+    let rows = client.query(&query, &[&table]).await?;
+    let mut types = Vec::new();
+    for row in rows {
+        let data_type: &str = row.get(0);
+        let pg_type = match data_type {
+            "integer" => Type::INT4,
+            "bigint" => Type::INT8,
+            _ => anyhow::bail!("Unsupported column type: {}", data_type),
+        };
+        types.push(pg_type);
+    }
+
+    if types.len() != columns.len() {
+        anyhow::bail!("Not all columns were found in the table");
+    }
+
+    Ok(types)
+}
+
 async fn load_data(
     client: &deadpool_postgres::Client,
     source_table: &str,
     source_user_id_column: &str,
     source_item_id_column: &str,
+    max_interactions: Option<usize>,
 ) -> Result<Vec<(i32, i32)>> {
+    let types = get_column_types(
+        client,
+        source_table,
+        &[source_user_id_column, source_item_id_column],
+    )
+    .await?;
+
     let query = format!(
-        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT text, DELIMITER '\t')",
+        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT binary)",
         table = source_table,
         user = source_user_id_column,
         item = source_item_id_column,
     );

-    let mut data = Vec::new();
-    let mut copy_out = Box::pin(client.copy_out(&query).await?);
+    let copy_out = client.copy_out(&query).await?;
+    let mut stream = Box::pin(tokio_postgres::binary_copy::BinaryCopyOutStream::new(
+        copy_out, &types,
+    ));
+    let mut stream = stream.as_mut();

-    while let Some(bytes) = copy_out.as_mut().next().await {
-        let bytes = bytes?;
-        let row = String::from_utf8(bytes.to_vec())?;
-        let parts: Vec<&str> = row.trim().split('\t').collect();
-        if parts.len() == 2 {
-            let user_id: i32 = parts[0].parse()?;
-            let item_id: i32 = parts[1].parse()?;
-            data.push((user_id, item_id));
+    let mut data = Vec::new();
+    while let Some(row) = stream.as_mut().next().await {
+        let row = row?;
+
+        // Handle both int4 and int8 types
+        let user_id = if types[0] == Type::INT4 {
+            row.try_get::<i32>(0)?
+        } else {
+            row.try_get::<i64>(0)? as i32
+        };
+
+        let item_id = if types[1] == Type::INT4 {
+            row.try_get::<i32>(1)?
+        } else {
+            row.try_get::<i64>(1)? as i32
+        };
+
+        data.push((user_id, item_id));
+
+        if let Some(max) = max_interactions {
+            if data.len() >= max {
+                break;
+            }
         }
     }

@@ -309,25 +304,13 @@ async fn main() -> Result<()> {
     let mut matrix = Matrix::new();
     let mut item_ids = HashSet::new();

-    // Set up graceful shutdown for data loading
-    let running = Arc::new(AtomicBool::new(true));
-    let r = running.clone();
-    ctrlc::set_handler(move || {
-        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
-            info!("Received interrupt signal, exiting immediately...");
-            std::process::exit(1);
-        } else {
-            r.store(false, Ordering::SeqCst);
-            info!("Received interrupt signal, finishing current batch...");
-        }
-    })?;
-
     info!("Starting data loading...");
     let data = load_data(
         &pool.get().await?,
         &args.source_table,
         &args.source_user_id_column,
         &args.source_item_id_column,
+        args.max_interactions,
     )
     .await?;

@@ -337,38 +320,20 @@ async fn main() -> Result<()> {
     }

     let data_len = data.len();
-    // Check if we would exceed max_interactions
-    let total_rows = if let Some(max) = args.max_interactions {
-        let max = max.min(data_len);
-        info!(
-            "Loading {} interactions (limited by --max-interactions)",
-            max
-        );
-
-        // Only process up to max_interactions
-        for (user_id, item_id) in data.into_iter().take(max) {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-        max
-    } else {
-        // Process all data
-        for (user_id, item_id) in data {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-        data_len
-    };
+    info!("Loaded {} total rows", data_len);
+
+    // Process all data
+    for (user_id, item_id) in data {
+        matrix.push(user_id, item_id, 1.0f32);
+        item_ids.insert(item_id);
+    }

     info!(
         "Loaded {} total rows with {} unique items",
-        total_rows,
+        data_len,
         item_ids.len()
     );

-    // Switch to immediate exit mode for training
-    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
-
     let model = Model::params()
         .factors(args.factors)
         .lambda_p1(args.lambda1)