clippy, refactor
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,2 +1,3 @@
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
target
|
target
|
||||||
|
/testdb
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
use std::borrow::Borrow;
|
|
||||||
use std::borrow::BorrowMut;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
|
|
||||||
@@ -13,7 +13,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
let num_shards = 8;
|
let num_shards = 8;
|
||||||
for _ in (0..num_shards).into_iter() {
|
for _ in 0..num_shards {
|
||||||
let pb = pb.clone();
|
let pb = pb.clone();
|
||||||
handles.push(std::thread::spawn(move || {
|
handles.push(std::thread::spawn(move || {
|
||||||
run_loop(pb).unwrap();
|
run_loop(pb).unwrap();
|
||||||
@@ -27,7 +27,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_loop(mut pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
|
fn run_loop(pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let client = reqwest::blocking::Client::new();
|
let client = reqwest::blocking::Client::new();
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = rand::thread_rng();
|
||||||
let mut rand_data = vec![0u8; 1024 * 1024];
|
let mut rand_data = vec![0u8; 1024 * 1024];
|
||||||
@@ -55,5 +55,4 @@ fn run_loop(mut pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
pb.update(1)?;
|
pb.update(1)?;
|
||||||
if resp.status() != 200 && resp.status() != 201 {}
|
if resp.status() != 200 && resp.status() != 201 {}
|
||||||
}
|
}
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|||||||
69
src/main.rs
69
src/main.rs
@@ -1,3 +1,5 @@
|
|||||||
|
mod shutdown_signal;
|
||||||
|
|
||||||
use axum::Json;
|
use axum::Json;
|
||||||
use axum::{body::Bytes, http::StatusCode, routing::post, Extension, Router};
|
use axum::{body::Bytes, http::StatusCode, routing::post, Extension, Router};
|
||||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||||
@@ -6,10 +8,10 @@ use rusqlite::ffi;
|
|||||||
use rusqlite::params;
|
use rusqlite::params;
|
||||||
use rusqlite::Error::SqliteFailure;
|
use rusqlite::Error::SqliteFailure;
|
||||||
use sha2::Digest;
|
use sha2::Digest;
|
||||||
use tokio::signal;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::{borrow::Borrow, error::Error, path::PathBuf, sync::Arc};
|
use std::{borrow::Borrow, error::Error, path::PathBuf, sync::Arc};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
use tokio_rusqlite::Connection;
|
use tokio_rusqlite::Connection;
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
@@ -77,7 +79,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (shard, conn) in shards.iter().enumerate() {
|
for (shard, conn) in shards.iter().enumerate() {
|
||||||
let count = num_entries_in(&conn).await?;
|
let count = num_entries_in(conn).await?;
|
||||||
info!("shard {} has {} entries", shard, count);
|
info!("shard {} has {} entries", shard, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,37 +99,15 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn
|
|||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/store", post(store_request_handler))
|
.route("/store", post(store_request_handler))
|
||||||
.layer(Extension(shards));
|
.layer(Extension(shards));
|
||||||
axum::serve(server, app.into_make_service()).with_graceful_shutdown(shutdown_signal()).await?;
|
axum::serve(server, app.into_make_service())
|
||||||
|
.with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
|
||||||
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn shutdown_signal() {
|
|
||||||
let ctrl_c = async {
|
|
||||||
signal::ctrl_c()
|
|
||||||
.await
|
|
||||||
.expect("failed to install Ctrl+C handler");
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(unix)]
|
|
||||||
let terminate = async {
|
|
||||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
|
||||||
.expect("failed to install signal handler")
|
|
||||||
.recv()
|
|
||||||
.await;
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(not(unix))]
|
|
||||||
let terminate = std::future::pending::<()>();
|
|
||||||
|
|
||||||
tokio::select! {
|
|
||||||
_ = ctrl_c => {},
|
|
||||||
_ = terminate => {},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ResultType = Result<
|
type ResultType = Result<
|
||||||
(StatusCode, Json<HashMap<&'static str, String>>),
|
(StatusCode, Json<HashMap<&'static str, String>>),
|
||||||
(StatusCode, Json<HashMap<&'static str, String>>)
|
(StatusCode, Json<HashMap<&'static str, String>>),
|
||||||
>;
|
>;
|
||||||
|
|
||||||
#[axum::debug_handler]
|
#[axum::debug_handler]
|
||||||
@@ -137,7 +117,7 @@ async fn store_request_handler(
|
|||||||
) -> ResultType {
|
) -> ResultType {
|
||||||
// compute sha256 of data
|
// compute sha256 of data
|
||||||
let data_bytes = &request.data.contents;
|
let data_bytes = &request.data.contents;
|
||||||
let sha256 = sha2::Sha256::digest(&data_bytes);
|
let sha256 = sha2::Sha256::digest(data_bytes);
|
||||||
let sha256_str = format!("{:x}", sha256);
|
let sha256_str = format!("{:x}", sha256);
|
||||||
let num_shards = shards.0.len();
|
let num_shards = shards.0.len();
|
||||||
// select shard
|
// select shard
|
||||||
@@ -163,13 +143,10 @@ async fn store_request_handler(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let conn = conn.borrow();
|
let conn = conn.borrow();
|
||||||
perform_store(&conn, request_parsed).await
|
perform_store(conn, request_parsed).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn perform_store(
|
async fn perform_store(conn: &Connection, store_request: StoreRequestParsed) -> ResultType {
|
||||||
conn: &Connection,
|
|
||||||
store_request: StoreRequestParsed,
|
|
||||||
) -> ResultType {
|
|
||||||
conn.call(move |conn| {
|
conn.call(move |conn| {
|
||||||
let created_at = chrono::Utc::now().to_rfc3339();
|
let created_at = chrono::Utc::now().to_rfc3339();
|
||||||
let maybe_error = conn.execute(
|
let maybe_error = conn.execute(
|
||||||
@@ -198,7 +175,7 @@ async fn perform_store(
|
|||||||
let mut response = HashMap::new();
|
let mut response = HashMap::new();
|
||||||
response.insert("status", "ok".to_owned());
|
response.insert("status", "ok".to_owned());
|
||||||
response.insert("message", "created".to_owned());
|
response.insert("message", "created".to_owned());
|
||||||
return Ok((StatusCode::CREATED, Json(response)));
|
Ok((StatusCode::CREATED, Json(response)))
|
||||||
})
|
})
|
||||||
.await.map_err(|e| {
|
.await.map_err(|e| {
|
||||||
error!("store failed: {}", e);
|
error!("store failed: {}", e);
|
||||||
@@ -248,7 +225,9 @@ async fn num_entries_in(conn: &Connection) -> Result<i64, Box<dyn Error>> {
|
|||||||
conn.call(|conn| {
|
conn.call(|conn| {
|
||||||
let count: i64 = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
|
let count: i64 = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
|
||||||
Ok(count)
|
Ok(count)
|
||||||
}).await.map_err(|e| e.into())
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_manifest(args: Args) -> Result<usize, Box<dyn Error>> {
|
fn validate_manifest(args: Args) -> Result<usize, Box<dyn Error>> {
|
||||||
@@ -266,15 +245,13 @@ fn validate_manifest(args: Args) -> Result<usize, Box<dyn Error>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(manifest.shards)
|
Ok(manifest.shards)
|
||||||
|
} else if let Some(shards) = args.shards {
|
||||||
|
std::fs::create_dir_all(&args.db_path)?;
|
||||||
|
let manifest = ManifestData { shards };
|
||||||
|
let manifest_json = serde_json::to_string(&manifest)?;
|
||||||
|
std::fs::write(manifest_path, manifest_json)?;
|
||||||
|
Ok(shards)
|
||||||
} else {
|
} else {
|
||||||
if let Some(shards) = args.shards {
|
Err("new database needs --shards argument".into())
|
||||||
std::fs::create_dir_all(&args.db_path)?;
|
|
||||||
let manifest = ManifestData { shards };
|
|
||||||
let manifest_json = serde_json::to_string(&manifest)?;
|
|
||||||
std::fs::write(manifest_path, manifest_json)?;
|
|
||||||
return Ok(shards);
|
|
||||||
} else {
|
|
||||||
return Err("new database needs --shards argument".into());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
25
src/shutdown_signal.rs
Normal file
25
src/shutdown_signal.rs
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
use tokio::signal;
|
||||||
|
|
||||||
|
pub async fn shutdown_signal() {
|
||||||
|
let ctrl_c = async {
|
||||||
|
signal::ctrl_c()
|
||||||
|
.await
|
||||||
|
.expect("failed to install Ctrl+C handler");
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
let terminate = async {
|
||||||
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||||
|
.expect("failed to install signal handler")
|
||||||
|
.recv()
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
let terminate = std::future::pending::<()>();
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = ctrl_c => {},
|
||||||
|
_ = terminate => {},
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user