add max interactions flag
This commit is contained in:
19
src/main.rs
19
src/main.rs
@@ -70,6 +70,10 @@ struct Args {
|
|||||||
/// Number of bins to use for training
|
/// Number of bins to use for training
|
||||||
#[arg(long, default_value = "10")]
|
#[arg(long, default_value = "10")]
|
||||||
bins: i32,
|
bins: i32,
|
||||||
|
|
||||||
|
/// Maximum number of interactions to load (optional)
|
||||||
|
#[arg(long)]
|
||||||
|
max_interactions: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_pool() -> Result<Pool> {
|
async fn create_pool() -> Result<Pool> {
|
||||||
@@ -355,6 +359,21 @@ async fn main() -> Result<()> {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we would exceed max_interactions
|
||||||
|
if let Some(max) = args.max_interactions {
|
||||||
|
if total_rows + batch.len() > max {
|
||||||
|
// Only process up to max_interactions
|
||||||
|
let remaining = max - total_rows;
|
||||||
|
for (user_id, item_id) in batch.into_iter().take(remaining) {
|
||||||
|
matrix.push(user_id, item_id, 1.0f32);
|
||||||
|
item_ids.insert(item_id);
|
||||||
|
}
|
||||||
|
total_rows += remaining;
|
||||||
|
info!("Reached maximum interactions limit of {max}",);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
total_rows += batch.len();
|
total_rows += batch.len();
|
||||||
info!(
|
info!(
|
||||||
"Loaded batch of {} rows (total: {})",
|
"Loaded batch of {} rows (total: {})",
|
||||||
|
|||||||
Reference in New Issue
Block a user