more fixes
Cargo.lock (generated, 1 change)
@@ -1183,6 +1183,7 @@ dependencies = [
  "log",
  "ndarray",
  "ndarray-linalg",
+ "num_cpus",
  "plotly",
  "pretty_env_logger",
  "rand",
Cargo.toml

@@ -13,6 +13,7 @@ libmf = "0.3"
 log = "0.4"
 ndarray = { version = "0.15", features = ["blas"] }
 ndarray-linalg = { version = "0.16", features = ["openblas-system"] }
+num_cpus = "1.16"
 plotly = { version = "0.8", features = ["kaleido"] }
 pretty_env_logger = "0.5"
 rand = "0.8"
@@ -2,11 +2,9 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
-use log::{info, warn};
-use rand::distributions::{Distribution, Uniform};
+use log::info;
 use rand::seq::SliceRandom;
 use rand::Rng;
-use rand_distr::Normal;
 use std::collections::{HashMap, HashSet};
 use std::env;
 use tokio_postgres::NoTls;
@@ -101,26 +99,33 @@ async fn main() -> Result<()> {
         )
         .await?;

     // Create table for cluster assignments
     client
         .execute(
             "CREATE TABLE IF NOT EXISTS item_clusters (
                 item_id INTEGER PRIMARY KEY,
                 cluster_id INTEGER,
                 cluster_center FLOAT
             )",
             &[],
         )
         .await?;

-    // Generate cluster centers that are well-separated
+    // Generate cluster centers that are well-separated in 3D space
     let mut rng = rand::thread_rng();
     let mut cluster_centers = Vec::new();
-    for i in 0..args.item_clusters {
-        // Place centers evenly around a circle in 2D space, then project to 1D
-        let angle = 2.0 * std::f64::consts::PI * (i as f64) / (args.item_clusters as f64);
-        let center = angle.cos(); // Project to 1D by taking cosine
-        cluster_centers.push(center);
-    }
+    // Base vertices of a regular octahedron
+    let base_centers = vec![
+        vec![1.0, 0.0, 0.0],  // +x
+        vec![-1.0, 0.0, 0.0], // -x
+        vec![0.0, 1.0, 0.0],  // +y
+        vec![0.0, -1.0, 0.0], // -y
+        vec![0.0, 0.0, 1.0],  // +z
+        vec![0.0, 0.0, -1.0], // -z
+    ];
+
+    // Scale factor to control separation
+    let scale = 2.0;
+
+    // Add jittered versions of the base centers
+    for i in 0..args.item_clusters as usize {
+        let base = &base_centers[i % base_centers.len()];
+        // Add controlled random jitter (up to 10% of the scale)
+        let jitter = 0.1;
+        let x = base[0] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
+        let y = base[1] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
+        let z = base[2] * scale + (rng.gen::<f64>() - 0.5) * jitter * scale;
+
+        cluster_centers.push(vec![x, y, z]);
+    }

     // Create shuffled list of items
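Note: the six octahedron vertices are a deliberate choice for "well-separated". Scaled by 2.0, adjacent vertices sit sqrt(2) * scale apart (about 2.83) and opposite ones 2 * scale apart, while the jitter moves each coordinate by at most 0.05 * scale, so no two centers can come close to colliding. One caveat visible in the loop: with more than six item clusters the `i % base_centers.len()` wraparound reuses vertices, and only the jitter separates those clusters. A standalone sketch (not part of the commit) that checks the minimum separation:

    // Verify the minimum pairwise distance between the six scaled
    // octahedron vertices used as cluster centers above.
    fn main() {
        let scale = 2.0_f64;
        let base_centers: Vec<Vec<f64>> = vec![
            vec![1.0, 0.0, 0.0],
            vec![-1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, -1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![0.0, 0.0, -1.0],
        ];
        let mut min_dist = f64::INFINITY;
        for i in 0..base_centers.len() {
            for j in (i + 1)..base_centers.len() {
                let d = base_centers[i]
                    .iter()
                    .zip(&base_centers[j])
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum::<f64>()
                    .sqrt()
                    * scale;
                min_dist = min_dist.min(d);
            }
        }
        // Prints 2.828 (sqrt(2) * scale); opposite vertices are 4.0 apart.
        println!("minimum center separation: {min_dist:.3}");
    }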
@@ -130,35 +135,54 @@ async fn main() -> Result<()> {
     // Assign items to clusters in contiguous blocks
     let mut items_per_cluster = vec![Vec::new(); args.item_clusters as usize];
     let items_per_cluster_count = args.num_items / args.item_clusters;

+    // Drop existing item_clusters table and recreate with 3D centers
+    client
+        .execute("DROP TABLE IF EXISTS item_clusters", &[])
+        .await?;
+    client
+        .execute(
+            "CREATE TABLE item_clusters (
+                item_id INTEGER PRIMARY KEY,
+                cluster_id INTEGER,
+                center_x FLOAT,
+                center_y FLOAT,
+                center_z FLOAT
+            )",
+            &[],
+        )
+        .await?;
+
     // Clear existing interactions
     client
         .execute(&format!("TRUNCATE TABLE {}", args.interactions_table), &[])
         .await?;

     for (i, &item_id) in all_items.iter().enumerate() {
         let cluster_id = (i as i32) / items_per_cluster_count;
         if cluster_id < args.item_clusters {
             items_per_cluster[cluster_id as usize].push(item_id);

-            // Store cluster assignment with meaningful center
+            // Store cluster assignment with 3D center
+            let center = &cluster_centers[cluster_id as usize];
             client
                 .execute(
-                    "INSERT INTO item_clusters (item_id, cluster_id, cluster_center)
-                     VALUES ($1, $2, $3)
-                     ON CONFLICT (item_id) DO UPDATE
-                     SET cluster_id = EXCLUDED.cluster_id,
-                         cluster_center = EXCLUDED.cluster_center",
-                    &[&item_id, &cluster_id, &cluster_centers[cluster_id as usize]],
+                    "INSERT INTO item_clusters (item_id, cluster_id, center_x, center_y, center_z)
+                     VALUES ($1, $2, $3, $4, $5)",
+                    &[&item_id, &cluster_id, &center[0], &center[1], &center[2]],
                 )
                 .await?;
         }
     }

-    // Generate interactions with strong cluster affinity
+    // Generate interactions with very strong cluster affinity
     info!("Generating interactions...");
     let mut item_interactions = HashMap::new();

     // For each user cluster
     for user_cluster in 0..args.user_clusters {
-        // Each user cluster strongly prefers 1-2 item clusters
-        let preferred_item_clusters: HashSet<i32> = (0..args.item_clusters)
-            .filter(|&i| i % args.user_clusters == user_cluster)
-            .collect();
+        // Each user cluster strongly prefers exactly one item cluster with wraparound
+        let preferred_item_cluster = user_cluster % args.item_clusters;

         // For each user in this cluster
         let cluster_start = user_cluster * (args.num_users / args.user_clusters);
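Note: two pieces of integer arithmetic above shape the synthetic data. Item slots map to clusters in contiguous blocks (cluster_id = i / items_per_cluster_count, with remainder slots rejected by the cluster_id < args.item_clusters guard), and each user cluster picks one preferred item cluster by modulo wraparound. A standalone sketch with hypothetical sizes:

    fn main() {
        let (num_items, item_clusters, user_clusters) = (20_i32, 6_i32, 8_i32);
        let items_per_cluster_count = num_items / item_clusters; // 3 (integer division)

        // Contiguous blocks: slots 0..17 fill clusters 0..5; the two
        // remainder slots (18, 19) fail the guard and stay unassigned.
        for i in 0..num_items {
            let cluster_id = i / items_per_cluster_count;
            if cluster_id < item_clusters {
                println!("item slot {i} -> item cluster {cluster_id}");
            }
        }

        // Wraparound: user clusters 0..5 map to item clusters 0..5, then
        // 6 -> 0 and 7 -> 1, so those two item clusters serve two user clusters.
        for user_cluster in 0..user_clusters {
            println!(
                "user cluster {user_cluster} prefers item cluster {}",
                user_cluster % item_clusters
            );
        }
    }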
@@ -169,13 +193,13 @@ async fn main() -> Result<()> {
         };

         for user_id in cluster_start..cluster_end {
-            // Get all items from preferred clusters
-            let mut preferred_items = HashSet::new();
-            for &item_cluster in &preferred_item_clusters {
-                preferred_items.extend(items_per_cluster[item_cluster as usize].iter());
-            }
+            // Get all items from preferred cluster
+            let preferred_items: HashSet<_> = items_per_cluster[preferred_item_cluster as usize]
+                .iter()
+                .copied()
+                .collect();

-            // Interact with most preferred items (1 - noise_level of items in preferred clusters)
+            // Interact with most preferred items (1 - noise_level of items in preferred cluster)
             let num_interactions =
                 (preferred_items.len() as f64 * (1.0 - args.noise_level)) as usize;
             let selected_items: Vec<_> = preferred_items.into_iter().collect();
@@ -184,7 +208,7 @@ async fn main() -> Result<()> {
                 .copied()
                 .collect();

-            // Add random items from non-preferred clusters (noise)
+            // Add very few random items from non-preferred clusters (noise)
             let num_noise = (args.avg_interactions as f64 * args.noise_level) as i32;
             while interactions.len() < num_interactions + num_noise as usize {
                 let item_id = rng.gen_range(0..args.num_items);
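Note: the per-user interaction budget is plain arithmetic: a (1 - noise_level) fraction of the preferred cluster's items, topped up with roughly avg_interactions * noise_level uniformly random items as noise. Worked through with hypothetical CLI values:

    fn main() {
        let preferred_items_len = 150_usize; // items in the user's preferred cluster
        let avg_interactions = 50_i32;
        let noise_level = 0.1_f64;

        let num_interactions = (preferred_items_len as f64 * (1.0 - noise_level)) as usize;
        let num_noise = (avg_interactions as f64 * noise_level) as i32;

        // The while loop above keeps drawing random item IDs until the user
        // has num_interactions + num_noise interactions in total.
        println!("{num_interactions} preferred + {num_noise} noise"); // 135 + 5
    }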
@@ -41,85 +41,96 @@ async fn create_pool() -> Result<Pool> {
 async fn analyze_cluster_cohesion(client: &tokio_postgres::Client, args: &Args) -> Result<()> {
-    info!("Analyzing cluster cohesion...");
-
-    // Calculate within-cluster and between-cluster distances
-    let query = format!(
-        "WITH distances AS (
+    // Analyze cluster cohesion
+    info!("Analyzing cluster cohesion...");
+    let cohesion_stats = client
+        .query_one(
+            "WITH embedding_distances AS (
             SELECT
-                a.cluster_id = b.cluster_id as same_cluster,
                 a.item_id as item1,
                 b.item_id as item2,
-                a.cluster_id as cluster1,
-                b.cluster_id as cluster2,
-                SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
-            FROM {} a
-            JOIN {} b ON a.item_id < b.item_id
-            JOIN {} ie1 ON a.item_id = ie1.item_id
-            JOIN {} ie2 ON b.item_id = ie2.item_id,
-            UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
-            GROUP BY a.item_id, b.item_id, same_cluster, a.cluster_id, b.cluster_id
+                SQRT(
+                    POW(a.center_x - b.center_x, 2) +
+                    POW(a.center_y - b.center_y, 2) +
+                    POW(a.center_z - b.center_z, 2)
+                ) as distance,
+                CASE WHEN a.cluster_id = b.cluster_id THEN 'within' ELSE 'between' END as distance_type
+            FROM item_clusters a
+            CROSS JOIN item_clusters b
+            WHERE a.item_id < b.item_id
             )
             SELECT
-                CASE WHEN same_cluster THEN 'Within Cluster' ELSE 'Between Clusters' END as comparison,
-                AVG(distance)::float8 as avg_distance,
-                STDDEV(distance)::float8 as stddev_distance,
-                MIN(distance)::float8 as min_distance,
-                MAX(distance)::float8 as max_distance,
-                COUNT(*) as num_pairs
-            FROM distances
-            GROUP BY same_cluster
-            ORDER BY same_cluster DESC",
-        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
-    );
+                AVG(CASE WHEN distance_type = 'within' THEN distance END) as avg_within,
+                STDDEV(CASE WHEN distance_type = 'within' THEN distance END) as stddev_within,
+                MIN(CASE WHEN distance_type = 'within' THEN distance END) as min_within,
+                MAX(CASE WHEN distance_type = 'within' THEN distance END) as max_within,
+                COUNT(CASE WHEN distance_type = 'within' THEN 1 END) as count_within,
+                AVG(CASE WHEN distance_type = 'between' THEN distance END) as avg_between,
+                STDDEV(CASE WHEN distance_type = 'between' THEN distance END) as stddev_between,
+                MIN(CASE WHEN distance_type = 'between' THEN distance END) as min_between,
+                MAX(CASE WHEN distance_type = 'between' THEN distance END) as max_between,
+                COUNT(CASE WHEN distance_type = 'between' THEN 1 END) as count_between
+            FROM embedding_distances",
+            &[],
+        )
+        .await?;

-    let rows = client.query(&query, &[]).await?;
-    for row in rows {
-        let comparison: &str = row.get(0);
-        let avg: f64 = row.get(1);
-        let stddev: f64 = row.get(2);
-        let min: f64 = row.get(3);
-        let max: f64 = row.get(4);
-        let count: i64 = row.get(5);
+    // Print cohesion statistics
+    info!(
+        "Within Cluster: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
+        cohesion_stats.get::<_, f64>("avg_within"),
+        cohesion_stats.get::<_, f64>("stddev_within"),
+        cohesion_stats.get::<_, f64>("min_within"),
+        cohesion_stats.get::<_, f64>("max_within"),
+        cohesion_stats.get::<_, i64>("count_within")
+    );

-        info!(
-            "{}: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
-            comparison, avg, stddev, min, max, count
-        );
-    }
+    info!(
+        "Between Clusters: avg={:.3}, stddev={:.3}, min={:.3}, max={:.3}, pairs={}",
+        cohesion_stats.get::<_, f64>("avg_between"),
+        cohesion_stats.get::<_, f64>("stddev_between"),
+        cohesion_stats.get::<_, f64>("min_between"),
+        cohesion_stats.get::<_, f64>("max_between"),
+        cohesion_stats.get::<_, i64>("count_between")
+    );

-    // Calculate per-cluster statistics
-    let query = format!(
-        "WITH distances AS (
+    // Print per-cluster statistics
+    info!("\nPer-cluster cohesion:");
+    let cluster_stats = client
+        .query(
+            "WITH cluster_distances AS (
             SELECT
                 a.cluster_id,
-                SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as distance
-            FROM {} a
-            JOIN {} b ON a.item_id < b.item_id
-            JOIN {} ie1 ON a.item_id = ie1.item_id
-            JOIN {} ie2 ON b.item_id = ie2.item_id,
-            UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
-            WHERE a.cluster_id = b.cluster_id
-            GROUP BY a.item_id, b.item_id, a.cluster_id
+                SQRT(
+                    POW(a.center_x - b.center_x, 2) +
+                    POW(a.center_y - b.center_y, 2) +
+                    POW(a.center_z - b.center_z, 2)
+                ) as distance
+            FROM item_clusters a
+            JOIN item_clusters b ON a.cluster_id = b.cluster_id AND a.item_id < b.item_id
             )
             SELECT
                 cluster_id,
-                AVG(distance)::float8 as avg_distance,
-                STDDEV(distance)::float8 as stddev_distance,
+                AVG(distance) as avg_distance,
+                STDDEV(distance) as stddev_distance,
                 COUNT(*) as num_pairs
-            FROM distances
+            FROM cluster_distances
             GROUP BY cluster_id
             ORDER BY cluster_id",
-        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
-    );
+            &[],
+        )
+        .await?;

-    info!("\nPer-cluster cohesion:");
-    let rows = client.query(&query, &[]).await?;
-    for row in rows {
-        let cluster: i32 = row.get(0);
-        let avg: f64 = row.get(1);
-        let stddev: f64 = row.get(2);
-        let count: i64 = row.get(3);
+    for row in cluster_stats {
+        let cluster_id: i32 = row.get("cluster_id");
+        let avg_distance: f64 = row.get("avg_distance");
+        let stddev_distance: f64 = row.get("stddev_distance");
+        let num_pairs: i64 = row.get("num_pairs");
         info!(
-            "Cluster {}: avg_distance={:.3}±{:.3} ({} pairs)",
-            cluster, avg, stddev, count
+            "Cluster {}: avg={:.3}, stddev={:.3}, pairs={}",
+            cluster_id, avg_distance, stddev_distance, num_pairs
         );
     }
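Note: the rewritten cohesion query buckets every pair of stored 3D centers as 'within' or 'between' and aggregates both buckets in a single pass. If the generator worked, within-cluster distances should be on the order of the jitter and between-cluster distances near the vertex separations. A minimal in-memory version of the same statistic (hypothetical data, not part of the commit), useful for sanity-checking the SQL:

    fn dist(a: &[f64; 3], b: &[f64; 3]) -> f64 {
        a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f64>().sqrt()
    }

    fn main() {
        // (cluster_id, center) rows like item_clusters: jittered copies of two vertices.
        let rows: Vec<(i32, [f64; 3])> = vec![
            (0, [2.0, 0.1, 0.0]),
            (0, [1.9, -0.1, 0.1]),
            (1, [0.0, 2.1, 0.0]),
            (1, [0.1, 1.9, -0.1]),
        ];
        let (mut within, mut between) = (Vec::new(), Vec::new());
        for i in 0..rows.len() {
            for j in (i + 1)..rows.len() {
                let d = dist(&rows[i].1, &rows[j].1);
                if rows[i].0 == rows[j].0 {
                    within.push(d);
                } else {
                    between.push(d);
                }
            }
        }
        let avg = |v: &[f64]| v.iter().sum::<f64>() / v.len() as f64;
        // Within-cluster pairs stay near the jitter scale; between-cluster
        // pairs sit near the center separation (about 2.8 here).
        println!("within avg {:.3}, between avg {:.3}", avg(&within), avg(&between));
    }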
@@ -223,39 +234,40 @@ async fn analyze_cluster_correlation(client: &tokio_postgres::Client, args: &Arg
     info!("Analyzing correlation between cluster centers and embedding distances...");

     // Calculate correlation between cluster center distances and embedding distances
-    let query = format!(
-        "WITH distances AS (
+    let correlation = client
+        .query_one(
+            "WITH distances AS (
             SELECT
                 a.cluster_id as cluster1,
                 b.cluster_id as cluster2,
-                ABS(a.cluster_center - b.cluster_center)::float8 as center_distance,
-                SQRT(SUM((e1 - e2) * (e1 - e2)))::float8 as embedding_distance
-            FROM {} a
-            JOIN {} b ON a.item_id < b.item_id
-            JOIN {} ie1 ON a.item_id = ie1.item_id
-            JOIN {} ie2 ON b.item_id = ie2.item_id,
-            UNNEST(ie1.embedding, ie2.embedding) AS t(e1, e2)
-            GROUP BY a.item_id, b.item_id, a.cluster_id, b.cluster_id, a.cluster_center, b.cluster_center
+                SQRT(
+                    POW(a.center_x - b.center_x, 2) +
+                    POW(a.center_y - b.center_y, 2) +
+                    POW(a.center_z - b.center_z, 2)
+                ) as center_distance,
+                SQRT(
+                    POW(ae.embedding[1] - be.embedding[1], 2) +
+                    POW(ae.embedding[2] - be.embedding[2], 2) +
+                    POW(ae.embedding[3] - be.embedding[3], 2)
+                ) as embedding_distance
+            FROM item_clusters a
+            JOIN item_clusters b ON a.cluster_id < b.cluster_id
+            JOIN item_embeddings ae ON a.item_id = ae.item_id
+            JOIN item_embeddings be ON b.item_id = be.item_id
            )
            SELECT
-                CORR(center_distance, embedding_distance)::float8 as correlation,
-                COUNT(*) as num_pairs,
-                AVG(center_distance)::float8 as avg_center_dist,
-                AVG(embedding_distance)::float8 as avg_emb_dist
+                corr(center_distance, embedding_distance) as correlation,
+                COUNT(*) as num_pairs
             FROM distances",
-        args.clusters_table, args.clusters_table, args.embeddings_table, args.embeddings_table
-    );
+            &[],
+        )
+        .await?;

-    let row = client.query_one(&query, &[]).await?;
+    let correlation_value: f64 = correlation.get("correlation");
+    let num_pairs: i64 = correlation.get("num_pairs");
     info!(
-        "Correlation between cluster center distances and embedding distances: {:.3} (from {} pairs)",
-        row.get::<_, f64>(0),
-        row.get::<_, i64>(1)
-    );
-    info!(
-        "Average distances: cluster centers={:.3}, embeddings={:.3}",
-        row.get::<_, f64>(2),
-        row.get::<_, f64>(3)
+        "Correlation between cluster center distances and embedding distances: {:.3} ({} pairs)",
+        correlation_value, num_pairs
     );

     Ok(())
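Note: two details in the correlation query are easy to miss. Postgres arrays are 1-indexed, so ae.embedding[1] through ae.embedding[3] read the first three factors, and only the first three, which matches the 3D centers only when the model is trained with factors = 3. Also, the pair join is now on a.cluster_id < b.cluster_id, so the correlation covers cross-cluster pairs only. corr() is Postgres's sample Pearson correlation; an equivalent for cross-checking (a standalone sketch, not part of the commit):

    // Pearson correlation, the statistic Postgres's corr() computes.
    fn pearson(xs: &[f64], ys: &[f64]) -> f64 {
        let n = xs.len() as f64;
        let mx = xs.iter().sum::<f64>() / n;
        let my = ys.iter().sum::<f64>() / n;
        let cov: f64 = xs.iter().zip(ys).map(|(x, y)| (x - mx) * (y - my)).sum();
        let vx: f64 = xs.iter().map(|x| (x - mx).powi(2)).sum();
        let vy: f64 = ys.iter().map(|y| (y - my).powi(2)).sum();
        cov / (vx.sqrt() * vy.sqrt())
    }

    fn main() {
        let center_dist = [1.0, 2.0, 3.0, 4.0];
        let embedding_dist = [1.1, 1.9, 3.2, 3.8];
        println!("correlation: {:.3}", pearson(&center_dist, &embedding_dist));
    }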
src/main.rs (152 changes)
@@ -2,8 +2,9 @@ use anyhow::{Context, Result};
 use clap::Parser;
 use deadpool_postgres::{Config, Pool, Runtime};
 use dotenv::dotenv;
-use libmf::Matrix;
-use log::{info, warn};
+use libmf::{Loss, Matrix, Model};
+use log::info;
+use num_cpus;
 use std::env;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
@@ -39,6 +40,10 @@ struct Args {
     /// Lambda for regularization
     #[arg(long, default_value = "0.1")]
     lambda: f32,
+
+    /// Number of threads for matrix factorization (defaults to number of CPU cores)
+    #[arg(long, default_value_t = num_cpus::get() as i32)]
+    threads: i32,
 }

 async fn create_pool() -> Result<Pool> {
@@ -57,33 +62,51 @@ async fn create_pool() -> Result<Pool> {
     Ok(config.create_pool(Some(Runtime::Tokio1), NoTls)?)
 }

-async fn load_data_batch(pool: &Pool, args: &Args, offset: i64) -> Result<Vec<(i32, i32, f32)>> {
-    let client = pool.get().await?;
+async fn load_data_batch(
+    client: &deadpool_postgres::Client,
+    source_table: &str,
+    user_id_column: &str,
+    item_id_column: &str,
+    batch_size: usize,
+    last_user_id: Option<i32>,
+    last_item_id: Option<i32>,
+) -> Result<Vec<(i32, i32)>> {
+    let rows = if let (Some(last_user), Some(last_item)) = (last_user_id, last_item_id) {
         let query = format!(
-        "SELECT {}, {} FROM {} OFFSET $1 LIMIT $2",
-        args.user_id_column, args.item_id_column, args.source_table
+            "SELECT {user}, {item} FROM {table} \
+             WHERE ({user}, {item}) > ($1, $2) \
+             ORDER BY {user}, {item} \
+             LIMIT $3",
+            user = user_id_column,
+            item = item_id_column,
+            table = source_table,
         );
+        client
+            .query(&query, &[&last_user, &last_item, &(batch_size as i64)])
+            .await?
+    } else {
+        let query = format!(
+            "SELECT {user}, {item} FROM {table} \
+             ORDER BY {user}, {item} \
+             LIMIT $1",
+            user = user_id_column,
+            item = item_id_column,
+            table = source_table,
+        );
+        client.query(&query, &[&(batch_size as i64)]).await?
+    };

-    let rows = client
-        .query(&query, &[&offset, &(args.batch_size as i64)])
-        .await?;
-
-    Ok(rows
-        .into_iter()
-        .map(|row| {
+    let mut batch = Vec::with_capacity(rows.len());
+    for row in rows {
         let user_id: i32 = row.get(0);
         let item_id: i32 = row.get(1);
-            (user_id, item_id, 1.0) // Using 1.0 as interaction strength
-        })
-        .collect())
+        batch.push((user_id, item_id));
+    }
+
+    Ok(batch)
 }

-async fn save_embeddings(
-    pool: &Pool,
-    args: &Args,
-    model: &libmf::Model,
-    item_ids: &[i32],
-) -> Result<()> {
+async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i32]) -> Result<()> {
     let client = pool.get().await?;

     // Create the target table if it doesn't exist
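Note: the substantive change in load_data_batch is the move from OFFSET/LIMIT to keyset (seek) pagination. OFFSET makes Postgres re-scan and discard every skipped row on each call, so a full pass degrades quadratically with table size, while WHERE (user, item) > ($1, $2) ORDER BY user, item LIMIT $3 uses row-value comparison to seek straight to the continuation point, ideally backed by a composite index on the two columns. A standalone sketch of the same cursor logic over an in-memory sorted list, with Rust tuple ordering standing in for SQL row-value comparison:

    fn next_batch(rows: &[(i32, i32)], last: Option<(i32, i32)>, limit: usize) -> Vec<(i32, i32)> {
        rows.iter()
            .filter(|&&r| last.map_or(true, |l| r > l)) // tuple > matches SQL (a, b) > ($1, $2)
            .take(limit)
            .copied()
            .collect()
    }

    fn main() {
        let rows = vec![(1, 10), (1, 11), (2, 5), (2, 9), (3, 1)]; // already sorted
        let mut last = None;
        loop {
            let batch = next_batch(&rows, last, 2);
            if batch.is_empty() {
                break;
            }
            println!("batch: {batch:?}");
            last = batch.last().copied(); // same bookkeeping as last_user_id/last_item_id
        }
    }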
@@ -128,6 +151,19 @@ async fn save_embeddings(
     Ok(())
 }

+async fn get_unique_item_ids(
+    client: &deadpool_postgres::Client,
+    source_table: &str,
+    item_id_column: &str,
+) -> Result<Vec<i32>> {
+    let query = format!(
+        "SELECT DISTINCT {} FROM {} ORDER BY {}",
+        item_id_column, source_table, item_id_column
+    );
+    let rows = client.query(&query, &[]).await?;
+    Ok(rows.iter().map(|row| row.get(0)).collect())
+}
+
 #[tokio::main]
 async fn main() -> Result<()> {
     dotenv().ok();
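Note: get_unique_item_ids splices the table and column names into the SQL text with format!() rather than binding them, because Postgres bind parameters ($1, $2, ...) can only stand for values, never identifiers. That is why names come from trusted CLI arguments here while limits and cursor values go through the &[...] parameter slice. In miniature:

    fn main() {
        // Identifiers must be spliced into the SQL text (trusted input only);
        // values go through $n bind parameters at query time.
        let (table, col) = ("interactions", "item_id"); // hypothetical CLI-supplied names
        let query = format!("SELECT DISTINCT {col} FROM {table} ORDER BY {col}");
        println!("{query}");
    }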
@@ -143,59 +179,75 @@ async fn main() -> Result<()> {
     })?;

     let pool = create_pool().await?;
-    let mut offset = 0i64;
-    let mut all_data = Vec::new();
-    let mut unique_item_ids = Vec::new();
+    let mut matrix = Matrix::new();
+    let mut last_user_id = None;
+    let mut last_item_id = None;
+    let mut total_rows = 0;

     info!("Starting data loading...");
     while running.load(Ordering::SeqCst) {
-        let batch = load_data_batch(&pool, &args, offset).await?;
+        let batch = load_data_batch(
+            &pool.get().await?,
+            &args.source_table,
+            &args.user_id_column,
+            &args.item_id_column,
+            args.batch_size as usize,
+            last_user_id,
+            last_item_id,
+        )
+        .await?;

         if batch.is_empty() {
             break;
         }

-        // Track unique item IDs
-        for &(_, item_id, _) in &batch {
-            if !unique_item_ids.contains(&item_id) {
-                unique_item_ids.push(item_id);
-            }
-        }
+        total_rows += batch.len();
+        info!(
+            "Loaded batch of {} rows (total: {})",
+            batch.len(),
+            total_rows
+        );

-        all_data.extend(batch);
-        offset += args.batch_size as i64;
-        info!("Loaded {} rows so far", all_data.len());
+        // Update last seen IDs for next batch
+        if let Some((user_id, item_id)) = batch.last() {
+            last_user_id = Some(*user_id);
+            last_item_id = Some(*item_id);
+        }
+
+        // Process batch
+        for (user_id, item_id) in batch {
+            matrix.push(user_id as i32, item_id as i32, 1.0f32);
+        }
     }
+    info!("Loaded {} total rows", total_rows);

-    if all_data.is_empty() {
-        warn!("No data was loaded. Exiting.");
+    if total_rows == 0 {
+        info!("No data found in source table");
         return Ok(());
     }

-    info!("Creating libmf problem...");
-    let mut matrix = Matrix::new();
-    for (user_id, item_id, value) in all_data {
-        matrix.push(
-            user_id.try_into().unwrap(),
-            item_id.try_into().unwrap(),
-            value,
-        );
-    }
+    // Get unique item IDs from database
+    let client = pool.get().await?;
+    let unique_item_ids =
+        get_unique_item_ids(&client, &args.source_table, &args.item_id_column).await?;
+    info!("Found {} unique items", unique_item_ids.len());

     info!("Training model with {} factors...", args.factors);
-    let model = libmf::Model::params()
-        .factors(args.factors.try_into().unwrap())
+    // Set up training parameters
+    let model = Model::params()
+        .factors(args.factors as i32)
         .lambda_p2(args.lambda)
         .lambda_q2(args.lambda)
         .learning_rate(0.01)
         .iterations(100)
-        .loss(libmf::Loss::OneClassL2)
+        .loss(Loss::OneClassL2)
         .c(0.00001)
         .quiet(false)
+        .threads(args.threads)
         .fit(&matrix)?;

     info!("Saving embeddings...");
     save_embeddings(&pool, &args, &model, &unique_item_ids).await?;

     info!("Done!");
     Ok(())
 }
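Note: Loss::OneClassL2 selects libmf's one-class matrix factorization, the mode meant for implicit-feedback data like these interactions: every observed (user, item) cell is pushed as a positive with weight 1.0 and there are no explicit negatives; the small c(0.00001) coefficient weights the implicit negative (unobserved) entries in that mode. A minimal end-to-end sketch reusing only calls that appear in this diff, plus predict() for scoring (toy data, hypothetical parameter values):

    use libmf::{Loss, Matrix, Model};

    fn main() {
        // Toy implicit-feedback data: observed (user, item) pairs, all weight 1.0,
        // exactly as the loading loop above pushes them.
        let mut matrix = Matrix::new();
        for &(u, i) in &[(0, 0), (0, 1), (1, 1), (2, 2)] {
            matrix.push(u, i, 1.0);
        }
        let model = Model::params()
            .factors(3)
            .loss(Loss::OneClassL2)
            .quiet(true)
            .fit(&matrix)
            .expect("training failed");
        // Higher scores should land on cells consistent with the observed pattern.
        println!("score(0, 1) = {}", model.predict(0, 1));
    }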