cluster validation
This commit is contained in:
@@ -119,10 +119,46 @@ async fn main() -> Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Create table for cluster assignments
|
||||
client
|
||||
.execute(
|
||||
"CREATE TABLE IF NOT EXISTS item_clusters (
|
||||
item_id INTEGER PRIMARY KEY,
|
||||
cluster_id INTEGER,
|
||||
cluster_center FLOAT
|
||||
)",
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Generate cluster information
|
||||
let user_clusters = generate_cluster_info(args.user_clusters);
|
||||
let item_clusters = generate_cluster_info(args.item_clusters);
|
||||
|
||||
// Store item cluster assignments
|
||||
let mut item_to_cluster = HashMap::new();
|
||||
let mut rng = rand::thread_rng();
|
||||
let user_cluster_dist = Uniform::new(0, args.user_clusters);
|
||||
let item_cluster_dist = Uniform::new(0, args.item_clusters);
|
||||
|
||||
// Assign and store cluster information for each item
|
||||
for item_id in 0..args.num_items {
|
||||
let cluster_id = item_cluster_dist.sample(&mut rng);
|
||||
let cluster = item_clusters[cluster_id as usize];
|
||||
item_to_cluster.insert(item_id, (cluster_id, cluster));
|
||||
|
||||
client
|
||||
.execute(
|
||||
"INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (item_id) DO UPDATE
|
||||
SET cluster_id = EXCLUDED.cluster_id,
|
||||
cluster_center = EXCLUDED.cluster_center",
|
||||
&[&item_id, &cluster_id, &cluster.center],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let user_cluster_dist = Uniform::new(0, args.user_clusters);
|
||||
let item_cluster_dist = Uniform::new(0, args.item_clusters);
|
||||
|
||||
195
src/bin/validate_embeddings.rs
Normal file
195
src/bin/validate_embeddings.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Parser;
|
||||
use deadpool_postgres::{Config, Pool, Runtime};
|
||||
use dotenv::dotenv;
|
||||
use log::info;
|
||||
use std::env;
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Table containing user-item interactions
|
||||
#[arg(long)]
|
||||
interactions_table: String,
|
||||
|
||||
/// Table containing item embeddings
|
||||
#[arg(long)]
|
||||
embeddings_table: String,
|
||||
|
||||
/// Table containing item cluster information
|
||||
#[arg(long)]
|
||||
clusters_table: String,
|
||||
}
|
||||
|
||||
async fn create_pool() -> Result<Pool> {
|
||||
let mut config = Config::new();
|
||||
config.host = Some(env::var("POSTGRES_HOST").context("POSTGRES_HOST not set")?);
|
||||
config.port = Some(
|
||||
env::var("POSTGRES_PORT")
|
||||
.context("POSTGRES_PORT not set")?
|
||||
.parse()
|
||||
.context("Invalid POSTGRES_PORT")?,
|
||||
);
|
||||
config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
|
||||
config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
|
||||
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
|
||||
|
||||
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
|
||||
}
|
||||
|
||||
async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
|
||||
info!("Analyzing cluster cohesion...");
|
||||
|
||||
// Calculate within-cluster and between-cluster distances
|
||||
let query = format!(
|
||||
"WITH distances AS (
|
||||
SELECT
|
||||
a.cluster_id = b.cluster_id as same_cluster,
|
||||
SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
|
||||
FROM {} a
|
||||
JOIN {} b ON a.item_id < b.item_id
|
||||
JOIN {} ie1 ON a.item_id = ie1.item_id
|
||||
JOIN {} ie2 ON b.item_id = ie2.item_id,
|
||||
UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
|
||||
GROUP BY a.item_id, b.item_id, same_cluster
|
||||
)
|
||||
SELECT
|
||||
CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison,
|
||||
AVG(distance)::float8 as avg_distance,
|
||||
STDDEV(distance)::float8 as stddev_distance,
|
||||
MIN(distance)::float8 as min_distance,
|
||||
MAX(distance)::float8 as max_distance,
|
||||
COUNT(*) as num_pairs
|
||||
FROM distances
|
||||
GROUP BY same_cluster
|
||||
ORDER BY same_cluster DESC",
|
||||
args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
|
||||
);
|
||||
|
||||
let rows = client.query(&query, &[]).await?;
|
||||
for row in rows {
|
||||
let comparison: &str = row.get(0);
|
||||
let avg: f64 = row.get(1);
|
||||
let stddev: f64 = row.get(2);
|
||||
let min: f64 = row.get(3);
|
||||
let max: f64 = row.get(4);
|
||||
let count: i64 = row.get(5);
|
||||
|
||||
info!(
|
||||
"{}: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
|
||||
comparison, avg, stddev, min, max, count
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn analyze_embedding_stats(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
|
||||
info!("Analyzing embedding statistics...");
|
||||
|
||||
// Calculate embedding norms and component statistics
|
||||
let query = format!(
|
||||
"WITH stats AS (
|
||||
SELECT
|
||||
ie.item_id,
|
||||
c.cluster_id,
|
||||
SQRT(SUM(x * x))::float8 as norm,
|
||||
AVG(x)::float8 as avg_component,
|
||||
STDDEV(x)::float8 as stddev_component,
|
||||
MIN(x)::float8 as min_component,
|
||||
MAX(x)::float8 as max_component
|
||||
FROM {} ie
|
||||
JOIN {} c ON ie.item_id = c.item_id,
|
||||
UNNEST(embedding) x
|
||||
GROUP BY ie.item_id, c.cluster_id
|
||||
)
|
||||
SELECT
|
||||
cluster_id,
|
||||
AVG(norm)::float8 as avg_norm,
|
||||
STDDEV(norm)::float8 as stddev_norm,
|
||||
AVG(avg_component)::float8 as avg_component,
|
||||
AVG(stddev_component)::float8 as avg_component_spread,
|
||||
COUNT(*) as num_items
|
||||
FROM stats
|
||||
GROUP BY cluster_id
|
||||
ORDER BY cluster_id",
|
||||
args.embeddings_table, args.clusters_table
|
||||
);
|
||||
|
||||
let rows = client.query(&query, &[]).await?;
|
||||
info!("Per-cluster embedding statistics:");
|
||||
for row in rows {
|
||||
let cluster_id: i32 = row.get(0);
|
||||
let avg_norm: f64 = row.get(1);
|
||||
let stddev_norm: f64 = row.get(2);
|
||||
let avg_component: f64 = row.get(3);
|
||||
let avg_spread: f64 = row.get(4);
|
||||
let count: i64 = row.get(5);
|
||||
|
||||
info!(
|
||||
"Cluster {}: {} items, norm={:.3}±{:.3}, components={:.3}±{:.3}",
|
||||
cluster_id, count, avg_norm, stddev_norm, avg_component, avg_spread
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn analyze_cluster_correlation(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
|
||||
info!("Analyzing correlation between cluster centers and embedding distances...");
|
||||
|
||||
// Calculate correlation between cluster center distances and embedding distances
|
||||
let query = format!(
|
||||
"WITH distances AS (
|
||||
SELECT
|
||||
a.cluster_id as cluster1,
|
||||
b.cluster_id as cluster2,
|
||||
ABS(a.cluster_center - b.cluster_center)::float8 as center_distance,
|
||||
SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as embedding_distance
|
||||
FROM {} a
|
||||
JOIN {} b ON a.item_id < b.item_id
|
||||
JOIN {} ie1 ON a.item_id = ie1.item_id
|
||||
JOIN {} ie2 ON b.item_id = ie2.item_id,
|
||||
UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
|
||||
GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id, a.cluster_center, b.cluster_center
|
||||
)
|
||||
SELECT
|
||||
CORR(center_distance, embedding_distance)::float8 as correlation,
|
||||
COUNT(*) as num_pairs,
|
||||
AVG(center_distance)::float8 as avg_center_dist,
|
||||
AVG(embedding_distance)::float8 as avg_emb_dist
|
||||
FROM distances",
|
||||
args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
|
||||
);
|
||||
|
||||
let row = client.query_one(&query, &[]).await?;
|
||||
info!(
|
||||
"Correlation between cluster center distances and embedding distances: {:.3} (from {} pairs)",
|
||||
row.get::<_, f64>(0),
|
||||
row.get::<_, i64>(1)
|
||||
);
|
||||
info!(
|
||||
"Average distances: cluster centers={:.3}, embeddings={:.3}",
|
||||
row.get::<_, f64>(2),
|
||||
row.get::<_, f64>(3)
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
dotenv().ok();
|
||||
pretty_env_logger::init();
|
||||
let args = Args::parse();
|
||||
|
||||
let pool = create_pool().await?;
|
||||
let client = pool.get().await?;
|
||||
|
||||
analyze_cluster_cohesion(&client, &args).await?;
|
||||
analyze_embedding_stats(&client, &args).await?;
|
||||
analyze_cluster_correlation(&client, &args).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -188,7 +188,8 @@ async fn main() -> Result<()> {
|
||||
.lambda_q2(args.lambda)
|
||||
.learning_rate(0.01)
|
||||
.iterations(100)
|
||||
.loss(libmf::Loss::BinaryL2)
|
||||
.loss(libmf::Loss::OneClassL2)
|
||||
.c(0.00001)
|
||||
.quiet(false)
|
||||
.fit(&matrix)?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user