Refactor data loading and embedding saving process
- Updated `.cargo/config.toml` to optimize compilation flags for performance.
- Enhanced `main.rs` by:
  - Renaming user and item ID columns for clarity.
  - Adding validation functions to ensure the existence of tables and columns in the database schema.
  - Implementing immediate exit handling during data loading (see the sketch after this list).
  - Modifying the `save_embeddings` function to accept item IDs for processing.
  - Improving error handling with context messages for database operations.

These changes improve code readability, robustness, and performance during data processing.
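The interrupt handling deserves a short illustration. Below is a minimal sketch of the two-stage Ctrl-C pattern this commit introduces, using the same `ctrlc` crate and `IMMEDIATE_EXIT` static as the diff; the batch-loading and training bodies are elided stand-ins:

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Once data loading is done, a graceful stop is no longer useful:
// training cannot be resumed mid-run, so an interrupt should exit.
static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);

fn main() -> Result<(), ctrlc::Error> {
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
            // Second stage: during training, exit immediately.
            std::process::exit(1);
        } else {
            // First stage: ask the loading loop to finish its batch.
            r.store(false, Ordering::SeqCst);
        }
    })?;

    while running.load(Ordering::SeqCst) {
        // ... load one batch of (user_id, item_id) pairs ...
        std::thread::sleep(std::time::Duration::from_millis(100));
    }

    // Loading is done; flip to immediate-exit mode before training.
    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
    // ... fit the model and save embeddings ...
    Ok(())
}
```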
.cargo/config.toml
@@ -9,7 +9,7 @@ rustflags = [
 ]
 
 [env]
-CXXFLAGS = "-fopenmp -pthread -DUSEOMP=1"
-LDFLAGS = "-fopenmp -pthread -DUSEOMP=1"
+CXXFLAGS = "-O3 -march=native -fopenmp -pthread -DUSEOMP=1"
+LDFLAGS = "-O3 -march=native -fopenmp -pthread -DUSEOMP=1 -flto"
 CC = "gcc"
 CXX = "g++"
src/main.rs
@@ -4,11 +4,14 @@ use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
 use libmf::{Loss, Matrix, Model};
 use log::info;
+use std::collections::HashSet;
 use std::env;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use tokio_postgres::NoTls;
 
+static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -16,18 +19,26 @@ struct Args {
     #[arg(long)]
     source_table: String,
 
-    /// User ID column name
+    /// User ID column name in source table
     #[arg(long)]
-    user_id_column: String,
+    source_user_id_column: String,
 
-    /// Item ID column name
+    /// Item ID column name in source table
     #[arg(long)]
-    item_id_column: String,
+    source_item_id_column: String,
 
     /// Target table for item embeddings
     #[arg(long)]
     target_table: String,
 
+    /// Target ID column name in the target table
+    #[arg(long)]
+    target_id_column: String,
+
+    /// Target column name for embeddings array
+    #[arg(long)]
+    target_embedding_column: String,
+
     /// Number of iterations for matrix factorization
     #[arg(long, short = 'i', default_value = "100")]
     iterations: i32,
@@ -80,35 +91,39 @@ async fn create_pool() -> Result<Pool> {
 async fn load_data_batch(
     client: &deadpool_postgres::Client,
     source_table: &str,
-    user_id_column: &str,
-    item_id_column: &str,
+    source_user_id_column: &str,
+    source_item_id_column: &str,
     batch_size: usize,
     last_user_id: Option<i32>,
     last_item_id: Option<i32>,
 ) -> Result<Vec<(i32, i32)>> {
     let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
         let query = format!(
-            "SELECT {user}, {item} FROM {table} \
-             WHERE ({user}, {item}) > ($1, $2) \
+            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
+             WHERE ({user}, {item}) > ($1::int4, $2::int4) \
              ORDER BY {user}, {item} \
-             LIMIT $3",
-            user = user_id_column,
-            item = item_id_column,
+             LIMIT $3::bigint",
+            user = source_user_id_column,
+            item = source_item_id_column,
             table = source_table,
         );
         client
             .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
-            .await?
+            .await
+            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
     } else {
         let query = format!(
-            "SELECT {user}, {item} FROM {table} \
+            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
              ORDER BY {user}, {item} \
-             LIMIT $1",
-            user = user_id_column,
-            item = item_id_column,
+             LIMIT $1::bigint",
+            user = source_user_id_column,
+            item = source_item_id_column,
             table = source_table,
         );
-        client.query(&query, &[&(batch_size as i64)]).await?
+        client
+            .query(&query, &[&(batch_size as i64)])
+            .await
+            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
     };
 
     let mut batch = Vec::with_capacity(rows.len());
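A note on the pagination predicate above: `WHERE ({user}, {item}) > ($1, $2)` is PostgreSQL's row-constructor comparison, which orders pairs lexicographically, so each batch resumes strictly after the last `(user, item)` row of the previous one. Rust tuples order the same way, so the semantics can be sanity-checked locally (illustration only, not part of the commit):

```rust
// PostgreSQL's (a, b) > (x, y) is lexicographic, exactly like Rust tuples.
fn main() {
    let last = (42, 7); // final (user_id, item_id) of the previous batch
    assert!((42, 8) > last); // same user, later item: included in next batch
    assert!((43, 0) > last); // later user, any item: included in next batch
    assert!(!((42, 7) > last)); // the last row itself is excluded
    assert!(!((41, 99) > last)); // earlier user: already consumed
    println!("keyset pagination predicate behaves lexicographically");
}
```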
@@ -121,45 +136,59 @@ async fn load_data_batch(
     Ok(batch)
 }
 
-async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()> {
+async fn save_embeddings(
+    pool: &Pool,
+    args: &Args,
+    model: &Model,
+    item_ids: &HashSet<i32>,
+) -> Result<()> {
     let client = pool.get().await?;
 
-    // Create the target table if it doesn't exist
-    let create_table = format!(
-        "CREATE TABLE IF NOT EXISTS {} (item_id INTEGER PRIMARY KEY, embedding FLOAT4[])",
-        args.target_table
-    );
-    client.execute(&create_table, &[]).await?;
-
     let mut valid_embeddings = 0;
     let mut invalid_embeddings = 0;
     let batch_size = 128;
     let mut current_batch = Vec::with_capacity(batch_size);
-    let mut current_idx = 0;
 
-    // Process factors in chunks using the iterator directly
-    for factors in model.q_iter() {
+    // Process factors for items that appear in the source table
+    for (idx, factors) in model.q_iter().enumerate() {
+        let item_id = idx as i32;
+        if !item_ids.contains(&item_id) {
+            continue;
+        }
+
         // Skip invalid embeddings
         if factors.iter().any(|&x| x.is_nan()) {
             invalid_embeddings += 1;
-            current_idx += 1;
             continue;
         }
 
         valid_embeddings += 1;
-        current_batch.push((current_idx, factors));
-        current_idx += 1;
+        current_batch.push((item_id as i64, factors));
 
         // When batch is full, save it
         if current_batch.len() >= batch_size {
-            save_batch(&client, &args.target_table, &current_batch).await?;
+            save_batch(
+                &client,
+                &args.target_table,
+                &args.target_id_column,
+                &args.target_embedding_column,
+                &current_batch,
+            )
+            .await?;
             current_batch.clear();
         }
     }
 
     // Save any remaining items in the last batch
     if !current_batch.is_empty() {
-        save_batch(&client, &args.target_table, &current_batch).await?;
+        save_batch(
+            &client,
+            &args.target_table,
+            &args.target_id_column,
+            &args.target_embedding_column,
+            &current_batch,
+        )
+        .await?;
     }
 
     info!(
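The rewritten loop above leans on the row index of `model.q_iter()` doubling as the item id, which is consistent with `matrix.push(user_id, item_id, 1.0)` using raw ids as matrix indices; indices that never occur in the data still yield (untrained) factor rows, hence the `item_ids` filter. Below is a schematic of that filter, with a plain `Vec<Vec<f32>>` as a hypothetical stand-in for the model, not the libmf API:

```rust
use std::collections::HashSet;

fn main() {
    // Stand-in for model.q_iter(): one factor row per item index 0..n.
    let q: Vec<Vec<f32>> = vec![vec![0.1; 4]; 6];
    // Only items 1, 3 and 4 actually appeared in the source table.
    let item_ids: HashSet<i32> = [1, 3, 4].into_iter().collect();

    let kept: Vec<i32> = q
        .iter()
        .enumerate()
        .map(|(idx, _factors)| idx as i32)
        .filter(|item_id| item_ids.contains(item_id))
        .collect();
    assert_eq!(kept, vec![1, 3, 4]); // untrained rows 0, 2, 5 are skipped
}
```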
@@ -173,21 +202,26 @@ async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()>
 async fn save_batch(
     client: &deadpool_postgres::Client,
     target_table: &str,
-    batch_values: &[(i32, &[f32])],
+    target_id_column: &str,
+    target_embedding_column: &str,
+    batch_values: &[(i64, &[f32])],
 ) -> Result<()> {
     // Build the batch insert query
     let placeholders: Vec<String> = (0..batch_values.len())
-        .map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2))
+        .map(|i| format!("(${}::int8, ${}::float4[])", i * 2 + 1, i * 2 + 2))
         .collect();
     let query = format!(
         r#"
-        INSERT INTO {target_table} (item_id, embedding) VALUES {placeholders}
-        ON CONFLICT (item_id)
-        DO UPDATE SET embedding = EXCLUDED.embedding
+        INSERT INTO {target_table} ({target_id_column}, {target_embedding_column})
+        VALUES {placeholders}
+        ON CONFLICT ({target_id_column})
+        DO UPDATE SET {target_embedding_column} = EXCLUDED.{target_embedding_column}
         "#,
         placeholders = placeholders.join(",")
     );
 
+    // info!("Executing query: {}", query);
+
     // Flatten parameters for the query
     let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new();
     for (item_id, factors) in batch_values {
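The placeholder builder in `save_batch` allocates two parameters per row: `i * 2 + 1` for the id and `i * 2 + 2` for the embedding array. A standalone check of the generated SQL fragment, using the same arithmetic as the diff (illustration only):

```rust
fn main() {
    let n = 2; // batch of two rows, two parameters each
    let placeholders: Vec<String> = (0..n)
        .map(|i| format!("(${}::int8, ${}::float4[])", i * 2 + 1, i * 2 + 2))
        .collect();
    assert_eq!(
        placeholders.join(","),
        "($1::int8, $2::float4[]),($3::int8, $4::float4[])"
    );
    println!("{}", placeholders.join(","));
}
```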
@@ -195,7 +229,80 @@ async fn save_batch(
         params.push(factors);
     }
 
-    client.execute(&query, &params[..]).await?;
+    info!("Number of parameters: {}", params.len());
+    client.execute(&query, &params[..]).await.with_context(|| {
+        format!(
+            "Failed to execute batch insert at {}:{} with {} values",
+            file!(),
+            line!(),
+            batch_values.len()
+        )
+    })?;
+    Ok(())
+}
+
+async fn validate_table_exists(client: &deadpool_postgres::Client, table: &str) -> Result<()> {
+    let query = r#"SELECT EXISTS (
+        SELECT FROM pg_tables
+        WHERE schemaname = 'public'
+        AND tablename = $1
+    )"#;
+
+    let exists: bool = client.query_one(query, &[&table]).await?.get(0);
+
+    if !exists {
+        anyhow::bail!("Table '{}' does not exist", table);
+    }
+    Ok(())
+}
+
+async fn validate_column_exists(
+    client: &deadpool_postgres::Client,
+    table: &str,
+    column: &str,
+) -> Result<()> {
+    let query = r#"SELECT EXISTS (
+        SELECT FROM information_schema.columns
+        WHERE table_schema = 'public'
+        AND table_name = $1
+        AND column_name = $2
+    )"#;
+
+    let exists: bool = client.query_one(query, &[&table, &column]).await?.get(0);
+
+    if !exists {
+        anyhow::bail!("Column '{}' does not exist in table '{}'", column, table);
+    }
+    Ok(())
+}
+
+async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<()> {
+    // Validate source table exists
+    validate_table_exists(client, &args.source_table)
+        .await
+        .context("Failed to validate source table")?;
+
+    // Validate source columns exist
+    validate_column_exists(client, &args.source_table, &args.source_user_id_column)
+        .await
+        .context("Failed to validate source user ID column")?;
+    validate_column_exists(client, &args.source_table, &args.source_item_id_column)
+        .await
+        .context("Failed to validate source item ID column")?;
+
+    // Validate target table exists
+    validate_table_exists(client, &args.target_table)
+        .await
+        .context("Failed to validate target table")?;
+
+    // Validate target columns exist
+    validate_column_exists(client, &args.target_table, &args.target_id_column)
+        .await
+        .context("Failed to validate target ID column")?;
+    validate_column_exists(client, &args.target_table, &args.target_embedding_column)
+        .await
+        .context("Failed to validate target embedding column")?;
+
     Ok(())
 }
 
@@ -205,27 +312,39 @@ async fn main() -> Result<()> {
     pretty_env_logger::init();
     let args = Args::parse();
 
-    // Set up graceful shutdown
+    let pool = create_pool().await?;
+
+    // Validate schema before proceeding
+    info!("Validating database schema...");
+    validate_schema(&pool.get().await?, &args).await?;
+    info!("Schema validation successful");
+
+    let mut matrix = Matrix::new();
+    let mut last_user_id: Option<i32> = None;
+    let mut last_item_id: Option<i32> = None;
+    let mut total_rows = 0;
+    let mut item_ids = HashSet::new();
+
+    // Set up graceful shutdown for data loading
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
     ctrlc::set_handler(move || {
-        r.store(false, Ordering::SeqCst);
-        info!("Received interrupt signal, finishing current batch...");
+        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
+            info!("Received interrupt signal, exiting immediately...");
+            std::process::exit(1);
+        } else {
+            r.store(false, Ordering::SeqCst);
+            info!("Received interrupt signal, finishing current batch...");
+        }
     })?;
 
-    let pool = create_pool().await?;
-    let mut matrix = Matrix::new();
-    let mut last_user_id = None;
-    let mut last_item_id = None;
-    let mut total_rows = 0;
-
     info!("Starting data loading...");
     while running.load(Ordering::SeqCst) {
         let batch = load_data_batch(
             &pool.get().await?,
             &args.source_table,
-            &args.user_id_column,
-            &args.item_id_column,
+            &args.source_user_id_column,
+            &args.source_item_id_column,
             args.batch_size as usize,
             last_user_id,
             last_item_id,
@@ -252,17 +371,24 @@ async fn main() -> Result<()> {
         // Process batch
         for (user_id, item_id) in batch {
             matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
         }
     }
 
-    info!("Loaded {} total rows", total_rows);
+    info!(
+        "Loaded {} total rows with {} unique items",
+        total_rows,
+        item_ids.len()
+    );
 
     if total_rows == 0 {
         info!("No data found in source table");
         return Ok(());
     }
 
-    // Set up training parameters
+    // Switch to immediate exit mode for training
+    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
+
     let model = Model::params()
         .factors(args.factors)
         .lambda_p1(args.lambda1)
@@ -279,7 +405,7 @@ async fn main() -> Result<()> {
         .fit(&matrix)?;
 
     info!("Saving embeddings...");
-    save_embeddings(&pool, &args, &model).await?;
+    save_embeddings(&pool, &args, &model, &item_ids).await?;
 
     Ok(())
 }