use COPY for exporting data into temp table

This commit is contained in:
Dylan Knutson
2024-12-28 18:32:18 +00:00
parent c4e79a36f9
commit 2b1865f3d4

View File

@@ -7,6 +7,7 @@ use libmf::{Loss, Matrix, Model};
use log::info; use log::info;
use std::collections::HashSet; use std::collections::HashSet;
use std::env; use std::env;
use tokio_postgres::binary_copy::BinaryCopyInWriter;
use tokio_postgres::{types::Type, NoTls}; use tokio_postgres::{types::Type, NoTls};
mod args; mod args;
@@ -35,23 +36,28 @@ async fn get_column_types(
) -> Result<Vec<Type>> { ) -> Result<Vec<Type>> {
let column_list = columns.join("', '"); let column_list = columns.join("', '");
let query = format!( let query = format!(
"SELECT data_type \ "SELECT data_type, udt_name \
FROM information_schema.columns \ FROM information_schema.columns \
WHERE table_schema = 'public' \ WHERE table_schema = 'public' \
AND table_name = $1 \ AND table_name = $1 \
AND column_name IN ('{}') \ AND column_name IN ('{}') \
ORDER BY column_name", ORDER BY array_position(ARRAY['{}'], column_name)",
column_list column_list, column_list
); );
info!("Executing type query: {}", query);
let rows = client.query(&query, &[&table]).await?; let rows = client.query(&query, &[&table]).await?;
info!("Got types: {:?}", rows);
let mut types = Vec::new(); let mut types = Vec::new();
for row in rows { for row in rows {
let data_type: &str = row.get(0); let data_type: &str = row.get(0);
let pg_type = match data_type { let udt_name: &str = row.get(1);
"integer" => Type::INT4, let pg_type = match (data_type, udt_name) {
"bigint" => Type::INT8, ("integer", _) => Type::INT4,
_ => anyhow::bail!("Unsupported column type: {}", data_type), ("bigint", _) => Type::INT8,
("ARRAY", "_float4") => Type::FLOAT4_ARRAY,
("ARRAY", "_float8") => Type::FLOAT8_ARRAY,
_ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name),
}; };
types.push(pg_type); types.push(pg_type);
} }
@@ -126,10 +132,35 @@ async fn save_embeddings(
) -> Result<()> { ) -> Result<()> {
let client = pool.get().await?; let client = pool.get().await?;
// Get the column types from the target table
let types = get_column_types(
&client,
&args.target_table,
&[&args.target_id_column, &args.target_embedding_column],
)
.await?;
// Create a temporary table with the same structure
let create_temp = format!(
"CREATE TEMP TABLE temp_embeddings AS SELECT * FROM {} WHERE 1=0",
args.target_table
);
client.execute(&create_temp, &[]).await?;
// Copy data into temporary table
let copy_query = format!(
"COPY temp_embeddings ({id_col}, {embedding_col}) FROM STDIN (FORMAT binary)",
id_col = args.target_id_column,
embedding_col = args.target_embedding_column,
);
let mut writer = Box::pin(BinaryCopyInWriter::new(
client.copy_in(&copy_query).await?,
&types,
));
let mut valid_embeddings = 0; let mut valid_embeddings = 0;
let mut invalid_embeddings = 0; let mut invalid_embeddings = 0;
let batch_size = 128;
let mut current_batch = Vec::with_capacity(batch_size);
// Process factors for items that appear in the source table // Process factors for items that appear in the source table
for (idx, factors) in model.q_iter().enumerate() { for (idx, factors) in model.q_iter().enumerate() {
@@ -145,33 +176,33 @@ async fn save_embeddings(
} }
valid_embeddings += 1; valid_embeddings += 1;
current_batch.push((item_id as i64, factors)); let id_value: Box<dyn tokio_postgres::types::ToSql + Sync> = match types[0] {
Type::INT4 => Box::new(item_id),
// When batch is full, save it Type::INT8 => Box::new(item_id as i64),
if current_batch.len() >= batch_size { _ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
save_batch( };
&client, let factors: &[f32] = factors;
&args.target_table, writer.as_mut().write(&[&*id_value, &factors]).await?;
&args.target_id_column,
&args.target_embedding_column,
&current_batch,
)
.await?;
current_batch.clear();
}
} }
// Save any remaining items in the last batch writer.as_mut().finish().await?;
if !current_batch.is_empty() {
save_batch( // Insert from temp table with ON CONFLICT DO UPDATE
&client, let insert_query = format!(
&args.target_table, r#"
&args.target_id_column, INSERT INTO {target_table} ({id_col}, {embedding_col})
&args.target_embedding_column, SELECT {id_col}, {embedding_col} FROM temp_embeddings
&current_batch, ON CONFLICT ({id_col}) DO UPDATE
) SET {embedding_col} = EXCLUDED.{embedding_col}
.await?; "#,
} target_table = args.target_table,
id_col = args.target_id_column,
embedding_col = args.target_embedding_column,
);
client.execute(&insert_query, &[]).await?;
// Clean up temp table
client.execute("DROP TABLE temp_embeddings", &[]).await?;
info!( info!(
"Saved {} valid embeddings, skipped {} invalid embeddings", "Saved {} valid embeddings, skipped {} invalid embeddings",
@@ -181,48 +212,6 @@ async fn save_embeddings(
Ok(()) Ok(())
} }
/// Upserts one batch of `(id, embedding)` rows into `target_table` using a single
/// multi-row `INSERT ... ON CONFLICT (...) DO UPDATE` statement.
///
/// Every row contributes two bind parameters: the id (cast to `int8`) and the
/// embedding (cast to `float4[]`). The table and column names are interpolated
/// into the SQL text as-is, without quoting or escaping.
///
/// An empty `batch_values` is a no-op: no statement is built or executed.
///
/// # Errors
///
/// Returns the error from `client.execute`, wrapped with context that reports
/// the number of rows in the failed batch.
async fn save_batch(
    client: &deadpool_postgres::Client,
    target_table: &str,
    target_id_column: &str,
    target_embedding_column: &str,
    batch_values: &[(i64, &[f32])],
) -> Result<()> {
    // `INSERT ... VALUES` with zero tuples is not valid SQL; there is nothing
    // to write, so skip the round trip entirely.
    if batch_values.is_empty() {
        return Ok(());
    }

    // One `($a::int8, $b::float4[])` tuple per row; placeholders are numbered from $1.
    let placeholders = (0..batch_values.len())
        .map(|i| format!("(${}::int8, ${}::float4[])", i * 2 + 1, i * 2 + 2))
        .collect::<Vec<String>>()
        .join(",");

    let query = format!(
        r#"
        INSERT INTO {target_table} ({target_id_column}, {target_embedding_column})
        VALUES {placeholders}
        ON CONFLICT ({target_id_column})
        DO UPDATE SET {target_embedding_column} = EXCLUDED.{target_embedding_column}
        "#
    );

    // Flatten the rows into the positional parameter list, in the same
    // id-then-embedding order the placeholders were generated in.
    let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
        Vec::with_capacity(batch_values.len() * 2);
    for (item_id, factors) in batch_values {
        params.push(item_id);
        params.push(factors);
    }
    info!("Number of parameters: {}", params.len());

    client.execute(&query, &params[..]).await.with_context(|| {
        format!(
            "Failed to execute batch insert at {}:{} with {} values",
            file!(),
            line!(),
            batch_values.len()
        )
    })?;
    Ok(())
}
async fn validate_table_exists(client: &deadpool_postgres::Client, table: &str) -> Result<()> { async fn validate_table_exists(client: &deadpool_postgres::Client, table: &str) -> Result<()> {
let query = r#"SELECT EXISTS ( let query = r#"SELECT EXISTS (
SELECT FROM pg_tables SELECT FROM pg_tables