write vector binary type to database

This commit is contained in:
Dylan Knutson
2024-12-28 19:04:43 +00:00
parent 2b1865f3d4
commit bc88c54cb0
4 changed files with 80 additions and 24 deletions

1
Cargo.lock generated
View File

@@ -1210,6 +1210,7 @@ name = "mf-fitter"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes",
"clap", "clap",
"ctrlc", "ctrlc",
"deadpool-postgres", "deadpool-postgres",

View File

@@ -21,3 +21,4 @@ rand = "0.8"
rand_distr = "0.4" rand_distr = "0.4"
tokio = { version = "1.35", features = ["full"] } tokio = { version = "1.35", features = ["full"] }
tokio-postgres = "0.7" tokio-postgres = "0.7"
bytes = "1.5.0"

View File

@@ -43,6 +43,10 @@ struct Args {
/// Target table name for embeddings /// Target table name for embeddings
#[arg(long, default_value = "item_embeddings")] #[arg(long, default_value = "item_embeddings")]
embeddings_table: String, embeddings_table: String,
/// Number of factors (dimensions) for the embeddings
#[arg(long, default_value = "16")]
factors: i32,
} }
async fn create_pool() -> Result<Pool> { async fn create_pool() -> Result<Pool> {
@@ -70,6 +74,12 @@ async fn main() -> Result<()> {
let pool = create_pool().await?; let pool = create_pool().await?;
let client = pool.get().await?; let client = pool.get().await?;
// Enable pgvector extension
info!("Enabling pgvector extension...");
client
.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])
.await?;
// Create tables // Create tables
info!("Creating tables..."); info!("Creating tables...");
client client
@@ -86,14 +96,22 @@ async fn main() -> Result<()> {
) )
.await?; .await?;
// Drop and recreate embeddings table to ensure correct vector dimensions
client
.execute(
&format!("DROP TABLE IF EXISTS {}", args.embeddings_table),
&[],
)
.await?;
client client
.execute( .execute(
&format!( &format!(
"CREATE TABLE IF NOT EXISTS {} ( "CREATE TABLE {} (
item_id INTEGER PRIMARY KEY, item_id INTEGER PRIMARY KEY,
embedding FLOAT4[] embedding vector({})
)", )",
args.embeddings_table args.embeddings_table, args.factors
), ),
&[], &[],
) )

View File

