Write embeddings to the database as pgvector `vector` values using binary COPY

This commit is contained in:
Dylan Knutson
2024-12-28 19:04:43 +00:00
parent 2b1865f3d4
commit bc88c54cb0
4 changed files with 80 additions and 24 deletions

1
Cargo.lock generated
View File

@@ -1210,6 +1210,7 @@ name = "mf-fitter"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes",
"clap",
"ctrlc",
"deadpool-postgres",

View File

@@ -21,3 +21,4 @@ rand = "0.8"
rand_distr = "0.4"
tokio = { version = "1.35", features = ["full"] }
tokio-postgres = "0.7"
bytes = "1.5.0"

View File

@@ -43,6 +43,10 @@ struct Args {
/// Target table name for embeddings
#[arg(long, default_value = "item_embeddings")]
embeddings_table: String,
/// Number of factors (dimensions) for the embeddings
#[arg(long, default_value = "16")]
factors: i32,
}
async fn create_pool() -> Result<Pool> {
@@ -70,6 +74,12 @@ async fn main() -> Result<()> {
let pool = create_pool().await?;
let client = pool.get().await?;
// Enable pgvector extension
info!("Enabling pgvector extension...");
client
.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])
.await?;
// Create tables
info!("Creating tables...");
client
@@ -86,14 +96,22 @@ async fn main() -> Result<()> {
)
.await?;
// Drop and recreate embeddings table to ensure correct vector dimensions
client
.execute(
&format!("DROP TABLE IF EXISTS {}", args.embeddings_table),
&[],
)
.await?;
client
.execute(
&format!(
"CREATE TABLE IF NOT EXISTS {} (
"CREATE TABLE {} (
item_id INTEGER PRIMARY KEY,
embedding FLOAT4[]
embedding vector({})
)",
args.embeddings_table
args.embeddings_table, args.factors
),
&[],
)

View File

@@ -1,4 +1,5 @@
use anyhow::{Context, Result};
use bytes::BytesMut;
use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv;
@@ -57,6 +58,7 @@ async fn get_column_types(
("bigint", _) => Type::INT8,
("ARRAY", "_float4") => Type::FLOAT4_ARRAY,
("ARRAY", "_float8") => Type::FLOAT8_ARRAY,
("USER-DEFINED", "vector") => Type::BYTEA, // pgvector type maps to bytea
_ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name),
};
types.push(pg_type);
@@ -133,31 +135,44 @@ async fn save_embeddings(
let client = pool.get().await?;
// Get the column types from the target table
info!("Getting column types for target table...");
let types = get_column_types(
&client,
&args.target_table,
&[&args.target_id_column, &args.target_embedding_column],
)
.await?;
info!("Target table column types: {:?}", types);
// Create a temporary table with the same structure
info!("Creating temporary table...");
let id_type = match types[0] {
Type::INT4 => "INTEGER",
Type::INT8 => "BIGINT",
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
};
let create_temp = format!(
"CREATE TEMP TABLE temp_embeddings AS SELECT * FROM {} WHERE 1=0",
args.target_table
"CREATE TEMP TABLE temp_embeddings (
{} {},
{} vector({})
)",
args.target_id_column, id_type, args.target_embedding_column, args.factors
);
info!("Temp table creation SQL: {}", create_temp);
client.execute(&create_temp, &[]).await?;
info!("Temporary table created successfully");
// Copy data into temporary table
let copy_query = format!(
"COPY temp_embeddings ({id_col}, {embedding_col}) FROM STDIN (FORMAT binary)",
id_col = args.target_id_column,
embedding_col = args.target_embedding_column,
);
info!("Starting COPY operation...");
let copy_query =
format!("COPY temp_embeddings (item_id, embedding) FROM STDIN WITH (FORMAT BINARY)",);
info!("COPY query: {}", copy_query);
let mut writer = Box::pin(BinaryCopyInWriter::new(
client.copy_in(&copy_query).await?,
&types,
&[Type::INT4, Type::BYTEA],
));
info!("Binary writer initialized");
let mut valid_embeddings = 0;
let mut invalid_embeddings = 0;
@@ -176,33 +191,48 @@ async fn save_embeddings(
}
valid_embeddings += 1;
let id_value: Box<dyn tokio_postgres::types::ToSql + Sync> = match types[0] {
Type::INT4 => Box::new(item_id),
Type::INT8 => Box::new(item_id as i64),
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
};
let factors: &[f32] = factors;
writer.as_mut().write(&[&*id_value, &factors]).await?;
let vector_data: Vec<f32> = factors.to_vec();
// Create a buffer for the binary data
let mut buf = BytesMut::new();
buf.extend_from_slice(&(vector_data.len() as u16).to_be_bytes()); // Number of dimensions in big-endian
buf.extend_from_slice(&0u16.to_be_bytes()); // Unused padding in big-endian
for &value in &vector_data {
buf.extend_from_slice(&value.to_be_bytes()); // Float values in big-endian
}
let binary_data: Vec<u8> = buf.freeze().to_vec();
info!(
"Writing embedding for item_id={}, factors.len()={}",
item_id,
factors.len()
);
writer.as_mut().write(&[&item_id, &binary_data]).await?;
}
info!("Finishing COPY operation...");
writer.as_mut().finish().await?;
info!("COPY operation completed");
// Insert from temp table with ON CONFLICT DO UPDATE
info!("Inserting from temp table to target table...");
let insert_query = format!(
r#"
INSERT INTO {target_table} ({id_col}, {embedding_col})
"INSERT INTO {target_table} ({id_col}, {embedding_col})
SELECT {id_col}, {embedding_col} FROM temp_embeddings
ON CONFLICT ({id_col}) DO UPDATE
SET {embedding_col} = EXCLUDED.{embedding_col}
"#,
SET {embedding_col} = EXCLUDED.{embedding_col}",
target_table = args.target_table,
id_col = args.target_id_column,
embedding_col = args.target_embedding_column,
);
info!("Insert query: {}", insert_query);
client.execute(&insert_query, &[]).await?;
info!("Insert completed");
// Clean up temp table
info!("Cleaning up temporary table...");
client.execute("DROP TABLE temp_embeddings", &[]).await?;
info!("Temporary table dropped");
info!(
"Saved {} valid embeddings, skipped {} invalid embeddings",
@@ -283,6 +313,12 @@ async fn main() -> Result<()> {
pretty_env_logger::init();
let args = Args::parse();
// Set up Ctrl+C handler for immediate exit
ctrlc::set_handler(|| {
info!("Received interrupt signal, exiting immediately...");
std::process::exit(1);
})?;
let pool = create_pool().await?;
// Validate schema before proceeding