Update dependencies and enhance SQL formatting in fit_model

- Added new direct dependencies `sqlformat`, `colored`, and `regex` (which pull `lazy_static` and `unicode_categories` into `Cargo.lock` as transitive dependencies) to support colorized, consistently formatted SQL in log output.
- Introduced a new module `format_sql` for better SQL query formatting in `fit_model.rs`.
- Updated `fit_model.rs` to utilize the new `format_sql` function for logging SQL commands, enhancing clarity in database operations.
- Adjusted the `Cargo.toml` and `Cargo.lock` files to reflect the new dependencies and their versions.

These changes improve the maintainability and readability of the code, particularly in the context of SQL operations.
This commit is contained in:
Dylan Knutson
2024-12-28 21:46:14 +00:00
parent 4651b96785
commit 0dadf2654c
5 changed files with 176 additions and 20 deletions

35
Cargo.lock generated
View File

@@ -359,6 +359,16 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
[[package]]
name = "colored"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c"
dependencies = [
"lazy_static",
"windows-sys 0.59.0",
]
[[package]]
name = "console"
version = "0.15.10"
@@ -1151,6 +1161,12 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.169"
@@ -1244,6 +1260,7 @@ dependencies = [
"anyhow",
"bytes",
"clap",
"colored",
"ctrlc",
"deadpool-postgres",
"dotenv",
@@ -1258,6 +1275,8 @@ dependencies = [
"pretty_env_logger",
"rand",
"rand_distr",
"regex",
"sqlformat",
"tokio",
"tokio-postgres",
]
@@ -2021,6 +2040,16 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "sqlformat"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790"
dependencies = [
"nom",
"unicode_categories",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
@@ -2373,6 +2402,12 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
[[package]]
name = "unicode_categories"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "ureq"
version = "2.12.1"

View File

@@ -23,3 +23,6 @@ tokio = { version = "1.35", features = ["full"] }
tokio-postgres = "0.7"
bytes = "1.5.0"
indicatif = "0.17"
sqlformat = "0.2"
colored = "2.1"
regex = "1.10"

View File

@@ -12,6 +12,7 @@ use tokio_postgres::binary_copy::BinaryCopyInWriter;
use tokio_postgres::{types::Type, NoTls};
use mf_fitter::fit_model_args::Args;
use mf_fitter::format_sql::format_sql;
use mf_fitter::pg_types;
use mf_fitter::pgvector::PgVector;
@@ -29,6 +30,12 @@ async fn create_pool() -> Result<Pool> {
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
config.application_name = Some("fit_model".to_string());
// Set client_min_messages using connection parameters
let mut options = Vec::new();
options.push("-c".to_string());
options.push("client_min_messages=warning".to_string());
config.options = Some(options.join(" "));
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
}
@@ -172,6 +179,20 @@ async fn save_embeddings(
anyhow::bail!("Failed to get both column types from target table");
}
// Get the tablespace of the target table
let tablespace_query = r#"
SELECT t.spcname as tablespace
FROM pg_class c
LEFT JOIN pg_tablespace t ON c.reltablespace = t.oid
WHERE c.relname = $1
AND c.relkind = 'r'"#;
let tablespace = client
.query_one(tablespace_query, &[&args.target_table])
.await?
.get::<_, Option<String>>("tablespace")
.unwrap_or("pg_default".to_string());
// Create a new table for the embeddings
let new_table = format!("{}_new", args.target_table);
info!("Creating new table {}...", new_table);
@@ -182,27 +203,34 @@ async fn save_embeddings(
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
};
// Drop the new table if it exists from a previous failed run
let drop_new_sql = format!("DROP TABLE IF EXISTS {}", new_table);
info!("Cleaning up: {}", format_sql(&drop_new_sql));
client.execute(&drop_new_sql, &[]).await?;
let create_table = format!(
"CREATE UNLOGGED TABLE {} ({} {}, {} vector({}))",
new_table, args.target_id_column, id_type_str, args.target_embedding_column, vector_dim
"CREATE UNLOGGED TABLE {new_table} (
{id_column} {id_type_str},
{embedding_column} vector({vector_dim})
) TABLESPACE {tablespace}",
id_column = args.target_id_column,
embedding_column = args.target_embedding_column,
);
info!("Table creation SQL: {}", create_table);
info!("Creating table: {}", format_sql(&create_table));
client.execute(&create_table, &[]).await?;
info!("New table created successfully");
// Copy data into new table
info!("Starting COPY operation...");
let copy_query = format!(
"COPY {} ({}, {}) FROM STDIN WITH (FORMAT BINARY)",
new_table, args.target_id_column, args.target_embedding_column
);
info!("COPY query: {}", copy_query);
info!("COPY query: {}", format_sql(&copy_query));
let mut writer = Box::pin(BinaryCopyInWriter::new(
client.copy_in(&copy_query).await?,
&types,
));
info!("Binary writer initialized");
// Add progress bar for embedding processing
let pb = ProgressBar::new(model.q_iter().len() as u64);
@@ -288,30 +316,48 @@ async fn save_embeddings(
info!("Creating {} index {}...", index_type, index_name);
let new_index_sql = index_def.replace(&args.target_table, &new_table);
info!("Index SQL: {}", new_index_sql);
info!("Index definition: {}", format_sql(&new_index_sql));
client.execute(&new_index_sql, &[]).await?;
info!("Index created");
}
// Now set table to LOGGED after indexes are built
info!("Setting table as LOGGED...");
client
.execute(&format!("ALTER TABLE {} SET LOGGED", new_table), &[])
.await?;
let set_logged_sql = format!("ALTER TABLE {} SET LOGGED", new_table);
info!("Setting logged: {}", format_sql(&set_logged_sql));
client.execute(&set_logged_sql, &[]).await?;
// Start a transaction for the table swap
info!("Starting transaction for table swap...");
let tx = client.transaction().await?;
// Drop the old table and rename the new one
info!("Dropping old table and renaming new table...");
tx.execute(&format!("DROP TABLE IF EXISTS {}", args.target_table), &[])
.await?;
tx.execute(
&format!("ALTER TABLE {} RENAME TO {}", new_table, args.target_table),
&[],
)
.await?;
let drop_sql = format!("DROP TABLE IF EXISTS {}", args.target_table);
info!("Dropping table: {}", format_sql(&drop_sql));
tx.execute(&drop_sql, &[]).await?;
let rename_sql = format!("ALTER TABLE {} RENAME TO {}", new_table, args.target_table);
info!("Renaming table: {}", format_sql(&rename_sql));
tx.execute(&rename_sql, &[]).await?;
// Get and rename all indexes on the new table
let index_query = r#"
SELECT c.relname as index_name
FROM pg_index i
JOIN pg_class c ON i.indexrelid = c.oid
JOIN pg_class tc ON i.indrelid = tc.oid
WHERE tc.relname = $1
ORDER BY c.relname"#;
let rows = tx.query(index_query, &[&args.target_table]).await?;
for row in rows {
let index_name: String = row.get("index_name");
if index_name.contains("_new") {
let new_index_name = index_name.replace("_new", "");
let rename_index_sql =
format!("ALTER INDEX {} RENAME TO {}", index_name, new_index_name);
info!("Renaming index: {}", format_sql(&rename_index_sql));
tx.execute(&rename_index_sql, &[]).await?;
}
}
// Commit the transaction
info!("Committing transaction...");

71
src/format_sql.rs Normal file
View File

@@ -0,0 +1,71 @@
use colored::Colorize;
use regex::Regex;
use sqlformat::{format as sql_format, FormatOptions, QueryParams};
/// Signature shared by every highlighter in [`COLORIZERS`]: takes the keyword
/// text and returns it wrapped in a color.
type Colorizer = fn(&str) -> colored::ColoredString;

/// SQL keywords grouped by the color they are highlighted with.
///
/// Each entry pairs a highlighter with the keywords it applies to. The groups
/// are listed in the order in which `format_sql` processes them.
static COLORIZERS: &[(Colorizer, &[&str])] = &[
    // Blue: query keywords.
    (
        |text: &str| text.bright_blue(),
        &["SELECT", "FROM", "WHERE", "ORDER BY"],
    ),
    // Yellow: DDL keywords.
    (
        |text: &str| text.bright_yellow(),
        &[
            "TABLE",
            "TABLESPACE",
            "UNIQUE",
            "INDEX",
            "UNLOGGED",
            "LOGGED",
            "ALTER",
            "RENAME",
            "TO",
            "USING",
        ],
    ),
    // Red: destructive operations.
    (|text: &str| text.bright_red(), &["DROP"]),
    // Green: data modification.
    (
        |text: &str| text.bright_green(),
        &["COPY", "INSERT", "UPDATE", "SET", "CREATE"],
    ),
];
// Helper function to format SQL for logging
pub fn format_sql(sql: &str) -> String {
// First format with sqlformat for consistent spacing
let formatted = sql_format(
sql,
&QueryParams::None,
FormatOptions {
uppercase: true,
..Default::default()
},
);
// Replace newlines with spaces and collapse multiple spaces
let single_line = formatted
.lines()
.map(|line| line.trim())
.collect::<Vec<_>>()
.join(" ")
.split_whitespace()
.collect::<Vec<_>>()
.join(" ");
// Apply colors using the static array
let mut result = single_line;
for (color_fn, words) in COLORIZERS.iter() {
for word in *words {
let re = Regex::new(&format!(r"\b{}\b", word)).unwrap();
result = re
.replace_all(&result, color_fn(word).to_string())
.to_string();
}
}
result
}

View File

@@ -1,3 +1,4 @@
pub mod fit_model_args;
pub mod format_sql;
pub mod pg_types;
pub mod pgvector;