clippy, refactor
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
.DS_Store
|
||||
target
|
||||
target
|
||||
/testdb
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::borrow::Borrow;
|
||||
use std::borrow::BorrowMut;
|
||||
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
@@ -13,7 +13,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
let mut handles = vec![];
|
||||
let num_shards = 8;
|
||||
for _ in (0..num_shards).into_iter() {
|
||||
for _ in 0..num_shards {
|
||||
let pb = pb.clone();
|
||||
handles.push(std::thread::spawn(move || {
|
||||
run_loop(pb).unwrap();
|
||||
@@ -27,7 +27,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_loop(mut pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
fn run_loop(pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut rand_data = vec![0u8; 1024 * 1024];
|
||||
@@ -55,5 +55,4 @@ fn run_loop(mut pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pb.update(1)?;
|
||||
if resp.status() != 200 && resp.status() != 201 {}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
69
src/main.rs
69
src/main.rs
@@ -1,3 +1,5 @@
|
||||
mod shutdown_signal;
|
||||
|
||||
use axum::Json;
|
||||
use axum::{body::Bytes, http::StatusCode, routing::post, Extension, Router};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
@@ -6,10 +8,10 @@ use rusqlite::ffi;
|
||||
use rusqlite::params;
|
||||
use rusqlite::Error::SqliteFailure;
|
||||
use sha2::Digest;
|
||||
use tokio::signal;
|
||||
use std::collections::HashMap;
|
||||
use std::{borrow::Borrow, error::Error, path::PathBuf, sync::Arc};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use tokio_rusqlite::Connection;
|
||||
use tracing::{error, info};
|
||||
|
||||
@@ -77,7 +79,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
}
|
||||
|
||||
for (shard, conn) in shards.iter().enumerate() {
|
||||
let count = num_entries_in(&conn).await?;
|
||||
let count = num_entries_in(conn).await?;
|
||||
info!("shard {} has {} entries", shard, count);
|
||||
}
|
||||
|
||||
@@ -97,37 +99,15 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn
|
||||
let app = Router::new()
|
||||
.route("/store", post(store_request_handler))
|
||||
.layer(Extension(shards));
|
||||
axum::serve(server, app.into_make_service()).with_graceful_shutdown(shutdown_signal()).await?;
|
||||
axum::serve(server, app.into_make_service())
|
||||
.with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {},
|
||||
_ = terminate => {},
|
||||
}
|
||||
}
|
||||
|
||||
type ResultType = Result<
|
||||
(StatusCode, Json<HashMap<&'static str, String>>),
|
||||
(StatusCode, Json<HashMap<&'static str, String>>)
|
||||
(StatusCode, Json<HashMap<&'static str, String>>),
|
||||
(StatusCode, Json<HashMap<&'static str, String>>),
|
||||
>;
|
||||
|
||||
#[axum::debug_handler]
|
||||
@@ -137,7 +117,7 @@ async fn store_request_handler(
|
||||
) -> ResultType {
|
||||
// compute sha256 of data
|
||||
let data_bytes = &request.data.contents;
|
||||
let sha256 = sha2::Sha256::digest(&data_bytes);
|
||||
let sha256 = sha2::Sha256::digest(data_bytes);
|
||||
let sha256_str = format!("{:x}", sha256);
|
||||
let num_shards = shards.0.len();
|
||||
// select shard
|
||||
@@ -163,13 +143,10 @@ async fn store_request_handler(
|
||||
};
|
||||
|
||||
let conn = conn.borrow();
|
||||
perform_store(&conn, request_parsed).await
|
||||
perform_store(conn, request_parsed).await
|
||||
}
|
||||
|
||||
async fn perform_store(
|
||||
conn: &Connection,
|
||||
store_request: StoreRequestParsed,
|
||||
) -> ResultType {
|
||||
async fn perform_store(conn: &Connection, store_request: StoreRequestParsed) -> ResultType {
|
||||
conn.call(move |conn| {
|
||||
let created_at = chrono::Utc::now().to_rfc3339();
|
||||
let maybe_error = conn.execute(
|
||||
@@ -198,7 +175,7 @@ async fn perform_store(
|
||||
let mut response = HashMap::new();
|
||||
response.insert("status", "ok".to_owned());
|
||||
response.insert("message", "created".to_owned());
|
||||
return Ok((StatusCode::CREATED, Json(response)));
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
})
|
||||
.await.map_err(|e| {
|
||||
error!("store failed: {}", e);
|
||||
@@ -248,7 +225,9 @@ async fn num_entries_in(conn: &Connection) -> Result<i64, Box<dyn Error>> {
|
||||
conn.call(|conn| {
|
||||
let count: i64 = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
|
||||
Ok(count)
|
||||
}).await.map_err(|e| e.into())
|
||||
})
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
fn validate_manifest(args: Args) -> Result<usize, Box<dyn Error>> {
|
||||
@@ -266,15 +245,13 @@ fn validate_manifest(args: Args) -> Result<usize, Box<dyn Error>> {
|
||||
}
|
||||
}
|
||||
Ok(manifest.shards)
|
||||
} else if let Some(shards) = args.shards {
|
||||
std::fs::create_dir_all(&args.db_path)?;
|
||||
let manifest = ManifestData { shards };
|
||||
let manifest_json = serde_json::to_string(&manifest)?;
|
||||
std::fs::write(manifest_path, manifest_json)?;
|
||||
Ok(shards)
|
||||
} else {
|
||||
if let Some(shards) = args.shards {
|
||||
std::fs::create_dir_all(&args.db_path)?;
|
||||
let manifest = ManifestData { shards };
|
||||
let manifest_json = serde_json::to_string(&manifest)?;
|
||||
std::fs::write(manifest_path, manifest_json)?;
|
||||
return Ok(shards);
|
||||
} else {
|
||||
return Err("new database needs --shards argument".into());
|
||||
}
|
||||
Err("new database needs --shards argument".into())
|
||||
}
|
||||
}
|
||||
|
||||
25
src/shutdown_signal.rs
Normal file
25
src/shutdown_signal.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use tokio::signal;
|
||||
|
||||
pub async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {},
|
||||
_ = terminate => {},
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user