use COPY for importing data

Dylan Knutson
2024-12-28 17:55:56 +00:00
parent 857cbf5d1f
commit 428ca89c92
3 changed files with 86 additions and 88 deletions
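
Summary of the change: the importer previously paged through the source table with keyset-paginated SELECT ... ORDER BY ... LIMIT queries, tracking the last (user_id, item_id) pair between batches; it now issues a single COPY ... TO STDOUT and consumes the resulting stream in one pass. A minimal sketch of that pattern using tokio-postgres's copy_out and futures' StreamExt follows; the interactions table, its column names, and the copy_pairs helper are placeholders for illustration, not code from this commit:

use futures::StreamExt;
use tokio_postgres::Client;

// Sketch: stream (user_id, item_id) pairs out of PostgreSQL with a single
// COPY ... TO STDOUT instead of repeated SELECT ... ORDER BY ... LIMIT batches.
// Table and column names here are placeholders, not the project's real schema.
async fn copy_pairs(client: &Client) -> Result<Vec<(i32, i32)>, Box<dyn std::error::Error>> {
    let query = "COPY interactions (user_id, item_id) TO STDOUT (FORMAT text, DELIMITER '\t')";

    let mut pairs = Vec::new();
    // copy_out yields the COPY data as a stream of Bytes chunks; the stream is
    // not Unpin, so pin it before polling it with StreamExt::next.
    let mut rows = Box::pin(client.copy_out(query).await?);

    while let Some(chunk) = rows.next().await {
        let chunk = chunk?;
        // Text-format COPY rows are tab-separated with a trailing newline.
        let line = std::str::from_utf8(&chunk)?;
        let mut cols = line.trim().split('\t');
        if let (Some(user), Some(item)) = (cols.next(), cols.next()) {
            pairs.push((user.parse()?, item.parse()?));
        }
    }
    Ok(pairs)
}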

Cargo.lock (generated)

@@ -703,6 +703,21 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.31" version = "0.3.31"
@@ -719,6 +734,23 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"

+[[package]]
+name = "futures-executor"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
+name = "futures-io"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
+
 [[package]]
 name = "futures-macro"
 version = "0.3.31"
@@ -748,10 +780,13 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
 dependencies = [
+ "futures-channel",
  "futures-core",
+ "futures-io",
  "futures-macro",
  "futures-sink",
  "futures-task",
+ "memchr",
  "pin-project-lite",
  "pin-utils",
  "slab",
@@ -1179,6 +1214,7 @@ dependencies = [
"ctrlc", "ctrlc",
"deadpool-postgres", "deadpool-postgres",
"dotenv", "dotenv",
"futures",
"libmf", "libmf",
"log", "log",
"ndarray", "ndarray",

Cargo.toml

@@ -9,6 +9,7 @@ clap = { version = "4.4", features = ["derive"] }
 ctrlc = "3.4"
 deadpool-postgres = "0.11"
 dotenv = "0.15"
+futures = "0.3"
 libmf = "0.3"
 log = "0.4"
 ndarray = { version = "0.15", features = ["blas"] }


@@ -2,6 +2,7 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
+use futures::StreamExt;
 use libmf::{Loss, Matrix, Model};
 use log::info;
 use std::collections::HashSet;
@@ -92,52 +93,34 @@ async fn create_pool() -> Result<Pool> {
 Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }

-async fn load_data_batch(
+async fn load_data(
     client: &deadpool_postgres::Client,
     source_table: &str,
     source_user_id_column: &str,
     source_item_id_column: &str,
-    batch_size: usize,
-    last_user_id: Option<i32>,
-    last_item_id: Option<i32>,
 ) -> Result<Vec<(i32, i32)>> {
-    let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
-        let query = format!(
-            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
-             WHERE ({user}, {item}) > ($1::int4, $2::int4) \
-             ORDER BY {user}, {item} \
-             LIMIT $3::bigint",
-            user = source_user_id_column,
-            item = source_item_id_column,
-            table = source_table,
-        );
-        client
-            .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
-            .await
-            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
-    } else {
-        let query = format!(
-            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
-             ORDER BY {user}, {item} \
-             LIMIT $1::bigint",
-            user = source_user_id_column,
-            item = source_item_id_column,
-            table = source_table,
-        );
-        client
-            .query(&query, &[&(batch_size as i64)])
-            .await
-            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
-    };
+    let query = format!(
+        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT text, DELIMITER '\t')",
+        table = source_table,
+        user = source_user_id_column,
+        item = source_item_id_column,
+    );

-    let mut batch = Vec::with_capacity(rows.len());
-    for row in rows {
-        let user_id: i32 = row.get(0);
-        let item_id: i32 = row.get(1);
-        batch.push((user_id, item_id));
+    let mut data = Vec::new();
+    let mut copy_out = Box::pin(client.copy_out(&query).await?);
+
+    while let Some(bytes) = copy_out.as_mut().next().await {
+        let bytes = bytes?;
+        let row = String::from_utf8(bytes.to_vec())?;
+        let parts: Vec<&str> = row.trim().split('\t').collect();
+        if parts.len() == 2 {
+            let user_id: i32 = parts[0].parse()?;
+            let item_id: i32 = parts[1].parse()?;
+            data.push((user_id, item_id));
+        }
     }

-    Ok(batch)
+    Ok(data)
 }

 async fn save_embeddings(
@@ -324,9 +307,6 @@ async fn main() -> Result<()> {
info!("Schema validation successful"); info!("Schema validation successful");
let mut matrix = Matrix::new(); let mut matrix = Matrix::new();
let mut last_user_id: Option<i32> = None;
let mut last_item_id: Option<i32> = None;
let mut total_rows = 0;
let mut item_ids = HashSet::new(); let mut item_ids = HashSet::new();
// Set up graceful shutdown for data loading // Set up graceful shutdown for data loading
@@ -343,56 +323,42 @@ async fn main() -> Result<()> {
     })?;

     info!("Starting data loading...");
-    while running.load(Ordering::SeqCst) {
-        let batch = load_data_batch(
-            &pool.get().await?,
-            &args.source_table,
-            &args.source_user_id_column,
-            &args.source_item_id_column,
-            args.batch_size as usize,
-            last_user_id,
-            last_item_id,
-        )
-        .await?;
-
-        if batch.is_empty() {
-            break;
-        }
-
-        // Check if we would exceed max_interactions
-        if let Some(max) = args.max_interactions {
-            if total_rows + batch.len() > max {
-                // Only process up to max_interactions
-                let remaining = max - total_rows;
-                for (user_id, item_id) in batch.into_iter().take(remaining) {
-                    matrix.push(user_id, item_id, 1.0f32);
-                    item_ids.insert(item_id);
-                }
-                total_rows += remaining;
-                info!("Reached maximum interactions limit of {max}",);
-                break;
-            }
-        }
-
-        total_rows += batch.len();
-        info!(
-            "Loaded batch of {} rows (total: {})",
-            batch.len(),
-            total_rows
-        );
-
-        // Update last seen IDs for next batch
-        if let Some((user_id, item_id)) = batch.last() {
-            last_user_id = Some(*user_id);
-            last_item_id = Some(*item_id);
-        }
-
-        // Process batch
-        for (user_id, item_id) in batch {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-    }
+    let data = load_data(
+        &pool.get().await?,
+        &args.source_table,
+        &args.source_user_id_column,
+        &args.source_item_id_column,
+    )
+    .await?;
+
+    if data.is_empty() {
+        info!("No data found in source table");
+        return Ok(());
+    }
+
+    let data_len = data.len();
+    // Check if we would exceed max_interactions
+    let total_rows = if let Some(max) = args.max_interactions {
+        let max = max.min(data_len);
+        info!(
+            "Loading {} interactions (limited by --max-interactions)",
+            max
+        );
+
+        // Only process up to max_interactions
+        for (user_id, item_id) in data.into_iter().take(max) {
+            matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
+        }
+        max
+    } else {
+        // Process all data
+        for (user_id, item_id) in data {
+            matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
+        }
+        data_len
+    };

     info!(
         "Loaded {} total rows with {} unique items",
@@ -400,11 +366,6 @@ async fn main() -> Result<()> {
         item_ids.len()
     );

-    if total_rows == 0 {
-        info!("No data found in source table");
-        return Ok(());
-    }
-
     // Switch to immediate exit mode for training
     IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
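
A note on the row handling in the new load_data: with COPY ... TO STDOUT in text format, the server sends one CopyData message per row, so each chunk yielded by copy_out is a single tab-separated line with a trailing newline, which is why trimming and splitting each chunk directly works. A small standalone check of that parsing step, with hypothetical values and no database required (parse_copy_row is an illustrative helper, not part of this commit):

// Mirrors load_data's per-chunk parsing: trim the trailing newline, split on
// the tab delimiter, and keep rows with exactly two integer columns.
fn parse_copy_row(row: &str) -> Option<(i32, i32)> {
    let parts: Vec<&str> = row.trim().split('\t').collect();
    if parts.len() == 2 {
        let user_id = parts[0].parse().ok()?;
        let item_id = parts[1].parse().ok()?;
        Some((user_id, item_id))
    } else {
        None
    }
}

fn main() {
    // Hypothetical rows in COPY text format: user_id <tab> item_id <newline>.
    assert_eq!(parse_copy_row("42\t1001\n"), Some((42, 1001)));
    assert_eq!(parse_copy_row("7\t9\n"), Some((7, 9)));
    // Rows without exactly two columns are skipped, like load_data's len() == 2 guard.
    assert_eq!(parse_copy_row("malformed line\n"), None);
}

The text format also backslash-escapes special characters and encodes NULL as \N; both the committed code and this sketch assume the id columns are plain non-null integers.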