diff --git a/Cargo.lock b/Cargo.lock index 4e4ca1b..77593db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1210,6 +1210,7 @@ name = "mf-fitter" version = "0.1.0" dependencies = [ "anyhow", + "bytes", "clap", "ctrlc", "deadpool-postgres", diff --git a/Cargo.toml b/Cargo.toml index 30de10a..4f764c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,4 @@ rand = "0.8" rand_distr = "0.4" tokio = { version = "1.35", features = ["full"] } tokio-postgres = "0.7" +bytes = "1.5.0" diff --git a/src/bin/generate_test_data.rs b/src/bin/generate_test_data.rs index b891056..d10ce21 100644 --- a/src/bin/generate_test_data.rs +++ b/src/bin/generate_test_data.rs @@ -43,6 +43,10 @@ struct Args { /// Target table name for embeddings #[arg(long, default_value = "item_embeddings")] embeddings_table: String, + + /// Number of factors (dimensions) for the embeddings + #[arg(long, default_value = "16")] + factors: i32, } async fn create_pool() -> Result { @@ -70,6 +74,12 @@ async fn main() -> Result<()> { let pool = create_pool().await?; let client = pool.get().await?; + // Enable pgvector extension + info!("Enabling pgvector extension..."); + client + .execute("CREATE EXTENSION IF NOT EXISTS vector", &[]) + .await?; + // Create tables info!("Creating tables..."); client @@ -86,14 +96,22 @@ async fn main() -> Result<()> { ) .await?; + // Drop and recreate embeddings table to ensure correct vector dimensions + client + .execute( + &format!("DROP TABLE IF EXISTS {}", args.embeddings_table), + &[], + ) + .await?; + client .execute( &format!( - "CREATE TABLE IF NOT EXISTS {} ( + "CREATE TABLE {} ( item_id INTEGER PRIMARY KEY, - embedding FLOAT4[] + embedding vector({}) )", - args.embeddings_table + args.embeddings_table, args.factors ), &[], ) diff --git a/src/main.rs b/src/main.rs index c6ce4ff..b6c6529 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ use anyhow::{Context, Result}; +use bytes::BytesMut; use clap::Parser; use deadpool_postgres::{Config, Pool, Runtime}; use dotenv::dotenv; @@ -57,6 +58,7 @@ async fn get_column_types( ("bigint", _) => Type::INT8, ("ARRAY", "_float4") => Type::FLOAT4_ARRAY, ("ARRAY", "_float8") => Type::FLOAT8_ARRAY, + ("USER-DEFINED", "vector") => Type::BYTEA, // pgvector type maps to bytea _ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name), }; types.push(pg_type); @@ -133,31 +135,44 @@ async fn save_embeddings( let client = pool.get().await?; // Get the column types from the target table + info!("Getting column types for target table..."); let types = get_column_types( &client, &args.target_table, &[&args.target_id_column, &args.target_embedding_column], ) .await?; + info!("Target table column types: {:?}", types); // Create a temporary table with the same structure + info!("Creating temporary table..."); + let id_type = match types[0] { + Type::INT4 => "INTEGER", + Type::INT8 => "BIGINT", + _ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]), + }; let create_temp = format!( - "CREATE TEMP TABLE temp_embeddings AS SELECT * FROM {} WHERE 1=0", - args.target_table + "CREATE TEMP TABLE temp_embeddings ( + {} {}, + {} vector({}) + )", + args.target_id_column, id_type, args.target_embedding_column, args.factors ); + info!("Temp table creation SQL: {}", create_temp); client.execute(&create_temp, &[]).await?; + info!("Temporary table created successfully"); // Copy data into temporary table - let copy_query = format!( - "COPY temp_embeddings ({id_col}, {embedding_col}) FROM STDIN (FORMAT binary)", - id_col = args.target_id_column, - embedding_col = args.target_embedding_column, - ); + info!("Starting COPY operation..."); + let copy_query = + format!("COPY temp_embeddings (item_id, embedding) FROM STDIN WITH (FORMAT BINARY)",); + info!("COPY query: {}", copy_query); let mut writer = Box::pin(BinaryCopyInWriter::new( client.copy_in(©_query).await?, - &types, + &[Type::INT4, Type::BYTEA], )); + info!("Binary writer initialized"); let mut valid_embeddings = 0; let mut invalid_embeddings = 0; @@ -176,33 +191,48 @@ async fn save_embeddings( } valid_embeddings += 1; - let id_value: Box = match types[0] { - Type::INT4 => Box::new(item_id), - Type::INT8 => Box::new(item_id as i64), - _ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]), - }; - let factors: &[f32] = factors; - writer.as_mut().write(&[&*id_value, &factors]).await?; + let vector_data: Vec = factors.to_vec(); + + // Create a buffer for the binary data + let mut buf = BytesMut::new(); + buf.extend_from_slice(&(vector_data.len() as u16).to_be_bytes()); // Number of dimensions in big-endian + buf.extend_from_slice(&0u16.to_be_bytes()); // Unused padding in big-endian + for &value in &vector_data { + buf.extend_from_slice(&value.to_be_bytes()); // Float values in big-endian + } + let binary_data: Vec = buf.freeze().to_vec(); + + info!( + "Writing embedding for item_id={}, factors.len()={}", + item_id, + factors.len() + ); + writer.as_mut().write(&[&item_id, &binary_data]).await?; } + info!("Finishing COPY operation..."); writer.as_mut().finish().await?; + info!("COPY operation completed"); // Insert from temp table with ON CONFLICT DO UPDATE + info!("Inserting from temp table to target table..."); let insert_query = format!( - r#" - INSERT INTO {target_table} ({id_col}, {embedding_col}) - SELECT {id_col}, {embedding_col} FROM temp_embeddings - ON CONFLICT ({id_col}) DO UPDATE - SET {embedding_col} = EXCLUDED.{embedding_col} - "#, + "INSERT INTO {target_table} ({id_col}, {embedding_col}) + SELECT {id_col}, {embedding_col} FROM temp_embeddings + ON CONFLICT ({id_col}) DO UPDATE + SET {embedding_col} = EXCLUDED.{embedding_col}", target_table = args.target_table, id_col = args.target_id_column, embedding_col = args.target_embedding_column, ); + info!("Insert query: {}", insert_query); client.execute(&insert_query, &[]).await?; + info!("Insert completed"); // Clean up temp table + info!("Cleaning up temporary table..."); client.execute("DROP TABLE temp_embeddings", &[]).await?; + info!("Temporary table dropped"); info!( "Saved {} valid embeddings, skipped {} invalid embeddings", @@ -283,6 +313,12 @@ async fn main() -> Result<()> { pretty_env_logger::init(); let args = Args::parse(); + // Set up Ctrl+C handler for immediate exit + ctrlc::set_handler(|| { + info!("Received interrupt signal, exiting immediately..."); + std::process::exit(1); + })?; + let pool = create_pool().await?; // Validate schema before proceeding