add progress bar
This commit is contained in:
61
Cargo.lock
generated
61
Cargo.lock
generated
@@ -359,6 +359,19 @@ version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
|
||||
|
||||
[[package]]
|
||||
name = "console"
|
||||
version = "0.15.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b"
|
||||
dependencies = [
|
||||
"encode_unicode",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"unicode-width",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
@@ -607,6 +620,12 @@ version = "1.0.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
|
||||
|
||||
[[package]]
|
||||
name = "encode_unicode"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.10.2"
|
||||
@@ -1052,6 +1071,19 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indicatif"
|
||||
version = "0.17.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281"
|
||||
dependencies = [
|
||||
"console",
|
||||
"number_prefix",
|
||||
"portable-atomic",
|
||||
"unicode-width",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is-terminal"
|
||||
version = "0.4.13"
|
||||
@@ -1216,6 +1248,7 @@ dependencies = [
|
||||
"deadpool-postgres",
|
||||
"dotenv",
|
||||
"futures",
|
||||
"indicatif",
|
||||
"libmf",
|
||||
"log",
|
||||
"ndarray",
|
||||
@@ -1388,6 +1421,12 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "number_prefix"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.36.7"
|
||||
@@ -1590,6 +1629,12 @@ dependencies = [
|
||||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6"
|
||||
|
||||
[[package]]
|
||||
name = "postgres-protocol"
|
||||
version = "0.6.7"
|
||||
@@ -2322,6 +2367,12 @@ version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "2.12.1"
|
||||
@@ -2460,6 +2511,16 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web-time"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.5.2"
|
||||
|
||||
@@ -22,3 +22,4 @@ rand_distr = "0.4"
|
||||
tokio = { version = "1.35", features = ["full"] }
|
||||
tokio-postgres = "0.7"
|
||||
bytes = "1.5.0"
|
||||
indicatif = "0.17"
|
||||
|
||||
@@ -3,6 +3,7 @@ use clap::Parser;
|
||||
use deadpool_postgres::{Config, Pool, Runtime};
|
||||
use dotenv::dotenv;
|
||||
use futures::StreamExt;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use libmf::{Loss, Matrix, Model};
|
||||
use log::info;
|
||||
use std::collections::HashSet;
|
||||
@@ -77,6 +78,15 @@ async fn load_data(
|
||||
source_item_id_column: &str,
|
||||
max_interactions: Option<usize>,
|
||||
) -> Result<Vec<(i32, i32)>> {
|
||||
// First, get the total count for progress bar
|
||||
info!("Counting rows to load...");
|
||||
let count_query = if let Some(max) = max_interactions {
|
||||
format!("SELECT LEAST(COUNT(*), {}) FROM {}", max, source_table)
|
||||
} else {
|
||||
format!("SELECT COUNT(*) FROM {}", source_table)
|
||||
};
|
||||
let total_rows: i64 = client.query_one(&count_query, &[]).await?.get(0);
|
||||
|
||||
let types = get_column_types(
|
||||
client,
|
||||
source_table,
|
||||
@@ -84,6 +94,12 @@ async fn load_data(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let pb = ProgressBar::new(total_rows as u64);
|
||||
pb.set_style(ProgressStyle::default_bar()
|
||||
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} rows ({per_sec}, {eta})")?
|
||||
.progress_chars("#>-"));
|
||||
pb.set_message("Loading data...");
|
||||
|
||||
let query = if let Some(max) = max_interactions {
|
||||
format!(
|
||||
"COPY (
|
||||
@@ -111,7 +127,7 @@ async fn load_data(
|
||||
));
|
||||
let mut stream = stream.as_mut();
|
||||
|
||||
let mut data = Vec::new();
|
||||
let mut data = Vec::with_capacity(total_rows as usize);
|
||||
while let Some(row) = stream.as_mut().next().await {
|
||||
let row = row?;
|
||||
|
||||
@@ -129,8 +145,10 @@ async fn load_data(
|
||||
};
|
||||
|
||||
data.push((user_id, item_id));
|
||||
pb.inc(1);
|
||||
}
|
||||
|
||||
pb.finish_with_message("Data loading complete");
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
@@ -188,11 +206,19 @@ async fn save_embeddings(
|
||||
));
|
||||
info!("Binary writer initialized");
|
||||
|
||||
// Add progress bar for embedding processing
|
||||
let pb = ProgressBar::new(model.q_iter().len() as u64);
|
||||
pb.set_style(ProgressStyle::default_bar()
|
||||
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} embeddings ({per_sec}, {eta})")?
|
||||
.progress_chars("#>-"));
|
||||
pb.set_message("Processing embeddings...");
|
||||
|
||||
let mut valid_embeddings = 0;
|
||||
let mut invalid_embeddings = 0;
|
||||
|
||||
// Process factors for items that appear in the source table
|
||||
for (idx, factors) in model.q_iter().enumerate() {
|
||||
pb.inc(1);
|
||||
let item_id = idx as i32;
|
||||
if !item_ids.contains(&item_id) {
|
||||
continue;
|
||||
@@ -213,6 +239,11 @@ async fn save_embeddings(
|
||||
}
|
||||
}
|
||||
|
||||
pb.finish_with_message(format!(
|
||||
"Processed {} valid embeddings, skipped {} invalid embeddings",
|
||||
valid_embeddings, invalid_embeddings
|
||||
));
|
||||
|
||||
info!("Finishing COPY operation...");
|
||||
writer.as_mut().finish().await?;
|
||||
info!("COPY operation completed");
|
||||
@@ -246,12 +277,12 @@ async fn save_embeddings(
|
||||
};
|
||||
|
||||
// Set table as UNLOGGED
|
||||
// info!("Setting table as UNLOGGED...");
|
||||
// tx.execute(
|
||||
// &format!("ALTER TABLE {} SET UNLOGGED", args.target_table),
|
||||
// &[],
|
||||
// )
|
||||
// .await?;
|
||||
info!("Setting table as UNLOGGED...");
|
||||
tx.execute(
|
||||
&format!("ALTER TABLE {} SET UNLOGGED", args.target_table),
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Insert from temp table with ON CONFLICT DO UPDATE
|
||||
info!("Upserting from temp table to target table...");
|
||||
@@ -269,12 +300,12 @@ async fn save_embeddings(
|
||||
info!("Upsert completed");
|
||||
|
||||
// Set table back to LOGGED
|
||||
// info!("Setting table back to LOGGED...");
|
||||
// tx.execute(
|
||||
// &format!("ALTER TABLE {} SET LOGGED", args.target_table),
|
||||
// &[],
|
||||
// )
|
||||
// .await?;
|
||||
info!("Setting table back to LOGGED...");
|
||||
tx.execute(
|
||||
&format!("ALTER TABLE {} SET LOGGED", args.target_table),
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Recreate index if we had one
|
||||
if let Some(sql) = index_sql {
|
||||
|
||||
Reference in New Issue
Block a user