add max interactions flag
This commit is contained in:
19
src/main.rs
19
src/main.rs
@@ -70,6 +70,10 @@ struct Args {
|
||||
/// Number of bins to use for training
|
||||
#[arg(long, default_value = "10")]
|
||||
bins: i32,
|
||||
|
||||
/// Maximum number of interactions to load (optional)
|
||||
#[arg(long)]
|
||||
max_interactions: Option<usize>,
|
||||
}
|
||||
|
||||
async fn create_pool() -> Result<Pool> {
|
||||
@@ -355,6 +359,21 @@ async fn main() -> Result<()> {
|
||||
break;
|
||||
}
|
||||
|
||||
// Check if we would exceed max_interactions
|
||||
if let Some(max) = args.max_interactions {
|
||||
if total_rows + batch.len() > max {
|
||||
// Only process up to max_interactions
|
||||
let remaining = max - total_rows;
|
||||
for (user_id, item_id) in batch.into_iter().take(remaining) {
|
||||
matrix.push(user_id, item_id, 1.0f32);
|
||||
item_ids.insert(item_id);
|
||||
}
|
||||
total_rows += remaining;
|
||||
info!("Reached maximum interactions limit of {max}",);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
total_rows += batch.len();
|
||||
info!(
|
||||
"Loaded batch of {} rows (total: {})",
|
||||
|
||||
Reference in New Issue
Block a user