diff --git a/.cargo/config.toml b/.cargo/config.toml index f4d28a3..a883339 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,17 @@ -[build] -rustflags = ["-C", "link-arg=-fuse-ld=lld"] +[target.x86_64-unknown-linux-gnu] +rustflags = [ + "-C", + "link-arg=-fuse-ld=lld", + "-C", + "link-arg=-L/usr/lib/x86_64-linux-gnu", + "-C", + "link-arg=-lgomp", + "-C", + "link-arg=-fopenmp", +] + +[env] +CXXFLAGS = "-fopenmp -pthread -DUSEOMP=1" +LDFLAGS = "-fopenmp -pthread -DUSEOMP=1" +CC = "gcc" +CXX = "g++" diff --git a/src/main.rs b/src/main.rs index 19c7e76..6b7a56c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,6 +28,10 @@ struct Args { #[arg(long)] target_table: String, + /// Number of iterations for matrix factorization + #[arg(long, short = 'i', default_value = "100")] + iterations: i32, + /// Batch size for loading data #[arg(long, default_value = "10000")] batch_size: i32, @@ -51,6 +55,10 @@ struct Args { /// Number of threads for matrix factorization (defaults to number of CPU cores) #[arg(long, default_value_t = num_cpus::get() as i32)] threads: i32, + + /// Number of bins to use for training + #[arg(long, default_value = "10")] + bins: i32, } async fn create_pool() -> Result { @@ -113,6 +121,7 @@ async fn load_data_batch( Ok(batch) } +// TODO - don't load all item IDs at once async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i32]) -> Result<()> { let client = pool.get().await?; @@ -248,11 +257,12 @@ async fn main() -> Result<()> { .lambda_p2(args.lambda2) .lambda_q2(args.lambda2) .learning_rate(args.learning_rate) - .iterations(100) + .iterations(args.iterations) .loss(Loss::OneClassL2) .c(0.00001) .quiet(false) .threads(args.threads) + .bins(args.bins) .fit(&matrix)?; info!("Saving embeddings...");