Update dependencies and enhance SQL formatting in fit_model
- Added new dependencies: `colored`, `lazy_static`, `sqlformat`, and `unicode_categories` to improve code readability and SQL formatting capabilities. - Introduced a new module `format_sql` for better SQL query formatting in `fit_model.rs`. - Updated `fit_model.rs` to utilize the new `format_sql` function for logging SQL commands, enhancing clarity in database operations. - Adjusted the `Cargo.toml` and `Cargo.lock` files to reflect the new dependencies and their versions. These changes improve the maintainability and readability of the code, particularly in the context of SQL operations.
This commit is contained in:
35
Cargo.lock
generated
35
Cargo.lock
generated
@@ -359,6 +359,16 @@ version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
|
||||
|
||||
[[package]]
|
||||
name = "colored"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "console"
|
||||
version = "0.15.10"
|
||||
@@ -1151,6 +1161,12 @@ dependencies = [
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.169"
|
||||
@@ -1244,6 +1260,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
"clap",
|
||||
"colored",
|
||||
"ctrlc",
|
||||
"deadpool-postgres",
|
||||
"dotenv",
|
||||
@@ -1258,6 +1275,8 @@ dependencies = [
|
||||
"pretty_env_logger",
|
||||
"rand",
|
||||
"rand_distr",
|
||||
"regex",
|
||||
"sqlformat",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
]
|
||||
@@ -2021,6 +2040,16 @@ dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlformat"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790"
|
||||
dependencies = [
|
||||
"nom",
|
||||
"unicode_categories",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stable_deref_trait"
|
||||
version = "1.2.0"
|
||||
@@ -2373,6 +2402,12 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
|
||||
|
||||
[[package]]
|
||||
name = "unicode_categories"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "2.12.1"
|
||||
|
||||
@@ -23,3 +23,6 @@ tokio = { version = "1.35", features = ["full"] }
|
||||
tokio-postgres = "0.7"
|
||||
bytes = "1.5.0"
|
||||
indicatif = "0.17"
|
||||
sqlformat = "0.2"
|
||||
colored = "2.1"
|
||||
regex = "1.10"
|
||||
|
||||
@@ -12,6 +12,7 @@ use tokio_postgres::binary_copy::BinaryCopyInWriter;
|
||||
use tokio_postgres::{types::Type, NoTls};
|
||||
|
||||
use mf_fitter::fit_model_args::Args;
|
||||
use mf_fitter::format_sql::format_sql;
|
||||
use mf_fitter::pg_types;
|
||||
use mf_fitter::pgvector::PgVector;
|
||||
|
||||
@@ -29,6 +30,12 @@ async fn create_pool() -> Result<Pool> {
|
||||
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
|
||||
config.application_name = Some("fit_model".to_string());
|
||||
|
||||
// Set client_min_messages using connection parameters
|
||||
let mut options = Vec::new();
|
||||
options.push("-c".to_string());
|
||||
options.push("client_min_messages=warning".to_string());
|
||||
config.options = Some(options.join(" "));
|
||||
|
||||
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
|
||||
}
|
||||
|
||||
@@ -172,6 +179,20 @@ async fn save_embeddings(
|
||||
anyhow::bail!("Failed to get both column types from target table");
|
||||
}
|
||||
|
||||
// Get the tablespace of the target table
|
||||
let tablespace_query = r#"
|
||||
SELECT t.spcname as tablespace
|
||||
FROM pg_class c
|
||||
LEFT JOIN pg_tablespace t ON c.reltablespace = t.oid
|
||||
WHERE c.relname = $1
|
||||
AND c.relkind = 'r'"#;
|
||||
|
||||
let tablespace = client
|
||||
.query_one(tablespace_query, &[&args.target_table])
|
||||
.await?
|
||||
.get::<_, Option<String>>("tablespace")
|
||||
.unwrap_or("pg_default".to_string());
|
||||
|
||||
// Create a new table for the embeddings
|
||||
let new_table = format!("{}_new", args.target_table);
|
||||
info!("Creating new table {}...", new_table);
|
||||
@@ -182,27 +203,34 @@ async fn save_embeddings(
|
||||
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
|
||||
};
|
||||
|
||||
// Drop the new table if it exists from a previous failed run
|
||||
let drop_new_sql = format!("DROP TABLE IF EXISTS {}", new_table);
|
||||
info!("Cleaning up: {}", format_sql(&drop_new_sql));
|
||||
client.execute(&drop_new_sql, &[]).await?;
|
||||
|
||||
let create_table = format!(
|
||||
"CREATE UNLOGGED TABLE {} ({} {}, {} vector({}))",
|
||||
new_table, args.target_id_column, id_type_str, args.target_embedding_column, vector_dim
|
||||
"CREATE UNLOGGED TABLE {new_table} (
|
||||
{id_column} {id_type_str},
|
||||
{embedding_column} vector({vector_dim})
|
||||
) TABLESPACE {tablespace}",
|
||||
id_column = args.target_id_column,
|
||||
embedding_column = args.target_embedding_column,
|
||||
);
|
||||
info!("Table creation SQL: {}", create_table);
|
||||
info!("Creating table: {}", format_sql(&create_table));
|
||||
client.execute(&create_table, &[]).await?;
|
||||
info!("New table created successfully");
|
||||
|
||||
// Copy data into new table
|
||||
info!("Starting COPY operation...");
|
||||
let copy_query = format!(
|
||||
"COPY {} ({}, {}) FROM STDIN WITH (FORMAT BINARY)",
|
||||
new_table, args.target_id_column, args.target_embedding_column
|
||||
);
|
||||
info!("COPY query: {}", copy_query);
|
||||
info!("COPY query: {}", format_sql(©_query));
|
||||
|
||||
let mut writer = Box::pin(BinaryCopyInWriter::new(
|
||||
client.copy_in(©_query).await?,
|
||||
&types,
|
||||
));
|
||||
info!("Binary writer initialized");
|
||||
|
||||
// Add progress bar for embedding processing
|
||||
let pb = ProgressBar::new(model.q_iter().len() as u64);
|
||||
@@ -288,30 +316,48 @@ async fn save_embeddings(
|
||||
|
||||
info!("Creating {} index {}...", index_type, index_name);
|
||||
let new_index_sql = index_def.replace(&args.target_table, &new_table);
|
||||
info!("Index SQL: {}", new_index_sql);
|
||||
info!("Index definition: {}", format_sql(&new_index_sql));
|
||||
client.execute(&new_index_sql, &[]).await?;
|
||||
info!("Index created");
|
||||
}
|
||||
|
||||
// Now set table to LOGGED after indexes are built
|
||||
info!("Setting table as LOGGED...");
|
||||
client
|
||||
.execute(&format!("ALTER TABLE {} SET LOGGED", new_table), &[])
|
||||
.await?;
|
||||
let set_logged_sql = format!("ALTER TABLE {} SET LOGGED", new_table);
|
||||
info!("Setting logged: {}", format_sql(&set_logged_sql));
|
||||
client.execute(&set_logged_sql, &[]).await?;
|
||||
|
||||
// Start a transaction for the table swap
|
||||
info!("Starting transaction for table swap...");
|
||||
let tx = client.transaction().await?;
|
||||
|
||||
// Drop the old table and rename the new one
|
||||
info!("Dropping old table and renaming new table...");
|
||||
tx.execute(&format!("DROP TABLE IF EXISTS {}", args.target_table), &[])
|
||||
.await?;
|
||||
tx.execute(
|
||||
&format!("ALTER TABLE {} RENAME TO {}", new_table, args.target_table),
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
let drop_sql = format!("DROP TABLE IF EXISTS {}", args.target_table);
|
||||
info!("Dropping table: {}", format_sql(&drop_sql));
|
||||
tx.execute(&drop_sql, &[]).await?;
|
||||
|
||||
let rename_sql = format!("ALTER TABLE {} RENAME TO {}", new_table, args.target_table);
|
||||
info!("Renaming table: {}", format_sql(&rename_sql));
|
||||
tx.execute(&rename_sql, &[]).await?;
|
||||
|
||||
// Get and rename all indexes on the new table
|
||||
let index_query = r#"
|
||||
SELECT c.relname as index_name
|
||||
FROM pg_index i
|
||||
JOIN pg_class c ON i.indexrelid = c.oid
|
||||
JOIN pg_class tc ON i.indrelid = tc.oid
|
||||
WHERE tc.relname = $1
|
||||
ORDER BY c.relname"#;
|
||||
|
||||
let rows = tx.query(index_query, &[&args.target_table]).await?;
|
||||
for row in rows {
|
||||
let index_name: String = row.get("index_name");
|
||||
if index_name.contains("_new") {
|
||||
let new_index_name = index_name.replace("_new", "");
|
||||
let rename_index_sql =
|
||||
format!("ALTER INDEX {} RENAME TO {}", index_name, new_index_name);
|
||||
info!("Renaming index: {}", format_sql(&rename_index_sql));
|
||||
tx.execute(&rename_index_sql, &[]).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
info!("Committing transaction...");
|
||||
|
||||
71
src/format_sql.rs
Normal file
71
src/format_sql.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use colored::Colorize;
|
||||
use regex::Regex;
|
||||
use sqlformat::{format as sql_format, FormatOptions, QueryParams};
|
||||
|
||||
// Define SQL keywords by color group
|
||||
static COLORIZERS: &[(fn(&str) -> colored::ColoredString, &[&str])] = &[
|
||||
// Query keywords (blue)
|
||||
(
|
||||
|s: &str| s.bright_blue(),
|
||||
&["SELECT", "FROM", "WHERE", "ORDER BY"],
|
||||
),
|
||||
// DDL keywords (yellow)
|
||||
(
|
||||
|s: &str| s.bright_yellow(),
|
||||
&[
|
||||
"TABLE",
|
||||
"TABLESPACE",
|
||||
"UNIQUE",
|
||||
"INDEX",
|
||||
"UNLOGGED",
|
||||
"LOGGED",
|
||||
"ALTER",
|
||||
"RENAME",
|
||||
"TO",
|
||||
"USING",
|
||||
],
|
||||
),
|
||||
// Destructive operations (red)
|
||||
(|s: &str| s.bright_red(), &["DROP"]),
|
||||
// Data modification (green)
|
||||
(
|
||||
|s: &str| s.bright_green(),
|
||||
&["COPY", "INSERT", "UPDATE", "SET", "CREATE"],
|
||||
),
|
||||
];
|
||||
|
||||
// Helper function to format SQL for logging
|
||||
pub fn format_sql(sql: &str) -> String {
|
||||
// First format with sqlformat for consistent spacing
|
||||
let formatted = sql_format(
|
||||
sql,
|
||||
&QueryParams::None,
|
||||
FormatOptions {
|
||||
uppercase: true,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
// Replace newlines with spaces and collapse multiple spaces
|
||||
let single_line = formatted
|
||||
.lines()
|
||||
.map(|line| line.trim())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ")
|
||||
.split_whitespace()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
// Apply colors using the static array
|
||||
let mut result = single_line;
|
||||
for (color_fn, words) in COLORIZERS.iter() {
|
||||
for word in *words {
|
||||
let re = Regex::new(&format!(r"\b{}\b", word)).unwrap();
|
||||
result = re
|
||||
.replace_all(&result, color_fn(word).to_string())
|
||||
.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod fit_model_args;
|
||||
pub mod format_sql;
|
||||
pub mod pg_types;
|
||||
pub mod pgvector;
|
||||
|
||||
Reference in New Issue
Block a user