Remove float[] array support; only use the vector type

This commit is contained in:
Dylan Knutson
2024-12-28 19:58:46 +00:00
parent 75e7a4538d
commit 5430fdd501

View File

@@ -1,5 +1,4 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use bytes::BytesMut;
use clap::Parser; use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime}; use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv; use dotenv::dotenv;
@@ -57,8 +56,6 @@ async fn get_column_types(
let pg_type = match (data_type, udt_name) { let pg_type = match (data_type, udt_name) {
("integer", _) => Type::INT4, ("integer", _) => Type::INT4,
("bigint", _) => Type::INT8, ("bigint", _) => Type::INT8,
("ARRAY", "_float4") => Type::FLOAT4_ARRAY,
("ARRAY", "_float8") => Type::FLOAT8_ARRAY,
("USER-DEFINED", type_name) => pg_types::get_pg_type(client, type_name).await?, ("USER-DEFINED", type_name) => pg_types::get_pg_type(client, type_name).await?,
_ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name), _ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name),
}; };
@@ -152,11 +149,14 @@ async fn save_embeddings(
&[&args.target_id_column, &args.target_embedding_column], &[&args.target_id_column, &args.target_embedding_column],
) )
.await?; .await?;
info!("Target table column types: {:?}", types);
if types.len() != 2 {
anyhow::bail!("Failed to get both column types from target table");
}
// Create a temporary table with the same structure // Create a temporary table with the same structure
info!("Creating temporary table..."); info!("Creating temporary table...");
let id_type = match types[0] { let id_type_str = match types[0] {
Type::INT4 => "INTEGER", Type::INT4 => "INTEGER",
Type::INT8 => "BIGINT", Type::INT8 => "BIGINT",
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]), _ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
@@ -166,7 +166,7 @@ async fn save_embeddings(
{} {}, {} {},
{} vector({}) {} vector({})
)", )",
args.target_id_column, id_type, args.target_embedding_column, args.factors args.target_id_column, id_type_str, args.target_embedding_column, args.factors
); );
info!("Temp table creation SQL: {}", create_temp); info!("Temp table creation SQL: {}", create_temp);
client.execute(&create_temp, &[]).await?; client.execute(&create_temp, &[]).await?;
@@ -327,6 +327,13 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res
let udt_name: &str = row.get("udt_name"); let udt_name: &str = row.get("udt_name");
if udt_name == "vector" { if udt_name == "vector" {
let vector_dim: i32 = row.get("vector_dim"); let vector_dim: i32 = row.get("vector_dim");
if vector_dim <= 0 {
anyhow::bail!(
"Invalid vector dimension {} for column '{}'",
vector_dim,
args.target_embedding_column,
);
}
if vector_dim != args.factors as i32 { if vector_dim != args.factors as i32 {
anyhow::bail!( anyhow::bail!(
"Vector dimension mismatch: column '{}' has dimension {}, but factors is {}", "Vector dimension mismatch: column '{}' has dimension {}, but factors is {}",