@@ -1,4 +1,5 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use bytes::BytesMut;
use clap::Parser; use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime}; use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv; use dotenv::dotenv;
@@ -57,6 +58,7 @@ async fn get_column_types(
("bigint", _) => Type::INT8, ("bigint", _) => Type::INT8,
("ARRAY", "_float4") => Type::FLOAT4_ARRAY, ("ARRAY", "_float4") => Type::FLOAT4_ARRAY,
("ARRAY", "_float8") => Type::FLOAT8_ARRAY, ("ARRAY", "_float8") => Type::FLOAT8_ARRAY,
("USER-DEFINED", "vector") => Type::BYTEA, // pgvector type maps to bytea
_ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name), _ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name),
}; };
types.push(pg_type); types.push(pg_type);
@@ -133,31 +135,44 @@ async fn save_embeddings(
let client = pool.get().await?; let client = pool.get().await?;
// Get the column types from the target table // Get the column types from the target table
info!("Getting column types for target table...");
let types = get_column_types( let types = get_column_types(
&client, &client,
&args.target_table, &args.target_table,
&[&args.target_id_column, &args.target_embedding_column], &[&args.target_id_column, &args.target_embedding_column],
) )
.await?; .await?;
info!("Target table column types: {:?}", types);
// Create a temporary table with the same structure // Create a temporary table with the same structure
info!("Creating temporary table...");
let id_type = match types[0] {
Type::INT4 => "INTEGER",
Type::INT8 => "BIGINT",
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
};
let create_temp = format!( let create_temp = format!(
"CREATE TEMP TABLE temp_embeddings AS SELECT * FROM {} WHERE 1=0", "CREATE TEMP TABLE temp_embeddings (
args.target_table {} {},
{} vector({})
)",
args.target_id_column, id_type, args.target_embedding_column, args.factors
); );
info!("Temp table creation SQL: {}", create_temp);
client.execute(&create_temp, &[]).await?; client.execute(&create_temp, &[]).await?;
info!("Temporary table created successfully");
// Copy data into temporary table // Copy data into temporary table
let copy_query = format!( info!("Starting COPY operation...");
"COPY temp_embeddings ({id_col}, {embedding_col}) FROM STDIN (FORMAT binary)", let copy_query =
id_col = args.target_id_column, format!("COPY temp_embeddings (item_id, embedding) FROM STDIN WITH (FORMAT BINARY)",);
embedding_col = args.target_embedding_column, info!("COPY query: {}", copy_query);
);
let mut writer = Box::pin(BinaryCopyInWriter::new( let mut writer = Box::pin(BinaryCopyInWriter::new(
client.copy_in(&copy_query).await?, client.copy_in(&copy_query).await?,
&types, &[Type::INT4, Type::BYTEA],
)); ));
info!("Binary writer initialized");
let mut valid_embeddings = 0; let mut valid_embeddings = 0;
let mut invalid_embeddings = 0; let mut invalid_embeddings = 0;
@@ -176,33 +191,48 @@ async fn save_embeddings(
} }
valid_embeddings += 1; valid_embeddings += 1;
let id_value: Box<dyn tokio_postgres::types::ToSql + Sync> = match types[0] { let vector_data: Vec<f32> = factors.to_vec();
Type::INT4 => Box::new(item_id),
Type::INT8 => Box::new(item_id as i64), // Create a buffer for the binary data
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]), let mut buf = BytesMut::new();
}; buf.extend_from_slice(&(vector_data.len() as u16).to_be_bytes()); // Number of dimensions in big-endian
let factors: &[f32] = factors; buf.extend_from_slice(&0u16.to_be_bytes()); // Unused padding in big-endian
writer.as_mut().write(&[&*id_value, &factors]).await?; for &value in &vector_data {
buf.extend_from_slice(&value.to_be_bytes()); // Float values in big-endian
}
let binary_data: Vec<u8> = buf.freeze().to_vec();
info!(
"Writing embedding for item_id={}, factors.len()={}",
item_id,
factors.len()
);
writer.as_mut().write(&[&item_id, &binary_data]).await?;
} }
info!("Finishing COPY operation...");
writer.as_mut().finish().await?; writer.as_mut().finish().await?;
info!("COPY operation completed");
// Insert from temp table with ON CONFLICT DO UPDATE // Insert from temp table with ON CONFLICT DO UPDATE
info!("Inserting from temp table to target table...");
let insert_query = format!( let insert_query = format!(
r#" "INSERT INTO {target_table} ({id_col}, {embedding_col})
INSERT INTO {target_table} ({id_col}, {embedding_col}) SELECT {id_col}, {embedding_col} FROM temp_embeddings
SELECT {id_col}, {embedding_col} FROM temp_embeddings ON CONFLICT ({id_col}) DO UPDATE
ON CONFLICT ({id_col}) DO UPDATE SET {embedding_col} = EXCLUDED.{embedding_col}",
SET {embedding_col} = EXCLUDED.{embedding_col}
"#,
target_table = args.target_table, target_table = args.target_table,
id_col = args.target_id_column, id_col = args.target_id_column,
embedding_col = args.target_embedding_column, embedding_col = args.target_embedding_column,
); );
info!("Insert query: {}", insert_query);
client.execute(&insert_query, &[]).await?; client.execute(&insert_query, &[]).await?;
info!("Insert completed");
// Clean up temp table // Clean up temp table
info!("Cleaning up temporary table...");
client.execute("DROP TABLE temp_embeddings", &[]).await?; client.execute("DROP TABLE temp_embeddings", &[]).await?;
info!("Temporary table dropped");
info!( info!(
"Saved {} valid embeddings, skipped {} invalid embeddings", "Saved {} valid embeddings, skipped {} invalid embeddings",
@@ -283,6 +313,12 @@ async fn main() -> Result<()> {
pretty_env_logger::init(); pretty_env_logger::init();
let args = Args::parse(); let args = Args::parse();
// Set up Ctrl+C handler for immediate exit
ctrlc::set_handler(|| {
info!("Received interrupt signal, exiting immediately...");
std::process::exit(1);
})?;
let pool = create_pool().await?; let pool = create_pool().await?;
// Validate schema before proceeding // Validate schema before proceeding