more fixes

Dylan Knutson
2024-12-28 03:04:50 +00:00
parent 56b6604142
commit 32a7292481
5 changed files with 282 additions and 192 deletions

Cargo.lock (generated)

@@ -1183,6 +1183,7 @@ dependencies = [
"log",
"ndarray",
"ndarray-linalg",
"num_cpus",
"plotly",
"pretty_env_logger",
"rand",

Cargo.toml

@@ -13,6 +13,7 @@ libmf = "0.3"
log = "0.4"
ndarray = { version = "0.15", features = ["blas"] }
ndarray-linalg = { version = "0.16", features = ["openblas-system"] }
num_cpus = "1.16"
plotly = { version = "0.8", features = ["kaleido"] }
pretty_env_logger = "0.5"
rand = "0.8"


@@ -2,11 +2,9 @@ use anyhow::{Context, Result};
use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv;
-use log::{info, warn};
-use rand::distributions::{Distribution, Uniform};
+use log::info;
use rand::seq::SliceRandom;
use rand::Rng;
-use rand_distr::Normal;
use std::collections::{HashMap, HashSet};
use std::env;
use tokio_postgres::NoTls;
@@ -101,26 +99,33 @@ async fn main() -> Result<()> {
)
.await?;
-// Create table for cluster assignments
-client
-.execute(
-"CREATE TABLE IF NOT EXISTS item_clusters (
-item_id INTEGER PRIMARY KEY,
-cluster_id INTEGER,
-cluster_center FLOAT
-)",
-&[],
-)
-.await?;
-// Generate cluster centers that are well-separated
+// Generate cluster centers that are well-separated in 3D space
let mut rng = rand::thread_rng();
let mut cluster_centers = Vec::new();
-for i in 0..args.item_clusters {
-// Place centers evenly around a circle in 2D space, then project to 1D
-let angle = 2.0 * std::f64::consts::PI * (i as f64) / (args.item_clusters as f64);
-let center = angle.cos(); // Project to 1D by taking cosine
-cluster_centers.push(center);
+// Base vertices of a regular octahedron
+let base_centers = vec![
+vec![1.0, 0.0, 0.0], // +x
+vec![-1.0, 0.0, 0.0], // -x
+vec![0.0, 1.0, 0.0], // +y
+vec![0.0, -1.0, 0.0], // -y
+vec![0.0, 0.0, 1.0], // +z
+vec![0.0, 0.0, -1.0], // -z
+];
+// Scale factor to control separation
+let scale = 2.0;
+// Add jittered versions of the base centers
+for i in 0..args.item_clusters as usize {
+let base = &base_centers[i % base_centers.len()];
+// Add controlled random jitter (up to 10% of the scale)
+let jitter = 0.1;
+let x = base[0] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
+let y = base[1] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
+let z = base[2] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
+cluster_centers.push(vec![x, y, z]);
}
// Create shuffled list of items
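
Aside: the jittered-octahedron scheme added above is easy to sanity-check in isolation. A minimal standalone sketch (rand 0.8 assumed; the cluster count is hard-coded here purely for illustration, while the real binary reads it from Args):

```rust
// Standalone sketch of the center-generation logic in the hunk above.
use rand::Rng;

fn main() {
    let base_centers: [[f64; 3]; 6] = [
        [1.0, 0.0, 0.0], [-1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0], [0.0, -1.0, 0.0],
        [0.0, 0.0, 1.0], [0.0, 0.0, -1.0],
    ];
    let (scale, jitter) = (2.0, 0.1);
    let item_clusters = 6usize; // hypothetical; Args supplies this in the real code
    let mut rng = rand::thread_rng();
    let centers: Vec<[f64; 3]> = (0..item_clusters)
        .map(|i| {
            let b = base_centers[i % base_centers.len()];
            // Jitter each coordinate by up to +/-5% of scale, as in the diff.
            b.map(|c| c * scale + (rng.gen::<f64>() - 0.5) * jitter * scale)
        })
        .collect();
    println!("{centers:?}");
}
```

Adjacent octahedron vertices sit sqrt(2) * scale apart and opposite ones 2 * scale apart, so this jitter cannot bring two centers anywhere near collision.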
@@ -130,35 +135,54 @@ async fn main() -> Result<()> {
// Assign items to clusters in contiguous blocks
let mut items_per_cluster = vec![Vec::new(); args.item_clusters as usize];
let items_per_cluster_count = args.num_items / args.item_clusters;
+// Drop existing item_clusters table and recreate with 3D centers
+client
+.execute("DROP TABLE IF EXISTS item_clusters", &[])
+.await?;
+client
+.execute(
+"CREATE TABLE item_clusters (
+item_id INTEGER PRIMARY KEY,
+cluster_id INTEGER,
+center_x FLOAT,
+center_y FLOAT,
+center_z FLOAT
+)",
+&[],
+)
+.await?;
+// Clear existing interactions
+client
+.execute(&format!("TRUNCATE TABLE {}", args.interactions_table), &[])
+.await?;
for (i, &item_id) in all_items.iter().enumerate() {
let cluster_id = (i as i32) / items_per_cluster_count;
if cluster_id < args.item_clusters {
items_per_cluster[cluster_id as usize].push(item_id);
-// Store cluster assignment with meaningful center
+// Store cluster assignment with 3D center
+let center = &cluster_centers[cluster_id as usize];
client
.execute(
"INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
VALUES ($1, $2, $3)
ON CONFLICT (item_id) DO UPDATE
SET cluster_id = EXCLUDED.cluster_id,
cluster_center = EXCLUDED.cluster_center",
&[&item_id, &cluster_id, &cluster_centers[cluster_id as usize]],
"INSERT INTO item_clusters (item_id, cluster_id, center_x, center_y, center_z)
VALUES ($1, $2, $3, $4, $5)",
&[&item_id, &cluster_id, &center[0], &center[1], &center[2]],
)
.await?;
}
}
-// Generate interactions with strong cluster affinity
+// Generate interactions with very strong cluster affinity
info!("Generating interactions...");
let mut item_interactions = HashMap::new();
// For each user cluster
for user_cluster in 0..args.user_clusters {
-// Each user cluster strongly prefers 1-2 item clusters
-let preferred_item_clusters: HashSet<i32> = (0..args.item_clusters)
-.filter(|&i| i % args.user_clusters == user_cluster)
-.collect();
+// Each user cluster strongly prefers exactly one item cluster with wraparound
+let preferred_item_cluster = user_cluster % args.item_clusters;
// For each user in this cluster
let cluster_start = user_cluster * (args.num_users / args.user_clusters);
@@ -169,13 +193,13 @@ async fn main() -> Result<()> {
};
for user_id in cluster_start..cluster_end {
-// Get all items from preferred clusters
-let mut preferred_items = HashSet::new();
-for &item_cluster in &preferred_item_clusters {
-preferred_items.extend(items_per_cluster[item_cluster as usize].iter());
-}
+// Get all items from preferred cluster
+let preferred_items: HashSet<_> = items_per_cluster[preferred_item_cluster as usize]
+.iter()
+.copied()
+.collect();
-// Interact with most preferred items (1 - noise_level of items in preferred clusters)
+// Interact with most preferred items (1 - noise_level of items in preferred cluster)
let num_interactions =
(preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize;
let selected_items: Vec<_> = preferred_items.into_iter().collect();
@@ -184,7 +208,7 @@ async fn main() -> Result<()> {
.copied()
.collect();
-// Add random items from non-preferred clusters (noise)
+// Add very few random items from non-preferred clusters (noise)
let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32;
while interactions.len() < num_interactions + num_noise as usize {
let item_id = rng.gen_range(0..args.num_items);


@@ -41,85 +41,96 @@ async fn create_pool() -> Result<Pool> {
async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
info!("Analyzing cluster cohesion...");
// Calculate within-cluster and between-cluster distances
let query = format!(
"WITH distances AS (
// Analyze cluster cohesion
info!("Analyzing cluster cohesion...");
let cohesion_stats = client
.query_one(
"WITH embedding_distances AS (
SELECT
+a.item_id as item1,
+b.item_id as item2,
+a.cluster_id as cluster1,
+b.cluster_id as cluster2,
+SQRT(
+POW(a.center_x - b.center_x, 2) +
+POW(a.center_y - b.center_y, 2) +
+POW(a.center_z - b.center_z, 2)
+) as distance,
+CASE WHEN a.cluster_id = b.cluster_id THEN 'within' ELSE 'between' END as distance_type
+FROM item_clusters a
+CROSS JOIN item_clusters b
+WHERE a.item_id < b.item_id
+)
+SELECT
-a.cluster_id = b.cluster_id as same_cluster,
-a.cluster_id as cluster1,
-b.cluster_id as cluster2,
-SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
-FROM {} a
-JOIN {} b ON a.item_id < b.item_id
-JOIN {} ie1 ON a.item_id = ie1.item_id
-JOIN {} ie2 ON b.item_id = ie2.item_id,
-UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
-GROUP BY a.item_id, b.item_id, same_cluster, a.cluster_id, b.cluster_id
+AVG(CASE WHEN distance_type = 'within' THEN distance END) as avg_within,
+STDDEV(CASE WHEN distance_type = 'within' THEN distance END) as stddev_within,
+MIN(CASE WHEN distance_type = 'within' THEN distance END) as min_within,
+MAX(CASE WHEN distance_type = 'within' THEN distance END) as max_within,
+COUNT(CASE WHEN distance_type = 'within' THEN 1 END) as count_within,
+AVG(CASE WHEN distance_type = 'between' THEN distance END) as avg_between,
+STDDEV(CASE WHEN distance_type = 'between' THEN distance END) as stddev_between,
+MIN(CASE WHEN distance_type = 'between' THEN distance END) as min_between,
+MAX(CASE WHEN distance_type = 'between' THEN distance END) as max_between,
+COUNT(CASE WHEN distance_type = 'between' THEN 1 END) as count_between
+FROM embedding_distances",
+&[],
+)
-SELECT
-CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison,
-AVG(distance)::float8 as avg_distance,
-STDDEV(distance)::float8 as stddev_distance,
-MIN(distance)::float8 as min_distance,
-MAX(distance)::float8 as max_distance,
-COUNT(*) as num_pairs
-FROM distances
-GROUP BY same_cluster
-ORDER BY same_cluster DESC",
-args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
+.await?;
+// Print cohesion statistics
+info!(
+"Within Cluster: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
+cohesion_stats.get::<_, f64>("avg_within"),
+cohesion_stats.get::<_, f64>("stddev_within"),
+cohesion_stats.get::<_, f64>("min_within"),
+cohesion_stats.get::<_, f64>("max_within"),
+cohesion_stats.get::<_, i64>("count_within")
+);
let rows = client.query(&query, &[]).await?;
for row in rows {
let comparison: &str = row.get(0);
let avg: f64 = row.get(1);
let stddev: f64 = row.get(2);
let min: f64 = row.get(3);
let max: f64 = row.get(4);
let count: i64 = row.get(5);
info!(
"{}: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
comparison, avg, stddev, min, max, count
);
}
-// Calculate per-cluster statistics
-let query = format!(
-"WITH distances AS (
-SELECT
-a.cluster_id,
-SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
-FROM {} a
-JOIN {} b ON a.item_id < b.item_id
-JOIN {} ie1 ON a.item_id = ie1.item_id
-JOIN {} ie2 ON b.item_id = ie2.item_id,
-UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
-WHERE a.cluster_id = b.cluster_id
-GROUP BY a.item_id, b.item_id, a.cluster_id
-)
-SELECT
-cluster_id,
-AVG(distance)::float8 as avg_distance,
-STDDEV(distance)::float8 as stddev_distance,
-COUNT(*) as num_pairs
-FROM distances
-GROUP BY cluster_id
-ORDER BY cluster_id",
-args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
+info!(
+"Between Clusters: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
+cohesion_stats.get::<_, f64>("avg_between"),
+cohesion_stats.get::<_, f64>("stddev_between"),
+cohesion_stats.get::<_, f64>("min_between"),
+cohesion_stats.get::<_, f64>("max_between"),
+cohesion_stats.get::<_, i64>("count_between")
+);
+// Print per-cluster statistics
info!("\nPer-cluster cohesion:");
let rows = client.query(&query, &[]).await?;
for row in rows {
let cluster: i32 = row.get(0);
let avg: f64 = row.get(1);
let stddev: f64 = row.get(2);
let count: i64 = row.get(3);
+let cluster_stats = client
+.query(
+"WITH cluster_distances AS (
+SELECT
+a.cluster_id,
+SQRT(
+POW(a.center_x - b.center_x, 2) +
+POW(a.center_y - b.center_y, 2) +
+POW(a.center_z - b.center_z, 2)
+) as distance
+FROM item_clusters a
+JOIN item_clusters b ON a.cluster_id = b.cluster_id AND a.item_id < b.item_id
+)
+SELECT
+cluster_id,
+AVG(distance) as avg_distance,
+STDDEV(distance) as stddev_distance,
+COUNT(*) as num_pairs
+FROM cluster_distances
+GROUP BY cluster_id
+ORDER BY cluster_id",
+&[],
+)
+.await?;
+for row in cluster_stats {
+let cluster_id: i32 = row.get("cluster_id");
+let avg_distance: f64 = row.get("avg_distance");
+let stddev_distance: f64 = row.get("stddev_distance");
+let num_pairs: i64 = row.get("num_pairs");
info!(
"Cluster {}: avg_distance={:.3}±{:.3} ({} pairs)",
cluster, avg, stddev, count
"Cluster {}: avg={:.3}, stddev={:.3}, pairs={}",
cluster_id, avg_distance, stddev_distance, num_pairs
);
}
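
Aside: the within/between split that the query above computes can be spot-checked off-line. A minimal in-memory sketch (the sample points are invented; only the arithmetic mirrors the SQL):

```rust
// Sketch: same within/between distance split as the SQL above, in memory.
fn dist(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f64>().sqrt()
}

fn main() {
    // (cluster_id, 3D center) rows; values are illustrative only.
    let items: Vec<(i32, [f64; 3])> = vec![
        (0, [2.0, 0.0, 0.1]),
        (0, [1.9, 0.1, 0.0]),
        (1, [0.0, 2.0, 0.0]),
        (1, [0.1, 1.9, 0.1]),
    ];
    let (mut within, mut between) = (Vec::new(), Vec::new());
    for i in 0..items.len() {
        for j in (i + 1)..items.len() {
            let d = dist(&items[i].1, &items[j].1);
            if items[i].0 == items[j].0 { within.push(d) } else { between.push(d) }
        }
    }
    let avg = |v: &[f64]| v.iter().sum::<f64>() / v.len() as f64;
    println!("within avg {:.3}, between avg {:.3}", avg(&within), avg(&between));
}
```

For well-separated centers the between average should dominate the within average, which is what the logged statistics are meant to confirm.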
@@ -223,39 +234,40 @@ async fn analyze_cluster_correlation(client: &tokio_postgres::Client, args: &Arg
info!("Analyzing correlation between cluster centers and embedding distances...");
// Calculate correlation between cluster center distances and embedding distances
-let query = format!(
-"WITH distances AS (
+let correlation = client
+.query_one(
+"WITH distances AS (
SELECT
+a.cluster_id as cluster1,
+b.cluster_id as cluster2,
+SQRT(
+POW(a.center_x - b.center_x, 2) +
+POW(a.center_y - b.center_y, 2) +
+POW(a.center_z - b.center_z, 2)
+) as center_distance,
+SQRT(
+POW(ae.embedding[1] - be.embedding[1], 2) +
+POW(ae.embedding[2] - be.embedding[2], 2) +
+POW(ae.embedding[3] - be.embedding[3], 2)
+) as embedding_distance
+FROM item_clusters a
+JOIN item_clusters b ON a.cluster_id < b.cluster_id
+JOIN item_embeddings ae ON a.item_id = ae.item_id
+JOIN item_embeddings be ON b.item_id = be.item_id
+)
+SELECT
-a.cluster_id as cluster1,
-b.cluster_id as cluster2,
-ABS(a.cluster_center - b.cluster_center)::float8 as center_distance,
-SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as embedding_distance
-FROM {} a
-JOIN {} b ON a.item_id < b.item_id
-JOIN {} ie1 ON a.item_id = ie1.item_id
-JOIN {} ie2 ON b.item_id = ie2.item_id,
-UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
-GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id, a.cluster_center, b.cluster_center
+corr(center_distance, embedding_distance) as correlation,
+COUNT(*) as num_pairs
+FROM distances",
+&[],
+)
-SELECT
-CORR(center_distance, embedding_distance)::float8 as correlation,
-COUNT(*) as num_pairs,
-AVG(center_distance)::float8 as avg_center_dist,
-AVG(embedding_distance)::float8 as avg_emb_dist
-FROM distances",
-args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
-);
+.await?;
-let row = client.query_one(&query, &[]).await?;
+let correlation_value: f64 = correlation.get("correlation");
+let num_pairs: i64 = correlation.get("num_pairs");
info!(
"Correlation between cluster center distances and embedding distances: {:.3} (from {} pairs)",
row.get::<_, f64>(0),
row.get::<_, i64>(1)
);
info!(
"Average distances: cluster centers={:.3}, embeddings={:.3}",
row.get::<_, f64>(2),
row.get::<_, f64>(3)
"Correlation between cluster center distances and embedding distances: {:.3} ({} pairs)",
correlation_value, num_pairs
);
Ok(())
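
Aside: CORR() here is Postgres's Pearson correlation aggregate. A small sketch of the same statistic, usable for spot-checking the query's output on a few pairs (the sample values are invented):

```rust
// Sketch: Pearson correlation, the statistic CORR() computes server-side.
fn pearson(xs: &[f64], ys: &[f64]) -> f64 {
    let n = xs.len() as f64;
    let (mx, my) = (xs.iter().sum::<f64>() / n, ys.iter().sum::<f64>() / n);
    let cov: f64 = xs.iter().zip(ys).map(|(x, y)| (x - mx) * (y - my)).sum();
    let vx: f64 = xs.iter().map(|x| (x - mx).powi(2)).sum();
    let vy: f64 = ys.iter().map(|y| (y - my).powi(2)).sum();
    cov / (vx.sqrt() * vy.sqrt())
}

fn main() {
    // Hypothetical (center_distance, embedding_distance) pairs.
    let center = [0.5, 1.0, 2.0, 2.8];
    let embedding = [0.4, 1.1, 1.9, 3.0];
    println!("corr = {:.3}", pearson(&center, &embedding)); // close to 1.0
}
```

A correlation near 1.0 indicates the learned embedding geometry tracks the synthetic cluster geometry.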


@@ -2,8 +2,9 @@ use anyhow::{Context, Result};
use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv;
-use libmf::Matrix;
-use log::{info, warn};
+use libmf::{Loss, Matrix, Model};
+use log::info;
+use num_cpus;
use std::env;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
@@ -39,6 +40,10 @@ struct Args {
/// Lambda for regularization
#[arg(long, default_value = "0.1")]
lambda: f32,
+/// Number of threads for matrix factorization (defaults to number of CPU cores)
+#[arg(long, default_value_t = num_cpus::get() as i32)]
+threads: i32,
}
async fn create_pool() -> Result<Pool> {
@@ -57,33 +62,51 @@ async fn create_pool() -> Result<Pool> {
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
}
async fn load_data_batch(pool: &Pool, args: &Args, offset: i64) -> Result<Vec<(i32, i32, f32)>> {
let client = pool.get().await?;
let query = format!(
"SELECT {}, {} FROM {} OFFSET $1 LIMIT $2",
args.user_id_column, args.item_id_column, args.source_table
);
+async fn load_data_batch(
+client: &deadpool_postgres::Client,
+source_table: &str,
+user_id_column: &str,
+item_id_column: &str,
+batch_size: usize,
+last_user_id: Option<i32>,
+last_item_id: Option<i32>,
+) -> Result<Vec<(i32, i32)>> {
+let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
+let query = format!(
+"SELECT {user}, {item} FROM {table} \
+WHERE ({user}, {item}) > ($1, $2) \
+ORDER BY {user}, {item} \
+LIMIT $3",
+user = user_id_column,
+item = item_id_column,
+table = source_table,
+);
+client
+.query(&query, &[&last_user, &last_item, &(batch_size as i64)])
+.await?
+} else {
+let query = format!(
+"SELECT {user}, {item} FROM {table} \
+ORDER BY {user}, {item} \
+LIMIT $1",
+user = user_id_column,
+item = item_id_column,
+table = source_table,
+);
+client.query(&query, &[&(batch_size as i64)]).await?
+};
-let rows = client
-.query(&query, &[&offset, &(args.batch_size as i64)])
-.await?;
+let mut batch = Vec::with_capacity(rows.len());
+for row in rows {
+let user_id: i32 = row.get(0);
+let item_id: i32 = row.get(1);
+batch.push((user_id, item_id));
+}
-Ok(rows
-.into_iter()
-.map(|row| {
-let user_id: i32 = row.get(0);
-let item_id: i32 = row.get(1);
-(user_id, item_id, 1.0) // Using 1.0 as interaction strength
-})
-.collect())
+Ok(batch)
}
-async fn save_embeddings(
-pool: &Pool,
-args: &Args,
-model: &libmf::Model,
-item_ids: &[i32],
-) -> Result<()> {
+async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i32]) -> Result<()> {
let client = pool.get().await?;
// Create the target table if it doesn't exist
@@ -128,6 +151,19 @@ async fn save_embeddings(
Ok(())
}
+async fn get_unique_item_ids(
+client: &deadpool_postgres::Client,
+source_table: &str,
+item_id_column: &str,
+) -> Result<Vec<i32>> {
+let query = format!(
+"SELECT DISTINCT {} FROM {} ORDER BY {}",
+item_id_column, source_table, item_id_column
+);
+let rows = client.query(&query, &[]).await?;
+Ok(rows.iter().map(|row| row.get(0)).collect())
+}
#[tokio::main]
async fn main() -> Result<()> {
dotenv().ok();
@@ -143,59 +179,75 @@ async fn main() -> Result<()> {
})?;
let pool = create_pool().await?;
-let mut offset = 0i64;
-let mut all_data = Vec::new();
-let mut unique_item_ids = Vec::new();
+let mut matrix = Matrix::new();
+let mut last_user_id = None;
+let mut last_item_id = None;
+let mut total_rows = 0;
+info!("Starting data loading...");
while running.load(Ordering::SeqCst) {
let batch = load_data_batch(&pool, &args, offset).await?;
let batch = load_data_batch(
&pool.get().await?,
&args.source_table,
&args.user_id_column,
&args.item_id_column,
args.batch_size as usize,
last_user_id,
last_item_id,
)
.await?;
if batch.is_empty() {
break;
}
-// Track unique item IDs
-for &(_, item_id, _) in &batch {
-if !unique_item_ids.contains(&item_id) {
-unique_item_ids.push(item_id);
-}
+total_rows += batch.len();
+info!(
+"Loaded batch of {} rows (total: {})",
+batch.len(),
+total_rows
+);
+// Update last seen IDs for next batch
+if let Some((user_id, item_id)) = batch.last() {
+last_user_id = Some(*user_id);
+last_item_id = Some(*item_id);
+}
-all_data.extend(batch);
-offset += args.batch_size as i64;
-info!("Loaded {} rows so far", all_data.len());
+// Process batch
+for (user_id, item_id) in batch {
+matrix.push(user_id as i32, item_id as i32, 1.0f32);
+}
}
-if all_data.is_empty() {
-warn!("No data was loaded. Exiting.");
+info!("Loaded {} total rows", total_rows);
+if total_rows == 0 {
+info!("No data found in source table");
return Ok(());
}
info!("Creating libmf problem...");
let mut matrix = Matrix::new();
for (user_id, item_id, value) in all_data {
matrix.push(
user_id.try_into().unwrap(),
item_id.try_into().unwrap(),
value,
);
}
// Get unique item IDs from database
let client = pool.get().await?;
let unique_item_ids =
get_unique_item_ids(&client, &args.source_table, &args.item_id_column).await?;
info!("Found {} unique items", unique_item_ids.len());
info!("Training model with {} factors...", args.factors);
-let model = libmf::Model::params()
-.factors(args.factors.try_into().unwrap())
+// Set up training parameters
+let model = Model::params()
+.factors(args.factors as i32)
.lambda_p2(args.lambda)
.lambda_q2(args.lambda)
.learning_rate(0.01)
.iterations(100)
-.loss(libmf::Loss::OneClassL2)
+.loss(Loss::OneClassL2)
.c(0.00001)
.quiet(false)
+.threads(args.threads)
.fit(&matrix)?;
info!("Saving embeddings...");
save_embeddings(&pool, &args, &model, &unique_item_ids).await?;
info!("Done!");
Ok(())
}
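
Aside: the switch from OFFSET to keyset pagination in load_data_batch leans on SQL's row-value comparison, which Rust's lexicographic tuple ordering mirrors exactly. A minimal sketch (no database; `next_page` is a hypothetical stand-in for the query):

```rust
// Sketch: keyset pagination over a sorted key space, as in load_data_batch.
// `(user, item) > last` is lexicographic, matching SQL's (user, item) > ($1, $2).
fn next_page(rows: &[(i32, i32)], last: Option<(i32, i32)>, limit: usize) -> Vec<(i32, i32)> {
    rows.iter()
        .copied()
        .filter(|&r| last.map_or(true, |l| r > l))
        .take(limit)
        .collect()
}

fn main() {
    let mut rows = vec![(1, 3), (1, 1), (2, 2), (1, 2), (2, 1)];
    rows.sort(); // ORDER BY user_id, item_id
    let page1 = next_page(&rows, None, 2);
    let page2 = next_page(&rows, page1.last().copied(), 2);
    println!("{page1:?} {page2:?}"); // [(1, 1), (1, 2)] [(1, 3), (2, 1)]
}
```

Unlike OFFSET, which rescans from the start and can skip or repeat rows if the table changes between batches, resuming from the last-seen key is stable and keeps each batch an index-friendly range scan.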