From 428ca89c928632185dfcd51eb2a9a746b7b73a8b Mon Sep 17 00:00:00 2001
From: Dylan Knutson
Date: Sat, 28 Dec 2024 17:55:56 +0000
Subject: [PATCH] use COPY for importing data

---
 Cargo.lock  |  36 ++++++++++++++
 Cargo.toml  |   1 +
 src/main.rs | 137 +++++++++++++++++++---------------------------------
 3 files changed, 86 insertions(+), 88 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 63e0ae4..4e4ca1b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -703,6 +703,21 @@ dependencies = [
  "percent-encoding",
 ]

+[[package]]
+name = "futures"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.31"
@@ -719,6 +734,23 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"

+[[package]]
+name = "futures-executor"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
+name = "futures-io"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
+
 [[package]]
 name = "futures-macro"
 version = "0.3.31"
@@ -748,10 +780,13 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
 dependencies = [
+ "futures-channel",
  "futures-core",
+ "futures-io",
  "futures-macro",
  "futures-sink",
  "futures-task",
+ "memchr",
  "pin-project-lite",
  "pin-utils",
  "slab",
@@ -1179,6 +1214,7 @@ dependencies = [
  "ctrlc",
  "deadpool-postgres",
  "dotenv",
+ "futures",
  "libmf",
  "log",
  "ndarray",
diff --git a/Cargo.toml b/Cargo.toml
index 3113f76..30de10a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,6 +9,7 @@ clap = { version = "4.4", features = ["derive"] }
 ctrlc = "3.4"
 deadpool-postgres = "0.11"
 dotenv = "0.15"
+futures = "0.3"
 libmf = "0.3"
 log = "0.4"
 ndarray = { version = "0.15", features = ["blas"] }
diff --git a/src/main.rs b/src/main.rs
index f2ee141..f1aca85 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,6 +2,7 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
+use futures::StreamExt;
 use libmf::{Loss, Matrix, Model};
 use log::info;
 use std::collections::HashSet;
@@ -92,52 +93,34 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }

-async fn load_data_batch(
+async fn load_data(
     client: &deadpool_postgres::Client,
     source_table: &str,
     source_user_id_column: &str,
     source_item_id_column: &str,
-    batch_size: usize,
-    last_user_id: Option<i32>,
-    last_item_id: Option<i32>,
 ) -> Result<Vec<(i32, i32)>> {
-    let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
-        let query = format!(
-            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
-             WHERE ({user}, {item}) > ($1::int4, $2::int4) \
-             ORDER BY {user}, {item} \
-             LIMIT $3::bigint",
-            user = source_user_id_column,
-            item = source_item_id_column,
-            table = source_table,
-        );
-        client
-            .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
-            .await
-            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
-    } else {
-        let query = format!(
-            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
-             ORDER BY {user}, {item} \
-             LIMIT $1::bigint",
-            user = source_user_id_column,
-            item = source_item_id_column,
-            table = source_table,
-        );
-        client
-            .query(&query, &[&(batch_size as i64)])
-            .await
-            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
-    };
+    let query = format!(
+        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT text, DELIMITER '\t')",
+        table = source_table,
+        user = source_user_id_column,
+        item = source_item_id_column,
+    );

-    let mut batch = Vec::with_capacity(rows.len());
-    for row in rows {
-        let user_id: i32 = row.get(0);
-        let item_id: i32 = row.get(1);
-        batch.push((user_id, item_id));
+    let mut data = Vec::new();
+    let mut copy_out = Box::pin(client.copy_out(&query).await?);
+
+    while let Some(bytes) = copy_out.as_mut().next().await {
+        let bytes = bytes?;
+        let row = String::from_utf8(bytes.to_vec())?;
+        let parts: Vec<&str> = row.trim().split('\t').collect();
+        if parts.len() == 2 {
+            let user_id: i32 = parts[0].parse()?;
+            let item_id: i32 = parts[1].parse()?;
+            data.push((user_id, item_id));
+        }
     }

-    Ok(batch)
+    Ok(data)
 }

 async fn save_embeddings(
@@ -324,9 +307,6 @@ async fn main() -> Result<()> {
     info!("Schema validation successful");

     let mut matrix = Matrix::new();
-    let mut last_user_id: Option<i32> = None;
-    let mut last_item_id: Option<i32> = None;
-    let mut total_rows = 0;
     let mut item_ids = HashSet::new();

     // Set up graceful shutdown for data loading
@@ -343,56 +323,42 @@ async fn main() -> Result<()> {
     })?;

     info!("Starting data loading...");
-    while running.load(Ordering::SeqCst) {
-        let batch = load_data_batch(
-            &pool.get().await?,
-            &args.source_table,
-            &args.source_user_id_column,
-            &args.source_item_id_column,
-            args.batch_size as usize,
-            last_user_id,
-            last_item_id,
-        )
-        .await?;
+    let data = load_data(
+        &pool.get().await?,
+        &args.source_table,
+        &args.source_user_id_column,
+        &args.source_item_id_column,
+    )
+    .await?;

-        if batch.is_empty() {
-            break;
-        }
+    if data.is_empty() {
+        info!("No data found in source table");
+        return Ok(());
+    }

-        // Check if we would exceed max_interactions
-        if let Some(max) = args.max_interactions {
-            if total_rows + batch.len() > max {
-                // Only process up to max_interactions
-                let remaining = max - total_rows;
-                for (user_id, item_id) in batch.into_iter().take(remaining) {
-                    matrix.push(user_id, item_id, 1.0f32);
-                    item_ids.insert(item_id);
-                }
-                total_rows += remaining;
-                info!("Reached maximum interactions limit of {max}",);
-                break;
-            }
-        }
-
-        total_rows += batch.len();
+    let data_len = data.len();
+    // Check if we would exceed max_interactions
+    let total_rows = if let Some(max) = args.max_interactions {
+        let max = max.min(data_len);
         info!(
-            "Loaded batch of {} rows (total: {})",
-            batch.len(),
-            total_rows
+            "Loading {} interactions (limited by --max-interactions)",
+            max
         );

-        // Update last seen IDs for next batch
-        if let Some((user_id, item_id)) = batch.last() {
-            last_user_id = Some(*user_id);
-            last_item_id = Some(*item_id);
-        }
-
-        // Process batch
-        for (user_id, item_id) in batch {
+        // Only process up to max_interactions
+        for (user_id, item_id) in data.into_iter().take(max) {
             matrix.push(user_id, item_id, 1.0f32);
             item_ids.insert(item_id);
         }
-    }
+        max
+    } else {
+        // Process all data
+        for (user_id, item_id) in data {
+            matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
+        }
+        data_len
+    };

     info!(
         "Loaded {} total rows with {} unique items",
@@ -400,11 +366,6 @@ async fn main() -> Result<()> {
         total_rows,
         item_ids.len()
     );

-    if total_rows == 0 {
-        info!("No data found in source table");
-        return Ok(());
-    }
-
     // Switch to immediate exit mode for training
     IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
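Note: below is a minimal standalone sketch of the same COPY-based import, for illustration only. It assumes a bare tokio-postgres `Client` (the patch goes through a deadpool-postgres pool) and hypothetical fixed `user_id`/`item_id` columns. One behavioral difference: `load_data` above treats each item of the `CopyOutStream` as exactly one row, which holds when the server emits one CopyData message per row but is not guaranteed by the API contract; the sketch instead buffers bytes across chunks and splits on newlines.

use futures::StreamExt;
use tokio_postgres::{Client, Error};

// `user_id`/`item_id` are placeholder column names; the patch takes
// the table and column names as arguments instead.
async fn copy_rows(client: &Client, table: &str) -> Result<Vec<(i32, i32)>, Error> {
    let query =
        format!("COPY {table} (user_id, item_id) TO STDOUT (FORMAT text, DELIMITER '\t')");
    let mut stream = Box::pin(client.copy_out(&query).await?);

    let mut buf: Vec<u8> = Vec::new();
    let mut rows = Vec::new();
    while let Some(chunk) = stream.next().await {
        buf.extend_from_slice(&chunk?);
        // Drain every complete, newline-terminated row currently buffered;
        // a row split across two chunks stays in `buf` until its tail arrives.
        while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
            let line: Vec<u8> = buf.drain(..=pos).collect();
            let text = String::from_utf8_lossy(&line);
            let mut cols = text.trim_end().split('\t');
            if let (Some(user), Some(item)) = (cols.next(), cols.next()) {
                // Skip unparsable rows, mirroring the patch's parts.len() == 2 check.
                if let (Ok(user), Ok(item)) = (user.parse(), item.parse()) {
                    rows.push((user, item));
                }
            }
        }
    }
    Ok(rows)
}

If per-chunk row alignment ever changes server-side, the buffered variant degrades gracefully, at the cost of one extra copy per row.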