cluster validation
This commit is contained in:
@@ -119,10 +119,46 @@ async fn main() -> Result<()> {
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Create table for cluster assignments
|
||||||
|
client
|
||||||
|
.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS item_clusters (
|
||||||
|
item_id INTEGER PRIMARY KEY,
|
||||||
|
cluster_id INTEGER,
|
||||||
|
cluster_center FLOAT
|
||||||
|
)",
|
||||||
|
&[],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
// Generate cluster information
|
// Generate cluster information
|
||||||
let user_clusters = generate_cluster_info(args.user_clusters);
|
let user_clusters = generate_cluster_info(args.user_clusters);
|
||||||
let item_clusters = generate_cluster_info(args.item_clusters);
|
let item_clusters = generate_cluster_info(args.item_clusters);
|
||||||
|
|
||||||
|
// Store item cluster assignments
|
||||||
|
let mut item_to_cluster = HashMap::new();
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let user_cluster_dist = Uniform::new(0, args.user_clusters);
|
||||||
|
let item_cluster_dist = Uniform::new(0, args.item_clusters);
|
||||||
|
|
||||||
|
// Assign and store cluster information for each item
|
||||||
|
for item_id in 0..args.num_items {
|
||||||
|
let cluster_id = item_cluster_dist.sample(&mut rng);
|
||||||
|
let cluster = item_clusters[cluster_id as usize];
|
||||||
|
item_to_cluster.insert(item_id, (cluster_id, cluster));
|
||||||
|
|
||||||
|
client
|
||||||
|
.execute(
|
||||||
|
"INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
ON CONFLICT (item_id) DO UPDATE
|
||||||
|
SET cluster_id = EXCLUDED.cluster_id,
|
||||||
|
cluster_center = EXCLUDED.cluster_center",
|
||||||
|
&[&item_id, &cluster_id, &cluster.center],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = rand::thread_rng();
|
||||||
let user_cluster_dist = Uniform::new(0, args.user_clusters);
|
let user_cluster_dist = Uniform::new(0, args.user_clusters);
|
||||||
let item_cluster_dist = Uniform::new(0, args.item_clusters);
|
let item_cluster_dist = Uniform::new(0, args.item_clusters);
|
||||||
|
|||||||
195
src/bin/validate_embeddings.rs
Normal file
195
src/bin/validate_embeddings.rs
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
use anyhow::{Context, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
use deadpool_postgres::{Config, Pool, Runtime};
|
||||||
|
use dotenv::dotenv;
|
||||||
|
use log::info;
|
||||||
|
use std::env;
|
||||||
|
use tokio_postgres::NoTls;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Table containing user-item interactions
|
||||||
|
#[arg(long)]
|
||||||
|
interactions_table: String,
|
||||||
|
|
||||||
|
/// Table containing item embeddings
|
||||||
|
#[arg(long)]
|
||||||
|
embeddings_table: String,
|
||||||
|
|
||||||
|
/// Table containing item cluster information
|
||||||
|
#[arg(long)]
|
||||||
|
clusters_table: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_pool() -> Result<Pool> {
|
||||||
|
let mut config = Config::new();
|
||||||
|
config.host = Some(env::var("POSTGRES_HOST").context("POSTGRES_HOST not set")?);
|
||||||
|
config.port = Some(
|
||||||
|
env::var("POSTGRES_PORT")
|
||||||
|
.context("POSTGRES_PORT not set")?
|
||||||
|
.parse()
|
||||||
|
.context("Invalid POSTGRES_PORT")?,
|
||||||
|
);
|
||||||
|
config.dbname = Some(env::var("POSTGRES_DB").context("POSTGRES_DB not set")?);
|
||||||
|
config.user = Some(env::var("POSTGRES_USER").context("POSTGRES_USER not set")?);
|
||||||
|
config.password = Some(env::var("POSTGRES_PASSWORD").context("POSTGRES_PASSWORD not set")?);
|
||||||
|
|
||||||
|
Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Logs summary statistics of pairwise embedding distances, split into
/// within-cluster and between-cluster pairs.
///
/// The query pairs every item in `args.clusters_table` with every item of
/// higher `item_id`, computes the Euclidean distance between their embeddings
/// from `args.embeddings_table`, and aggregates avg/stddev/min/max/count per
/// "same cluster?" group. One log line is emitted per group.
///
/// # Errors
/// Returns an error if the query fails or a column cannot be decoded.
async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
    info!("Analyzing cluster cohesion...");

    // NOTE(review): table names are spliced into the SQL text because
    // identifiers cannot be sent as bind parameters. They come straight from
    // the command line, so only run this with trusted arguments.
    let query = format!(
        "WITH distances AS (
            SELECT
                a.cluster_id = b.cluster_id as same_cluster,
                SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
            FROM {} a
            JOIN {} b ON a.item_id < b.item_id
            JOIN {} ie1 ON a.item_id = ie1.item_id
            JOIN {} ie2 ON b.item_id = ie2.item_id,
            UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
            GROUP BY a.item_id, b.item_id, same_cluster
        )
        SELECT
            CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison,
            AVG(distance)::float8 as avg_distance,
            STDDEV(distance)::float8 as stddev_distance,
            MIN(distance)::float8 as min_distance,
            MAX(distance)::float8 as max_distance,
            COUNT(*) as num_pairs
        FROM distances
        GROUP BY same_cluster
        ORDER BY same_cluster DESC",
        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
    );

    let rows = client.query(&query, &[]).await?;
    for row in rows {
        // SQL aggregates can be NULL (e.g. STDDEV over a single pair). Decode
        // them as Option and fall back to NaN so the report is still printed
        // instead of panicking inside `Row::get`.
        let stat = |idx: usize| -> Result<f64> {
            Ok(row.try_get::<_, Option<f64>>(idx)?.unwrap_or(f64::NAN))
        };

        let comparison: &str = row.try_get(0)?;
        let avg = stat(1)?;
        let stddev = stat(2)?;
        let min = stat(3)?;
        let max = stat(4)?;
        let count: i64 = row.try_get(5)?;

        info!(
            "{}: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
            comparison, avg, stddev, min, max, count
        );
    }

    Ok(())
}
|
||||||
|
|
||||||
|
/// Logs per-cluster statistics of item embedding norms and components.
///
/// For every item joined between `args.embeddings_table` and
/// `args.clusters_table`, the query computes the L2 norm and the
/// avg/stddev/min/max of the embedding components, then averages those
/// per-item figures within each `cluster_id`. One log line is emitted per
/// cluster, ordered by `cluster_id`.
///
/// # Errors
/// Returns an error if the query fails or a column cannot be decoded.
async fn analyze_embedding_stats(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
    info!("Analyzing embedding statistics...");

    // NOTE(review): table names are spliced into the SQL text because
    // identifiers cannot be sent as bind parameters. They come straight from
    // the command line, so only run this with trusted arguments.
    let query = format!(
        "WITH stats AS (
            SELECT
                ie.item_id,
                c.cluster_id,
                SQRT(SUM(x * x))::float8 as norm,
                AVG(x)::float8 as avg_component,
                STDDEV(x)::float8 as stddev_component,
                MIN(x)::float8 as min_component,
                MAX(x)::float8 as max_component
            FROM {} ie
            JOIN {} c ON ie.item_id = c.item_id,
            UNNEST(embedding) x
            GROUP BY ie.item_id, c.cluster_id
        )
        SELECT
            cluster_id,
            AVG(norm)::float8 as avg_norm,
            STDDEV(norm)::float8 as stddev_norm,
            AVG(avg_component)::float8 as avg_component,
            AVG(stddev_component)::float8 as avg_component_spread,
            COUNT(*) as num_items
        FROM stats
        GROUP BY cluster_id
        ORDER BY cluster_id",
        args.embeddings_table, args.clusters_table
    );

    let rows = client.query(&query, &[]).await?;
    info!("Per-cluster embedding statistics:");
    for row in rows {
        // SQL aggregates can be NULL (e.g. STDDEV(norm) for a cluster holding
        // a single item). Decode them as Option and fall back to NaN rather
        // than panicking inside `Row::get`.
        let stat = |idx: usize| -> Result<f64> {
            Ok(row.try_get::<_, Option<f64>>(idx)?.unwrap_or(f64::NAN))
        };

        // The cluster column may itself be NULL if the table allows it; such
        // items form their own group and are labelled "NULL" in the report.
        let cluster_label = row
            .try_get::<_, Option<i32>>(0)?
            .map_or_else(|| "NULL".to_string(), |id| id.to_string());
        let avg_norm = stat(1)?;
        let stddev_norm = stat(2)?;
        let avg_component = stat(3)?;
        let avg_spread = stat(4)?;
        let count: i64 = row.try_get(5)?;

        info!(
            "Cluster {}: {} items, norm={:.3}±{:.3}, components={:.3}±{:.3}",
            cluster_label, count, avg_norm, stddev_norm, avg_component, avg_spread
        );
    }

    Ok(())
}
|
||||||
|
|
||||||
|
/// Logs the correlation between cluster-center distance and embedding
/// distance over all item pairs.
///
/// The query pairs every item in `args.clusters_table` with every item of
/// higher `item_id`, takes `ABS` of the difference of their `cluster_center`
/// values and the Euclidean distance between their embeddings from
/// `args.embeddings_table`, then reports `CORR` of the two along with the
/// pair count and both average distances.
///
/// # Errors
/// Returns an error if the query fails or a column cannot be decoded.
async fn analyze_cluster_correlation(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
    info!("Analyzing correlation between cluster centers and embedding distances...");

    // NOTE(review): table names are spliced into the SQL text because
    // identifiers cannot be sent as bind parameters. They come straight from
    // the command line, so only run this with trusted arguments.
    let query = format!(
        "WITH distances AS (
            SELECT
                a.cluster_id as cluster1,
                b.cluster_id as cluster2,
                ABS(a.cluster_center - b.cluster_center)::float8 as center_distance,
                SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as embedding_distance
            FROM {} a
            JOIN {} b ON a.item_id < b.item_id
            JOIN {} ie1 ON a.item_id = ie1.item_id
            JOIN {} ie2 ON b.item_id = ie2.item_id,
            UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
            GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id, a.cluster_center, b.cluster_center
        )
        SELECT
            CORR(center_distance, embedding_distance)::float8 as correlation,
            COUNT(*) as num_pairs,
            AVG(center_distance)::float8 as avg_center_dist,
            AVG(embedding_distance)::float8 as avg_emb_dist
        FROM distances",
        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
    );

    let row = client.query_one(&query, &[]).await?;

    // CORR is NULL when there are too few pairs or one side has no variance,
    // and AVG is NULL when there are no pairs at all. Decode as Option and
    // fall back to NaN so the report is printed instead of panicking inside
    // `Row::get`.
    let stat = |idx: usize| -> Result<f64> {
        Ok(row.try_get::<_, Option<f64>>(idx)?.unwrap_or(f64::NAN))
    };
    let correlation = stat(0)?;
    let num_pairs: i64 = row.try_get(1)?;
    let avg_center_dist = stat(2)?;
    let avg_emb_dist = stat(3)?;

    info!(
        "Correlation between cluster center distances and embedding distances: {:.3} (from {} pairs)",
        correlation,
        num_pairs
    );
    info!(
        "Average distances: cluster centers={:.3}, embeddings={:.3}",
        avg_center_dist,
        avg_emb_dist
    );

    Ok(())
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
dotenv().ok();
|
||||||
|
pretty_env_logger::init();
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let pool = create_pool().await?;
|
||||||
|
let client = pool.get().await?;
|
||||||
|
|
||||||
|
analyze_cluster_cohesion(&client, &args).await?;
|
||||||
|
analyze_embedding_stats(&client, &args).await?;
|
||||||
|
analyze_cluster_correlation(&client, &args).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -188,7 +188,8 @@ async fn main() -> Result<()> {
|
|||||||
.lambda_q2(args.lambda)
|
.lambda_q2(args.lambda)
|
||||||
.learning_rate(0.01)
|
.learning_rate(0.01)
|
||||||
.iterations(100)
|
.iterations(100)
|
||||||
.loss(libmf::Loss::BinaryL2)
|
.loss(libmf::Loss::OneClassL2)
|
||||||
|
.c(0.00001)
|
||||||
.quiet(false)
|
.quiet(false)
|
||||||
.fit(&matrix)?;
|
.fit(&matrix)?;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user