use COPY for importing data

Dylan Knutson
2024-12-28 17:55:56 +00:00
parent 857cbf5d1f
commit 428ca89c92
3 changed files with 86 additions and 88 deletions
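The gist of the change: instead of paging through the source table with keyset-paginated SELECTs, the loader now streams every row in a single pass with COPY ... TO STDOUT. A minimal standalone sketch of the same pattern, assuming tokio-postgres 0.7 and a tokio runtime with the macros/rt features; the connection string and the `interactions`/`user_id`/`item_id` names are placeholders, not taken from this repo:

```rust
use futures::StreamExt;
use tokio_postgres::NoTls;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let (client, connection) =
        tokio_postgres::connect("host=localhost user=postgres dbname=app", NoTls).await?;
    // tokio-postgres hands back the connection as a separate future that
    // must be driven for the client to make progress.
    tokio::spawn(connection);

    // One COPY statement streams the whole table; text format produces one
    // tab-delimited CopyData message per row, so each chunk below is one row.
    let stream = client
        .copy_out("COPY interactions (user_id, item_id) TO STDOUT (FORMAT text, DELIMITER '\t')")
        .await?;
    futures::pin_mut!(stream);

    let mut pairs: Vec<(i32, i32)> = Vec::new();
    while let Some(chunk) = stream.next().await {
        let line = String::from_utf8(chunk?.to_vec())?;
        let mut fields = line.trim_end().split('\t');
        if let (Some(u), Some(i)) = (fields.next(), fields.next()) {
            pairs.push((u.parse()?, i.parse()?));
        }
    }
    println!("loaded {} pairs", pairs.len());
    Ok(())
}
```

Compared with LIMIT/OFFSET or keyset batching, the server plans the scan once and never re-sorts, which is the usual reason bulk loaders reach for COPY.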

Cargo.lock (generated)

@@ -703,6 +703,21 @@ dependencies = [
  "percent-encoding",
 ]

+[[package]]
+name = "futures"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.31"
@@ -719,6 +734,23 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"

+[[package]]
+name = "futures-executor"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
+name = "futures-io"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
+
 [[package]]
 name = "futures-macro"
 version = "0.3.31"
@@ -748,10 +780,13 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
 dependencies = [
+ "futures-channel",
  "futures-core",
+ "futures-io",
  "futures-macro",
  "futures-sink",
  "futures-task",
+ "memchr",
  "pin-project-lite",
  "pin-utils",
  "slab",
@@ -1179,6 +1214,7 @@ dependencies = [
  "ctrlc",
  "deadpool-postgres",
  "dotenv",
+ "futures",
  "libmf",
  "log",
  "ndarray",

Cargo.toml

@@ -9,6 +9,7 @@ clap = { version = "4.4", features = ["derive"] }
 ctrlc = "3.4"
 deadpool-postgres = "0.11"
 dotenv = "0.15"
+futures = "0.3"
 libmf = "0.3"
 log = "0.4"
 ndarray = { version = "0.15", features = ["blas"] }
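The new futures dependency exists for a single import: StreamExt, which supplies the `.next()` combinator used to drain the COPY stream. A toy illustration of the pattern, with made-up values:

```rust
use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() {
    // StreamExt::next is the async analogue of Iterator::next; the
    // CopyOutStream in the source change below is consumed in this same shape.
    let mut rows = stream::iter(vec![(1, 10), (2, 20)]);
    while let Some((user_id, item_id)) = rows.next().await {
        println!("user {user_id} -> item {item_id}");
    }
}
```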

src/main.rs

@@ -2,6 +2,7 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
+use futures::StreamExt;
 use libmf::{Loss, Matrix, Model};
 use log::info;
 use std::collections::HashSet;
@@ -92,52 +93,34 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }

-async fn load_data_batch(
+async fn load_data(
     client: &deadpool_postgres::Client,
     source_table: &str,
     source_user_id_column: &str,
     source_item_id_column: &str,
-    batch_size: usize,
-    last_user_id: Option<i32>,
-    last_item_id: Option<i32>,
 ) -> Result<Vec<(i32, i32)>> {
-    let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
     let query = format!(
-            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
-             WHERE ({user}, {item}) > ($1::int4, $2::int4) \
-             ORDER BY {user}, {item} \
-             LIMIT $3::bigint",
+        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT text, DELIMITER '\t')",
+        table = source_table,
         user = source_user_id_column,
         item = source_item_id_column,
-            table = source_table,
     );
-        client
-            .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
-            .await
-            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
-    } else {
-        let query = format!(
-            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
-             ORDER BY {user}, {item} \
-             LIMIT $1::bigint",
-            user = source_user_id_column,
-            item = source_item_id_column,
-            table = source_table,
-        );
-        client
-            .query(&query, &[&(batch_size as i64)])
-            .await
-            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
-    };
-    let mut batch = Vec::with_capacity(rows.len());
-    for row in rows {
-        let user_id: i32 = row.get(0);
-        let item_id: i32 = row.get(1);
-        batch.push((user_id, item_id));
+    let mut data = Vec::new();
+    let mut copy_out = Box::pin(client.copy_out(&query).await?);
+    while let Some(bytes) = copy_out.as_mut().next().await {
+        let bytes = bytes?;
+        let row = String::from_utf8(bytes.to_vec())?;
+        let parts: Vec<&str> = row.trim().split('\t').collect();
+        if parts.len() == 2 {
+            let user_id: i32 = parts[0].parse()?;
+            let item_id: i32 = parts[1].parse()?;
+            data.push((user_id, item_id));
+        }
     }
-    Ok(batch)
+    Ok(data)
 }
async fn save_embeddings(
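A note on the parsing above: PostgreSQL's text-format COPY emits one row per message, fields joined by the delimiter and terminated by a newline, with NULLs rendered as \N (which would fail the integer parse and, via `?`, abort the load). A hypothetical helper, not in the commit, that isolates the row-splitting logic and skips malformed rows instead of erroring:

```rust
/// Parse one text-format COPY row, e.g. "123\t456\n", into (user_id, item_id).
/// Returns None unless the row has exactly two integer fields.
fn parse_copy_row(row: &str) -> Option<(i32, i32)> {
    let mut fields = row.trim_end().split('\t');
    let user = fields.next()?.parse().ok()?;
    let item = fields.next()?.parse().ok()?;
    // Reject rows with extra fields, mirroring the commit's `parts.len() == 2` check.
    if fields.next().is_some() {
        return None;
    }
    Some((user, item))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_tab_delimited_pairs() {
        assert_eq!(parse_copy_row("42\t7\n"), Some((42, 7)));
        assert_eq!(parse_copy_row("a\tb\n"), None);
        assert_eq!(parse_copy_row("1\t2\t3\n"), None);
    }
}
```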
@@ -324,9 +307,6 @@ async fn main() -> Result<()> {
     info!("Schema validation successful");

     let mut matrix = Matrix::new();
-    let mut last_user_id: Option<i32> = None;
-    let mut last_item_id: Option<i32> = None;
-    let mut total_rows = 0;
     let mut item_ids = HashSet::new();

     // Set up graceful shutdown for data loading
@@ -343,56 +323,42 @@ async fn main() -> Result<()> {
     })?;

     info!("Starting data loading...");

-    while running.load(Ordering::SeqCst) {
-        let batch = load_data_batch(
+    let data = load_data(
         &pool.get().await?,
         &args.source_table,
         &args.source_user_id_column,
         &args.source_item_id_column,
-            args.batch_size as usize,
-            last_user_id,
-            last_item_id,
     )
     .await?;

-        if batch.is_empty() {
-            break;
+    if data.is_empty() {
+        info!("No data found in source table");
+        return Ok(());
     }

+    let data_len = data.len();
+
     // Check if we would exceed max_interactions
-        if let Some(max) = args.max_interactions {
-            if total_rows + batch.len() > max {
-                // Only process up to max_interactions
-                let remaining = max - total_rows;
-                for (user_id, item_id) in batch.into_iter().take(remaining) {
-                    matrix.push(user_id, item_id, 1.0f32);
-                    item_ids.insert(item_id);
-                }
-                total_rows += remaining;
-                info!("Reached maximum interactions limit of {max}",);
-                break;
-            }
-        }
-        total_rows += batch.len();
+    let total_rows = if let Some(max) = args.max_interactions {
+        let max = max.min(data_len);
         info!(
-            "Loaded batch of {} rows (total: {})",
-            batch.len(),
-            total_rows
+            "Loading {} interactions (limited by --max-interactions)",
+            max
         );
-        // Update last seen IDs for next batch
-        if let Some((user_id, item_id)) = batch.last() {
-            last_user_id = Some(*user_id);
-            last_item_id = Some(*item_id);
-        }
-        // Process batch
-        for (user_id, item_id) in batch {
+        // Only process up to max_interactions
+        for (user_id, item_id) in data.into_iter().take(max) {
             matrix.push(user_id, item_id, 1.0f32);
             item_ids.insert(item_id);
         }
+        max
+    } else {
+        // Process all data
+        for (user_id, item_id) in data {
+            matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
+        }
+        data_len
+    };

     info!(
         "Loaded {} total rows with {} unique items",
@@ -400,11 +366,6 @@ async fn main() -> Result<()> {
         item_ids.len()
     );

-    if total_rows == 0 {
-        info!("No data found in source table");
-        return Ok(());
-    }
-
     // Switch to immediate exit mode for training
     IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
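With the batch loop and cursor bookkeeping gone, total_rows collapses into a single expression: the capped count when --max-interactions is set, otherwise the full data length. The same shape in miniature, with hypothetical values rather than the commit's code:

```rust
fn main() {
    let data = vec![(1, 10), (2, 20), (3, 30)];
    let max_interactions: Option<usize> = Some(2); // stand-in for --max-interactions

    let data_len = data.len();
    // The if/else is an expression whose value is the number of rows consumed.
    let total_rows = if let Some(max) = max_interactions {
        let max = max.min(data_len); // never take more than we have
        data.into_iter().take(max).count()
    } else {
        data_len
    };

    assert_eq!(total_rows, 2);
    println!("processed {total_rows} rows");
}
```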