Files
mf-fitter/src/bin/validate_embeddings.rs
2024-12-28 03:32:38 +00:00

280 lines
10 KiB
Rust

use anyhow::{Context, Result};
use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv;
use log::info;
use std::env;
use tokio_postgres::NoTls;
// Command-line arguments naming the database tables to validate. Each value
// is spliced into SQL as an identifier (identifiers cannot be bound as query
// parameters), so these must be trusted, operator-supplied table names.
// NOTE: the `///` doc comments below double as clap help text (and with bare
// `about` in `#[command]`, a struct doc comment would become the CLI about
// string), so they are deliberately left untouched.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Table containing user-item interactions
// NOTE(review): `interactions_table` is not referenced by any query in this
// file — confirm whether it is used elsewhere or is a vestigial flag.
#[arg(long)]
interactions_table: String,
/// Table containing item embeddings
#[arg(long)]
embeddings_table: String,
/// Table containing item cluster information
#[arg(long)]
clusters_table: String,
}
async fn create_pool() -> Result<Pool> {
let mut config = Config::new();
config.host = Some(env::var("POSTGRES_HOST").context("POSTGRES_HOST not set")?);
config.port = Some(
env::var("POSTGRES_PORT")
.context("POSTGRES_PORT not set")?
.parse()
.context("Invalid POSTGRES_PORT")?,
);
config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
}
/// Compare cosine similarity of cluster-affinity vectors for item pairs
/// within the same cluster vs. across clusters, logging overall, per-cluster,
/// and per-cluster-pair statistics.
///
/// Fixes: the queries previously hard-coded the `item_clusters` table and
/// ignored `args.clusters_table` (one query was even wrapped in an
/// argument-less `format!`); they now use the configured table name. NULL
/// aggregates (empty categories, single-pair clusters) are decoded as
/// `Option<f64>` instead of panicking inside `Row::get`.
async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
    info!("Analyzing cluster cohesion...");
    // Table identifiers cannot be bound as query parameters, so the name is
    // spliced into the SQL text. It comes from an operator-supplied CLI flag
    // and is treated as trusted; never feed end-user input here.
    let t = &args.clusters_table;
    // Overall within- vs. between-cluster similarity statistics.
    let query = format!(
        "WITH affinity_similarities AS (
            SELECT
                a.item_id as item1,
                b.item_id as item2,
                a.cluster_id as cluster1,
                b.cluster_id as cluster2,
                -- Compute cosine similarity between affinity vectors
                SUM(a1 * b1) / (SQRT(SUM(a1 * a1)) * SQRT(SUM(b1 * b1))) as similarity,
                CASE WHEN a.cluster_id = b.cluster_id THEN 'within' ELSE 'between' END as similarity_type
            FROM {t} a
            CROSS JOIN {t} b
            CROSS JOIN UNNEST(a.cluster_affinities, b.cluster_affinities) AS t(a1, b1)
            WHERE a.item_id < b.item_id
            GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id
        )
        SELECT
            AVG(CASE WHEN similarity_type = 'within' THEN similarity END) as avg_within,
            STDDEV(CASE WHEN similarity_type = 'within' THEN similarity END) as stddev_within,
            MIN(CASE WHEN similarity_type = 'within' THEN similarity END) as min_within,
            MAX(CASE WHEN similarity_type = 'within' THEN similarity END) as max_within,
            COUNT(CASE WHEN similarity_type = 'within' THEN 1 END) as count_within,
            AVG(CASE WHEN similarity_type = 'between' THEN similarity END) as avg_between,
            STDDEV(CASE WHEN similarity_type = 'between' THEN similarity END) as stddev_between,
            MIN(CASE WHEN similarity_type = 'between' THEN similarity END) as min_between,
            MAX(CASE WHEN similarity_type = 'between' THEN similarity END) as max_between,
            COUNT(CASE WHEN similarity_type = 'between' THEN 1 END) as count_between
        FROM affinity_similarities"
    );
    let cohesion_stats = client.query_one(&query, &[]).await?;
    // AVG/STDDEV/MIN/MAX are SQL NULL when a category has no pairs (e.g.
    // every item in its own cluster); decode as Option and render NaN rather
    // than panicking in `Row::get`.
    let stat = |col: &str| {
        cohesion_stats
            .get::<_, Option<f64>>(col)
            .unwrap_or(f64::NAN)
    };
    info!(
        "Within Cluster Similarity: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
        stat("avg_within"),
        stat("stddev_within"),
        stat("min_within"),
        stat("max_within"),
        cohesion_stats.get::<_, i64>("count_within")
    );
    info!(
        "Between Clusters Similarity: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
        stat("avg_between"),
        stat("stddev_between"),
        stat("min_between"),
        stat("max_between"),
        cohesion_stats.get::<_, i64>("count_between")
    );
    // Per-cluster cohesion: similarity of every item pair inside one cluster.
    info!("\nPer-cluster cohesion:");
    let query = format!(
        "WITH cluster_similarities AS (
            SELECT
                a.cluster_id,
                SUM(a1 * b1) / (SQRT(SUM(a1 * a1)) * SQRT(SUM(b1 * b1))) as similarity
            FROM {t} a
            JOIN {t} b ON a.cluster_id = b.cluster_id AND a.item_id < b.item_id
            CROSS JOIN UNNEST(a.cluster_affinities, b.cluster_affinities) AS t(a1, b1)
            GROUP BY a.cluster_id, a.item_id, b.item_id
        )
        SELECT
            cluster_id,
            AVG(similarity) as avg_similarity,
            STDDEV(similarity) as stddev_similarity,
            COUNT(*) as num_pairs
        FROM cluster_similarities
        GROUP BY cluster_id
        ORDER BY cluster_id"
    );
    for row in client.query(&query, &[]).await? {
        let cluster_id: i32 = row.get("cluster_id");
        let avg_similarity: f64 = row.get("avg_similarity");
        // STDDEV is NULL for a cluster that contributes only one pair.
        let stddev_similarity = row
            .get::<_, Option<f64>>("stddev_similarity")
            .unwrap_or(f64::NAN);
        let num_pairs: i64 = row.get("num_pairs");
        info!(
            "Cluster {}: avg_similarity={:.3}, stddev={:.3}, pairs={}",
            cluster_id, avg_similarity, stddev_similarity, num_pairs
        );
    }
    // Between-cluster separation: average similarity for each pair of
    // distinct clusters (cluster1 < cluster2).
    let query = format!(
        "WITH similarities AS (
            SELECT
                a.cluster_id as cluster1,
                b.cluster_id as cluster2,
                SUM(a1 * b1) / (SQRT(SUM(a1 * a1)) * SQRT(SUM(b1 * b1))) as similarity
            FROM {t} a
            JOIN {t} b ON a.cluster_id < b.cluster_id
            CROSS JOIN UNNEST(a.cluster_affinities, b.cluster_affinities) AS t(a1, b1)
            GROUP BY a.cluster_id, b.cluster_id, a.item_id, b.item_id
        )
        SELECT
            cluster1,
            cluster2,
            AVG(similarity) as avg_similarity,
            STDDEV(similarity) as stddev_similarity,
            COUNT(*) as num_pairs
        FROM similarities
        GROUP BY cluster1, cluster2
        ORDER BY cluster1, cluster2"
    );
    info!("\nBetween-cluster separation:");
    for row in client.query(&query, &[]).await? {
        let cluster1: i32 = row.get(0);
        let cluster2: i32 = row.get(1);
        let avg: f64 = row.get(2);
        // NULL when a cluster pair contributes only a single item pair.
        let stddev = row.get::<_, Option<f64>>(3).unwrap_or(f64::NAN);
        let count: i64 = row.get(4);
        info!(
            "Clusters {} <-> {}: avg_similarity={:.3}±{:.3} ({} pairs)",
            cluster1, cluster2, avg, stddev, count
        );
    }
    Ok(())
}
/// Summarize embedding vector norms and per-component statistics grouped by
/// cluster, logging one line per cluster.
///
/// `args.embeddings_table` is expected to expose `(item_id, embedding)` and
/// `args.clusters_table` `(item_id, cluster_id)` — assumed from the query;
/// confirm against the loader that populates these tables.
///
/// Fix: STDDEV aggregates are SQL NULL for single-item clusters (and
/// `AVG(stddev_component)` can be NULL for one-component embeddings);
/// previously `get::<_, f64>` panicked on those rows.
async fn analyze_embedding_stats(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
    info!("Analyzing embedding statistics...");
    // Table names are operator-supplied CLI flags spliced in as identifiers
    // (identifiers cannot be bound as query parameters); treated as trusted.
    let query = format!(
        "WITH stats AS (
            SELECT
                ie.item_id,
                c.cluster_id,
                SQRT(SUM(x * x))::float8 as norm,
                AVG(x)::float8 as avg_component,
                STDDEV(x)::float8 as stddev_component,
                MIN(x)::float8 as min_component,
                MAX(x)::float8 as max_component
            FROM {} ie
            JOIN {} c ON ie.item_id = c.item_id,
            UNNEST(embedding) x
            GROUP BY ie.item_id, c.cluster_id
        )
        SELECT
            cluster_id,
            AVG(norm)::float8 as avg_norm,
            STDDEV(norm)::float8 as stddev_norm,
            AVG(avg_component)::float8 as avg_component,
            AVG(stddev_component)::float8 as avg_component_spread,
            COUNT(*) as num_items
        FROM stats
        GROUP BY cluster_id
        ORDER BY cluster_id",
        args.embeddings_table, args.clusters_table
    );
    info!("Per-cluster embedding statistics:");
    for row in client.query(&query, &[]).await? {
        let cluster_id: i32 = row.get(0);
        let avg_norm: f64 = row.get(1);
        // Nullable aggregates: render NaN instead of panicking in `Row::get`.
        let stddev_norm = row.get::<_, Option<f64>>(2).unwrap_or(f64::NAN);
        let avg_component: f64 = row.get(3);
        let avg_spread = row.get::<_, Option<f64>>(4).unwrap_or(f64::NAN);
        let count: i64 = row.get(5);
        info!(
            "Cluster {}: {} items, norm={:.3}±{:.3}, components={:.3}±{:.3}",
            cluster_id, count, avg_norm, stddev_norm, avg_component, avg_spread
        );
    }
    Ok(())
}
/// Measure how well cluster-affinity similarity tracks embedding similarity:
/// computes both cosine similarities for every cross-cluster item pair and
/// logs their Pearson correlation.
///
/// Fixes: the query previously hard-coded `item_clusters`/`item_embeddings`,
/// ignoring `args.clusters_table`/`args.embeddings_table`; and `corr()` is
/// SQL NULL with fewer than two pairs (or zero variance), which made
/// `get::<_, f64>` panic.
async fn analyze_cluster_correlation(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
    info!("Analyzing correlation between cluster affinities and embedding similarities...");
    // Identifiers cannot be bound as query parameters, so the operator-
    // supplied (trusted) table names are spliced into the SQL text.
    let clusters = &args.clusters_table;
    let embeddings = &args.embeddings_table;
    let query = format!(
        "WITH distances AS (
            SELECT
                a.cluster_id as cluster1,
                b.cluster_id as cluster2,
                -- Compute affinity similarity
                SUM(a1 * b1) / (SQRT(SUM(a1 * a1)) * SQRT(SUM(b1 * b1))) as affinity_similarity,
                -- Compute embedding similarity
                SUM(e1 * e2) / (SQRT(SUM(e1 * e1)) * SQRT(SUM(e2 * e2))) as embedding_similarity
            FROM {clusters} a
            JOIN {clusters} b ON a.cluster_id < b.cluster_id
            JOIN {embeddings} ae ON a.item_id = ae.item_id
            JOIN {embeddings} be ON b.item_id = be.item_id
            CROSS JOIN UNNEST(a.cluster_affinities, b.cluster_affinities) AS t1(a1, b1)
            CROSS JOIN UNNEST(ae.embedding, be.embedding) AS t2(e1, e2)
            GROUP BY a.cluster_id, b.cluster_id, a.item_id, b.item_id
        )
        SELECT
            corr(affinity_similarity, embedding_similarity) as correlation,
            COUNT(*) as num_pairs
        FROM distances"
    );
    let row = client.query_one(&query, &[]).await?;
    let num_pairs: i64 = row.get("num_pairs");
    // corr() is NULL with < 2 pairs or zero variance in either series.
    match row.get::<_, Option<f64>>("correlation") {
        Some(correlation_value) => info!(
            "Correlation between affinity similarities and embedding similarities: {:.3} ({} pairs)",
            correlation_value, num_pairs
        ),
        None => info!(
            "Correlation between affinity similarities and embedding similarities: undefined ({} pairs)",
            num_pairs
        ),
    }
    Ok(())
}
/// Entry point: loads `.env` (best effort), initializes logging, parses the
/// CLI flags, opens a pooled database connection, and runs the three
/// validation analyses in sequence, aborting on the first failure.
#[tokio::main]
async fn main() -> Result<()> {
    // Missing .env file is fine; environment variables may be set directly.
    dotenv().ok();
    pretty_env_logger::init();
    let cli = Args::parse();
    let db_pool = create_pool().await?;
    let db_client = db_pool.get().await?;
    analyze_cluster_cohesion(&db_client, &cli).await?;
    analyze_embedding_stats(&db_client, &cli).await?;
    analyze_cluster_correlation(&db_client, &cli).await?;
    Ok(())
}