write vector binary type to database
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -1210,6 +1210,7 @@ name = "mf-fitter"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
"clap",
|
||||
"ctrlc",
|
||||
"deadpool-postgres",
|
||||
|
||||
@@ -21,3 +21,4 @@ rand = "0.8"
|
||||
rand_distr = "0.4"
|
||||
tokio = { version = "1.35", features = ["full"] }
|
||||
tokio-postgres = "0.7"
|
||||
bytes = "1.5.0"
|
||||
|
||||
@@ -43,6 +43,10 @@ struct Args {
|
||||
/// Target table name for embeddings
|
||||
#[arg(long, default_value = "item_embeddings")]
|
||||
embeddings_table: String,
|
||||
|
||||
/// Number of factors (dimensions) for the embeddings
|
||||
#[arg(long, default_value = "16")]
|
||||
factors: i32,
|
||||
}
|
||||
|
||||
async fn create_pool() -> Result<Pool> {
|
||||
@@ -70,6 +74,12 @@ async fn main() -> Result<()> {
|
||||
let pool = create_pool().await?;
|
||||
let client = pool.get().await?;
|
||||
|
||||
// Enable pgvector extension
|
||||
info!("Enabling pgvector extension...");
|
||||
client
|
||||
.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])
|
||||
.await?;
|
||||
|
||||
// Create tables
|
||||
info!("Creating tables...");
|
||||
client
|
||||
@@ -86,14 +96,22 @@ async fn main() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Drop and recreate embeddings table to ensure correct vector dimensions
|
||||
client
|
||||
.execute(
|
||||
&format!("DROP TABLE IF EXISTS {}", args.embeddings_table),
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
|
||||
client
|
||||
.execute(
|
||||
&format!(
|
||||
"CREATE TABLE IF NOT EXISTS {} (
|
||||
"CREATE TABLE {} (
|
||||
item_id INTEGER PRIMARY KEY,
|
||||
embedding FLOAT4[]
|
||||
embedding vector({})
|
||||
)",
|
||||
args.embeddings_table
|
||||
args.embeddings_table, args.factors
|
||||
),
|
||||
&[],
|
||||
)
|
||||
|
||||
74
src/main.rs
74
src/main.rs
@@ -1,4 +1,5 @@
|
||||
use anyhow::{Context, Result};
|
||||
use bytes::BytesMut;
|
||||
use clap::Parser;
|
||||
use deadpool_postgres::{Config, Pool, Runtime};
|
||||
use dotenv::dotenv;
|
||||
@@ -57,6 +58,7 @@ async fn get_column_types(
|
||||
("bigint", _) => Type::INT8,
|
||||
("ARRAY", "_float4") => Type::FLOAT4_ARRAY,
|
||||
("ARRAY", "_float8") => Type::FLOAT8_ARRAY,
|
||||
("USER-DEFINED", "vector") => Type::BYTEA, // pgvector type maps to bytea
|
||||
_ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name),
|
||||
};
|
||||
types.push(pg_type);
|
||||
@@ -133,31 +135,44 @@ async fn save_embeddings(
|
||||
let client = pool.get().await?;
|
||||
|
||||
// Get the column types from the target table
|
||||
info!("Getting column types for target table...");
|
||||
let types = get_column_types(
|
||||
&client,
|
||||
&args.target_table,
|
||||
&[&args.target_id_column, &args.target_embedding_column],
|
||||
)
|
||||
.await?;
|
||||
info!("Target table column types: {:?}", types);
|
||||
|
||||
// Create a temporary table with the same structure
|
||||
info!("Creating temporary table...");
|
||||
let id_type = match types[0] {
|
||||
Type::INT4 => "INTEGER",
|
||||
Type::INT8 => "BIGINT",
|
||||
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
|
||||
};
|
||||
let create_temp = format!(
|
||||
"CREATE TEMP TABLE temp_embeddings AS SELECT * FROM {} WHERE 1=0",
|
||||
args.target_table
|
||||
"CREATE TEMP TABLE temp_embeddings (
|
||||
{} {},
|
||||
{} vector({})
|
||||
)",
|
||||
args.target_id_column, id_type, args.target_embedding_column, args.factors
|
||||
);
|
||||
info!("Temp table creation SQL: {}", create_temp);
|
||||
client.execute(&create_temp, &[]).await?;
|
||||
info!("Temporary table created successfully");
|
||||
|
||||
// Copy data into temporary table
|
||||
let copy_query = format!(
|
||||
"COPY temp_embeddings ({id_col}, {embedding_col}) FROM STDIN (FORMAT binary)",
|
||||
id_col = args.target_id_column,
|
||||
embedding_col = args.target_embedding_column,
|
||||
);
|
||||
info!("Starting COPY operation...");
|
||||
let copy_query =
|
||||
format!("COPY temp_embeddings (item_id, embedding) FROM STDIN WITH (FORMAT BINARY)",);
|
||||
info!("COPY query: {}", copy_query);
|
||||
|
||||
let mut writer = Box::pin(BinaryCopyInWriter::new(
|
||||
client.copy_in(©_query).await?,
|
||||
&types,
|
||||
&[Type::INT4, Type::BYTEA],
|
||||
));
|
||||
info!("Binary writer initialized");
|
||||
|
||||
let mut valid_embeddings = 0;
|
||||
let mut invalid_embeddings = 0;
|
||||
@@ -176,33 +191,48 @@ async fn save_embeddings(
|
||||
}
|
||||
|
||||
valid_embeddings += 1;
|
||||
let id_value: Box<dyn tokio_postgres::types::ToSql + Sync> = match types[0] {
|
||||
Type::INT4 => Box::new(item_id),
|
||||
Type::INT8 => Box::new(item_id as i64),
|
||||
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
|
||||
};
|
||||
let factors: &[f32] = factors;
|
||||
writer.as_mut().write(&[&*id_value, &factors]).await?;
|
||||
let vector_data: Vec<f32> = factors.to_vec();
|
||||
|
||||
// Create a buffer for the binary data
|
||||
let mut buf = BytesMut::new();
|
||||
buf.extend_from_slice(&(vector_data.len() as u16).to_be_bytes()); // Number of dimensions in big-endian
|
||||
buf.extend_from_slice(&0u16.to_be_bytes()); // Unused padding in big-endian
|
||||
for &value in &vector_data {
|
||||
buf.extend_from_slice(&value.to_be_bytes()); // Float values in big-endian
|
||||
}
|
||||
let binary_data: Vec<u8> = buf.freeze().to_vec();
|
||||
|
||||
info!(
|
||||
"Writing embedding for item_id={}, factors.len()={}",
|
||||
item_id,
|
||||
factors.len()
|
||||
);
|
||||
writer.as_mut().write(&[&item_id, &binary_data]).await?;
|
||||
}
|
||||
|
||||
info!("Finishing COPY operation...");
|
||||
writer.as_mut().finish().await?;
|
||||
info!("COPY operation completed");
|
||||
|
||||
// Insert from temp table with ON CONFLICT DO UPDATE
|
||||
info!("Inserting from temp table to target table...");
|
||||
let insert_query = format!(
|
||||
r#"
|
||||
INSERT INTO {target_table} ({id_col}, {embedding_col})
|
||||
"INSERT INTO {target_table} ({id_col}, {embedding_col})
|
||||
SELECT {id_col}, {embedding_col} FROM temp_embeddings
|
||||
ON CONFLICT ({id_col}) DO UPDATE
|
||||
SET {embedding_col} = EXCLUDED.{embedding_col}
|
||||
"#,
|
||||
SET {embedding_col} = EXCLUDED.{embedding_col}",
|
||||
target_table = args.target_table,
|
||||
id_col = args.target_id_column,
|
||||
embedding_col = args.target_embedding_column,
|
||||
);
|
||||
info!("Insert query: {}", insert_query);
|
||||
client.execute(&insert_query, &[]).await?;
|
||||
info!("Insert completed");
|
||||
|
||||
// Clean up temp table
|
||||
info!("Cleaning up temporary table...");
|
||||
client.execute("DROP TABLE temp_embeddings", &[]).await?;
|
||||
info!("Temporary table dropped");
|
||||
|
||||
info!(
|
||||
"Saved {} valid embeddings, skipped {} invalid embeddings",
|
||||
@@ -283,6 +313,12 @@ async fn main() -> Result<()> {
|
||||
pretty_env_logger::init();
|
||||
let args = Args::parse();
|
||||
|
||||
// Set up Ctrl+C handler for immediate exit
|
||||
ctrlc::set_handler(|| {
|
||||
info!("Received interrupt signal, exiting immediately...");
|
||||
std::process::exit(1);
|
||||
})?;
|
||||
|
||||
let pool = create_pool().await?;
|
||||
|
||||
// Validate schema before proceeding
|
||||
|
||||
Reference in New Issue
Block a user