Enhance fit model functionality and argument handling
- Updated `fit_model_args.rs` to make the factors argument optional for matrix factorization and add an index name argument for index management.
- Modified `fit_model.rs` to drop the named index before the data upsert and recreate it afterwards, all inside a single transaction (sketched below).
- Adjusted schema validation to infer the vector dimension from the target column and validate it against the factors argument when one is given.
- Enhanced `generate_test_data.rs` to create an IVFFlat index on the embeddings column.

These changes improve the flexibility and robustness of the fit model process, allowing for better management of database indices and more intuitive argument handling.
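The index handling added to `save_embeddings` follows a capture-drop-replay pattern: the index definition is read with `pg_get_indexdef`, the index is dropped inside the same transaction as the bulk upsert, and the saved statement is replayed before the commit. Below is a minimal sketch of that pattern against tokio-postgres; the helper name `capture_and_drop_index` is illustrative and not part of this patch, and it omits the tablespace handling the patch includes.

use anyhow::Result;
use tokio_postgres::Transaction;

// Capture the index DDL, drop the index inside the upsert transaction, and hand
// the CREATE INDEX statement back so the caller can replay it before committing.
async fn capture_and_drop_index(tx: &Transaction<'_>, index_name: &str) -> Result<String> {
    // pg_get_indexdef returns the complete CREATE INDEX statement for the index.
    let row = tx
        .query_one(
            "SELECT pg_get_indexdef(i.indexrelid) AS index_def
             FROM pg_index i
             JOIN pg_class c ON i.indexrelid = c.oid
             WHERE c.relname = $1",
            &[&index_name],
        )
        .await?;
    let create_sql: String = row.get("index_def");

    // With the index gone, the bulk upsert pays no per-row index maintenance.
    let drop_sql = format!("DROP INDEX IF EXISTS {}", index_name);
    tx.execute(drop_sql.as_str(), &[]).await?;

    Ok(create_sql)
}

After the upsert, the caller re-runs the returned statement with `tx.execute(create_sql.as_str(), &[])` and then commits, which is the order the new `save_embeddings` code follows.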
fit_model.rs

@@ -26,6 +26,7 @@ async fn create_pool() -> Result<Pool> {
     config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
     config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
     config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
+    config.application_name = Some("fit_model".to_string());
 
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }
@@ -138,8 +139,9 @@ async fn save_embeddings(
     args: &Args,
     model: &Model,
     item_ids: &HashSet<i32>,
+    vector_dim: i32,
 ) -> Result<()> {
-    let client = pool.get().await?;
+    let mut client = pool.get().await?;
 
     // Get the column types from the target table
     info!("Getting column types for target table...");
@@ -166,7 +168,7 @@ async fn save_embeddings(
             {} {},
             {} vector({})
         )",
-        args.target_id_column, id_type_str, args.target_embedding_column, args.factors
+        args.target_id_column, id_type_str, args.target_embedding_column, vector_dim
     );
     info!("Temp table creation SQL: {}", create_temp);
     client.execute(&create_temp, &[]).await?;
@@ -215,6 +217,42 @@ async fn save_embeddings(
     writer.as_mut().finish().await?;
     info!("COPY operation completed");
 
+    // Start a transaction for index management and data upsert
+    info!("Starting transaction for index management and data upsert...");
+    let tx = client.transaction().await?;
+
+    // Get and drop index if specified
+    let index_sql = if let Some(index_name) = &args.index_name {
+        let query = r#"
+            SELECT pg_get_indexdef(i.indexrelid) ||
+                   CASE WHEN t.spcname IS NOT NULL
+                        THEN ' TABLESPACE ' || t.spcname
+                        ELSE ''
+                   END as index_def
+            FROM pg_index i
+            JOIN pg_class c ON i.indexrelid = c.oid
+            LEFT JOIN pg_tablespace t ON c.reltablespace = t.oid
+            WHERE c.relname = $1"#;
+        let row = tx.query_one(query, &[index_name]).await?;
+        let sql: String = row.get("index_def");
+
+        info!("Dropping index {}...", index_name);
+        tx.execute(&format!("DROP INDEX IF EXISTS {}", index_name), &[])
+            .await?;
+
+        Some(sql)
+    } else {
+        None
+    };
+
+    // Set table as UNLOGGED
+    // info!("Setting table as UNLOGGED...");
+    // tx.execute(
+    //     &format!("ALTER TABLE {} SET UNLOGGED", args.target_table),
+    //     &[],
+    // )
+    // .await?;
+
     // Insert from temp table with ON CONFLICT DO UPDATE
     info!("Upserting from temp table to target table...");
     let upsert_query = format!(
@@ -227,9 +265,30 @@ async fn save_embeddings(
         embedding_col = args.target_embedding_column,
     );
     info!("Upsert query: {}", upsert_query);
-    client.execute(&upsert_query, &[]).await?;
+    tx.execute(&upsert_query, &[]).await?;
     info!("Upsert completed");
 
+    // Set table back to LOGGED
+    // info!("Setting table back to LOGGED...");
+    // tx.execute(
+    //     &format!("ALTER TABLE {} SET LOGGED", args.target_table),
+    //     &[],
+    // )
+    // .await?;
+
+    // Recreate index if we had one
+    if let Some(sql) = index_sql {
+        info!("Recreating index...");
+        info!("Index SQL: {}", sql);
+        tx.execute(&sql, &[]).await?;
+        info!("Index recreated");
+    }
+
+    // Commit the transaction
+    info!("Committing transaction...");
+    tx.commit().await?;
+    info!("Transaction committed");
+
     // Clean up temp table
     info!("Cleaning up temporary table...");
     client.execute("DROP TABLE temp_embeddings", &[]).await?;
@@ -278,7 +337,7 @@ async fn validate_column_exists(
     Ok(())
 }
 
-async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<()> {
+async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Result<i32> {
     // Validate source table exists
     validate_table_exists(client, &args.source_table)
         .await
@@ -305,7 +364,7 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res
         .await
         .context("Failed to validate target embedding column")?;
 
-    // Validate vector dimension matches factors
+    // Get vector dimension from target column
     let query = r#"
         SELECT a.atttypmod as vector_dim, c.data_type, c.udt_name
         FROM pg_attribute a
@@ -323,29 +382,44 @@ async fn validate_schema(client: &deadpool_postgres::Client, args: &Args) -> Res
         .await?;
 
     let data_type: &str = row.get("data_type");
-    if data_type == "USER-DEFINED" {
-        let udt_name: &str = row.get("udt_name");
-        if udt_name == "vector" {
-            let vector_dim: i32 = row.get("vector_dim");
-            if vector_dim <= 0 {
-                anyhow::bail!(
-                    "Invalid vector dimension {} for column '{}'",
-                    vector_dim,
-                    args.target_embedding_column,
-                );
-            }
-            if vector_dim != args.factors as i32 {
-                anyhow::bail!(
-                    "Vector dimension mismatch: column '{}' has dimension {}, but factors is {}",
-                    args.target_embedding_column,
-                    vector_dim,
-                    args.factors
-                );
-            }
-        }
-    }
+    if data_type != "USER-DEFINED" {
+        anyhow::bail!(
+            "Column '{}' is not a user-defined type (got {})",
+            args.target_embedding_column,
+            data_type
+        );
+    }
+
+    let udt_name: &str = row.get("udt_name");
+    if udt_name != "vector" {
+        anyhow::bail!(
+            "Column '{}' is not a vector type (got {})",
+            args.target_embedding_column,
+            udt_name
+        );
+    }
+
+    let vector_dim: i32 = row.get("vector_dim");
+    if vector_dim <= 0 {
+        anyhow::bail!(
+            "Invalid vector dimension {} for column '{}'",
+            vector_dim,
+            args.target_embedding_column,
+        );
+    }
+
+    // If factors is specified, validate it matches the column dimension
+    if let Some(factors) = args.factors {
+        if factors != vector_dim {
+            anyhow::bail!(
+                "Specified factors ({}) does not match column vector dimension ({})",
+                factors,
+                vector_dim
+            );
+        }
+    }
 
-    Ok(())
+    Ok(vector_dim)
 }
 
 #[tokio::main]
@@ -362,10 +436,10 @@ async fn main() -> Result<()> {
 
     let pool = create_pool().await?;
 
-    // Validate schema before proceeding
+    // Validate schema and get vector dimension
     info!("Validating database schema...");
-    validate_schema(&pool.get().await?, &args).await?;
-    info!("Schema validation successful");
+    let factors = validate_schema(&pool.get().await?, &args).await?;
+    info!("Schema validation successful, inferred {} factors", factors);
 
     let mut matrix = Matrix::new();
     let mut item_ids = HashSet::new();
@@ -401,7 +475,7 @@ async fn main() -> Result<()> {
     );
 
     let model = Model::params()
-        .factors(args.factors)
+        .factors(factors)
         .lambda_p1(args.lambda1)
         .lambda_q1(args.lambda1)
         .lambda_p2(args.lambda2)
@@ -416,7 +490,7 @@ async fn main() -> Result<()> {
         .fit(&matrix)?;
 
     info!("Saving embeddings...");
-    save_embeddings(&pool, &args, &model, &item_ids).await?;
+    save_embeddings(&pool, &args, &model, &item_ids, factors).await?;
 
     Ok(())
 }
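The dimension inference above relies on the fact that, for a pgvector column declared as vector(N), the declared dimension is carried in pg_attribute.atttypmod (that is what `validate_schema` selects as `vector_dim`). A standalone sketch of that lookup follows; the helper name `infer_vector_dim` is hypothetical, and it talks to tokio-postgres directly rather than through the deadpool client used elsewhere.

use anyhow::Result;
use tokio_postgres::Client;

// Read the declared dimension of a pgvector column (e.g. 8 for vector(8)).
async fn infer_vector_dim(client: &Client, table: &str, column: &str) -> Result<i32> {
    let row = client
        .query_one(
            "SELECT a.atttypmod AS vector_dim
             FROM pg_attribute a
             JOIN pg_class c ON a.attrelid = c.oid
             WHERE c.relname = $1 AND a.attname = $2",
            &[&table, &column],
        )
        .await?;
    let dim: i32 = row.get("vector_dim");
    // atttypmod is -1 when the column was declared as a bare, dimensionless vector.
    anyhow::ensure!(dim > 0, "column {}.{} has no declared vector dimension", table, column);
    Ok(dim)
}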
generate_test_data.rs

@@ -117,6 +117,22 @@ async fn main() -> Result<()> {
         )
         .await?;
 
+    // Create IVFFlat index on the embeddings column
+    let index_name = format!("{}_embeddings_idx", args.embeddings_table);
+    info!(
+        "Creating IVFFlat index `{}' on embeddings column...",
+        index_name
+    );
+    client
+        .execute(
+            &format!(
+                "CREATE INDEX {} ON {} USING ivfflat (embedding)",
+                index_name, args.embeddings_table
+            ),
+            &[],
+        )
+        .await?;
+
     // Generate cluster centers that are well-separated in 3D space
     let mut rng = rand::thread_rng();
     let mut cluster_centers = Vec::new();
fit_model_args.rs

@@ -39,9 +39,10 @@ pub struct Args {
     #[arg(long, default_value = "0.01")]
     pub learning_rate: f32,
 
-    /// Number of factors for matrix factorization
-    #[arg(long, default_value = "8")]
-    pub factors: i32,
+    /// Number of factors (dimensions) for matrix factorization
+    /// If not specified, will be inferred from the target column's vector dimension
+    #[arg(long)]
+    pub factors: Option<i32>,
 
     /// Lambda for regularization
     #[arg(long, default_value = "0.0")]
@@ -62,4 +63,8 @@ pub struct Args {
     /// Maximum number of interactions to load (optional)
     #[arg(long)]
     pub max_interactions: Option<usize>,
+
+    /// Name of the index to drop before upserting and recreate after
+    #[arg(long)]
+    pub index_name: Option<String>,
 }
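With clap's derive API, Option fields become flags that may simply be omitted, which is what makes the inference path the default. A self-contained sketch using a hypothetical `DemoArgs` struct (not the real `Args`) to show how the two optional arguments surface on the command line:

use clap::Parser;

/// Hypothetical stand-in for the real Args struct, showing only the optional fields.
#[derive(Parser, Debug)]
struct DemoArgs {
    /// Number of factors; inferred from the target column when omitted
    #[arg(long)]
    factors: Option<i32>,

    /// Index to drop before the upsert and recreate afterwards
    #[arg(long)]
    index_name: Option<String>,
}

fn main() {
    // Invoking with only `--index-name <name>` leaves `factors` as None,
    // so the dimension is taken from the target column instead.
    let args = DemoArgs::parse();
    println!("{:?}", args);
}

Passing `--factors` as well keeps the explicit value, which the new validation then checks against the column's declared vector dimension.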