Add `--max-interactions` flag to optionally cap the number of interactions loaded

This commit is contained in:
Dylan Knutson
2024-12-28 17:41:42 +00:00
parent 350c61c313
commit 857cbf5d1f

View File

@@ -70,6 +70,10 @@ struct Args {
/// Number of bins to use for training /// Number of bins to use for training
#[arg(long, default_value = "10")] #[arg(long, default_value = "10")]
bins: i32, bins: i32,
/// Maximum number of interactions to load (optional)
#[arg(long)]
max_interactions: Option<usize>,
} }
async fn create_pool() -> Result<Pool> { async fn create_pool() -> Result<Pool> {
@@ -355,6 +359,21 @@ async fn main() -> Result<()> {
break; break;
} }
// Check if we would exceed max_interactions
if let Some(max) = args.max_interactions {
if total_rows + batch.len() > max {
// Only process up to max_interactions
let remaining = max - total_rows;
for (user_id, item_id) in batch.into_iter().take(remaining) {
matrix.push(user_id, item_id, 1.0f32);
item_ids.insert(item_id);
}
total_rows += remaining;
info!("Reached maximum interactions limit of {max}",);
break;
}
}
total_rows += batch.len(); total_rows += batch.len();
info!( info!(
"Loaded batch of {} rows (total: {})", "Loaded batch of {} rows (total: {})",