From 0dadf2654c521d496d43e30d182db53936291132 Mon Sep 17 00:00:00 2001 From: Dylan Knutson Date: Sat, 28 Dec 2024 21:46:14 +0000 Subject: [PATCH] Update dependencies and enhance SQL formatting in fit_model - Added new dependencies: `colored`, `lazy_static`, `sqlformat`, and `unicode_categories` to improve code readability and SQL formatting capabilities. - Introduced a new module `format_sql` for better SQL query formatting in `fit_model.rs`. - Updated `fit_model.rs` to utilize the new `format_sql` function for logging SQL commands, enhancing clarity in database operations. - Adjusted the `Cargo.toml` and `Cargo.lock` files to reflect the new dependencies and their versions. These changes improve the maintainability and readability of the code, particularly in the context of SQL operations. --- Cargo.lock | 35 ++++++++++++++++++ Cargo.toml | 3 ++ src/bin/fit_model.rs | 86 +++++++++++++++++++++++++++++++++----------- src/format_sql.rs | 71 ++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 5 files changed, 176 insertions(+), 20 deletions(-) create mode 100644 src/format_sql.rs diff --git a/Cargo.lock b/Cargo.lock index 8da6c01..8c764cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -359,6 +359,16 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "colored" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" +dependencies = [ + "lazy_static", + "windows-sys 0.59.0", +] + [[package]] name = "console" version = "0.15.10" @@ -1151,6 +1161,12 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.169" @@ -1244,6 +1260,7 @@ dependencies = [ "anyhow", "bytes", "clap", + "colored", "ctrlc", "deadpool-postgres", "dotenv", @@ -1258,6 +1275,8 @@ dependencies = [ "pretty_env_logger", "rand", "rand_distr", + "regex", + "sqlformat", "tokio", "tokio-postgres", ] @@ -2021,6 +2040,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "sqlformat" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" +dependencies = [ + "nom", + "unicode_categories", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2373,6 +2402,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "ureq" version = "2.12.1" diff --git a/Cargo.toml b/Cargo.toml index a18a502..8791b09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,6 @@ tokio = { version = "1.35", features = ["full"] } tokio-postgres = "0.7" bytes = "1.5.0" indicatif = "0.17" +sqlformat = "0.2" +colored = "2.1" +regex = "1.10" diff --git a/src/bin/fit_model.rs b/src/bin/fit_model.rs index 4ff2a5c..c954f8c 100644 --- a/src/bin/fit_model.rs +++ b/src/bin/fit_model.rs @@ -12,6 +12,7 @@ use tokio_postgres::binary_copy::BinaryCopyInWriter; use tokio_postgres::{types::Type, NoTls}; use mf_fitter::fit_model_args::Args; +use mf_fitter::format_sql::format_sql; use mf_fitter::pg_types; use mf_fitter::pgvector::PgVector; @@ -29,6 +30,12 @@ async fn create_pool() -> Result { config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?); config.application_name = Some("fit_model".to_string()); + // Set client_min_messages using connection parameters + let mut options = Vec::new(); + options.push("-c".to_string()); + options.push("client_min_messages=warning".to_string()); + config.options = Some(options.join(" ")); + Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?) } @@ -172,6 +179,20 @@ async fn save_embeddings( anyhow::bail!("Failed to get both column types from target table"); } + // Get the tablespace of the target table + let tablespace_query = r#" + SELECT t.spcname as tablespace + FROM pg_class c + LEFT JOIN pg_tablespace t ON c.reltablespace = t.oid + WHERE c.relname = $1 + AND c.relkind = 'r'"#; + + let tablespace = client + .query_one(tablespace_query, &[&args.target_table]) + .await? + .get::<_, Option>("tablespace") + .unwrap_or("pg_default".to_string()); + // Create a new table for the embeddings let new_table = format!("{}_new", args.target_table); info!("Creating new table {}...", new_table); @@ -182,27 +203,34 @@ async fn save_embeddings( _ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]), }; + // Drop the new table if it exists from a previous failed run + let drop_new_sql = format!("DROP TABLE IF EXISTS {}", new_table); + info!("Cleaning up: {}", format_sql(&drop_new_sql)); + client.execute(&drop_new_sql, &[]).await?; + let create_table = format!( - "CREATE UNLOGGED TABLE {} ({} {}, {} vector({}))", - new_table, args.target_id_column, id_type_str, args.target_embedding_column, vector_dim + "CREATE UNLOGGED TABLE {new_table} ( + {id_column} {id_type_str}, + {embedding_column} vector({vector_dim}) + ) TABLESPACE {tablespace}", + id_column = args.target_id_column, + embedding_column = args.target_embedding_column, ); - info!("Table creation SQL: {}", create_table); + info!("Creating table: {}", format_sql(&create_table)); client.execute(&create_table, &[]).await?; info!("New table created successfully"); // Copy data into new table - info!("Starting COPY operation..."); let copy_query = format!( "COPY {} ({}, {}) FROM STDIN WITH (FORMAT BINARY)", new_table, args.target_id_column, args.target_embedding_column ); - info!("COPY query: {}", copy_query); + info!("COPY query: {}", format_sql(©_query)); let mut writer = Box::pin(BinaryCopyInWriter::new( client.copy_in(©_query).await?, &types, )); - info!("Binary writer initialized"); // Add progress bar for embedding processing let pb = ProgressBar::new(model.q_iter().len() as u64); @@ -288,30 +316,48 @@ async fn save_embeddings( info!("Creating {} index {}...", index_type, index_name); let new_index_sql = index_def.replace(&args.target_table, &new_table); - info!("Index SQL: {}", new_index_sql); + info!("Index definition: {}", format_sql(&new_index_sql)); client.execute(&new_index_sql, &[]).await?; - info!("Index created"); } // Now set table to LOGGED after indexes are built - info!("Setting table as LOGGED..."); - client - .execute(&format!("ALTER TABLE {} SET LOGGED", new_table), &[]) - .await?; + let set_logged_sql = format!("ALTER TABLE {} SET LOGGED", new_table); + info!("Setting logged: {}", format_sql(&set_logged_sql)); + client.execute(&set_logged_sql, &[]).await?; // Start a transaction for the table swap info!("Starting transaction for table swap..."); let tx = client.transaction().await?; // Drop the old table and rename the new one - info!("Dropping old table and renaming new table..."); - tx.execute(&format!("DROP TABLE IF EXISTS {}", args.target_table), &[]) - .await?; - tx.execute( - &format!("ALTER TABLE {} RENAME TO {}", new_table, args.target_table), - &[], - ) - .await?; + let drop_sql = format!("DROP TABLE IF EXISTS {}", args.target_table); + info!("Dropping table: {}", format_sql(&drop_sql)); + tx.execute(&drop_sql, &[]).await?; + + let rename_sql = format!("ALTER TABLE {} RENAME TO {}", new_table, args.target_table); + info!("Renaming table: {}", format_sql(&rename_sql)); + tx.execute(&rename_sql, &[]).await?; + + // Get and rename all indexes on the new table + let index_query = r#" + SELECT c.relname as index_name + FROM pg_index i + JOIN pg_class c ON i.indexrelid = c.oid + JOIN pg_class tc ON i.indrelid = tc.oid + WHERE tc.relname = $1 + ORDER BY c.relname"#; + + let rows = tx.query(index_query, &[&args.target_table]).await?; + for row in rows { + let index_name: String = row.get("index_name"); + if index_name.contains("_new") { + let new_index_name = index_name.replace("_new", ""); + let rename_index_sql = + format!("ALTER INDEX {} RENAME TO {}", index_name, new_index_name); + info!("Renaming index: {}", format_sql(&rename_index_sql)); + tx.execute(&rename_index_sql, &[]).await?; + } + } // Commit the transaction info!("Committing transaction..."); diff --git a/src/format_sql.rs b/src/format_sql.rs new file mode 100644 index 0000000..a92f3ab --- /dev/null +++ b/src/format_sql.rs @@ -0,0 +1,71 @@ +use colored::Colorize; +use regex::Regex; +use sqlformat::{format as sql_format, FormatOptions, QueryParams}; + +// Define SQL keywords by color group +static COLORIZERS: &[(fn(&str) -> colored::ColoredString, &[&str])] = &[ + // Query keywords (blue) + ( + |s: &str| s.bright_blue(), + &["SELECT", "FROM", "WHERE", "ORDER BY"], + ), + // DDL keywords (yellow) + ( + |s: &str| s.bright_yellow(), + &[ + "TABLE", + "TABLESPACE", + "UNIQUE", + "INDEX", + "UNLOGGED", + "LOGGED", + "ALTER", + "RENAME", + "TO", + "USING", + ], + ), + // Destructive operations (red) + (|s: &str| s.bright_red(), &["DROP"]), + // Data modification (green) + ( + |s: &str| s.bright_green(), + &["COPY", "INSERT", "UPDATE", "SET", "CREATE"], + ), +]; + +// Helper function to format SQL for logging +pub fn format_sql(sql: &str) -> String { + // First format with sqlformat for consistent spacing + let formatted = sql_format( + sql, + &QueryParams::None, + FormatOptions { + uppercase: true, + ..Default::default() + }, + ); + + // Replace newlines with spaces and collapse multiple spaces + let single_line = formatted + .lines() + .map(|line| line.trim()) + .collect::>() + .join(" ") + .split_whitespace() + .collect::>() + .join(" "); + + // Apply colors using the static array + let mut result = single_line; + for (color_fn, words) in COLORIZERS.iter() { + for word in *words { + let re = Regex::new(&format!(r"\b{}\b", word)).unwrap(); + result = re + .replace_all(&result, color_fn(word).to_string()) + .to_string(); + } + } + + result +} diff --git a/src/lib.rs b/src/lib.rs index ca0b654..c723f9e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ pub mod fit_model_args; +pub mod format_sql; pub mod pg_types; pub mod pgvector;