Enhance fit model functionality and argument handling

- Updated `fit_model_args.rs` to allow optional factors for matrix factorization and added an index name argument for index management.
- Modified `fit_model.rs` to handle index creation and dropping during data upsert, improving database interaction.
- Adjusted schema validation to infer vector dimensions and validate against specified factors.
- Enhanced `generate_test_data.rs` to create an IVFFlat index on the embeddings column.

These changes improve the flexibility and robustness of the fit model process, allowing for better management of database indices and more intuitive argument handling.
This commit is contained in:
Dylan Knutson
2024-12-28 20:37:12 +00:00
parent 5430fdd501
commit b3ba58723c
3 changed files with 128 additions and 33 deletions

View File

@@ -26,6 +26,7 @@ async fn create_pool() -> Result<Pool> {
config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
config.application_name = Some("fit_model".to_string());
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
}
@@ -138,8 +139,9 @@ async fn save_embeddings(
args: &Args,
model: &Model,
item_ids: &HashSet<i32>,
vector_dim: i32,
) -> Result<()> {
let client = pool.get().await?;
let mut client = pool.get().await?;
// Get the column types from the target table
info!("Getting column types for target table...");
@@ -166,7 +168,7 @@ async fn save_embeddings(
{} {},
{} vector({})
)",
args.target_id_column, id_type_str, args.target_embedding_column, args.factors
args.target_id_column, id_type_str, args.target_embedding_column, vector_dim
);
info!("Temp table creation SQL: {}", create_temp);
client.execute(&create_temp, &[]).await?;
@@ -215,6 +217,42 @@ async fn save_embeddings(
writer.as_mut().finish().await?;
info!("COPY operation completed");
// Start a transaction for index management and data upsert
info!("Starting transaction for index management and data upsert...");
let tx = client.transaction().await?;
// Get and drop index if specified
let index_sql = if let Some(index_name) = &args.index_name {
let query = r#"
SELECT pg_get_indexdef(i.indexrelid) ||
CASE WHEN t.spcname IS NOT NULL
THEN ' TABLESPACE ' || t.spcname
ELSE ''
END as index_def
FROM pg_index i
JOIN pg_class c ON i.indexrelid = c.oid
LEFT JOIN pg_tablespace t ON c.reltablespace = t.oid
WHERE c.relname = $1"#;
let row = tx.query_one(query, &[index_name]).await?;
let sql: String = row.get("index_def");
info!("Dropping index {}...", index_name);
tx.execute(&format!("DROP INDEX IF EXISTS {}", index_name), &[])
.await?;
Some(sql)
} else {
None
};
// Set table as UNLOGGED
// info!("Setting table as UNLOGGED...");
// tx.execute(
// &format!("ALTER TABLE {} SET UNLOGGED", args.target_table),
// &[],
// )
// .await?;
// Insert from temp table with ON CONFLICT DO UPDATE
info!("Upserting from temp table to target table...");
let upsert_query = format!(
@@ -227,9 +265,30 @@ async fn save_embeddings(
embedding_col = args.target_embedding_column,
);
info!("Upsert query: {}", upsert_query);
client.execute(&upsert_query, &[]).await?;
tx.execute(&upsert_query, &[]).await?;
info!("Upsert completed");
// Set table back to LOGGED
// info!("Setting table back to LOGGED...");
// tx.execute(
// &format!("ALTER TABLE {} SET LOGGED", args.target_table),
// &[],
// )
// .await?;
// Recreate index if we had one
if let Some(sql) = index_sql {
info!("Recreating index...");
info!("Index SQL: {}", sql);
tx.execute(&sql, &[]).await?;
info!("Index recreated");
}
// Commit the transaction
info!("Committing transaction...");
tx.commit().await?;
info!("Transaction committed");
// Clean up temp table
info!("Cleaning up temporary table...");
client.execute("DROP TABLE temp_embeddings", &[]).await?;
@@ -278,7 +337,7 @@ async fn validate_column_exists(
Ok(())
}
async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<()> {
async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<i32> {
// Validate source table exists
validate_table_exists(client, &args.source_table)
.await
@@ -305,7 +364,7 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res
.await
.context("Failed to validate target embedding column")?;
// Validate vector dimension matches factors
// Get vector dimension from target column
let query = r#"
SELECT a.atttypmod as vector_dim, c.data_type, c.udt_name
FROM pg_attribute a
@@ -323,29 +382,44 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res
.await?;
let data_type: &str = row.get("data_type");
if data_type == "USER-DEFINED" {
let udt_name: &str = row.get("udt_name");
if udt_name == "vector" {
let vector_dim: i32 = row.get("vector_dim");
if vector_dim <= 0 {
anyhow::bail!(
"Invalid vector dimension {} for column '{}'",
vector_dim,
args.target_embedding_column,
);
}
if vector_dim != args.factors as i32 {
anyhow::bail!(
"Vector dimension mismatch: column '{}' has dimension {}, but factors is {}",
args.target_embedding_column,
vector_dim,
args.factors
);
}
if data_type != "USER-DEFINED" {
anyhow::bail!(
"Column '{}' is not a user-defined type (got {})",
args.target_embedding_column,
data_type
);
}
let udt_name: &str = row.get("udt_name");
if udt_name != "vector" {
anyhow::bail!(
"Column '{}' is not a vector type (got {})",
args.target_embedding_column,
udt_name
);
}
let vector_dim: i32 = row.get("vector_dim");
if vector_dim <= 0 {
anyhow::bail!(
"Invalid vector dimension {} for column '{}'",
vector_dim,
args.target_embedding_column,
);
}
// If factors is specified, validate it matches the column dimension
if let Some(factors) = args.factors {
if factors != vector_dim {
anyhow::bail!(
"Specified factors ({}) does not match column vector dimension ({})",
factors,
vector_dim
);
}
}
Ok(())
Ok(vector_dim)
}
#[tokio::main]
@@ -362,10 +436,10 @@ async fn main() -> Result<()> {
let pool = create_pool().await?;
// Validate schema before proceeding
// Validate schema and get vector dimension
info!("Validating database schema...");
validate_schema(&pool.get().await?, &args).await?;
info!("Schema validation successful");
let factors = validate_schema(&pool.get().await?, &args).await?;
info!("Schema validation successful, inferred {} factors", factors);
let mut matrix = Matrix::new();
let mut item_ids = HashSet::new();
@@ -401,7 +475,7 @@ async fn main() -> Result<()> {
);
let model = Model::params()
.factors(args.factors)
.factors(factors)
.lambda_p1(args.lambda1)
.lambda_q1(args.lambda1)
.lambda_p2(args.lambda2)
@@ -416,7 +490,7 @@ async fn main() -> Result<()> {
.fit(&matrix)?;
info!("Saving embeddings...");
save_embeddings(&pool, &args, &model, &item_ids).await?;
save_embeddings(&pool, &args, &model, &item_ids, factors).await?;
Ok(())
}

View File

@@ -117,6 +117,22 @@ async fn main() -> Result<()> {
)
.await?;
// Create IVFFlat index on the embeddings column
let index_name = format!("{}_embeddings_idx", args.embeddings_table);
info!(
"Creating IVFFlat index `{}' on embeddings column...",
index_name
);
client
.execute(
&format!(
"CREATE INDEX {} ON {} USING ivfflat (embedding)",
index_name, args.embeddings_table
),
&[],
)
.await?;
// Generate cluster centers that are well-separated in 3D space
let mut rng = rand::thread_rng();
let mut cluster_centers = Vec::new();

View File

@@ -39,9 +39,10 @@ pub struct Args {
#[arg(long, default_value = "0.01")]
pub learning_rate: f32,
/// Number of factors for matrix factorization
#[arg(long, default_value = "8")]
pub factors: i32,
/// Number of factors (dimensions) for matrix factorization
/// If not specified, will be inferred from the target column's vector dimension
#[arg(long)]
pub factors: Option<i32>,
/// Lambda for regularization
#[arg(long, default_value = "0.0")]
@@ -62,4 +63,8 @@ pub struct Args {
/// Maximum number of interactions to load (optional)
#[arg(long)]
pub max_interactions: Option<usize>,
/// Name of the index to drop before upserting and recreate after
#[arg(long)]
pub index_name: Option<String>,
}