diff --git a/src/bin/fit_model.rs b/src/bin/fit_model.rs index 8bde7b9..4ff2a5c 100644 --- a/src/bin/fit_model.rs +++ b/src/bin/fit_model.rs @@ -48,9 +48,7 @@ async fn get_column_types( column_list, column_list ); - info!("Executing type query: {}", query); let rows = client.query(&query, &[&table]).await?; - info!("Got types: {:?}", rows); let mut types = Vec::new(); for row in rows { let data_type: &str = row.get(0); @@ -174,29 +172,29 @@ async fn save_embeddings( anyhow::bail!("Failed to get both column types from target table"); } - // Create a temporary table with the same structure - info!("Creating temporary table..."); + // Create a new table for the embeddings + let new_table = format!("{}_new", args.target_table); + info!("Creating new table {}...", new_table); + let id_type_str = match types[0] { Type::INT4 => "INTEGER", Type::INT8 => "BIGINT", _ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]), }; - let create_temp = format!( - "CREATE TEMP TABLE temp_embeddings ( - {} {}, - {} vector({}) - )", - args.target_id_column, id_type_str, args.target_embedding_column, vector_dim - ); - info!("Temp table creation SQL: {}", create_temp); - client.execute(&create_temp, &[]).await?; - info!("Temporary table created successfully"); - // Copy data into temporary table + let create_table = format!( + "CREATE UNLOGGED TABLE {} ({} {}, {} vector({}))", + new_table, args.target_id_column, id_type_str, args.target_embedding_column, vector_dim + ); + info!("Table creation SQL: {}", create_table); + client.execute(&create_table, &[]).await?; + info!("New table created successfully"); + + // Copy data into new table info!("Starting COPY operation..."); let copy_query = format!( - "COPY temp_embeddings ({}, {}) FROM STDIN WITH (FORMAT BINARY)", - args.target_id_column, args.target_embedding_column + "COPY {} ({}, {}) FROM STDIN WITH (FORMAT BINARY)", + new_table, args.target_id_column, args.target_embedding_column ); info!("COPY query: {}", copy_query); @@ -248,83 +246,78 @@ async fn save_embeddings( writer.as_mut().finish().await?; info!("COPY operation completed"); - // Start a transaction for index management and data upsert - info!("Starting transaction for index management and data upsert..."); + // Get creation SQL for all indexes + let query = r#" + SELECT c.relname as index_name, + pg_get_indexdef(i.indexrelid) || + CASE WHEN t.spcname IS NOT NULL + THEN ' TABLESPACE ' || t.spcname + ELSE '' + END as index_def, + i.indisprimary, + i.indisunique + FROM pg_index i + JOIN pg_class c ON i.indexrelid = c.oid + JOIN pg_class tc ON i.indrelid = tc.oid + LEFT JOIN pg_tablespace t ON c.reltablespace = t.oid + WHERE tc.relname = $1 + ORDER BY i.indisprimary DESC, i.indisunique DESC, c.relname"#; + + // Get all indexes + let rows = client.query(query, &[&args.target_table]).await?; + let mut indexes = Vec::new(); + + for row in rows { + let index_name: String = row.get("index_name"); + let index_def: String = row.get("index_def"); + let is_primary: bool = row.get("indisprimary"); + let is_unique: bool = row.get("indisunique"); + + indexes.push((index_name, index_def, is_primary, is_unique)); + } + + // Create indexes while table is still UNLOGGED + for (index_name, index_def, is_primary, is_unique) in indexes { + let index_type = if is_primary { + "primary key" + } else if is_unique { + "unique" + } else { + "non-unique" + }; + + info!("Creating {} index {}...", index_type, index_name); + let new_index_sql = index_def.replace(&args.target_table, &new_table); + info!("Index SQL: {}", new_index_sql); + client.execute(&new_index_sql, &[]).await?; + info!("Index created"); + } + + // Now set table to LOGGED after indexes are built + info!("Setting table as LOGGED..."); + client + .execute(&format!("ALTER TABLE {} SET LOGGED", new_table), &[]) + .await?; + + // Start a transaction for the table swap + info!("Starting transaction for table swap..."); let tx = client.transaction().await?; - // Get and drop index if specified - let index_sql = if let Some(index_name) = &args.index_name { - let query = r#" - SELECT pg_get_indexdef(i.indexrelid) || - CASE WHEN t.spcname IS NOT NULL - THEN ' TABLESPACE ' || t.spcname - ELSE '' - END as index_def - FROM pg_index i - JOIN pg_class c ON i.indexrelid = c.oid - LEFT JOIN pg_tablespace t ON c.reltablespace = t.oid - WHERE c.relname = $1"#; - let row = tx.query_one(query, &[index_name]).await?; - let sql: String = row.get("index_def"); - - info!("Dropping index {}...", index_name); - tx.execute(&format!("DROP INDEX IF EXISTS {}", index_name), &[]) - .await?; - - Some(sql) - } else { - None - }; - - // Set table as UNLOGGED - info!("Setting table as UNLOGGED..."); + // Drop the old table and rename the new one + info!("Dropping old table and renaming new table..."); + tx.execute(&format!("DROP TABLE IF EXISTS {}", args.target_table), &[]) + .await?; tx.execute( - &format!("ALTER TABLE {} SET UNLOGGED", args.target_table), + &format!("ALTER TABLE {} RENAME TO {}", new_table, args.target_table), &[], ) .await?; - // Insert from temp table with ON CONFLICT DO UPDATE - info!("Upserting from temp table to target table..."); - let upsert_query = format!( - "INSERT INTO {target_table} ({id_col}, {embedding_col}) - SELECT {id_col}, {embedding_col} FROM temp_embeddings - ON CONFLICT ({id_col}) DO UPDATE - SET {embedding_col} = EXCLUDED.{embedding_col}", - target_table = args.target_table, - id_col = args.target_id_column, - embedding_col = args.target_embedding_column, - ); - info!("Upsert query: {}", upsert_query); - tx.execute(&upsert_query, &[]).await?; - info!("Upsert completed"); - - // Set table back to LOGGED - info!("Setting table back to LOGGED..."); - tx.execute( - &format!("ALTER TABLE {} SET LOGGED", args.target_table), - &[], - ) - .await?; - - // Recreate index if we had one - if let Some(sql) = index_sql { - info!("Recreating index..."); - info!("Index SQL: {}", sql); - tx.execute(&sql, &[]).await?; - info!("Index recreated"); - } - // Commit the transaction info!("Committing transaction..."); tx.commit().await?; info!("Transaction committed"); - // Clean up temp table - info!("Cleaning up temporary table..."); - client.execute("DROP TABLE temp_embeddings", &[]).await?; - info!("Temporary table dropped"); - info!( "Saved {} valid embeddings, skipped {} invalid embeddings", valid_embeddings, invalid_embeddings diff --git a/src/fit_model_args.rs b/src/fit_model_args.rs index 9754e5e..ef70b59 100644 --- a/src/fit_model_args.rs +++ b/src/fit_model_args.rs @@ -63,8 +63,4 @@ pub struct Args { /// Maximum number of interactions to load (optional) #[arg(long)] pub max_interactions: Option, - - /// Name of the index to drop before upserting and recreate after - #[arg(long)] - pub index_name: Option, }