From 857cbf5d1f1e1c7118f7cd36ba48ba3c3254c698 Mon Sep 17 00:00:00 2001 From: Dylan Knutson Date: Sat, 28 Dec 2024 17:41:42 +0000 Subject: [PATCH] add max interactions flag --- src/main.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/main.rs b/src/main.rs index 090645e..f2ee141 100644 --- a/src/main.rs +++ b/src/main.rs @@ -70,6 +70,10 @@ struct Args { /// Number of bins to use for training #[arg(long, default_value = "10")] bins: i32, + + /// Maximum number of interactions to load (optional) + #[arg(long)] + max_interactions: Option<usize>, } async fn create_pool() -> Result { @@ -355,6 +359,21 @@ async fn main() -> Result<()> { break; } + // Check if we would exceed max_interactions + if let Some(max) = args.max_interactions { + if total_rows + batch.len() > max { + // Only process up to max_interactions + let remaining = max - total_rows; + for (user_id, item_id) in batch.into_iter().take(remaining) { + matrix.push(user_id, item_id, 1.0f32); + item_ids.insert(item_id); + } + total_rows += remaining; + info!("Reached maximum interactions limit of {max}",); + break; + } + } + total_rows += batch.len(); info!( "Loaded batch of {} rows (total: {})",