diff --git a/Cargo.lock b/Cargo.lock index 77593db..8da6c01 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -359,6 +359,19 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "console" +version = "0.15.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -607,6 +620,12 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "env_logger" version = "0.10.2" @@ -1052,6 +1071,19 @@ dependencies = [ "serde", ] +[[package]] +name = "indicatif" +version = "0.17.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + [[package]] name = "is-terminal" version = "0.4.13" @@ -1216,6 +1248,7 @@ dependencies = [ "deadpool-postgres", "dotenv", "futures", + "indicatif", "libmf", "log", "ndarray", @@ -1388,6 +1421,12 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.36.7" @@ -1590,6 +1629,12 @@ dependencies = [ "zip", ] +[[package]] +name = "portable-atomic" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" + [[package]] name = "postgres-protocol" version = "0.6.7" @@ -2322,6 +2367,12 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "ureq" version = "2.12.1" @@ -2460,6 +2511,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "whoami" version = "1.5.2" diff --git a/Cargo.toml b/Cargo.toml index 4f764c9..a18a502 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,3 +22,4 @@ rand_distr = "0.4" tokio = { version = "1.35", features = ["full"] } tokio-postgres = "0.7" bytes = "1.5.0" +indicatif = "0.17" diff --git a/src/bin/fit_model.rs b/src/bin/fit_model.rs index fb0a5e2..8bde7b9 100644 --- a/src/bin/fit_model.rs +++ b/src/bin/fit_model.rs @@ -3,6 +3,7 @@ use clap::Parser; use deadpool_postgres::{Config, Pool, Runtime}; use dotenv::dotenv; use futures::StreamExt; +use indicatif::{ProgressBar, ProgressStyle}; use libmf::{Loss, Matrix, Model}; use log::info; use std::collections::HashSet; @@ -77,6 +78,15 @@ async fn load_data( source_item_id_column: &str, max_interactions: Option, ) -> Result> { + // First, get the total count for progress bar + info!("Counting rows to load..."); + let count_query = if let Some(max) = max_interactions { + format!("SELECT LEAST(COUNT(*), {}) FROM {}", max, source_table) + } else { + format!("SELECT COUNT(*) FROM {}", source_table) + }; + let total_rows: i64 = client.query_one(&count_query, &[]).await?.get(0); + let types = get_column_types( client, source_table, @@ -84,6 +94,12 @@ async fn load_data( ) .await?; + let pb = ProgressBar::new(total_rows as u64); + pb.set_style(ProgressStyle::default_bar() + .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} rows ({per_sec}, {eta})")? + .progress_chars("#>-")); + pb.set_message("Loading data..."); + let query = if let Some(max) = max_interactions { format!( "COPY ( @@ -111,7 +127,7 @@ async fn load_data( )); let mut stream = stream.as_mut(); - let mut data = Vec::new(); + let mut data = Vec::with_capacity(total_rows as usize); while let Some(row) = stream.as_mut().next().await { let row = row?; @@ -129,8 +145,10 @@ async fn load_data( }; data.push((user_id, item_id)); + pb.inc(1); } + pb.finish_with_message("Data loading complete"); Ok(data) } @@ -188,11 +206,19 @@ async fn save_embeddings( )); info!("Binary writer initialized"); + // Add progress bar for embedding processing + let pb = ProgressBar::new(model.q_iter().len() as u64); + pb.set_style(ProgressStyle::default_bar() + .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} embeddings ({per_sec}, {eta})")? + .progress_chars("#>-")); + pb.set_message("Processing embeddings..."); + let mut valid_embeddings = 0; let mut invalid_embeddings = 0; // Process factors for items that appear in the source table for (idx, factors) in model.q_iter().enumerate() { + pb.inc(1); let item_id = idx as i32; if !item_ids.contains(&item_id) { continue; @@ -213,6 +239,11 @@ async fn save_embeddings( } } + pb.finish_with_message(format!( + "Processed {} valid embeddings, skipped {} invalid embeddings", + valid_embeddings, invalid_embeddings + )); + info!("Finishing COPY operation..."); writer.as_mut().finish().await?; info!("COPY operation completed"); @@ -246,12 +277,12 @@ async fn save_embeddings( }; // Set table as UNLOGGED - // info!("Setting table as UNLOGGED..."); - // tx.execute( - // &format!("ALTER TABLE {} SET UNLOGGED", args.target_table), - // &[], - // ) - // .await?; + info!("Setting table as UNLOGGED..."); + tx.execute( + &format!("ALTER TABLE {} SET UNLOGGED", args.target_table), + &[], + ) + .await?; // Insert from temp table with ON CONFLICT DO UPDATE info!("Upserting from temp table to target table..."); @@ -269,12 +300,12 @@ async fn save_embeddings( info!("Upsert completed"); // Set table back to LOGGED - // info!("Setting table back to LOGGED..."); - // tx.execute( - // &format!("ALTER TABLE {} SET LOGGED", args.target_table), - // &[], - // ) - // .await?; + info!("Setting table back to LOGGED..."); + tx.execute( + &format!("ALTER TABLE {} SET LOGGED", args.target_table), + &[], + ) + .await?; // Recreate index if we had one if let Some(sql) = index_sql {