use COPY for importing data
Cargo.lock (generated): 36 lines changed
@@ -703,6 +703,21 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "futures"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.31"
@@ -719,6 +734,23 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
 
+[[package]]
+name = "futures-executor"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
+name = "futures-io"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
+
 [[package]]
 name = "futures-macro"
 version = "0.3.31"
@@ -748,10 +780,13 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
 dependencies = [
+ "futures-channel",
  "futures-core",
+ "futures-io",
  "futures-macro",
  "futures-sink",
  "futures-task",
+ "memchr",
  "pin-project-lite",
  "pin-utils",
  "slab",
@@ -1179,6 +1214,7 @@ dependencies = [
  "ctrlc",
  "deadpool-postgres",
  "dotenv",
+ "futures",
  "libmf",
  "log",
  "ndarray",
Cargo.toml
@@ -9,6 +9,7 @@ clap = { version = "4.4", features = ["derive"] }
 ctrlc = "3.4"
 deadpool-postgres = "0.11"
 dotenv = "0.15"
+futures = "0.3"
 libmf = "0.3"
 log = "0.4"
 ndarray = { version = "0.15", features = ["blas"] }
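Note: the futures = "0.3" line added above is what the src/main.rs changes below rely on. tokio_postgres::Client::copy_out returns a CopyOutStream, which implements the Stream trait; consuming it with .next().await requires the StreamExt extension trait provided by the futures crate, hence the new use futures::StreamExt; import in the next file.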
src/main.rs: 111 lines changed
@@ -2,6 +2,7 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
+use futures::StreamExt;
 use libmf::{Loss, Matrix, Model};
 use log::info;
 use std::collections::HashSet;
@@ -92,52 +93,34 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }
 
-async fn load_data_batch(
+async fn load_data(
     client: &deadpool_postgres::Client,
     source_table: &str,
     source_user_id_column: &str,
     source_item_id_column: &str,
-    batch_size: usize,
-    last_user_id: Option<i32>,
-    last_item_id: Option<i32>,
 ) -> Result<Vec<(i32, i32)>> {
-    let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
-        let query = format!(
-            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
-             WHERE ({user}, {item}) > ($1::int4, $2::int4) \
-             ORDER BY {user}, {item} \
-             LIMIT $3::bigint",
-            user = source_user_id_column,
-            item = source_item_id_column,
-            table = source_table,
-        );
-        client
-            .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
-            .await
-            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
-    } else {
-        let query = format!(
-            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
-             ORDER BY {user}, {item} \
-             LIMIT $1::bigint",
-            user = source_user_id_column,
-            item = source_item_id_column,
-            table = source_table,
-        );
-        client
-            .query(&query, &[&(batch_size as i64)])
-            .await
-            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
-    };
-
-    let mut batch = Vec::with_capacity(rows.len());
-    for row in rows {
-        let user_id: i32 = row.get(0);
-        let item_id: i32 = row.get(1);
-        batch.push((user_id, item_id));
-    }
-
-    Ok(batch)
+    let query = format!(
+        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT text, DELIMITER '\t')",
+        table = source_table,
+        user = source_user_id_column,
+        item = source_item_id_column,
+    );
+
+    let mut data = Vec::new();
+    let mut copy_out = Box::pin(client.copy_out(&query).await?);
+
+    while let Some(bytes) = copy_out.as_mut().next().await {
+        let bytes = bytes?;
+        let row = String::from_utf8(bytes.to_vec())?;
+        let parts: Vec<&str> = row.trim().split('\t').collect();
+        if parts.len() == 2 {
+            let user_id: i32 = parts[0].parse()?;
+            let item_id: i32 = parts[1].parse()?;
+            data.push((user_id, item_id));
+        }
+    }
+
+    Ok(data)
 }
 
 async fn save_embeddings(
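For reference, below is a self-contained sketch of the COPY TO STDOUT streaming pattern that the new load_data uses. It is an illustration, not part of this commit: the connection string and the interactions table are placeholder assumptions, and errors are collapsed into a boxed error. It uses only APIs that appear in the diff (tokio_postgres's copy_out, which yields a stream of Bytes chunks, and StreamExt from futures); in PostgreSQL's text format the server sends one CopyData message per row, which is what makes the chunk-per-row parse valid.

use futures::StreamExt;
use tokio_postgres::NoTls;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder connection string; adjust for a real database.
    let (client, connection) =
        tokio_postgres::connect("host=localhost user=postgres dbname=test", NoTls).await?;
    // tokio-postgres requires the connection future to be polled on its own
    // task so the client can make progress.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {e}");
        }
    });

    // COPY ... TO STDOUT streams the table to the client row by row instead
    // of materializing result sets through repeated paginated SELECTs.
    // "interactions" is a hypothetical table with int4 user_id/item_id columns.
    let stream = client
        .copy_out("COPY interactions (user_id, item_id) TO STDOUT (FORMAT text, DELIMITER '\t')")
        .await?;
    // CopyOutStream is not Unpin, so pin it before calling next().
    futures::pin_mut!(stream);

    let mut pairs: Vec<(i32, i32)> = Vec::new();
    while let Some(chunk) = stream.next().await {
        // In text format each chunk is one row: tab-separated columns
        // terminated by a newline.
        let line = String::from_utf8(chunk?.to_vec())?;
        let mut cols = line.trim_end().split('\t');
        if let (Some(user), Some(item)) = (cols.next(), cols.next()) {
            pairs.push((user.parse()?, item.parse()?));
        }
    }
    println!("loaded {} pairs", pairs.len());
    Ok(())
}

Relative to the keyset-paginated SELECT it replaces, COPY streams the table in a single pass without per-batch ORDER BY/LIMIT queries; the trade-off is that the max_interactions cap is now enforced client-side, as the main() changes below show.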
@@ -324,9 +307,6 @@ async fn main() -> Result<()> {
     info!("Schema validation successful");
 
     let mut matrix = Matrix::new();
-    let mut last_user_id: Option<i32> = None;
-    let mut last_item_id: Option<i32> = None;
-    let mut total_rows = 0;
     let mut item_ids = HashSet::new();
 
     // Set up graceful shutdown for data loading
@@ -343,56 +323,42 @@ async fn main() -> Result<()> {
     })?;
 
     info!("Starting data loading...");
-    while running.load(Ordering::SeqCst) {
-        let batch = load_data_batch(
-            &pool.get().await?,
-            &args.source_table,
-            &args.source_user_id_column,
-            &args.source_item_id_column,
-            args.batch_size as usize,
-            last_user_id,
-            last_item_id,
-        )
-        .await?;
-
-        if batch.is_empty() {
-            break;
-        }
-
-        // Check if we would exceed max_interactions
-        if let Some(max) = args.max_interactions {
-            if total_rows + batch.len() > max {
-                // Only process up to max_interactions
-                let remaining = max - total_rows;
-                for (user_id, item_id) in batch.into_iter().take(remaining) {
-                    matrix.push(user_id, item_id, 1.0f32);
-                    item_ids.insert(item_id);
-                }
-                total_rows += remaining;
-                info!("Reached maximum interactions limit of {max}",);
-                break;
-            }
-        }
-
-        total_rows += batch.len();
-        info!(
-            "Loaded batch of {} rows (total: {})",
-            batch.len(),
-            total_rows
-        );
-
-        // Update last seen IDs for next batch
-        if let Some((user_id, item_id)) = batch.last() {
-            last_user_id = Some(*user_id);
-            last_item_id = Some(*item_id);
-        }
-
-        // Process batch
-        for (user_id, item_id) in batch {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-    }
+    let data = load_data(
+        &pool.get().await?,
+        &args.source_table,
+        &args.source_user_id_column,
+        &args.source_item_id_column,
+    )
+    .await?;
+
+    if data.is_empty() {
+        info!("No data found in source table");
+        return Ok(());
+    }
+
+    let data_len = data.len();
+
+    let total_rows = if let Some(max) = args.max_interactions {
+        let max = max.min(data_len);
+        info!(
+            "Loading {} interactions (limited by --max-interactions)",
+            max
+        );
+
+        // Only process up to max_interactions
+        for (user_id, item_id) in data.into_iter().take(max) {
+            matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
+        }
+        max
+    } else {
+        // Process all data
+        for (user_id, item_id) in data {
+            matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
+        }
+        data_len
+    };
 
     info!(
         "Loaded {} total rows with {} unique items",
@@ -400,11 +366,6 @@ async fn main() -> Result<()> {
         item_ids.len()
     );
 
-    if total_rows == 0 {
-        info!("No data found in source table");
-        return Ok(());
-    }
-
     // Switch to immediate exit mode for training
     IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
 