write vector binary type to database
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -1210,6 +1210,7 @@ name = "mf-fitter"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"bytes",
|
||||||
"clap",
|
"clap",
|
||||||
"ctrlc",
|
"ctrlc",
|
||||||
"deadpool-postgres",
|
"deadpool-postgres",
|
||||||
|
|||||||
@@ -21,3 +21,4 @@ rand = "0.8"
|
|||||||
rand_distr = "0.4"
|
rand_distr = "0.4"
|
||||||
tokio = { version = "1.35", features = ["full"] }
|
tokio = { version = "1.35", features = ["full"] }
|
||||||
tokio-postgres = "0.7"
|
tokio-postgres = "0.7"
|
||||||
|
bytes = "1.5.0"
|
||||||
|
|||||||
@@ -43,6 +43,10 @@ struct Args {
|
|||||||
/// Target table name for embeddings
|
/// Target table name for embeddings
|
||||||
#[arg(long, default_value = "item_embeddings")]
|
#[arg(long, default_value = "item_embeddings")]
|
||||||
embeddings_table: String,
|
embeddings_table: String,
|
||||||
|
|
||||||
|
/// Number of factors (dimensions) for the embeddings
|
||||||
|
#[arg(long, default_value = "16")]
|
||||||
|
factors: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_pool() -> Result<Pool> {
|
async fn create_pool() -> Result<Pool> {
|
||||||
@@ -70,6 +74,12 @@ async fn main() -> Result<()> {
|
|||||||
let pool = create_pool().await?;
|
let pool = create_pool().await?;
|
||||||
let client = pool.get().await?;
|
let client = pool.get().await?;
|
||||||
|
|
||||||
|
// Enable pgvector extension
|
||||||
|
info!("Enabling pgvector extension...");
|
||||||
|
client
|
||||||
|
.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])
|
||||||
|
.await?;
|
||||||
|
|
||||||
// Create tables
|
// Create tables
|
||||||
info!("Creating tables...");
|
info!("Creating tables...");
|
||||||
client
|
client
|
||||||
@@ -86,14 +96,22 @@ async fn main() -> Result<()> {
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Drop and recreate embeddings table to ensure correct vector dimensions
|
||||||
|
client
|
||||||
|
.execute(
|
||||||
|
&format!("DROP TABLE IF EXISTS {}", args.embeddings_table),
|
||||||
|
&[],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
client
|
client
|
||||||
.execute(
|
.execute(
|
||||||
&format!(
|
&format!(
|
||||||
"CREATE TABLE IF NOT EXISTS {} (
|
"CREATE TABLE {} (
|
||||||
item_id INTEGER PRIMARY KEY,
|
item_id INTEGER PRIMARY KEY,
|
||||||
embedding FLOAT4[]
|
embedding vector({})
|
||||||
)",
|
)",
|
||||||
args.embeddings_table
|
args.embeddings_table, args.factors
|
||||||
),
|
),
|
||||||
&[],
|
&[],
|
||||||
)
|
)
|
||||||
|
|||||||
78
src/main.rs
78
src/main.rs
@@ -1,4 +1,5 @@
|
|||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
|
use bytes::BytesMut;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use deadpool_postgres::{Config, Pool, Runtime};
|
use deadpool_postgres::{Config, Pool, Runtime};
|
||||||
use dotenv::dotenv;
|
use dotenv::dotenv;
|
||||||
@@ -57,6 +58,7 @@ async fn get_column_types(
|
|||||||
("bigint", _) => Type::INT8,
|
("bigint", _) => Type::INT8,
|
||||||
("ARRAY", "_float4") => Type::FLOAT4_ARRAY,
|
("ARRAY", "_float4") => Type::FLOAT4_ARRAY,
|
||||||
("ARRAY", "_float8") => Type::FLOAT8_ARRAY,
|
("ARRAY", "_float8") => Type::FLOAT8_ARRAY,
|
||||||
|
("USER-DEFINED", "vector") => Type::BYTEA, // pgvector type maps to bytea
|
||||||
_ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name),
|
_ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name),
|
||||||
};
|
};
|
||||||
types.push(pg_type);
|
types.push(pg_type);
|
||||||
@@ -133,31 +135,44 @@ async fn save_embeddings(
|
|||||||
let client = pool.get().await?;
|
let client = pool.get().await?;
|
||||||
|
|
||||||
// Get the column types from the target table
|
// Get the column types from the target table
|
||||||
|
info!("Getting column types for target table...");
|
||||||
let types = get_column_types(
|
let types = get_column_types(
|
||||||
&client,
|
&client,
|
||||||
&args.target_table,
|
&args.target_table,
|
||||||
&[&args.target_id_column, &args.target_embedding_column],
|
&[&args.target_id_column, &args.target_embedding_column],
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
info!("Target table column types: {:?}", types);
|
||||||
|
|
||||||
// Create a temporary table with the same structure
|
// Create a temporary table with the same structure
|
||||||
|
info!("Creating temporary table...");
|
||||||
|
let id_type = match types[0] {
|
||||||
|
Type::INT4 => "INTEGER",
|
||||||
|
Type::INT8 => "BIGINT",
|
||||||
|
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
|
||||||
|
};
|
||||||
let create_temp = format!(
|
let create_temp = format!(
|
||||||
"CREATE TEMP TABLE temp_embeddings AS SELECT * FROM {} WHERE 1=0",
|
"CREATE TEMP TABLE temp_embeddings (
|
||||||
args.target_table
|
{} {},
|
||||||
|
{} vector({})
|
||||||
|
)",
|
||||||
|
args.target_id_column, id_type, args.target_embedding_column, args.factors
|
||||||
);
|
);
|
||||||
|
info!("Temp table creation SQL: {}", create_temp);
|
||||||
client.execute(&create_temp, &[]).await?;
|
client.execute(&create_temp, &[]).await?;
|
||||||
|
info!("Temporary table created successfully");
|
||||||
|
|
||||||
// Copy data into temporary table
|
// Copy data into temporary table
|
||||||
let copy_query = format!(
|
info!("Starting COPY operation...");
|
||||||
"COPY temp_embeddings ({id_col}, {embedding_col}) FROM STDIN (FORMAT binary)",
|
let copy_query =
|
||||||
id_col = args.target_id_column,
|
format!("COPY temp_embeddings (item_id, embedding) FROM STDIN WITH (FORMAT BINARY)",);
|
||||||
embedding_col = args.target_embedding_column,
|
info!("COPY query: {}", copy_query);
|
||||||
);
|
|
||||||
|
|
||||||
let mut writer = Box::pin(BinaryCopyInWriter::new(
|
let mut writer = Box::pin(BinaryCopyInWriter::new(
|
||||||
client.copy_in(©_query).await?,
|
client.copy_in(©_query).await?,
|
||||||
&types,
|
&[Type::INT4, Type::BYTEA],
|
||||||
));
|
));
|
||||||
|
info!("Binary writer initialized");
|
||||||
|
|
||||||
let mut valid_embeddings = 0;
|
let mut valid_embeddings = 0;
|
||||||
let mut invalid_embeddings = 0;
|
let mut invalid_embeddings = 0;
|
||||||
@@ -176,33 +191,48 @@ async fn save_embeddings(
|
|||||||
}
|
}
|
||||||
|
|
||||||
valid_embeddings += 1;
|
valid_embeddings += 1;
|
||||||
let id_value: Box<dyn tokio_postgres::types::ToSql + Sync> = match types[0] {
|
let vector_data: Vec<f32> = factors.to_vec();
|
||||||
Type::INT4 => Box::new(item_id),
|
|
||||||
Type::INT8 => Box::new(item_id as i64),
|
// Create a buffer for the binary data
|
||||||
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
|
let mut buf = BytesMut::new();
|
||||||
};
|
buf.extend_from_slice(&(vector_data.len() as u16).to_be_bytes()); // Number of dimensions in big-endian
|
||||||
let factors: &[f32] = factors;
|
buf.extend_from_slice(&0u16.to_be_bytes()); // Unused padding in big-endian
|
||||||
writer.as_mut().write(&[&*id_value, &factors]).await?;
|
for &value in &vector_data {
|
||||||
|
buf.extend_from_slice(&value.to_be_bytes()); // Float values in big-endian
|
||||||
|
}
|
||||||
|
let binary_data: Vec<u8> = buf.freeze().to_vec();
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Writing embedding for item_id={}, factors.len()={}",
|
||||||
|
item_id,
|
||||||
|
factors.len()
|
||||||
|
);
|
||||||
|
writer.as_mut().write(&[&item_id, &binary_data]).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
info!("Finishing COPY operation...");
|
||||||
writer.as_mut().finish().await?;
|
writer.as_mut().finish().await?;
|
||||||
|
info!("COPY operation completed");
|
||||||
|
|
||||||
// Insert from temp table with ON CONFLICT DO UPDATE
|
// Insert from temp table with ON CONFLICT DO UPDATE
|
||||||
|
info!("Inserting from temp table to target table...");
|
||||||
let insert_query = format!(
|
let insert_query = format!(
|
||||||
r#"
|
"INSERT INTO {target_table} ({id_col}, {embedding_col})
|
||||||
INSERT INTO {target_table} ({id_col}, {embedding_col})
|
SELECT {id_col}, {embedding_col} FROM temp_embeddings
|
||||||
SELECT {id_col}, {embedding_col} FROM temp_embeddings
|
ON CONFLICT ({id_col}) DO UPDATE
|
||||||
ON CONFLICT ({id_col}) DO UPDATE
|
SET {embedding_col} = EXCLUDED.{embedding_col}",
|
||||||
SET {embedding_col} = EXCLUDED.{embedding_col}
|
|
||||||
"#,
|
|
||||||
target_table = args.target_table,
|
target_table = args.target_table,
|
||||||
id_col = args.target_id_column,
|
id_col = args.target_id_column,
|
||||||
embedding_col = args.target_embedding_column,
|
embedding_col = args.target_embedding_column,
|
||||||
);
|
);
|
||||||
|
info!("Insert query: {}", insert_query);
|
||||||
client.execute(&insert_query, &[]).await?;
|
client.execute(&insert_query, &[]).await?;
|
||||||
|
info!("Insert completed");
|
||||||
|
|
||||||
// Clean up temp table
|
// Clean up temp table
|
||||||
|
info!("Cleaning up temporary table...");
|
||||||
client.execute("DROP TABLE temp_embeddings", &[]).await?;
|
client.execute("DROP TABLE temp_embeddings", &[]).await?;
|
||||||
|
info!("Temporary table dropped");
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Saved {} valid embeddings, skipped {} invalid embeddings",
|
"Saved {} valid embeddings, skipped {} invalid embeddings",
|
||||||
@@ -283,6 +313,12 @@ async fn main() -> Result<()> {
|
|||||||
pretty_env_logger::init();
|
pretty_env_logger::init();
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
|
// Set up Ctrl+C handler for immediate exit
|
||||||
|
ctrlc::set_handler(|| {
|
||||||
|
info!("Received interrupt signal, exiting immediately...");
|
||||||
|
std::process::exit(1);
|
||||||
|
})?;
|
||||||
|
|
||||||
let pool = create_pool().await?;
|
let pool = create_pool().await?;
|
||||||
|
|
||||||
// Validate schema before proceeding
|
// Validate schema before proceeding
|
||||||
|
|||||||
Reference in New Issue
Block a user