cluster validation

This commit is contained in:
Dylan Knutson
2024-12-28 01:46:48 +00:00
parent f7bb5b0cdd
commit 00b30ac285
3 changed files with 233 additions and 1 deletions

View File

@@ -119,10 +119,46 @@ async fn main() -> Result<()> {
)
.await?;
// Create table for cluster assignments
client
.execute(
"CREATE TABLE IF NOT EXISTS item_clusters (
item_id INTEGER PRIMARY KEY,
cluster_id INTEGER,
cluster_center FLOAT
)",
&[],
)
.await?;
// Generate cluster information
let user_clusters = generate_cluster_info(args.user_clusters);
let item_clusters = generate_cluster_info(args.item_clusters);
// Store item cluster assignments
let mut item_to_cluster = HashMap::new();
let mut rng = rand::thread_rng();
let user_cluster_dist = Uniform::new(0, args.user_clusters);
let item_cluster_dist = Uniform::new(0, args.item_clusters);
// Assign and store cluster information for each item
for item_id in 0..args.num_items {
let cluster_id = item_cluster_dist.sample(&mut rng);
let cluster = item_clusters[cluster_id as usize];
item_to_cluster.insert(item_id, (cluster_id, cluster));
client
.execute(
"INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
VALUES ($1, $2, $3)
ON CONFLICT (item_id) DO UPDATE
SET cluster_id = EXCLUDED.cluster_id,
cluster_center = EXCLUDED.cluster_center",
&[&item_id, &cluster_id, &cluster.center],
)
.await?;
}
let mut rng = rand::thread_rng();
let user_cluster_dist = Uniform::new(0, args.user_clusters);
let item_cluster_dist = Uniform::new(0, args.item_clusters);

View File

@@ -0,0 +1,195 @@
use anyhow::{Context, Result};
use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv;
use log::info;
use std::env;
use tokio_postgres::NoTls;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Table containing user-item interactions
#[arg(long)]
interactions_table: String,
/// Table containing item embeddings
#[arg(long)]
embeddings_table: String,
/// Table containing item cluster information
#[arg(long)]
clusters_table: String,
}
async fn create_pool() -> Result<Pool> {
let mut config = Config::new();
config.host = Some(env::var("POSTGRES_HOST").context("POSTGRES_HOST not set")?);
config.port = Some(
env::var("POSTGRES_PORT")
.context("POSTGRES_PORT not set")?
.parse()
.context("Invalid POSTGRES_PORT")?,
);
config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
}
async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
info!("Analyzing cluster cohesion...");
// Calculate within-cluster and between-cluster distances
let query = format!(
"WITH distances AS (
SELECT
a.cluster_id = b.cluster_id as same_cluster,
SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
FROM {} a
JOIN {} b ON a.item_id < b.item_id
JOIN {} ie1 ON a.item_id = ie1.item_id
JOIN {} ie2 ON b.item_id = ie2.item_id,
UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
GROUP BY a.item_id, b.item_id, same_cluster
)
SELECT
CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison,
AVG(distance)::float8 as avg_distance,
STDDEV(distance)::float8 as stddev_distance,
MIN(distance)::float8 as min_distance,
MAX(distance)::float8 as max_distance,
COUNT(*) as num_pairs
FROM distances
GROUP BY same_cluster
ORDER BY same_cluster DESC",
args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
);
let rows = client.query(&query, &[]).await?;
for row in rows {
let comparison: &str = row.get(0);
let avg: f64 = row.get(1);
let stddev: f64 = row.get(2);
let min: f64 = row.get(3);
let max: f64 = row.get(4);
let count: i64 = row.get(5);
info!(
"{}: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
comparison, avg, stddev, min, max, count
);
}
Ok(())
}
async fn analyze_embedding_stats(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
info!("Analyzing embedding statistics...");
// Calculate embedding norms and component statistics
let query = format!(
"WITH stats AS (
SELECT
ie.item_id,
c.cluster_id,
SQRT(SUM(x * x))::float8 as norm,
AVG(x)::float8 as avg_component,
STDDEV(x)::float8 as stddev_component,
MIN(x)::float8 as min_component,
MAX(x)::float8 as max_component
FROM {} ie
JOIN {} c ON ie.item_id = c.item_id,
UNNEST(embedding) x
GROUP BY ie.item_id, c.cluster_id
)
SELECT
cluster_id,
AVG(norm)::float8 as avg_norm,
STDDEV(norm)::float8 as stddev_norm,
AVG(avg_component)::float8 as avg_component,
AVG(stddev_component)::float8 as avg_component_spread,
COUNT(*) as num_items
FROM stats
GROUP BY cluster_id
ORDER BY cluster_id",
args.embeddings_table, args.clusters_table
);
let rows = client.query(&query, &[]).await?;
info!("Per-cluster embedding statistics:");
for row in rows {
let cluster_id: i32 = row.get(0);
let avg_norm: f64 = row.get(1);
let stddev_norm: f64 = row.get(2);
let avg_component: f64 = row.get(3);
let avg_spread: f64 = row.get(4);
let count: i64 = row.get(5);
info!(
"Cluster {}: {} items, norm={:.3}±{:.3}, components={:.3}±{:.3}",
cluster_id, count, avg_norm, stddev_norm, avg_component, avg_spread
);
}
Ok(())
}
async fn analyze_cluster_correlation(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
info!("Analyzing correlation between cluster centers and embedding distances...");
// Calculate correlation between cluster center distances and embedding distances
let query = format!(
"WITH distances AS (
SELECT
a.cluster_id as cluster1,
b.cluster_id as cluster2,
ABS(a.cluster_center - b.cluster_center)::float8 as center_distance,
SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as embedding_distance
FROM {} a
JOIN {} b ON a.item_id < b.item_id
JOIN {} ie1 ON a.item_id = ie1.item_id
JOIN {} ie2 ON b.item_id = ie2.item_id,
UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id, a.cluster_center, b.cluster_center
)
SELECT
CORR(center_distance, embedding_distance)::float8 as correlation,
COUNT(*) as num_pairs,
AVG(center_distance)::float8 as avg_center_dist,
AVG(embedding_distance)::float8 as avg_emb_dist
FROM distances",
args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
);
let row = client.query_one(&query, &[]).await?;
info!(
"Correlation between cluster center distances and embedding distances: {:.3} (from {} pairs)",
row.get::<_, f64>(0),
row.get::<_, i64>(1)
);
info!(
"Average distances: cluster centers={:.3}, embeddings={:.3}",
row.get::<_, f64>(2),
row.get::<_, f64>(3)
);
Ok(())
}
#[tokio::main]
async fn main() -> Result<()> {
dotenv().ok();
pretty_env_logger::init();
let args = Args::parse();
let pool = create_pool().await?;
let client = pool.get().await?;
analyze_cluster_cohesion(&client, &args).await?;
analyze_embedding_stats(&client, &args).await?;
analyze_cluster_correlation(&client, &args).await?;
Ok(())
}

View File

@@ -188,7 +188,8 @@ async fn main() -> Result<()> {
.lambda_q2(args.lambda)
.learning_rate(0.01)
.iterations(100)
.loss(libmf::Loss::BinaryL2)
.loss(libmf::Loss::OneClassL2)
.c(0.00001)
.quiet(false)
.fit(&matrix)?;