use COPY for importing data
Cargo.lock | 36 (generated)
@@ -703,6 +703,21 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "futures"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.31"
@@ -719,6 +734,23 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
 
+[[package]]
+name = "futures-executor"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
+name = "futures-io"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
+
 [[package]]
 name = "futures-macro"
 version = "0.3.31"
@@ -748,10 +780,13 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
 dependencies = [
+ "futures-channel",
  "futures-core",
+ "futures-io",
  "futures-macro",
  "futures-sink",
  "futures-task",
+ "memchr",
  "pin-project-lite",
  "pin-utils",
  "slab",
@@ -1179,6 +1214,7 @@ dependencies = [
  "ctrlc",
  "deadpool-postgres",
  "dotenv",
+ "futures",
  "libmf",
  "log",
  "ndarray",
Cargo.toml | 1

@@ -9,6 +9,7 @@ clap = { version = "4.4", features = ["derive"] }
 ctrlc = "3.4"
 deadpool-postgres = "0.11"
 dotenv = "0.15"
+futures = "0.3"
 libmf = "0.3"
 log = "0.4"
 ndarray = { version = "0.15", features = ["blas"] }
src/main.rs | 137
@@ -2,6 +2,7 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
+use futures::StreamExt;
 use libmf::{Loss, Matrix, Model};
 use log::info;
 use std::collections::HashSet;
@@ -92,52 +93,34 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }
 
-async fn load_data_batch(
+async fn load_data(
     client: &deadpool_postgres::Client,
     source_table: &str,
     source_user_id_column: &str,
     source_item_id_column: &str,
-    batch_size: usize,
-    last_user_id: Option<i32>,
-    last_item_id: Option<i32>,
 ) -> Result<Vec<(i32, i32)>> {
-    let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
-        let query = format!(
-            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
-             WHERE ({user}, {item}) > ($1::int4, $2::int4) \
-             ORDER BY {user}, {item} \
-             LIMIT $3::bigint",
-            user = source_user_id_column,
-            item = source_item_id_column,
-            table = source_table,
-        );
-        client
-            .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
-            .await
-            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
-    } else {
-        let query = format!(
-            "SELECT ({user})::int4, ({item})::int4 FROM {table} \
-             ORDER BY {user}, {item} \
-             LIMIT $1::bigint",
-            user = source_user_id_column,
-            item = source_item_id_column,
-            table = source_table,
-        );
-        client
-            .query(&query, &[&(batch_size as i64)])
-            .await
-            .with_context(|| format!("Query failed at {}:{}", file!(), line!()))?
-    };
-
-    let mut batch = Vec::with_capacity(rows.len());
-    for row in rows {
-        let user_id: i32 = row.get(0);
-        let item_id: i32 = row.get(1);
-        batch.push((user_id, item_id));
+    let query = format!(
+        "COPY {table} ({user}, {item}) TO STDOUT (FORMAT text, DELIMITER '\t')",
+        table = source_table,
+        user = source_user_id_column,
+        item = source_item_id_column,
+    );
+
+    let mut data = Vec::new();
+    let mut copy_out = Box::pin(client.copy_out(&query).await?);
+
+    while let Some(bytes) = copy_out.as_mut().next().await {
+        let bytes = bytes?;
+        let row = String::from_utf8(bytes.to_vec())?;
+        let parts: Vec<&str> = row.trim().split('\t').collect();
+        if parts.len() == 2 {
+            let user_id: i32 = parts[0].parse()?;
+            let item_id: i32 = parts[1].parse()?;
+            data.push((user_id, item_id));
+        }
     }
 
-    Ok(batch)
+    Ok(data)
 }
 
 async fn save_embeddings(
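
Note: the essential change in load_data is replacing keyset-paginated SELECT ... ORDER BY ... LIMIT batches with a single COPY ... TO STDOUT stream, consumed through tokio-postgres's copy_out and futures::StreamExt. The following standalone sketch shows that pattern in isolation. The connection string, the interactions(user_id, item_id) table, and the tokio runtime features are assumptions for illustration, and, like the patch, it assumes the server delivers whole text rows in each chunk.

use futures::StreamExt;
use tokio_postgres::NoTls;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical connection parameters; deadpool_postgres::Client derefs to
    // tokio_postgres::Client, so the same calls work through the pool.
    let (client, connection) =
        tokio_postgres::connect("host=localhost user=postgres dbname=example", NoTls).await?;
    // The connection future drives the socket and must be polled on its own task.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {e}");
        }
    });

    // One server-side scan, streamed straight to the client.
    let stream = client
        .copy_out("COPY interactions (user_id, item_id) TO STDOUT (FORMAT text, DELIMITER '\t')")
        .await?;
    futures::pin_mut!(stream);

    let mut rows: u64 = 0;
    while let Some(chunk) = stream.next().await {
        // Each chunk is a Bytes buffer of tab-separated text rows.
        let chunk = chunk?;
        for line in std::str::from_utf8(&chunk)?.lines() {
            let mut cols = line.split('\t');
            if let (Some(user), Some(item)) = (cols.next(), cols.next()) {
                let _interaction: (i32, i32) = (user.parse()?, item.parse()?);
                rows += 1;
            }
        }
    }
    println!("streamed {rows} rows");
    Ok(())
}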
@@ -324,9 +307,6 @@ async fn main() -> Result<()> {
     info!("Schema validation successful");
 
     let mut matrix = Matrix::new();
-    let mut last_user_id: Option<i32> = None;
-    let mut last_item_id: Option<i32> = None;
-    let mut total_rows = 0;
     let mut item_ids = HashSet::new();
 
     // Set up graceful shutdown for data loading
@@ -343,56 +323,42 @@ async fn main() -> Result<()> {
     })?;
 
     info!("Starting data loading...");
-    while running.load(Ordering::SeqCst) {
-        let batch = load_data_batch(
-            &pool.get().await?,
-            &args.source_table,
-            &args.source_user_id_column,
-            &args.source_item_id_column,
-            args.batch_size as usize,
-            last_user_id,
-            last_item_id,
-        )
-        .await?;
-
-        if batch.is_empty() {
-            break;
-        }
-
-        // Check if we would exceed max_interactions
-        if let Some(max) = args.max_interactions {
-            if total_rows + batch.len() > max {
-                // Only process up to max_interactions
-                let remaining = max - total_rows;
-                for (user_id, item_id) in batch.into_iter().take(remaining) {
-                    matrix.push(user_id, item_id, 1.0f32);
-                    item_ids.insert(item_id);
-                }
-                total_rows += remaining;
-                info!("Reached maximum interactions limit of {max}",);
-                break;
-            }
-        }
-
-        total_rows += batch.len();
-        info!(
-            "Loaded batch of {} rows (total: {})",
-            batch.len(),
-            total_rows
-        );
-
-        // Update last seen IDs for next batch
-        if let Some((user_id, item_id)) = batch.last() {
-            last_user_id = Some(*user_id);
-            last_item_id = Some(*item_id);
-        }
-
-        // Process batch
-        for (user_id, item_id) in batch {
-            matrix.push(user_id, item_id, 1.0f32);
-            item_ids.insert(item_id);
-        }
-    }
+    let data = load_data(
+        &pool.get().await?,
+        &args.source_table,
+        &args.source_user_id_column,
+        &args.source_item_id_column,
+    )
+    .await?;
+
+    if data.is_empty() {
+        info!("No data found in source table");
+        return Ok(());
+    }
+
+    let data_len = data.len();
+    // Check if we would exceed max_interactions
+    let total_rows = if let Some(max) = args.max_interactions {
+        let max = max.min(data_len);
+        info!(
+            "Loading {} interactions (limited by --max-interactions)",
+            max
+        );
+
+        // Only process up to max_interactions
+        for (user_id, item_id) in data.into_iter().take(max) {
+            matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
+        }
+        max
+    } else {
+        // Process all data
+        for (user_id, item_id) in data {
+            matrix.push(user_id, item_id, 1.0f32);
+            item_ids.insert(item_id);
+        }
+        data_len
+    };
+
     info!(
         "Loaded {} total rows with {} unique items",
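
The max_interactions branch above reduces to "take at most max interactions, otherwise take everything, and report how many were used". As a hypothetical extraction of that logic (a Vec of triples stands in for libmf::Matrix so the sketch has no external dependencies):

use std::collections::HashSet;

/// Push at most `limit` interactions and return how many were taken.
fn push_interactions(
    matrix: &mut Vec<(i32, i32, f32)>, // stand-in for libmf::Matrix
    item_ids: &mut HashSet<i32>,
    data: Vec<(i32, i32)>,
    limit: Option<usize>,
) -> usize {
    // None means "no limit"; a limit larger than the data is clamped, as in main().
    let take = limit.unwrap_or(data.len()).min(data.len());
    for (user_id, item_id) in data.into_iter().take(take) {
        matrix.push((user_id, item_id, 1.0f32));
        item_ids.insert(item_id);
    }
    take
}

With limit = Some(2) and three input rows, only the first two land in the matrix and the function returns 2; with limit = None it returns the full length.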
@@ -400,11 +366,6 @@ async fn main() -> Result<()> {
         item_ids.len()
     );
 
-    if total_rows == 0 {
-        info!("No data found in source table");
-        return Ok(());
-    }
-
     // Switch to immediate exit mode for training
     IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
 
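
These hunks only brush against the shutdown plumbing: the ctrlc handler set up before loading, the running flag that the removed batch loop used to poll, and the IMMEDIATE_EXIT switch flipped before training. For reference, a generic sketch of that ctrlc plus AtomicBool pattern; the handler body and placement are assumptions, since the real definitions sit outside the shown hunks.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Mirrors the IMMEDIATE_EXIT static referenced above; its real definition is not in this diff.
static IMMEDIATE_EXIT: AtomicBool = AtomicBool::new(false);

fn main() -> Result<(), ctrlc::Error> {
    let running = Arc::new(AtomicBool::new(true));
    let flag = Arc::clone(&running);

    // First Ctrl-C requests a graceful stop; once IMMEDIATE_EXIT is set
    // (here, just before training starts), Ctrl-C terminates the process directly.
    ctrlc::set_handler(move || {
        if IMMEDIATE_EXIT.load(Ordering::SeqCst) {
            std::process::exit(130);
        }
        flag.store(false, Ordering::SeqCst);
    })?;

    // Long-running phases check the flag between units of work.
    while running.load(Ordering::SeqCst) {
        // ... one unit of work; empty in this sketch ...
        break;
    }

    IMMEDIATE_EXIT.store(true, Ordering::SeqCst);
    // ... training would run here ...
    Ok(())
}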