Refactor data loading and embedding saving process
- Updated `.cargo/config.toml` to optimize compilation flags for performance.
- Enhanced `main.rs` by:
  - Renaming user and item ID columns for clarity.
  - Adding validation functions to ensure the existence of tables and columns in the database schema.
  - Implementing immediate exit handling during data loading.
  - Modifying the `save_embeddings` function to accept item IDs for processing.
  - Improving error handling with context messages for database operations.

These changes improve code readability, robustness, and performance during data processing.
.cargo/config.toml

```diff
@@ -9,7 +9,7 @@ rustflags = [
 ]
 
 [env]
-CXXFLAGS = "-fopenmp -pthread -DUSEOMP=1"
-LDFLAGS = "-fopenmp -pthread -DUSEOMP=1"
+CXXFLAGS = "-O3 -march=native -fopenmp -pthread -DUSEOMP=1"
+LDFLAGS = "-O3 -march=native -fopenmp -pthread -DUSEOMP=1 -flto"
 CC = "gcc"
 CXX = "g++"
```
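For context on why the `[env]` section matters (an assumption about the build, not something this diff shows): Cargo exports these variables to build scripts, and the `cc` crate — the usual way bundled C/C++ such as libmf gets compiled — picks up `CC`/`CXX` and `CFLAGS`/`CXXFLAGS` from the environment automatically. A minimal `build.rs` sketch under that assumption:

```rust
// build.rs — hedged sketch; assumes the native code is compiled via the `cc` crate.
fn main() {
    cc::Build::new()
        .cpp(true)              // honors CXX and CXXFLAGS from the environment set above
        .file("vendor/mf.cpp")  // hypothetical path to the bundled C++ source
        .compile("mf");
}
```

Note that `-march=native` ties the resulting binary to the build machine's CPU; that is the usual trade-off for the extra vectorization.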
src/main.rs (228 changed lines)
```diff
@@ -4,11 +4,14 @@ use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
 use libmf::{Loss, Matrix, Model};
 use log::info;
+use std::collections::HashSet;
 use std::env;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use tokio_postgres::NoTls;
 
+static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
```
```diff
@@ -16,18 +19,26 @@ struct Args {
     #[arg(long)]
     source_table: String,
 
-    /// User ID column name
+    /// User ID column name in source table
     #[arg(long)]
-    user_id_column: String,
+    source_user_id_column: String,
 
-    /// Item ID column name
+    /// Item ID column name in source table
     #[arg(long)]
-    item_id_column: String,
+    source_item_id_column: String,
 
     /// Target table for item embeddings
     #[arg(long)]
     target_table: String,
 
+    /// Target ID column name in the target table
+    #[arg(long)]
+    target_id_column: String,
+
+    /// Target column name for embeddings array
+    #[arg(long)]
+    target_embedding_column: String,
+
     /// Number of iterations for matrix factorization
     #[arg(long, short = 'i', default_value = "100")]
     iterations: i32,
```
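With the renamed arguments, an invocation might look like the following (table and column names are placeholders; clap derives the kebab-case long flags from the snake_case field names):

```
cargo run --release -- \
    --source-table user_item_events \
    --source-user-id-column user_id \
    --source-item-id-column item_id \
    --target-table item_embeddings \
    --target-id-column item_id \
    --target-embedding-column embedding
```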
```diff
@@ -80,35 +91,39 @@ async fn create_pool() -> Result<Pool> {
 async fn load_data_batch(
     client: &deadpool_postgres::Client,
     source_table: &str,
-    user_id_column: &str,
-    item_id_column: &str,
+    source_user_id_column: &str,
+    source_item_id_column: &str,
     batch_size: usize,
     last_user_id: Option<i32>,
     last_item_id: Option<i32>,
 ) -> Result<Vec<(i32, i32)>> {
     let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
         let query = format!(
-            "SELECT {user}, {item} FROM {table} \
-             WHERE ({user}, {item}) > ($1, $2) \
+            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
+             WHERE ({user}, {item}) > ($1::int4, $2::int4) \
              ORDER BY {user}, {item} \
-             LIMIT $3",
-            user = user_id_column,
-            item = item_id_column,
+             LIMIT $3::bigint",
+            user = source_user_id_column,
+            item = source_item_id_column,
             table = source_table,
         );
         client
             .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
-            .await?
+            .await
+            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
     } else {
         let query = format!(
-            "SELECT {user}, {item} FROM {table} \
+            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
              ORDER BY {user}, {item} \
-             LIMIT $1",
-            user = user_id_column,
-            item = item_id_column,
+             LIMIT $1::bigint",
+            user = source_user_id_column,
+            item = source_item_id_column,
             table = source_table,
         );
-        client.query(&query, &[&(batch_size as i64)]).await?
+        client
+            .query(&query, &[&(batch_size as i64)])
+            .await
+            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
     };
 
     let mut batch = Vec::with_capacity(rows.len());
```
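The composite `WHERE ({user}, {item}) > ($1, $2)` is keyset pagination: each batch resumes just past the last (user, item) pair seen, so Postgres can seek an index on those columns instead of rescanning an `OFFSET`. The `::int4`/`::bigint` casts pin the parameter types so prepared-statement type inference cannot fail on them. A self-contained sketch of the cursor rule, using nothing beyond the standard library (tuple comparison in Rust is lexicographic, matching SQL row comparison for these types):

```rust
// Which rows belong to the next batch? Exactly those strictly after the cursor.
fn after_cursor(row: (i32, i32), cursor: (i32, i32)) -> bool {
    row > cursor // lexicographic, like SQL's (user, item) > ($1, $2)
}

fn main() {
    assert!(after_cursor((3, 1), (2, 9)));   // user advanced
    assert!(after_cursor((2, 10), (2, 9)));  // same user, later item
    assert!(!after_cursor((2, 9), (2, 9)));  // the cursor row itself is excluded
}
```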
```diff
@@ -121,45 +136,59 @@ async fn load_data_batch(
     Ok(batch)
 }
 
-async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()> {
+async fn save_embeddings(
+    pool: &Pool,
+    args: &Args,
+    model: &Model,
+    item_ids: &HashSet<i32>,
+) -> Result<()> {
     let client = pool.get().await?;
 
-    // Create the target table if it doesn't exist
-    let create_table = format!(
-        "CREATE TABLE IF NOT EXISTS {} (item_id INTEGER PRIMARY KEY, embedding FLOAT4[])",
-        args.target_table
-    );
-    client.execute(&create_table, &[]).await?;
-
     let mut valid_embeddings = 0;
     let mut invalid_embeddings = 0;
     let batch_size = 128;
     let mut current_batch = Vec::with_capacity(batch_size);
-    let mut current_idx = 0;
 
-    // Process factors in chunks using the iterator directly
-    for factors in model.q_iter() {
+    // Process factors for items that appear in the source table
+    for (idx, factors) in model.q_iter().enumerate() {
+        let item_id = idx as i32;
+        if !item_ids.contains(&item_id) {
+            continue;
+        }
+
         // Skip invalid embeddings
         if factors.iter().any(|&x| x.is_nan()) {
             invalid_embeddings += 1;
-            current_idx += 1;
             continue;
         }
 
         valid_embeddings += 1;
-        current_batch.push((current_idx, factors));
-        current_idx += 1;
+        current_batch.push((item_id as i64, factors));
 
         // When batch is full, save it
         if current_batch.len() >= batch_size {
-            save_batch(&client, &args.target_table, &current_batch).await?;
+            save_batch(
+                &client,
+                &args.target_table,
+                &args.target_id_column,
+                &args.target_embedding_column,
+                &current_batch,
+            )
+            .await?;
             current_batch.clear();
         }
     }
 
     // Save any remaining items in the last batch
     if !current_batch.is_empty() {
-        save_batch(&client, &args.target_table, &current_batch).await?;
+        save_batch(
+            &client,
+            &args.target_table,
+            &args.target_id_column,
+            &args.target_embedding_column,
+            &current_batch,
+        )
+        .await?;
    }
 
     info!(
```
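`model.q_iter()` yields one factor row per dense item index starting at 0, whether or not that index ever occurred in the data; the `HashSet` filter keeps only observed items, and the kept index is widened to `i64` to match the `::int8` placeholder cast below. A standalone sketch of that filtering, with a stand-in for the model iterator:

```rust
use std::collections::HashSet;

fn main() {
    // Stand-in for model.q_iter(): dense per-index factor rows 0..4.
    let q_rows: Vec<Vec<f32>> = vec![vec![0.1; 2]; 4];
    // Only items 1 and 3 were observed in the source table.
    let item_ids: HashSet<i32> = [1, 3].into_iter().collect();

    let kept: Vec<i64> = q_rows
        .iter()
        .enumerate()
        .filter(|(idx, _)| item_ids.contains(&(*idx as i32)))
        .map(|(idx, _)| idx as i64) // widened to i64 to match the $n::int8 casts
        .collect();

    assert_eq!(kept, vec![1, 3]); // unseen indices 0 and 2 are skipped
}
```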
```diff
@@ -173,21 +202,26 @@ async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()>
 async fn save_batch(
     client: &deadpool_postgres::Client,
     target_table: &str,
-    batch_values: &[(i32, &[f32])],
+    target_id_column: &str,
+    target_embedding_column: &str,
+    batch_values: &[(i64, &[f32])],
 ) -> Result<()> {
     // Build the batch insert query
     let placeholders: Vec<String> = (0..batch_values.len())
-        .map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2))
+        .map(|i| format!("(${}::int8, ${}::float4[])", i * 2 + 1, i * 2 + 2))
         .collect();
     let query = format!(
         r#"
-        INSERT INTO {target_table} (item_id, embedding) VALUES {placeholders}
-        ON CONFLICT (item_id)
-        DO UPDATE SET embedding = EXCLUDED.embedding
+        INSERT INTO {target_table} ({target_id_column}, {target_embedding_column})
+        VALUES {placeholders}
+        ON CONFLICT ({target_id_column})
+        DO UPDATE SET {target_embedding_column} = EXCLUDED.{target_embedding_column}
         "#,
         placeholders = placeholders.join(",")
     );
 
+    // info!("Executing query: {}", query);
+
     // Flatten parameters for the query
     let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new();
     for (item_id, factors) in batch_values {
```
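The placeholder generator is easy to sanity-check in isolation; a minimal sketch (standard library only) of what the new `map` produces for a two-row batch:

```rust
fn main() {
    let batch_len = 2; // pretend batch_values.len() == 2
    let placeholders: Vec<String> = (0..batch_len)
        .map(|i| format!("(${}::int8, ${}::float4[])", i * 2 + 1, i * 2 + 2))
        .collect();
    assert_eq!(
        placeholders.join(","),
        "($1::int8, $2::float4[]),($3::int8, $4::float4[])"
    );
}
```

The explicit `::int8`/`::float4[]` casts spare the server from inferring parameter types across a variable-length `VALUES` list.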
```diff
@@ -195,7 +229,80 @@ async fn save_batch(
         params.push(factors);
     }
 
-    client.execute(&query, &params[..]).await?;
+    info!("Number of parameters: {}", params.len());
+    client.execute(&query, &params[..]).await.with_context(|| {
+        format!(
+            "Failed to execute batch insert at {}:{} with {} values",
+            file!(),
+            line!(),
+            batch_values.len()
+        )
+    })?;
     Ok(())
 }
 
+async fn validate_table_exists(client: &deadpool_postgres::Client, table: &str) -> Result<()> {
+    let query = r#"SELECT EXISTS (
+        SELECT FROM pg_tables
+        WHERE schemaname = 'public'
+        AND tablename = $1
+    )"#;
+
+    let exists: bool = client.query_one(query, &[&table]).await?.get(0);
+
+    if !exists {
+        anyhow::bail!("Table '{}' does not exist", table);
+    }
+    Ok(())
+}
+
+async fn validate_column_exists(
+    client: &deadpool_postgres::Client,
+    table: &str,
+    column: &str,
+) -> Result<()> {
+    let query = r#"SELECT EXISTS (
+        SELECT FROM information_schema.columns
+        WHERE table_schema = 'public'
+        AND table_name = $1
+        AND column_name = $2
+    )"#;
+
+    let exists: bool = client.query_one(query, &[&table, &column]).await?.get(0);
+
+    if !exists {
+        anyhow::bail!("Column '{}' does not exist in table '{}'", column, table);
+    }
+    Ok(())
+}
+
+async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<()> {
+    // Validate source table exists
+    validate_table_exists(client, &args.source_table)
+        .await
+        .context("Failed to validate source table")?;
+
+    // Validate source columns exist
+    validate_column_exists(client, &args.source_table, &args.source_user_id_column)
+        .await
+        .context("Failed to validate source user ID column")?;
+    validate_column_exists(client, &args.source_table, &args.source_item_id_column)
+        .await
+        .context("Failed to validate source item ID column")?;
+
+    // Validate target table exists
+    validate_table_exists(client, &args.target_table)
+        .await
+        .context("Failed to validate target table")?;
+
+    // Validate target columns exist
+    validate_column_exists(client, &args.target_table, &args.target_id_column)
+        .await
+        .context("Failed to validate target ID column")?;
+    validate_column_exists(client, &args.target_table, &args.target_embedding_column)
+        .await
+        .context("Failed to validate target embedding column")?;
+
+    Ok(())
+}
+
```
```diff
@@ -205,27 +312,39 @@ async fn main() -> Result<()> {
     pretty_env_logger::init();
     let args = Args::parse();
 
-    // Set up graceful shutdown
+    let pool = create_pool().await?;
+
+    // Validate schema before proceeding
+    info!("Validating database schema...");
+    validate_schema(&pool.get().await?, &args).await?;
+    info!("Schema validation successful");
+
+    let mut matrix = Matrix::new();
+    let mut last_user_id: Option<i32> = None;
+    let mut last_item_id: Option<i32> = None;
+    let mut total_rows = 0;
+    let mut item_ids = HashSet::new();
+
+    // Set up graceful shutdown for data loading
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
     ctrlc::set_handler(move || {
+        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
+            info!("Received interrupt signal, exiting immediately...");
+            std::process::exit(1);
+        } else {
             r.store(false, Ordering::SeqCst);
             info!("Received interrupt signal, finishing current batch...");
+        }
     })?;
 
-    let pool = create_pool().await?;
-    let mut matrix = Matrix::new();
-    let mut last_user_id = None;
-    let mut last_item_id = None;
-    let mut total_rows = 0;
-
     info!("Starting data loading...");
     while running.load(Ordering::SeqCst) {
         let batch = load_data_batch(
             &pool.get().await?,
             &args.source_table,
-            &args.user_id_column,
-            &args.item_id_column,
+            &args.source_user_id_column,
+            &args.source_item_id_column,
             args.batch_size as usize,
             last_user_id,
             last_item_id,
```
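Interrupt handling is now two-phase: while batches load, Ctrl-C clears `running` so the loop drains the current batch and stops; once training starts, `IMMEDIATE_EXIT` flips the same handler to a hard exit, presumably because `fit` cannot be paused midway. Distilled into a runnable sketch (assuming the `ctrlc` crate, as the diff does):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);

fn main() {
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
            std::process::exit(1); // phase 2: training in progress, bail out now
        } else {
            r.store(false, Ordering::SeqCst); // phase 1: finish the current batch
        }
    })
    .expect("failed to install Ctrl-C handler");

    while running.load(Ordering::SeqCst) {
        // Stand-in for loading one batch of rows.
        std::thread::sleep(std::time::Duration::from_millis(100));
    }
    IMMEDIATE_EXIT.store(true, Ordering::SeqCst); // entering the unsafe-to-pause phase
    // ... train the model ...
}
```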
```diff
@@ -252,17 +371,24 @@ async fn main() -> Result<()> {
         // Process batch
         for (user_id, item_id) in batch {
             matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
         }
     }
 
-    info!("Loaded {} total rows", total_rows);
+    info!(
+        "Loaded {} total rows with {} unique items",
+        total_rows,
+        item_ids.len()
+    );
 
     if total_rows == 0 {
         info!("No data found in source table");
         return Ok(());
     }
 
-    // Set up training parameters
+    // Switch to immediate exit mode for training
+    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
+
     let model = Model::params()
         .factors(args.factors)
         .lambda_p1(args.lambda1)
```
```diff
@@ -279,7 +405,7 @@ async fn main() -> Result<()> {
         .fit(&matrix)?;
 
     info!("Saving embeddings...");
-    save_embeddings(&pool, &args, &model).await?;
+    save_embeddings(&pool, &args, &model, &item_ids).await?;
 
     Ok(())
 }
```