Enhance fit model functionality and argument handling
- Updated `fit_model_args.rs` to make the `factors` argument for matrix factorization optional and added an `index_name` argument for index management.
- Modified `fit_model.rs` to drop the named index before the data upsert and recreate it afterwards, all within a single transaction, improving database interaction.
- Adjusted schema validation to infer the vector dimension from the target embedding column and to validate it against `factors` when that argument is specified.
- Enhanced `generate_test_data.rs` to create an IVFFlat index on the embeddings column.

These changes improve the flexibility and robustness of the fit model process, allowing for better management of database indices and more intuitive argument handling.
This commit is contained in:
@@ -26,6 +26,7 @@ async fn create_pool() -> Result<Pool> {
|
||||
config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
|
||||
config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
|
||||
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
|
||||
config.application_name = Some("fit_model".to_string());
|
||||
|
||||
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
|
||||
}
|
||||
@@ -138,8 +139,9 @@ async fn save_embeddings(
|
||||
args: &Args,
|
||||
model: &Model,
|
||||
item_ids: &HashSet<i32>,
|
||||
vector_dim: i32,
|
||||
) -> Result<()> {
|
||||
let client = pool.get().await?;
|
||||
let mut client = pool.get().await?;
|
||||
|
||||
// Get the column types from the target table
|
||||
info!("Getting column types for target table...");
|
||||
@@ -166,7 +168,7 @@ async fn save_embeddings(
|
||||
{} {},
|
||||
{} vector({})
|
||||
)",
|
||||
args.target_id_column, id_type_str, args.target_embedding_column, args.factors
|
||||
args.target_id_column, id_type_str, args.target_embedding_column, vector_dim
|
||||
);
|
||||
info!("Temp table creation SQL: {}", create_temp);
|
||||
client.execute(&create_temp, &[]).await?;
|
||||
@@ -215,6 +217,42 @@ async fn save_embeddings(
|
||||
writer.as_mut().finish().await?;
|
||||
info!("COPY operation completed");
|
||||
|
||||
// Start a transaction for index management and data upsert
|
||||
info!("Starting transaction for index management and data upsert...");
|
||||
let tx = client.transaction().await?;
|
||||
|
||||
// Get and drop index if specified
|
||||
let index_sql = if let Some(index_name) = &args.index_name {
|
||||
let query = r#"
|
||||
SELECT pg_get_indexdef(i.indexrelid) ||
|
||||
CASE WHEN t.spcname IS NOT NULL
|
||||
THEN ' TABLESPACE ' || t.spcname
|
||||
ELSE ''
|
||||
END as index_def
|
||||
FROM pg_index i
|
||||
JOIN pg_class c ON i.indexrelid = c.oid
|
||||
LEFT JOIN pg_tablespace t ON c.reltablespace = t.oid
|
||||
WHERE c.relname = $1"#;
|
||||
let row = tx.query_one(query, &[index_name]).await?;
|
||||
let sql: String = row.get("index_def");
|
||||
|
||||
info!("Dropping index {}...", index_name);
|
||||
tx.execute(&format!("DROP INDEX IF EXISTS {}", index_name), &[])
|
||||
.await?;
|
||||
|
||||
Some(sql)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Set table as UNLOGGED
|
||||
// info!("Setting table as UNLOGGED...");
|
||||
// tx.execute(
|
||||
// &format!("ALTER TABLE {} SET UNLOGGED", args.target_table),
|
||||
// &[],
|
||||
// )
|
||||
// .await?;
|
||||
|
||||
// Insert from temp table with ON CONFLICT DO UPDATE
|
||||
info!("Upserting from temp table to target table...");
|
||||
let upsert_query = format!(
|
||||
@@ -227,9 +265,30 @@ async fn save_embeddings(
|
||||
embedding_col = args.target_embedding_column,
|
||||
);
|
||||
info!("Upsert query: {}", upsert_query);
|
||||
client.execute(&upsert_query, &[]).await?;
|
||||
tx.execute(&upsert_query, &[]).await?;
|
||||
info!("Upsert completed");
|
||||
|
||||
// Set table back to LOGGED
|
||||
// info!("Setting table back to LOGGED...");
|
||||
// tx.execute(
|
||||
// &format!("ALTER TABLE {} SET LOGGED", args.target_table),
|
||||
// &[],
|
||||
// )
|
||||
// .await?;
|
||||
|
||||
// Recreate index if we had one
|
||||
if let Some(sql) = index_sql {
|
||||
info!("Recreating index...");
|
||||
info!("Index SQL: {}", sql);
|
||||
tx.execute(&sql, &[]).await?;
|
||||
info!("Index recreated");
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
info!("Committing transaction...");
|
||||
tx.commit().await?;
|
||||
info!("Transaction committed");
|
||||
|
||||
// Clean up temp table
|
||||
info!("Cleaning up temporary table...");
|
||||
client.execute("DROP TABLE temp_embeddings", &[]).await?;
|
||||
@@ -278,7 +337,7 @@ async fn validate_column_exists(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<()> {
|
||||
async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<i32> {
|
||||
// Validate source table exists
|
||||
validate_table_exists(client, &args.source_table)
|
||||
.await
|
||||
@@ -305,7 +364,7 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res
|
||||
.await
|
||||
.context("Failed to validate target embedding column")?;
|
||||
|
||||
// Validate vector dimension matches factors
|
||||
// Get vector dimension from target column
|
||||
let query = r#"
|
||||
SELECT a.atttypmod as vector_dim, c.data_type, c.udt_name
|
||||
FROM pg_attribute a
|
||||
@@ -323,9 +382,23 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res
|
||||
.await?;
|
||||
|
||||
let data_type: &str = row.get("data_type");
|
||||
if data_type == "USER-DEFINED" {
|
||||
if data_type != "USER-DEFINED" {
|
||||
anyhow::bail!(
|
||||
"Column '{}' is not a user-defined type (got {})",
|
||||
args.target_embedding_column,
|
||||
data_type
|
||||
);
|
||||
}
|
||||
|
||||
let udt_name: &str = row.get("udt_name");
|
||||
if udt_name == "vector" {
|
||||
if udt_name != "vector" {
|
||||
anyhow::bail!(
|
||||
"Column '{}' is not a vector type (got {})",
|
||||
args.target_embedding_column,
|
||||
udt_name
|
||||
);
|
||||
}
|
||||
|
||||
let vector_dim: i32 = row.get("vector_dim");
|
||||
if vector_dim <= 0 {
|
||||
anyhow::bail!(
|
||||
@@ -334,18 +407,19 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res
|
||||
args.target_embedding_column,
|
||||
);
|
||||
}
|
||||
if vector_dim != args.factors as i32 {
|
||||
|
||||
// If factors is specified, validate it matches the column dimension
|
||||
if let Some(factors) = args.factors {
|
||||
if factors != vector_dim {
|
||||
anyhow::bail!(
|
||||
"Vector dimension mismatch: column '{}' has dimension {}, but factors is {}",
|
||||
args.target_embedding_column,
|
||||
vector_dim,
|
||||
args.factors
|
||||
"Specified factors ({}) does not match column vector dimension ({})",
|
||||
factors,
|
||||
vector_dim
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(vector_dim)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@@ -362,10 +436,10 @@ async fn main() -> Result<()> {
|
||||
|
||||
let pool = create_pool().await?;
|
||||
|
||||
// Validate schema before proceeding
|
||||
// Validate schema and get vector dimension
|
||||
info!("Validating database schema...");
|
||||
validate_schema(&pool.get().await?, &args).await?;
|
||||
info!("Schema validation successful");
|
||||
let factors = validate_schema(&pool.get().await?, &args).await?;
|
||||
info!("Schema validation successful, inferred {} factors", factors);
|
||||
|
||||
let mut matrix = Matrix::new();
|
||||
let mut item_ids = HashSet::new();
|
||||
@@ -401,7 +475,7 @@ async fn main() -> Result<()> {
|
||||
);
|
||||
|
||||
let model = Model::params()
|
||||
.factors(args.factors)
|
||||
.factors(factors)
|
||||
.lambda_p1(args.lambda1)
|
||||
.lambda_q1(args.lambda1)
|
||||
.lambda_p2(args.lambda2)
|
||||
@@ -416,7 +490,7 @@ async fn main() -> Result<()> {
|
||||
.fit(&matrix)?;
|
||||
|
||||
info!("Saving embeddings...");
|
||||
save_embeddings(&pool, &args, &model, &item_ids).await?;
|
||||
save_embeddings(&pool, &args, &model, &item_ids, factors).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -117,6 +117,22 @@ async fn main() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Create IVFFlat index on the embeddings column
|
||||
let index_name = format!("{}_embeddings_idx", args.embeddings_table);
|
||||
info!(
|
||||
"Creating IVFFlat index `{}' on embeddings column...",
|
||||
index_name
|
||||
);
|
||||
client
|
||||
.execute(
|
||||
&format!(
|
||||
"CREATE INDEX {} ON {} USING ivfflat (embedding)",
|
||||
index_name, args.embeddings_table
|
||||
),
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Generate cluster centers that are well-separated in 3D space
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut cluster_centers = Vec::new();
|
||||
|
||||
@@ -39,9 +39,10 @@ pub struct Args {
|
||||
#[arg(long, default_value = "0.01")]
|
||||
pub learning_rate: f32,
|
||||
|
||||
/// Number of factors for matrix factorization
|
||||
#[arg(long, default_value = "8")]
|
||||
pub factors: i32,
|
||||
/// Number of factors (dimensions) for matrix factorization
|
||||
/// If not specified, will be inferred from the target column's vector dimension
|
||||
#[arg(long)]
|
||||
pub factors: Option<i32>,
|
||||
|
||||
/// Lambda for regularization
|
||||
#[arg(long, default_value = "0.0")]
|
||||
@@ -62,4 +63,8 @@ pub struct Args {
|
||||
/// Maximum number of interactions to load (optional)
|
||||
#[arg(long)]
|
||||
pub max_interactions: Option<usize>,
|
||||
|
||||
/// Name of the index to drop before upserting and recreate after
|
||||
#[arg(long)]
|
||||
pub index_name: Option<String>,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user