Make libmf multithreading work
This commit is contained in:
@@ -1,2 +1,17 @@
|
|||||||
[build]
|
[target.x86_64-unknown-linux-gnu]
|
||||||
rustflags = ["-C", "link-arg=-fuse-ld=lld"]
|
rustflags = [
|
||||||
|
"-C",
|
||||||
|
"link-arg=-fuse-ld=lld",
|
||||||
|
"-C",
|
||||||
|
"link-arg=-L/usr/lib/x86_64-linux-gnu",
|
||||||
|
"-C",
|
||||||
|
"link-arg=-lgomp",
|
||||||
|
"-C",
|
||||||
|
"link-arg=-fopenmp",
|
||||||
|
]
|
||||||
|
|
||||||
|
[env]
|
||||||
|
CXXFLAGS = "-fopenmp -pthread -DUSEOMP=1"
|
||||||
|
LDFLAGS = "-fopenmp -pthread -DUSEOMP=1"
|
||||||
|
CC = "gcc"
|
||||||
|
CXX = "g++"
|
||||||
|
|||||||
12
src/main.rs
12
src/main.rs
@@ -28,6 +28,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
target_table: String,
|
target_table: String,
|
||||||
|
|
||||||
|
/// Number of iterations for matrix factorization
|
||||||
|
#[arg(long, short = 'i', default_value = "100")]
|
||||||
|
iterations: i32,
|
||||||
|
|
||||||
/// Batch size for loading data
|
/// Batch size for loading data
|
||||||
#[arg(long, default_value = "10000")]
|
#[arg(long, default_value = "10000")]
|
||||||
batch_size: i32,
|
batch_size: i32,
|
||||||
@@ -51,6 +55,10 @@ struct Args {
|
|||||||
/// Number of threads for matrix factorization (defaults to number of CPU cores)
|
/// Number of threads for matrix factorization (defaults to number of CPU cores)
|
||||||
#[arg(long, default_value_t = num_cpus::get() as i32)]
|
#[arg(long, default_value_t = num_cpus::get() as i32)]
|
||||||
threads: i32,
|
threads: i32,
|
||||||
|
|
||||||
|
/// Number of bins to use for training
|
||||||
|
#[arg(long, default_value = "10")]
|
||||||
|
bins: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_pool() -> Result<Pool> {
|
async fn create_pool() -> Result<Pool> {
|
||||||
@@ -113,6 +121,7 @@ async fn load_data_batch(
|
|||||||
Ok(batch)
|
Ok(batch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: load item IDs in batches instead of all at once
|
||||||
async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i32]) -> Result<()> {
|
async fn save_embeddings(pool: &Pool, args: &Args, model: &Model, item_ids: &[i32]) -> Result<()> {
|
||||||
let client = pool.get().await?;
|
let client = pool.get().await?;
|
||||||
|
|
||||||
@@ -248,11 +257,12 @@ async fn main() -> Result<()> {
|
|||||||
.lambda_p2(args.lambda2)
|
.lambda_p2(args.lambda2)
|
||||||
.lambda_q2(args.lambda2)
|
.lambda_q2(args.lambda2)
|
||||||
.learning_rate(args.learning_rate)
|
.learning_rate(args.learning_rate)
|
||||||
.iterations(100)
|
.iterations(args.iterations)
|
||||||
.loss(Loss::OneClassL2)
|
.loss(Loss::OneClassL2)
|
||||||
.c(0.00001)
|
.c(0.00001)
|
||||||
.quiet(false)
|
.quiet(false)
|
||||||
.threads(args.threads)
|
.threads(args.threads)
|
||||||
|
.bins(args.bins)
|
||||||
.fit(&matrix)?;
|
.fit(&matrix)?;
|
||||||
|
|
||||||
info!("Saving embeddings...");
|
info!("Saving embeddings...");
|
||||||
|
|||||||
Reference in New Issue
Block a user