From b5c367d26c82f0db0773f108de031c29e4f3bf53 Mon Sep 17 00:00:00 2001 From: Dylan Knutson Date: Tue, 23 Apr 2024 09:52:46 -0700 Subject: [PATCH] clippy, refactor --- .gitignore | 3 +- src/load_test.rs | 9 +++--- src/main.rs | 69 ++++++++++++++---------------------------- src/shutdown_signal.rs | 25 +++++++++++++++ 4 files changed, 54 insertions(+), 52 deletions(-) create mode 100644 src/shutdown_signal.rs diff --git a/.gitignore b/.gitignore index 4c16076..12718bc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .DS_Store -target \ No newline at end of file +target +/testdb \ No newline at end of file diff --git a/src/load_test.rs b/src/load_test.rs index 4c13ce0..8ef5b09 100644 --- a/src/load_test.rs +++ b/src/load_test.rs @@ -1,5 +1,5 @@ -use std::borrow::Borrow; -use std::borrow::BorrowMut; + + use std::sync::Arc; use std::sync::Mutex; @@ -13,7 +13,7 @@ fn main() -> Result<(), Box> { let mut handles = vec![]; let num_shards = 8; - for _ in (0..num_shards).into_iter() { + for _ in 0..num_shards { let pb = pb.clone(); handles.push(std::thread::spawn(move || { run_loop(pb).unwrap(); @@ -27,7 +27,7 @@ fn main() -> Result<(), Box> { Ok(()) } -fn run_loop(mut pb: Arc>) -> Result<(), Box> { +fn run_loop(pb: Arc>) -> Result<(), Box> { let client = reqwest::blocking::Client::new(); let mut rng = rand::thread_rng(); let mut rand_data = vec![0u8; 1024 * 1024]; @@ -55,5 +55,4 @@ fn run_loop(mut pb: Arc>) -> Result<(), Box> { pb.update(1)?; if resp.status() != 200 && resp.status() != 201 {} } - Ok(()) } diff --git a/src/main.rs b/src/main.rs index 64de86e..1464655 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +mod shutdown_signal; + use axum::Json; use axum::{body::Bytes, http::StatusCode, routing::post, Extension, Router}; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; @@ -6,10 +8,10 @@ use rusqlite::ffi; use rusqlite::params; use rusqlite::Error::SqliteFailure; use sha2::Digest; -use tokio::signal; use std::collections::HashMap; use std::{borrow::Borrow, error::Error, path::PathBuf, sync::Arc}; use tokio::net::TcpListener; + use tokio_rusqlite::Connection; use tracing::{error, info}; @@ -77,7 +79,7 @@ fn main() -> Result<(), Box> { } for (shard, conn) in shards.iter().enumerate() { - let count = num_entries_in(&conn).await?; + let count = num_entries_in(conn).await?; info!("shard {} has {} entries", shard, count); } @@ -97,37 +99,15 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box(); - - tokio::select! { - _ = ctrl_c => {}, - _ = terminate => {}, - } -} - type ResultType = Result< - (StatusCode, Json>), - (StatusCode, Json>) + (StatusCode, Json>), + (StatusCode, Json>), >; #[axum::debug_handler] @@ -137,7 +117,7 @@ async fn store_request_handler( ) -> ResultType { // compute sha256 of data let data_bytes = &request.data.contents; - let sha256 = sha2::Sha256::digest(&data_bytes); + let sha256 = sha2::Sha256::digest(data_bytes); let sha256_str = format!("{:x}", sha256); let num_shards = shards.0.len(); // select shard @@ -163,13 +143,10 @@ async fn store_request_handler( }; let conn = conn.borrow(); - perform_store(&conn, request_parsed).await + perform_store(conn, request_parsed).await } -async fn perform_store( - conn: &Connection, - store_request: StoreRequestParsed, -) -> ResultType { +async fn perform_store(conn: &Connection, store_request: StoreRequestParsed) -> ResultType { conn.call(move |conn| { let created_at = chrono::Utc::now().to_rfc3339(); let maybe_error = conn.execute( @@ -198,7 +175,7 @@ async fn perform_store( let mut response = HashMap::new(); response.insert("status", "ok".to_owned()); response.insert("message", "created".to_owned()); - return Ok((StatusCode::CREATED, Json(response))); + Ok((StatusCode::CREATED, Json(response))) }) .await.map_err(|e| { error!("store failed: {}", e); @@ -248,7 +225,9 @@ async fn num_entries_in(conn: &Connection) -> Result> { conn.call(|conn| { let count: i64 = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?; Ok(count) - }).await.map_err(|e| e.into()) + }) + .await + .map_err(|e| e.into()) } fn validate_manifest(args: Args) -> Result> { @@ -266,15 +245,13 @@ fn validate_manifest(args: Args) -> Result> { } } Ok(manifest.shards) + } else if let Some(shards) = args.shards { + std::fs::create_dir_all(&args.db_path)?; + let manifest = ManifestData { shards }; + let manifest_json = serde_json::to_string(&manifest)?; + std::fs::write(manifest_path, manifest_json)?; + Ok(shards) } else { - if let Some(shards) = args.shards { - std::fs::create_dir_all(&args.db_path)?; - let manifest = ManifestData { shards }; - let manifest_json = serde_json::to_string(&manifest)?; - std::fs::write(manifest_path, manifest_json)?; - return Ok(shards); - } else { - return Err("new database needs --shards argument".into()); - } + Err("new database needs --shards argument".into()) } } diff --git a/src/shutdown_signal.rs b/src/shutdown_signal.rs new file mode 100644 index 0000000..a129d52 --- /dev/null +++ b/src/shutdown_signal.rs @@ -0,0 +1,25 @@ +use tokio::signal; + +pub async fn shutdown_signal() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } +}