add progress bar

This commit is contained in:
Dylan Knutson
2024-12-28 20:51:11 +00:00
parent b3ba58723c
commit b255f40ac7
3 changed files with 106 additions and 13 deletions

61
Cargo.lock generated
View File

@@ -359,6 +359,19 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
[[package]]
name = "console"
version = "0.15.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b"
dependencies = [
"encode_unicode",
"libc",
"once_cell",
"unicode-width",
"windows-sys 0.59.0",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -607,6 +620,12 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
[[package]]
name = "encode_unicode"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
[[package]]
name = "env_logger"
version = "0.10.2"
@@ -1052,6 +1071,19 @@ dependencies = [
"serde",
]
[[package]]
name = "indicatif"
version = "0.17.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281"
dependencies = [
"console",
"number_prefix",
"portable-atomic",
"unicode-width",
"web-time",
]
[[package]]
name = "is-terminal"
version = "0.4.13"
@@ -1216,6 +1248,7 @@ dependencies = [
"deadpool-postgres",
"dotenv",
"futures",
"indicatif",
"libmf",
"log",
"ndarray",
@@ -1388,6 +1421,12 @@ dependencies = [
"libc",
]
[[package]]
name = "number_prefix"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "object"
version = "0.36.7"
@@ -1590,6 +1629,12 @@ dependencies = [
"zip",
]
[[package]]
name = "portable-atomic"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6"
[[package]]
name = "postgres-protocol"
version = "0.6.7"
@@ -2322,6 +2367,12 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0"
[[package]]
name = "unicode-width"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
[[package]]
name = "ureq"
version = "2.12.1"
@@ -2460,6 +2511,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "whoami"
version = "1.5.2"

View File

@@ -22,3 +22,4 @@ rand_distr = "0.4"
tokio = { version = "1.35", features = ["full"] }
tokio-postgres = "0.7"
bytes = "1.5.0"
indicatif = "0.17"

View File

@@ -3,6 +3,7 @@ use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv;
use futures::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use libmf::{Loss, Matrix, Model};
use log::info;
use std::collections::HashSet;
@@ -77,6 +78,15 @@ async fn load_data(
source_item_id_column: &str,
max_interactions: Option<usize>,
) -> Result<Vec<(i32, i32)>> {
// First, get the total count for progress bar
info!("Counting rows to load...");
let count_query = if let Some(max) = max_interactions {
format!("SELECT LEAST(COUNT(*), {}) FROM {}", max, source_table)
} else {
format!("SELECT COUNT(*) FROM {}", source_table)
};
let total_rows: i64 = client.query_one(&count_query, &[]).await?.get(0);
let types = get_column_types(
client,
source_table,
@@ -84,6 +94,12 @@ async fn load_data(
)
.await?;
let pb = ProgressBar::new(total_rows as u64);
pb.set_style(ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} rows ({per_sec}, {eta})")?
.progress_chars("#>-"));
pb.set_message("Loading data...");
let query = if let Some(max) = max_interactions {
format!(
"COPY (
@@ -111,7 +127,7 @@ async fn load_data(
));
let mut stream = stream.as_mut();
let mut data = Vec::new();
let mut data = Vec::with_capacity(total_rows as usize);
while let Some(row) = stream.as_mut().next().await {
let row = row?;
@@ -129,8 +145,10 @@ async fn load_data(
};
data.push((user_id, item_id));
pb.inc(1);
}
pb.finish_with_message("Data loading complete");
Ok(data)
}
@@ -188,11 +206,19 @@ async fn save_embeddings(
));
info!("Binary writer initialized");
// Add progress bar for embedding processing
let pb = ProgressBar::new(model.q_iter().len() as u64);
pb.set_style(ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} embeddings ({per_sec}, {eta})")?
.progress_chars("#>-"));
pb.set_message("Processing embeddings...");
let mut valid_embeddings = 0;
let mut invalid_embeddings = 0;
// Process factors for items that appear in the source table
for (idx, factors) in model.q_iter().enumerate() {
pb.inc(1);
let item_id = idx as i32;
if !item_ids.contains(&item_id) {
continue;
@@ -213,6 +239,11 @@ async fn save_embeddings(
}
}
pb.finish_with_message(format!(
"Processed {} valid embeddings, skipped {} invalid embeddings",
valid_embeddings, invalid_embeddings
));
info!("Finishing COPY operation...");
writer.as_mut().finish().await?;
info!("COPY operation completed");
@@ -246,12 +277,12 @@ async fn save_embeddings(
};
// Set table as UNLOGGED
// info!("Setting table as UNLOGGED...");
// tx.execute(
// &format!("ALTER TABLE {} SET UNLOGGED", args.target_table),
// &[],
// )
// .await?;
info!("Setting table as UNLOGGED...");
tx.execute(
&format!("ALTER TABLE {} SET UNLOGGED", args.target_table),
&[],
)
.await?;
// Insert from temp table with ON CONFLICT DO UPDATE
info!("Upserting from temp table to target table...");
@@ -269,12 +300,12 @@ async fn save_embeddings(
info!("Upsert completed");
// Set table back to LOGGED
// info!("Setting table back to LOGGED...");
// tx.execute(
// &format!("ALTER TABLE {} SET LOGGED", args.target_table),
// &[],
// )
// .await?;
info!("Setting table back to LOGGED...");
tx.execute(
&format!("ALTER TABLE {} SET LOGGED", args.target_table),
&[],
)
.await?;
// Recreate index if we had one
if let Some(sql) = index_sql {