From 5430fdd501bb3c8cd01e0aa7c935490989ef789f Mon Sep 17 00:00:00 2001 From: Dylan Knutson Date: Sat, 28 Dec 2024 19:58:46 +0000 Subject: [PATCH] remove float[] array support, only use vector --- src/bin/fit_model.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/bin/fit_model.rs b/src/bin/fit_model.rs index 7be9a2c..bce43ec 100644 --- a/src/bin/fit_model.rs +++ b/src/bin/fit_model.rs @@ -1,5 +1,4 @@ use anyhow::{Context, Result}; -use bytes::BytesMut; use clap::Parser; use deadpool_postgres::{Config, Pool, Runtime}; use dotenv::dotenv; @@ -57,8 +56,6 @@ async fn get_column_types( let pg_type = match (data_type, udt_name) { ("integer", _) => Type::INT4, ("bigint", _) => Type::INT8, - ("ARRAY", "_float4") => Type::FLOAT4_ARRAY, - ("ARRAY", "_float8") => Type::FLOAT8_ARRAY, ("USER-DEFINED", type_name) => pg_types::get_pg_type(client, type_name).await?, _ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name), }; @@ -152,11 +149,14 @@ async fn save_embeddings( &[&args.target_id_column, &args.target_embedding_column], ) .await?; - info!("Target table column types: {:?}", types); + + if types.len() != 2 { + anyhow::bail!("Failed to get both column types from target table"); + } // Create a temporary table with the same structure info!("Creating temporary table..."); - let id_type = match types[0] { + let id_type_str = match types[0] { Type::INT4 => "INTEGER", Type::INT8 => "BIGINT", _ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]), @@ -166,7 +166,7 @@ async fn save_embeddings( {} {}, {} vector({}) )", - args.target_id_column, id_type, args.target_embedding_column, args.factors + args.target_id_column, id_type_str, args.target_embedding_column, args.factors ); info!("Temp table creation SQL: {}", create_temp); client.execute(&create_temp, &[]).await?; @@ -327,6 +327,13 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res let udt_name: &str = row.get("udt_name"); if udt_name == "vector" { let vector_dim: i32 = row.get("vector_dim"); + if vector_dim <= 0 { + anyhow::bail!( + "Invalid vector dimension {} for column '{}'", + vector_dim, + args.target_embedding_column, + ); + } if vector_dim != args.factors as i32 { anyhow::bail!( "Vector dimension mismatch: column '{}' has dimension {}, but factors is {}",