diff --git a/src/bin/fit_model.rs b/src/bin/fit_model.rs index bce43ec..fb0a5e2 100644 --- a/src/bin/fit_model.rs +++ b/src/bin/fit_model.rs @@ -26,6 +26,7 @@ async fn create_pool() -> Result { config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?); config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?); config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?); + config.application_name = Some("fit_model".to_string()); Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?) } @@ -138,8 +139,9 @@ async fn save_embeddings( args: &Args, model: &Model, item_ids: &HashSet, + vector_dim: i32, ) -> Result<()> { - let client = pool.get().await?; + let mut client = pool.get().await?; // Get the column types from the target table info!("Getting column types for target table..."); @@ -166,7 +168,7 @@ async fn save_embeddings( {} {}, {} vector({}) )", - args.target_id_column, id_type_str, args.target_embedding_column, args.factors + args.target_id_column, id_type_str, args.target_embedding_column, vector_dim ); info!("Temp table creation SQL: {}", create_temp); client.execute(&create_temp, &[]).await?; @@ -215,6 +217,42 @@ async fn save_embeddings( writer.as_mut().finish().await?; info!("COPY operation completed"); + // Start a transaction for index management and data upsert + info!("Starting transaction for index management and data upsert..."); + let tx = client.transaction().await?; + + // Get and drop index if specified + let index_sql = if let Some(index_name) = &args.index_name { + let query = r#" + SELECT pg_get_indexdef(i.indexrelid) || + CASE WHEN t.spcname IS NOT NULL + THEN ' TABLESPACE ' || t.spcname + ELSE '' + END as index_def + FROM pg_index i + JOIN pg_class c ON i.indexrelid = c.oid + LEFT JOIN pg_tablespace t ON c.reltablespace = t.oid + WHERE c.relname = $1"#; + let row = tx.query_one(query, &[index_name]).await?; + let sql: String = row.get("index_def"); + + info!("Dropping index {}...", index_name); + tx.execute(&format!("DROP INDEX IF EXISTS {}", index_name), &[]) + .await?; + + Some(sql) + } else { + None + }; + + // Set table as UNLOGGED + // info!("Setting table as UNLOGGED..."); + // tx.execute( + // &format!("ALTER TABLE {} SET UNLOGGED", args.target_table), + // &[], + // ) + // .await?; + // Insert from temp table with ON CONFLICT DO UPDATE info!("Upserting from temp table to target table..."); let upsert_query = format!( @@ -227,9 +265,30 @@ async fn save_embeddings( embedding_col = args.target_embedding_column, ); info!("Upsert query: {}", upsert_query); - client.execute(&upsert_query, &[]).await?; + tx.execute(&upsert_query, &[]).await?; info!("Upsert completed"); + // Set table back to LOGGED + // info!("Setting table back to LOGGED..."); + // tx.execute( + // &format!("ALTER TABLE {} SET LOGGED", args.target_table), + // &[], + // ) + // .await?; + + // Recreate index if we had one + if let Some(sql) = index_sql { + info!("Recreating index..."); + info!("Index SQL: {}", sql); + tx.execute(&sql, &[]).await?; + info!("Index recreated"); + } + + // Commit the transaction + info!("Committing transaction..."); + tx.commit().await?; + info!("Transaction committed"); + // Clean up temp table info!("Cleaning up temporary table..."); client.execute("DROP TABLE temp_embeddings", &[]).await?; @@ -278,7 +337,7 @@ async fn validate_column_exists( Ok(()) } -async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<()> { +async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result { // Validate source table exists validate_table_exists(client, &args.source_table) .await @@ -305,7 +364,7 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res .await .context("Failed to validate target embedding column")?; - // Validate vector dimension matches factors + // Get vector dimension from target column let query = r#" SELECT a.atttypmod as vector_dim, c.data_type, c.udt_name FROM pg_attribute a @@ -323,29 +382,44 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res .await?; let data_type: &str = row.get("data_type"); - if data_type == "USER-DEFINED" { - let udt_name: &str = row.get("udt_name"); - if udt_name == "vector" { - let vector_dim: i32 = row.get("vector_dim"); - if vector_dim <= 0 { - anyhow::bail!( - "Invalid vector dimension {} for column '{}'", - vector_dim, - args.target_embedding_column, - ); - } - if vector_dim != args.factors as i32 { - anyhow::bail!( - "Vector dimension mismatch: column '{}' has dimension {}, but factors is {}", - args.target_embedding_column, - vector_dim, - args.factors - ); - } + if data_type != "USER-DEFINED" { + anyhow::bail!( + "Column '{}' is not a user-defined type (got {})", + args.target_embedding_column, + data_type + ); + } + + let udt_name: &str = row.get("udt_name"); + if udt_name != "vector" { + anyhow::bail!( + "Column '{}' is not a vector type (got {})", + args.target_embedding_column, + udt_name + ); + } + + let vector_dim: i32 = row.get("vector_dim"); + if vector_dim <= 0 { + anyhow::bail!( + "Invalid vector dimension {} for column '{}'", + vector_dim, + args.target_embedding_column, + ); + } + + // If factors is specified, validate it matches the column dimension + if let Some(factors) = args.factors { + if factors != vector_dim { + anyhow::bail!( + "Specified factors ({}) does not match column vector dimension ({})", + factors, + vector_dim + ); } } - Ok(()) + Ok(vector_dim) } #[tokio::main] @@ -362,10 +436,10 @@ async fn main() -> Result<()> { let pool = create_pool().await?; - // Validate schema before proceeding + // Validate schema and get vector dimension info!("Validating database schema..."); - validate_schema(&pool.get().await?, &args).await?; - info!("Schema validation successful"); + let factors = validate_schema(&pool.get().await?, &args).await?; + info!("Schema validation successful, inferred {} factors", factors); let mut matrix = Matrix::new(); let mut item_ids = HashSet::new(); @@ -401,7 +475,7 @@ async fn main() -> Result<()> { ); let model = Model::params() - .factors(args.factors) + .factors(factors) .lambda_p1(args.lambda1) .lambda_q1(args.lambda1) .lambda_p2(args.lambda2) @@ -416,7 +490,7 @@ async fn main() -> Result<()> { .fit(&matrix)?; info!("Saving embeddings..."); - save_embeddings(&pool, &args, &model, &item_ids).await?; + save_embeddings(&pool, &args, &model, &item_ids, factors).await?; Ok(()) } diff --git a/src/bin/generate_test_data.rs b/src/bin/generate_test_data.rs index d10ce21..5339d32 100644 --- a/src/bin/generate_test_data.rs +++ b/src/bin/generate_test_data.rs @@ -117,6 +117,22 @@ async fn main() -> Result<()> { ) .await?; + // Create IVFFlat index on the embeddings column + let index_name = format!("{}_embeddings_idx", args.embeddings_table); + info!( + "Creating IVFFlat index `{}' on embeddings column...", + index_name + ); + client + .execute( + &format!( + "CREATE INDEX {} ON {} USING ivfflat (embedding)", + index_name, args.embeddings_table + ), + &[], + ) + .await?; + // Generate cluster centers that are well-separated in 3D space let mut rng = rand::thread_rng(); let mut cluster_centers = Vec::new(); diff --git a/src/fit_model_args.rs b/src/fit_model_args.rs index f8edbeb..9754e5e 100644 --- a/src/fit_model_args.rs +++ b/src/fit_model_args.rs @@ -39,9 +39,10 @@ pub struct Args { #[arg(long, default_value = "0.01")] pub learning_rate: f32, - /// Number of factors for matrix factorization - #[arg(long, default_value = "8")] - pub factors: i32, + /// Number of factors (dimensions) for matrix factorization + /// If not specified, will be inferred from the target column's vector dimension + #[arg(long)] + pub factors: Option, /// Lambda for regularization #[arg(long, default_value = "0.0")] @@ -62,4 +63,8 @@ pub struct Args { /// Maximum number of interactions to load (optional) #[arg(long)] pub max_interactions: Option, + + /// Name of the index to drop before upserting and recreate after + #[arg(long)] + pub index_name: Option, }