diff --git a/src/main.rs b/src/main.rs index f3d99bb..c6ce4ff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ use libmf::{Loss, Matrix, Model}; use log::info; use std::collections::HashSet; use std::env; +use tokio_postgres::binary_copy::BinaryCopyInWriter; use tokio_postgres::{types::Type, NoTls}; mod args; @@ -35,23 +36,28 @@ async fn get_column_types( ) -> Result<Vec<Type>> { let column_list = columns.join("', '"); let query = format!( - "SELECT data_type \ + "SELECT data_type, udt_name \ FROM information_schema.columns \ WHERE table_schema = 'public' \ AND table_name = $1 \ AND column_name IN ('{}') \ - ORDER BY column_name", - column_list + ORDER BY array_position(ARRAY['{}'], column_name)", + column_list, column_list ); + info!("Executing type query: {}", query); let rows = client.query(&query, &[&table]).await?; + info!("Got types: {:?}", rows); let mut types = Vec::new(); for row in rows { let data_type: &str = row.get(0); - let pg_type = match data_type { - "integer" => Type::INT4, - "bigint" => Type::INT8, - _ => anyhow::bail!("Unsupported column type: {}", data_type), + let udt_name: &str = row.get(1); + let pg_type = match (data_type, udt_name) { + ("integer", _) => Type::INT4, + ("bigint", _) => Type::INT8, + ("ARRAY", "_float4") => Type::FLOAT4_ARRAY, + ("ARRAY", "_float8") => Type::FLOAT8_ARRAY, + _ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name), }; types.push(pg_type); } @@ -126,10 +132,35 @@ async fn save_embeddings( ) -> Result<()> { let client = pool.get().await?; + // Get the column types from the target table + let types = get_column_types( + &client, + &args.target_table, + &[&args.target_id_column, &args.target_embedding_column], + ) + .await?; + + // Create a temporary table with the same structure + let create_temp = format!( + "CREATE TEMP TABLE temp_embeddings AS SELECT * FROM {} WHERE 1=0", + args.target_table + ); + client.execute(&create_temp, &[]).await?; + + // Copy data into temporary table + let 
copy_query = format!( + "COPY temp_embeddings ({id_col}, {embedding_col}) FROM STDIN (FORMAT binary)", + id_col = args.target_id_column, + embedding_col = args.target_embedding_column, + ); + + let mut writer = Box::pin(BinaryCopyInWriter::new( + client.copy_in(&copy_query).await?, + &types, + )); + let mut valid_embeddings = 0; let mut invalid_embeddings = 0; - let batch_size = 128; - let mut current_batch = Vec::with_capacity(batch_size); // Process factors for items that appear in the source table for (idx, factors) in model.q_iter().enumerate() { @@ -145,33 +176,33 @@ async fn save_embeddings( } valid_embeddings += 1; - current_batch.push((item_id as i64, factors)); - - // When batch is full, save it - if current_batch.len() >= batch_size { - save_batch( - &client, - &args.target_table, - &args.target_id_column, - &args.target_embedding_column, - &current_batch, - ) - .await?; - current_batch.clear(); - } + let id_value: Box<dyn tokio_postgres::types::ToSql + Sync> = match types[0] { + Type::INT4 => Box::new(item_id), + Type::INT8 => Box::new(item_id as i64), + _ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]), + }; + let factors: &[f32] = factors; + writer.as_mut().write(&[&*id_value, &factors]).await?; } - // Save any remaining items in the last batch - if !current_batch.is_empty() { - save_batch( - &client, - &args.target_table, - &args.target_id_column, - &args.target_embedding_column, - &current_batch, - ) - .await?; - } + writer.as_mut().finish().await?; + + // Insert from temp table with ON CONFLICT DO UPDATE + let insert_query = format!( + r#" + INSERT INTO {target_table} ({id_col}, {embedding_col}) + SELECT {id_col}, {embedding_col} FROM temp_embeddings + ON CONFLICT ({id_col}) DO UPDATE + SET {embedding_col} = EXCLUDED.{embedding_col} + "#, + target_table = args.target_table, + id_col = args.target_id_column, + embedding_col = args.target_embedding_column, + ); + client.execute(&insert_query, &[]).await?; + + // Clean up temp table + client.execute("DROP TABLE temp_embeddings", 
&[]).await?; info!( "Saved {} valid embeddings, skipped {} invalid embeddings", @@ -181,48 +212,6 @@ async fn save_embeddings( Ok(()) } -async fn save_batch( - client: &deadpool_postgres::Client, - target_table: &str, - target_id_column: &str, - target_embedding_column: &str, - batch_values: &[(i64, &[f32])], -) -> Result<()> { - // Build the batch insert query - let placeholders: Vec<String> = (0..batch_values.len()) - .map(|i| format!("(${}::int8, ${}::float4[])", i * 2 + 1, i * 2 + 2)) - .collect(); - let query = format!( - r#" - INSERT INTO {target_table} ({target_id_column}, {target_embedding_column}) - VALUES {placeholders} - ON CONFLICT ({target_id_column}) - DO UPDATE SET {target_embedding_column} = EXCLUDED.{target_embedding_column} - "#, - placeholders = placeholders.join(",") - ); - - // info!("Executing query: {}", query); - - // Flatten parameters for the query - let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new(); - for (item_id, factors) in batch_values { - params.push(item_id); - params.push(factors); - } - - info!("Number of parameters: {}", params.len()); - client.execute(&query, &params[..]).await.with_context(|| { - format!( - "Failed to execute batch insert at {}:{} with {} values", - file!(), - line!(), - batch_values.len() - ) - })?; - Ok(()) -} - async fn validate_table_exists(client: &deadpool_postgres::Client, table: &str) -> Result<()> { let query = r#"SELECT EXISTS ( SELECT FROM pg_tables