use COPY for exporting data into temp table
This commit is contained in:
141
src/main.rs
141
src/main.rs
@@ -7,6 +7,7 @@ use libmf::{Loss, Matrix, Model};
|
|||||||
use log::info;
|
use log::info;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::env;
|
use std::env;
|
||||||
|
use tokio_postgres::binary_copy::BinaryCopyInWriter;
|
||||||
use tokio_postgres::{types::Type, NoTls};
|
use tokio_postgres::{types::Type, NoTls};
|
||||||
|
|
||||||
mod args;
|
mod args;
|
||||||
@@ -35,23 +36,28 @@ async fn get_column_types(
|
|||||||
) -> Result<Vec<Type>> {
|
) -> Result<Vec<Type>> {
|
||||||
let column_list = columns.join("', '");
|
let column_list = columns.join("', '");
|
||||||
let query = format!(
|
let query = format!(
|
||||||
"SELECT data_type \
|
"SELECT data_type, udt_name \
|
||||||
FROM information_schema.columns \
|
FROM information_schema.columns \
|
||||||
WHERE table_schema = 'public' \
|
WHERE table_schema = 'public' \
|
||||||
AND table_name = $1 \
|
AND table_name = $1 \
|
||||||
AND column_name IN ('{}') \
|
AND column_name IN ('{}') \
|
||||||
ORDER BY column_name",
|
ORDER BY array_position(ARRAY['{}'], column_name)",
|
||||||
column_list
|
column_list, column_list
|
||||||
);
|
);
|
||||||
|
|
||||||
|
info!("Executing type query: {}", query);
|
||||||
let rows = client.query(&query, &[&table]).await?;
|
let rows = client.query(&query, &[&table]).await?;
|
||||||
|
info!("Got types: {:?}", rows);
|
||||||
let mut types = Vec::new();
|
let mut types = Vec::new();
|
||||||
for row in rows {
|
for row in rows {
|
||||||
let data_type: &str = row.get(0);
|
let data_type: &str = row.get(0);
|
||||||
let pg_type = match data_type {
|
let udt_name: &str = row.get(1);
|
||||||
"integer" => Type::INT4,
|
let pg_type = match (data_type, udt_name) {
|
||||||
"bigint" => Type::INT8,
|
("integer", _) => Type::INT4,
|
||||||
_ => anyhow::bail!("Unsupported column type: {}", data_type),
|
("bigint", _) => Type::INT8,
|
||||||
|
("ARRAY", "_float4") => Type::FLOAT4_ARRAY,
|
||||||
|
("ARRAY", "_float8") => Type::FLOAT8_ARRAY,
|
||||||
|
_ => anyhow::bail!("Unsupported column type: {} ({})", data_type, udt_name),
|
||||||
};
|
};
|
||||||
types.push(pg_type);
|
types.push(pg_type);
|
||||||
}
|
}
|
||||||
@@ -126,10 +132,35 @@ async fn save_embeddings(
|
|||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let client = pool.get().await?;
|
let client = pool.get().await?;
|
||||||
|
|
||||||
|
// Get the column types from the target table
|
||||||
|
let types = get_column_types(
|
||||||
|
&client,
|
||||||
|
&args.target_table,
|
||||||
|
&[&args.target_id_column, &args.target_embedding_column],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Create a temporary table with the same structure
|
||||||
|
let create_temp = format!(
|
||||||
|
"CREATE TEMP TABLE temp_embeddings AS SELECT * FROM {} WHERE 1=0",
|
||||||
|
args.target_table
|
||||||
|
);
|
||||||
|
client.execute(&create_temp, &[]).await?;
|
||||||
|
|
||||||
|
// Copy data into temporary table
|
||||||
|
let copy_query = format!(
|
||||||
|
"COPY temp_embeddings ({id_col}, {embedding_col}) FROM STDIN (FORMAT binary)",
|
||||||
|
id_col = args.target_id_column,
|
||||||
|
embedding_col = args.target_embedding_column,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut writer = Box::pin(BinaryCopyInWriter::new(
|
||||||
|
client.copy_in(©_query).await?,
|
||||||
|
&types,
|
||||||
|
));
|
||||||
|
|
||||||
let mut valid_embeddings = 0;
|
let mut valid_embeddings = 0;
|
||||||
let mut invalid_embeddings = 0;
|
let mut invalid_embeddings = 0;
|
||||||
let batch_size = 128;
|
|
||||||
let mut current_batch = Vec::with_capacity(batch_size);
|
|
||||||
|
|
||||||
// Process factors for items that appear in the source table
|
// Process factors for items that appear in the source table
|
||||||
for (idx, factors) in model.q_iter().enumerate() {
|
for (idx, factors) in model.q_iter().enumerate() {
|
||||||
@@ -145,33 +176,33 @@ async fn save_embeddings(
|
|||||||
}
|
}
|
||||||
|
|
||||||
valid_embeddings += 1;
|
valid_embeddings += 1;
|
||||||
current_batch.push((item_id as i64, factors));
|
let id_value: Box<dyn tokio_postgres::types::ToSql + Sync> = match types[0] {
|
||||||
|
Type::INT4 => Box::new(item_id),
|
||||||
// When batch is full, save it
|
Type::INT8 => Box::new(item_id as i64),
|
||||||
if current_batch.len() >= batch_size {
|
_ => anyhow::bail!("Unexpected type for ID column: {:?}", types[0]),
|
||||||
save_batch(
|
};
|
||||||
&client,
|
let factors: &[f32] = factors;
|
||||||
&args.target_table,
|
writer.as_mut().write(&[&*id_value, &factors]).await?;
|
||||||
&args.target_id_column,
|
|
||||||
&args.target_embedding_column,
|
|
||||||
¤t_batch,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
current_batch.clear();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save any remaining items in the last batch
|
writer.as_mut().finish().await?;
|
||||||
if !current_batch.is_empty() {
|
|
||||||
save_batch(
|
// Insert from temp table with ON CONFLICT DO UPDATE
|
||||||
&client,
|
let insert_query = format!(
|
||||||
&args.target_table,
|
r#"
|
||||||
&args.target_id_column,
|
INSERT INTO {target_table} ({id_col}, {embedding_col})
|
||||||
&args.target_embedding_column,
|
SELECT {id_col}, {embedding_col} FROM temp_embeddings
|
||||||
¤t_batch,
|
ON CONFLICT ({id_col}) DO UPDATE
|
||||||
)
|
SET {embedding_col} = EXCLUDED.{embedding_col}
|
||||||
.await?;
|
"#,
|
||||||
}
|
target_table = args.target_table,
|
||||||
|
id_col = args.target_id_column,
|
||||||
|
embedding_col = args.target_embedding_column,
|
||||||
|
);
|
||||||
|
client.execute(&insert_query, &[]).await?;
|
||||||
|
|
||||||
|
// Clean up temp table
|
||||||
|
client.execute("DROP TABLE temp_embeddings", &[]).await?;
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Saved {} valid embeddings, skipped {} invalid embeddings",
|
"Saved {} valid embeddings, skipped {} invalid embeddings",
|
||||||
@@ -181,48 +212,6 @@ async fn save_embeddings(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_batch(
|
|
||||||
client: &deadpool_postgres::Client,
|
|
||||||
target_table: &str,
|
|
||||||
target_id_column: &str,
|
|
||||||
target_embedding_column: &str,
|
|
||||||
batch_values: &[(i64, &[f32])],
|
|
||||||
) -> Result<()> {
|
|
||||||
// Build the batch insert query
|
|
||||||
let placeholders: Vec<String> = (0..batch_values.len())
|
|
||||||
.map(|i| format!("(${}::int8, ${}::float4[])", i * 2 + 1, i * 2 + 2))
|
|
||||||
.collect();
|
|
||||||
let query = format!(
|
|
||||||
r#"
|
|
||||||
INSERT INTO {target_table} ({target_id_column}, {target_embedding_column})
|
|
||||||
VALUES {placeholders}
|
|
||||||
ON CONFLICT ({target_id_column})
|
|
||||||
DO UPDATE SET {target_embedding_column} = EXCLUDED.{target_embedding_column}
|
|
||||||
"#,
|
|
||||||
placeholders = placeholders.join(",")
|
|
||||||
);
|
|
||||||
|
|
||||||
// info!("Executing query: {}", query);
|
|
||||||
|
|
||||||
// Flatten parameters for the query
|
|
||||||
let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new();
|
|
||||||
for (item_id, factors) in batch_values {
|
|
||||||
params.push(item_id);
|
|
||||||
params.push(factors);
|
|
||||||
}
|
|
||||||
|
|
||||||
info!("Number of parameters: {}", params.len());
|
|
||||||
client.execute(&query, ¶ms[..]).await.with_context(|| {
|
|
||||||
format!(
|
|
||||||
"Failed to execute batch insert at {}:{} with {} values",
|
|
||||||
file!(),
|
|
||||||
line!(),
|
|
||||||
batch_values.len()
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn validate_table_exists(client: &deadpool_postgres::Client, table: &str) -> Result<()> {
|
async fn validate_table_exists(client: &deadpool_postgres::Client, table: &str) -> Result<()> {
|
||||||
let query = r#"SELECT EXISTS (
|
let query = r#"SELECT EXISTS (
|
||||||
SELECT FROM pg_tables
|
SELECT FROM pg_tables
|
||||||
|
|||||||
Reference in New Issue
Block a user