//! Cluster-analysis CLI: reads item/cluster tables from Postgres and logs
//! cohesion, separation, and affinity-vs-embedding correlation statistics.

use anyhow::{Context, Result};
use clap::Parser;
use deadpool_postgres::{Config, Pool, Runtime};
use dotenv::dotenv;
use log::info;
use std::env;
use tokio_postgres::NoTls;
// CLI arguments naming the Postgres tables to analyze.
// The `///` field comments below are clap help text — part of the CLI's
// observable behavior, not ordinary comments.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Table containing user-item interactions
    // NOTE(review): not referenced anywhere in this file — possibly consumed
    // elsewhere, or dead. Confirm before removing.
    #[arg(long)]
    interactions_table: String,

    /// Table containing item embeddings
    #[arg(long)]
    embeddings_table: String,

    /// Table containing item cluster information
    #[arg(long)]
    clusters_table: String,
}
async fn create_pool() -> Result<Pool> {
|
|
let mut config = Config::new();
|
|
config.host = Some(env::var("POSTGRES_HOST").context("POSTGRES_HOST not set")?);
|
|
config.port = Some(
|
|
env::var("POSTGRES_PORT")
|
|
.context("POSTGRES_PORT not set")?
|
|
.parse()
|
|
.context("Invalid POSTGRES_PORT")?,
|
|
);
|
|
config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
|
|
config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
|
|
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
|
|
|
|
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
|
|
}
|
|
|
|
async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
|
|
info!("Analyzing cluster cohesion...");
|
|
|
|
// Calculate cosine similarity between affinity vectors
|
|
let cohesion_stats = client
|
|
.query_one(
|
|
"WITH affinity_similarities AS (
|
|
SELECT
|
|
a.item_id as item1,
|
|
b.item_id as item2,
|
|
a.cluster_id as cluster1,
|
|
b.cluster_id as cluster2,
|
|
-- Compute cosine similarity between affinity vectors
|
|
SUM(a1 * b1) / (SQRT(SUM(a1 * a1)) * SQRT(SUM(b1 * b1))) as similarity,
|
|
CASE WHEN a.cluster_id = b.cluster_id THEN 'within' ELSE 'between' END as similarity_type
|
|
FROM item_clusters a
|
|
CROSS JOIN item_clusters b
|
|
CROSS JOIN UNNEST(a.cluster_affinities, b.cluster_affinities) AS t(a1, b1)
|
|
WHERE a.item_id < b.item_id
|
|
GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id
|
|
)
|
|
SELECT
|
|
AVG(CASE WHEN similarity_type = 'within' THEN similarity END) as avg_within,
|
|
STDDEV(CASE WHEN similarity_type = 'within' THEN similarity END) as stddev_within,
|
|
MIN(CASE WHEN similarity_type = 'within' THEN similarity END) as min_within,
|
|
MAX(CASE WHEN similarity_type = 'within' THEN similarity END) as max_within,
|
|
COUNT(CASE WHEN similarity_type = 'within' THEN 1 END) as count_within,
|
|
AVG(CASE WHEN similarity_type = 'between' THEN similarity END) as avg_between,
|
|
STDDEV(CASE WHEN similarity_type = 'between' THEN similarity END) as stddev_between,
|
|
MIN(CASE WHEN similarity_type = 'between' THEN similarity END) as min_between,
|
|
MAX(CASE WHEN similarity_type = 'between' THEN similarity END) as max_between,
|
|
COUNT(CASE WHEN similarity_type = 'between' THEN 1 END) as count_between
|
|
FROM affinity_similarities",
|
|
&[],
|
|
)
|
|
.await?;
|
|
|
|
// Print cohesion statistics
|
|
info!(
|
|
"Within Cluster Similarity: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
|
|
cohesion_stats.get::<_, f64>("avg_within"),
|
|
cohesion_stats.get::<_, f64>("stddev_within"),
|
|
cohesion_stats.get::<_, f64>("min_within"),
|
|
cohesion_stats.get::<_, f64>("max_within"),
|
|
cohesion_stats.get::<_, i64>("count_within")
|
|
);
|
|
|
|
info!(
|
|
"Between Clusters Similarity: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
|
|
cohesion_stats.get::<_, f64>("avg_between"),
|
|
cohesion_stats.get::<_, f64>("stddev_between"),
|
|
cohesion_stats.get::<_, f64>("min_between"),
|
|
cohesion_stats.get::<_, f64>("max_between"),
|
|
cohesion_stats.get::<_, i64>("count_between")
|
|
);
|
|
|
|
// Print per-cluster statistics
|
|
info!("\nPer-cluster cohesion:");
|
|
let cluster_stats = client
|
|
.query(
|
|
"WITH cluster_similarities AS (
|
|
SELECT
|
|
a.cluster_id,
|
|
SUM(a1 * b1) / (SQRT(SUM(a1 * a1)) * SQRT(SUM(b1 * b1))) as similarity
|
|
FROM item_clusters a
|
|
JOIN item_clusters b ON a.cluster_id = b.cluster_id AND a.item_id < b.item_id
|
|
CROSS JOIN UNNEST(a.cluster_affinities, b.cluster_affinities) AS t(a1, b1)
|
|
GROUP BY a.cluster_id, a.item_id, b.item_id
|
|
)
|
|
SELECT
|
|
cluster_id,
|
|
AVG(similarity) as avg_similarity,
|
|
STDDEV(similarity) as stddev_similarity,
|
|
COUNT(*) as num_pairs
|
|
FROM cluster_similarities
|
|
GROUP BY cluster_id
|
|
ORDER BY cluster_id",
|
|
&[],
|
|
)
|
|
.await?;
|
|
|
|
for row in cluster_stats {
|
|
let cluster_id: i32 = row.get("cluster_id");
|
|
let avg_similarity: f64 = row.get("avg_similarity");
|
|
let stddev_similarity: f64 = row.get("stddev_similarity");
|
|
let num_pairs: i64 = row.get("num_pairs");
|
|
info!(
|
|
"Cluster {}: avg_similarity={:.3}, stddev={:.3}, pairs={}",
|
|
cluster_id, avg_similarity, stddev_similarity, num_pairs
|
|
);
|
|
}
|
|
|
|
// Calculate separation between specific cluster pairs
|
|
let query = format!(
|
|
"WITH similarities AS (
|
|
SELECT
|
|
a.cluster_id as cluster1,
|
|
b.cluster_id as cluster2,
|
|
SUM(a1 * b1) / (SQRT(SUM(a1 * a1)) * SQRT(SUM(b1 * b1))) as similarity
|
|
FROM item_clusters a
|
|
JOIN item_clusters b ON a.cluster_id < b.cluster_id
|
|
CROSS JOIN UNNEST(a.cluster_affinities, b.cluster_affinities) AS t(a1, b1)
|
|
GROUP BY a.cluster_id, b.cluster_id, a.item_id, b.item_id
|
|
)
|
|
SELECT
|
|
cluster1,
|
|
cluster2,
|
|
AVG(similarity) as avg_similarity,
|
|
STDDEV(similarity) as stddev_similarity,
|
|
COUNT(*) as num_pairs
|
|
FROM similarities
|
|
GROUP BY cluster1, cluster2
|
|
ORDER BY cluster1, cluster2",
|
|
);
|
|
|
|
info!("\nBetween-cluster separation:");
|
|
let rows = client.query(&query, &[]).await?;
|
|
for row in rows {
|
|
let cluster1: i32 = row.get(0);
|
|
let cluster2: i32 = row.get(1);
|
|
let avg: f64 = row.get(2);
|
|
let stddev: f64 = row.get(3);
|
|
let count: i64 = row.get(4);
|
|
|
|
info!(
|
|
"Clusters {} <-> {}: avg_similarity={:.3}±{:.3} ({} pairs)",
|
|
cluster1, cluster2, avg, stddev, count
|
|
);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn analyze_embedding_stats(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
|
|
info!("Analyzing embedding statistics...");
|
|
|
|
// Calculate embedding norms and component statistics
|
|
let query = format!(
|
|
"WITH stats AS (
|
|
SELECT
|
|
ie.item_id,
|
|
c.cluster_id,
|
|
SQRT(SUM(x * x))::float8 as norm,
|
|
AVG(x)::float8 as avg_component,
|
|
STDDEV(x)::float8 as stddev_component,
|
|
MIN(x)::float8 as min_component,
|
|
MAX(x)::float8 as max_component
|
|
FROM {} ie
|
|
JOIN {} c ON ie.item_id = c.item_id,
|
|
UNNEST(embedding) x
|
|
GROUP BY ie.item_id, c.cluster_id
|
|
)
|
|
SELECT
|
|
cluster_id,
|
|
AVG(norm)::float8 as avg_norm,
|
|
STDDEV(norm)::float8 as stddev_norm,
|
|
AVG(avg_component)::float8 as avg_component,
|
|
AVG(stddev_component)::float8 as avg_component_spread,
|
|
COUNT(*) as num_items
|
|
FROM stats
|
|
GROUP BY cluster_id
|
|
ORDER BY cluster_id",
|
|
args.embeddings_table, args.clusters_table
|
|
);
|
|
|
|
let rows = client.query(&query, &[]).await?;
|
|
info!("Per-cluster embedding statistics:");
|
|
for row in rows {
|
|
let cluster_id: i32 = row.get(0);
|
|
let avg_norm: f64 = row.get(1);
|
|
let stddev_norm: f64 = row.get(2);
|
|
let avg_component: f64 = row.get(3);
|
|
let avg_spread: f64 = row.get(4);
|
|
let count: i64 = row.get(5);
|
|
|
|
info!(
|
|
"Cluster {}: {} items, norm={:.3}±{:.3}, components={:.3}±{:.3}",
|
|
cluster_id, count, avg_norm, stddev_norm, avg_component, avg_spread
|
|
);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn analyze_cluster_correlation(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
|
|
info!("Analyzing correlation between cluster affinities and embedding similarities...");
|
|
|
|
// Calculate correlation between affinity similarities and embedding similarities
|
|
let correlation = client
|
|
.query_one(
|
|
"WITH distances AS (
|
|
SELECT
|
|
a.cluster_id as cluster1,
|
|
b.cluster_id as cluster2,
|
|
-- Compute affinity similarity
|
|
SUM(a1 * b1) / (SQRT(SUM(a1 * a1)) * SQRT(SUM(b1 * b1))) as affinity_similarity,
|
|
-- Compute embedding similarity
|
|
SUM(e1 * e2) / (SQRT(SUM(e1 * e1)) * SQRT(SUM(e2 * e2))) as embedding_similarity
|
|
FROM item_clusters a
|
|
JOIN item_clusters b ON a.cluster_id < b.cluster_id
|
|
JOIN item_embeddings ae ON a.item_id = ae.item_id
|
|
JOIN item_embeddings be ON b.item_id = be.item_id
|
|
CROSS JOIN UNNEST(a.cluster_affinities, b.cluster_affinities) AS t1(a1, b1)
|
|
CROSS JOIN UNNEST(ae.embedding, be.embedding) AS t2(e1, e2)
|
|
GROUP BY a.cluster_id, b.cluster_id, a.item_id, b.item_id
|
|
)
|
|
SELECT
|
|
corr(affinity_similarity, embedding_similarity) as correlation,
|
|
COUNT(*) as num_pairs
|
|
FROM distances",
|
|
&[],
|
|
)
|
|
.await?;
|
|
|
|
let correlation_value: f64 = correlation.get("correlation");
|
|
let num_pairs: i64 = correlation.get("num_pairs");
|
|
info!(
|
|
"Correlation between affinity similarities and embedding similarities: {:.3} ({} pairs)",
|
|
correlation_value, num_pairs
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::main]
async fn main() -> Result<()> {
    // Load .env before the logger so RUST_LOG (and the POSTGRES_* variables
    // read later by create_pool) can come from the .env file.
    dotenv().ok();
    pretty_env_logger::init();
    let args = Args::parse();

    // One pooled connection suffices: the analyses below run sequentially.
    // `pool.get()` yields a deadpool object that derefs to
    // `tokio_postgres::Client`, matching the analysis functions' signatures.
    let pool = create_pool().await?;
    let client = pool.get().await?;

    analyze_cluster_cohesion(&client, &args).await?;
    analyze_embedding_stats(&client, &args).await?;
    analyze_cluster_correlation(&client, &args).await?;

    Ok(())
}