clippy, refactor

This commit is contained in:
Dylan Knutson
2024-04-23 09:52:46 -07:00
parent 37cc74bfd1
commit b5c367d26c
4 changed files with 54 additions and 52 deletions

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
.DS_Store
target
/testdb

View File

@@ -1,5 +1,5 @@
use std::borrow::Borrow;
use std::borrow::BorrowMut;
use std::sync::Arc;
use std::sync::Mutex;
@@ -13,7 +13,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut handles = vec![];
let num_shards = 8;
for _ in (0..num_shards).into_iter() {
for _ in 0..num_shards {
let pb = pb.clone();
handles.push(std::thread::spawn(move || {
run_loop(pb).unwrap();
@@ -27,7 +27,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
fn run_loop(mut pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
fn run_loop(pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
let client = reqwest::blocking::Client::new();
let mut rng = rand::thread_rng();
let mut rand_data = vec![0u8; 1024 * 1024];
@@ -55,5 +55,4 @@ fn run_loop(mut pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
pb.update(1)?;
if resp.status() != 200 && resp.status() != 201 {}
}
Ok(())
}

View File

@@ -1,3 +1,5 @@
mod shutdown_signal;
use axum::Json;
use axum::{body::Bytes, http::StatusCode, routing::post, Extension, Router};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
@@ -6,10 +8,10 @@ use rusqlite::ffi;
use rusqlite::params;
use rusqlite::Error::SqliteFailure;
use sha2::Digest;
use tokio::signal;
use std::collections::HashMap;
use std::{borrow::Borrow, error::Error, path::PathBuf, sync::Arc};
use tokio::net::TcpListener;
use tokio_rusqlite::Connection;
use tracing::{error, info};
@@ -77,7 +79,7 @@ fn main() -> Result<(), Box<dyn Error>> {
}
for (shard, conn) in shards.iter().enumerate() {
let count = num_entries_in(&conn).await?;
let count = num_entries_in(conn).await?;
info!("shard {} has {} entries", shard, count);
}
@@ -97,37 +99,15 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn
let app = Router::new()
.route("/store", post(store_request_handler))
.layer(Extension(shards));
axum::serve(server, app.into_make_service()).with_graceful_shutdown(shutdown_signal()).await?;
axum::serve(server, app.into_make_service())
.with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
.await?;
Ok(())
}
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
}
type ResultType = Result<
(StatusCode, Json<HashMap<&'static str, String>>),
(StatusCode, Json<HashMap<&'static str, String>>)
(StatusCode, Json<HashMap<&'static str, String>>),
>;
#[axum::debug_handler]
@@ -137,7 +117,7 @@ async fn store_request_handler(
) -> ResultType {
// compute sha256 of data
let data_bytes = &request.data.contents;
let sha256 = sha2::Sha256::digest(&data_bytes);
let sha256 = sha2::Sha256::digest(data_bytes);
let sha256_str = format!("{:x}", sha256);
let num_shards = shards.0.len();
// select shard
@@ -163,13 +143,10 @@ async fn store_request_handler(
};
let conn = conn.borrow();
perform_store(&conn, request_parsed).await
perform_store(conn, request_parsed).await
}
async fn perform_store(
conn: &Connection,
store_request: StoreRequestParsed,
) -> ResultType {
async fn perform_store(conn: &Connection, store_request: StoreRequestParsed) -> ResultType {
conn.call(move |conn| {
let created_at = chrono::Utc::now().to_rfc3339();
let maybe_error = conn.execute(
@@ -198,7 +175,7 @@ async fn perform_store(
let mut response = HashMap::new();
response.insert("status", "ok".to_owned());
response.insert("message", "created".to_owned());
return Ok((StatusCode::CREATED, Json(response)));
Ok((StatusCode::CREATED, Json(response)))
})
.await.map_err(|e| {
error!("store failed: {}", e);
@@ -248,7 +225,9 @@ async fn num_entries_in(conn: &Connection) -> Result<i64, Box<dyn Error>> {
conn.call(|conn| {
let count: i64 = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
Ok(count)
}).await.map_err(|e| e.into())
})
.await
.map_err(|e| e.into())
}
fn validate_manifest(args: Args) -> Result<usize, Box<dyn Error>> {
@@ -266,15 +245,13 @@ fn validate_manifest(args: Args) -> Result<usize, Box<dyn Error>> {
}
}
Ok(manifest.shards)
} else if let Some(shards) = args.shards {
std::fs::create_dir_all(&args.db_path)?;
let manifest = ManifestData { shards };
let manifest_json = serde_json::to_string(&manifest)?;
std::fs::write(manifest_path, manifest_json)?;
Ok(shards)
} else {
if let Some(shards) = args.shards {
std::fs::create_dir_all(&args.db_path)?;
let manifest = ManifestData { shards };
let manifest_json = serde_json::to_string(&manifest)?;
std::fs::write(manifest_path, manifest_json)?;
return Ok(shards);
} else {
return Err("new database needs --shards argument".into());
}
Err("new database needs --shards argument".into())
}
}

25
src/shutdown_signal.rs Normal file
View File

@@ -0,0 +1,25 @@
use tokio::signal;
pub async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
}