batch loading for computed rows
This commit is contained in:
@@ -3,8 +3,6 @@ rustflags = [
|
||||
"-C",
|
||||
"link-arg=-fuse-ld=lld",
|
||||
"-C",
|
||||
"link-arg=-L/usr/lib/x86_64-linux-gnu",
|
||||
"-C",
|
||||
"link-arg=-lgomp",
|
||||
"-C",
|
||||
"link-arg=-fopenmp",
|
||||
|
||||
@@ -34,7 +34,7 @@ struct Args {
|
||||
|
||||
/// Noise level (0.0 to 1.0) - fraction of interactions that are random
|
||||
#[arg(long, default_value = "0.05")]
|
||||
noise_level: f64,
|
||||
noise_level: f32,
|
||||
|
||||
/// Source table name for interactions
|
||||
#[arg(long, default_value = "user_interactions")]
|
||||
@@ -91,7 +91,7 @@ async fn main() -> Result<()> {
|
||||
&format!(
|
||||
"CREATE TABLE IF NOT EXISTS {} (
|
||||
item_id INTEGER PRIMARY KEY,
|
||||
embedding FLOAT[]
|
||||
embedding FLOAT4[]
|
||||
)",
|
||||
args.embeddings_table
|
||||
),
|
||||
@@ -105,7 +105,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
// Base vertices of a regular octahedron
|
||||
// These vertices form three orthogonal golden rectangles
|
||||
let phi = (1.0 + 5.0_f64.sqrt()) / 2.0; // golden ratio
|
||||
let phi = (1.0 + 5.0_f32.sqrt()) / 2.0; // golden ratio
|
||||
let scale = 2.0;
|
||||
let base_centers = vec![
|
||||
vec![1.0, phi, 0.0], // front top
|
||||
@@ -123,7 +123,7 @@ async fn main() -> Result<()> {
|
||||
];
|
||||
|
||||
// Normalize and scale the vectors to ensure equal distances
|
||||
let base_centers: Vec<Vec<f64>> = base_centers
|
||||
let base_centers: Vec<Vec<f32>> = base_centers
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
|
||||
@@ -140,9 +140,9 @@ async fn main() -> Result<()> {
|
||||
let base = &base_centers[i % base_centers.len()];
|
||||
// Add very small jitter (1% of scale) to make it more natural
|
||||
let jitter = 0.01;
|
||||
let x = base[0] + (rng.gen::<f64>() - 0.5) * jitter * scale;
|
||||
let y = base[1] + (rng.gen::<f64>() - 0.5) * jitter * scale;
|
||||
let z = base[2] + (rng.gen::<f64>() - 0.5) * jitter * scale;
|
||||
let x = base[0] + (rng.gen::<f32>() - 0.5) * jitter * scale;
|
||||
let y = base[1] + (rng.gen::<f32>() - 0.5) * jitter * scale;
|
||||
let z = base[2] + (rng.gen::<f32>() - 0.5) * jitter * scale;
|
||||
|
||||
cluster_centers.push(vec![x, y, z]);
|
||||
}
|
||||
@@ -164,7 +164,7 @@ async fn main() -> Result<()> {
|
||||
"CREATE TABLE item_clusters (
|
||||
item_id INTEGER PRIMARY KEY,
|
||||
cluster_id INTEGER,
|
||||
cluster_affinities FLOAT[] -- Array of affinities to archetypal items
|
||||
cluster_affinities FLOAT4[] -- Array of affinities to archetypal items
|
||||
)",
|
||||
&[],
|
||||
)
|
||||
@@ -177,7 +177,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
// Create cluster affinity vectors based on item co-occurrence patterns
|
||||
let mut cluster_affinities =
|
||||
vec![vec![0.0; args.num_items as usize]; args.item_clusters as usize];
|
||||
vec![vec![0.0f32; args.num_items as usize]; args.item_clusters as usize];
|
||||
|
||||
// For each cluster, select a set of archetypal items that define the cluster
|
||||
for cluster_id in 0..args.item_clusters {
|
||||
@@ -214,7 +214,7 @@ async fn main() -> Result<()> {
|
||||
items_per_cluster[cluster_id as usize].push(item_id);
|
||||
|
||||
// Store cluster assignment with affinity vector
|
||||
let affinities: Vec<f64> = cluster_affinities[cluster_id as usize].clone();
|
||||
let affinities: Vec<f32> = cluster_affinities[cluster_id as usize].clone();
|
||||
client
|
||||
.execute(
|
||||
"INSERT INTO item_clusters (item_id, cluster_id, cluster_affinities)
|
||||
@@ -251,7 +251,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
// Interact with most preferred items (1 - noise_level of items in preferred cluster)
|
||||
let num_interactions =
|
||||
(preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize;
|
||||
(preferred_items.len() as f32 * (1.0 - args.noise_level)) as usize;
|
||||
let selected_items: Vec<_> = preferred_items.into_iter().collect();
|
||||
let mut interactions: HashSet<_> = selected_items
|
||||
.choose_multiple(&mut rng, num_interactions)
|
||||
@@ -259,7 +259,7 @@ async fn main() -> Result<()> {
|
||||
.collect();
|
||||
|
||||
// Add very few random items from non-preferred clusters (noise)
|
||||
let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32;
|
||||
let num_noise = (args.avg_interactions as f32 * args.noise_level) as i32;
|
||||
while interactions.len() < num_interactions + num_noise as usize {
|
||||
let item_id = rng.gen_range(0..args.num_items);
|
||||
if !interactions.contains(&item_id) {
|
||||
@@ -293,8 +293,8 @@ async fn main() -> Result<()> {
|
||||
"SELECT COUNT(*) as total_interactions,
|
||||
COUNT(DISTINCT user_id) as unique_users,
|
||||
COUNT(DISTINCT item_id) as unique_items,
|
||||
COUNT(DISTINCT user_id)::float / {} as user_coverage,
|
||||
COUNT(DISTINCT item_id)::float / {} as item_coverage
|
||||
COUNT(DISTINCT user_id)::float4 / {} as user_coverage,
|
||||
COUNT(DISTINCT item_id)::float4 / {} as item_coverage
|
||||
FROM {}",
|
||||
args.num_users, args.num_items, args.interactions_table
|
||||
),
|
||||
|
||||
@@ -47,7 +47,7 @@ async fn create_pool() -> Result<Pool> {
|
||||
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
|
||||
}
|
||||
|
||||
fn perform_pca(data: &Array2<f64>, n_components: usize) -> Result<Array2<f64>> {
|
||||
fn perform_pca(data: &Array2<f32>, n_components: usize) -> Result<Array2<f32>> {
|
||||
// Center the data
|
||||
let means = data.mean_axis(ndarray::Axis(0)).unwrap();
|
||||
let centered = data.clone() - means.view().insert_axis(ndarray::Axis(0));
|
||||
@@ -82,8 +82,8 @@ async fn main() -> Result<()> {
|
||||
let rows = client.query(&query, &[]).await?;
|
||||
let n_items = rows.len();
|
||||
let (n_dims, affinity_dims) = if let Some(first_row) = rows.first() {
|
||||
let embedding: Vec<f64> = first_row.get(1);
|
||||
let affinities: Vec<f64> = first_row.get(3);
|
||||
let embedding: Vec<f32> = first_row.get(1);
|
||||
let affinities: Vec<f32> = first_row.get(3);
|
||||
(embedding.len(), affinities.len())
|
||||
} else {
|
||||
return Ok(());
|
||||
@@ -102,9 +102,9 @@ async fn main() -> Result<()> {
|
||||
|
||||
for (i, row) in rows.iter().enumerate() {
|
||||
let item_id: i32 = row.get(0);
|
||||
let embedding: Vec<f64> = row.get(1);
|
||||
let embedding: Vec<f32> = row.get(1);
|
||||
let cluster_id: i32 = row.get(2);
|
||||
let affinities: Vec<f64> = row.get(3);
|
||||
let affinities: Vec<f32> = row.get(3);
|
||||
|
||||
item_ids.push(item_id);
|
||||
cluster_ids.push(cluster_id);
|
||||
|
||||
81
src/main.rs
81
src/main.rs
@@ -121,42 +121,45 @@ async fn load_data_batch(
|
||||
Ok(batch)
|
||||
}
|
||||
|
||||
// TODO - don't load all item IDs at once
|
||||
async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i32]) -> Result<()> {
|
||||
async fn save_embeddings(pool: &Pool, args: &Args, model: &Model) -> Result<()> {
|
||||
let client = pool.get().await?;
|
||||
|
||||
// Create the target table if it doesn't exist
|
||||
let create_table = format!(
|
||||
"CREATE TABLE IF NOT EXISTS {} (item_id INTEGER PRIMARY KEY, embedding FLOAT[])",
|
||||
"CREATE TABLE IF NOT EXISTS {} (item_id INTEGER PRIMARY KEY, embedding FLOAT4[])",
|
||||
args.target_table
|
||||
);
|
||||
client.execute(&create_table, &[]).await?;
|
||||
|
||||
// Get all factors at once
|
||||
let all_factors: Vec<_> = model.q_iter().collect();
|
||||
info!("Generated {} item embeddings", all_factors.len());
|
||||
|
||||
let mut valid_embeddings = 0;
|
||||
let mut invalid_embeddings = 0;
|
||||
let batch_size = 128;
|
||||
let mut current_batch = Vec::with_capacity(batch_size);
|
||||
let mut current_idx = 0;
|
||||
|
||||
for &item_id in item_ids {
|
||||
let factors = &all_factors[item_id as usize];
|
||||
let factors_array: Vec<f64> = factors.iter().map(|&x| x as f64).collect();
|
||||
|
||||
// Check if the embedding contains any NaN values
|
||||
if factors_array.iter().any(|x| x.is_nan()) {
|
||||
// Process factors in chunks using the iterator directly
|
||||
for factors in model.q_iter() {
|
||||
// Skip invalid embeddings
|
||||
if factors.iter().any(|&x| x.is_nan()) {
|
||||
invalid_embeddings += 1;
|
||||
current_idx += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
valid_embeddings += 1;
|
||||
let query = format!(
|
||||
"INSERT INTO {} (item_id, embedding) VALUES ($1, $2)
|
||||
ON CONFLICT (item_id) DO UPDATE SET embedding = EXCLUDED.embedding",
|
||||
args.target_table
|
||||
);
|
||||
current_batch.push((current_idx, factors));
|
||||
current_idx += 1;
|
||||
|
||||
client.execute(&query, &[&item_id, &factors_array]).await?;
|
||||
// When batch is full, save it
|
||||
if current_batch.len() >= batch_size {
|
||||
save_batch(&client, &args.target_table, ¤t_batch).await?;
|
||||
current_batch.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Save any remaining items in the last batch
|
||||
if !current_batch.is_empty() {
|
||||
save_batch(&client, &args.target_table, ¤t_batch).await?;
|
||||
}
|
||||
|
||||
info!(
|
||||
@@ -167,17 +170,33 @@ async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i3
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_unique_item_ids(
|
||||
async fn save_batch(
|
||||
client: &deadpool_postgres::Client,
|
||||
source_table: &str,
|
||||
item_id_column: &str,
|
||||
) -> Result<Vec<i32>> {
|
||||
target_table: &str,
|
||||
batch_values: &[(i32, &[f32])],
|
||||
) -> Result<()> {
|
||||
// Build the batch insert query
|
||||
let placeholders: Vec<String> = (0..batch_values.len())
|
||||
.map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2))
|
||||
.collect();
|
||||
let query = format!(
|
||||
"SELECT DISTINCT {} FROM {} ORDER BY {}",
|
||||
item_id_column, source_table, item_id_column
|
||||
r#"
|
||||
INSERT INTO {target_table} (item_id, embedding) VALUES {placeholders}
|
||||
ON CONFLICT (item_id)
|
||||
DO UPDATE SET embedding = EXCLUDED.embedding
|
||||
"#,
|
||||
placeholders = placeholders.join(",")
|
||||
);
|
||||
let rows = client.query(&query, &[]).await?;
|
||||
Ok(rows.iter().map(|row| row.get(0)).collect())
|
||||
|
||||
// Flatten parameters for the query
|
||||
let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new();
|
||||
for (item_id, factors) in batch_values {
|
||||
params.push(item_id);
|
||||
params.push(factors);
|
||||
}
|
||||
|
||||
client.execute(&query, ¶ms[..]).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@@ -243,12 +262,6 @@ async fn main() -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Get unique item IDs from database
|
||||
let client = pool.get().await?;
|
||||
let unique_item_ids =
|
||||
get_unique_item_ids(&client, &args.source_table, &args.item_id_column).await?;
|
||||
info!("Found {} unique items", unique_item_ids.len());
|
||||
|
||||
// Set up training parameters
|
||||
let model = Model::params()
|
||||
.factors(args.factors)
|
||||
@@ -266,7 +279,7 @@ async fn main() -> Result<()> {
|
||||
.fit(&matrix)?;
|
||||
|
||||
info!("Saving embeddings...");
|
||||
save_embeddings(&pool, &args, &model, &unique_item_ids).await?;
|
||||
save_embeddings(&pool, &args, &model).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user