From e9c36cb805d6f1e908637c1efcfe3502786b0ff3 Mon Sep 17 00:00:00 2001 From: Dylan Knutson Date: Tue, 23 Apr 2024 11:13:41 -0700 Subject: [PATCH] more refactors, type wrappers --- Cargo.lock | 4 +- Cargo.toml | 4 +- src/load_test.rs => load_test/main.rs | 2 - src/axum_result_type.rs | 3 + src/main.rs | 203 ++++++++------------------ src/sha256.rs | 27 ++++ src/shard/mod.rs | 142 ++++++++++++++++++ src/store_request.rs | 28 ++++ 8 files changed, 261 insertions(+), 152 deletions(-) rename src/load_test.rs => load_test/main.rs (99%) create mode 100644 src/axum_result_type.rs create mode 100644 src/sha256.rs create mode 100644 src/shard/mod.rs create mode 100644 src/store_request.rs diff --git a/Cargo.lock b/Cargo.lock index ab31409..ef2dde7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1359,9 +1359,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247" +checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" [[package]] name = "rustversion" diff --git a/Cargo.toml b/Cargo.toml index 1b1fd94..6d123e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ path = "src/main.rs" [[bin]] name = "load-test" -path = "src/load_test.rs" +path = "load_test/main.rs" [dependencies] axum = { version = "0.7.5", features = ["macros"] } @@ -20,7 +20,6 @@ clap = { version = "4.5.4", features = ["derive"] } futures = "0.3.30" kdam = "0.5.1" rand = "0.8.5" -reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] } rusqlite = "0.31.0" serde = { version = "1.0.198", features = ["serde_derive"] } serde_json = "1.0.116" @@ -29,3 +28,4 @@ tokio = { version = "1.37.0", features = ["full", "rt-multi-thread"] } tokio-rusqlite = "0.5.1" tracing = "0.1.40" tracing-subscriber = "0.3.18" +reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] } diff --git a/src/load_test.rs b/load_test/main.rs similarity index 99% rename from src/load_test.rs rename to load_test/main.rs index 8ef5b09..21f55f7 100644 --- a/src/load_test.rs +++ b/load_test/main.rs @@ -1,5 +1,3 @@ - - use std::sync::Arc; use std::sync::Mutex; diff --git a/src/axum_result_type.rs b/src/axum_result_type.rs new file mode 100644 index 0000000..7918a22 --- /dev/null +++ b/src/axum_result_type.rs @@ -0,0 +1,3 @@ +use axum::{http::StatusCode, Json}; + +pub type AxumJsonResultOf = Result<(StatusCode, Json), (StatusCode, Json)>; diff --git a/src/main.rs b/src/main.rs index 1464655..61ec2e3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,20 +1,27 @@ +mod axum_result_type; +mod sha256; +mod shard; mod shutdown_signal; +mod store_request; use axum::Json; -use axum::{body::Bytes, http::StatusCode, routing::post, Extension, Router}; -use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; +use axum::{http::StatusCode, routing::post, Extension, Router}; +use axum_result_type::AxumJsonResultOf; +use axum_typed_multipart::TypedMultipart; use clap::Parser; -use rusqlite::ffi; -use rusqlite::params; -use rusqlite::Error::SqliteFailure; -use sha2::Digest; + +use shard::Shard; use std::collections::HashMap; -use std::{borrow::Borrow, error::Error, path::PathBuf, sync::Arc}; +use std::{error::Error, path::PathBuf}; +use store_request::StoreResponse; use tokio::net::TcpListener; -use tokio_rusqlite::Connection; use tracing::{error, info}; +use crate::sha256::Sha256; +use crate::shard::Shards; +use crate::store_request::{StoreRequest, StoreRequestParsed}; + #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Args { @@ -22,6 +29,14 @@ struct Args { #[arg(short, long)] db_path: String, + /// Port to listen on + #[arg(short, long, default_value_t = 7692)] + port: u16, + + /// Host to bind to + #[arg(short, long, default_value = "127.0.0.1")] + bind: String, + /// Number of shards #[arg(short, long)] shards: Option, @@ -32,22 +47,6 @@ struct ManifestData { shards: usize, } -#[derive(TryFromMultipart)] -struct StoreRequest { - sha256: Option, - content_type: String, - data: FieldData, -} - -struct StoreRequestParsed { - sha256: String, - content_type: String, - data: Bytes, -} - -#[derive(Clone)] -struct Shards(Vec>); - fn main() -> Result<(), Box> { tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) @@ -55,7 +54,7 @@ fn main() -> Result<(), Box> { let args = Args::parse(); let db_path = PathBuf::from(&args.db_path); - let num_shards = validate_manifest(args)?; + let num_shards = validate_manifest(&args)?; // max num_shards threads let runtime = tokio::runtime::Builder::new_multi_thread() @@ -64,32 +63,30 @@ fn main() -> Result<(), Box> { .build()?; runtime.block_on(async { - let server = TcpListener::bind("127.0.0.1:7692").await?; + let server = TcpListener::bind(format!("{}:{}", args.bind, args.port)).await?; info!( "listening on {} with {} shards", server.local_addr()?, num_shards ); - let mut shards = vec![]; - for shard in 0..num_shards { - let shard_path = db_path.join(format!("shard{}.sqlite", shard)); - let conn = Connection::open(shard_path).await?; - migrate(&conn).await?; - shards.push(Arc::new(conn)); + let mut shards_vec = vec![]; + for shard_id in 0..num_shards { + let shard_path = db_path.join(format!("shard{}.sqlite", shard_id)); + let shard = Shard::open(shard_id, &shard_path).await?; + info!( + "shard {} has {} entries", + shard.id(), + shard.num_entries().await? + ); + shards_vec.push(shard); } - for (shard, conn) in shards.iter().enumerate() { - let count = num_entries_in(conn).await?; - info!("shard {} has {} entries", shard, count); - } - - server_loop(server, Shards(shards.clone())).await?; + let shards = Shards::new(shards_vec); + server_loop(server, shards.clone()).await?; info!("shutting down server..."); - for conn in shards.into_iter() { - (*conn).clone().close().await?; - } + shards.close_all().await?; info!("server closed sqlite connections. bye!"); - Ok::<(), Box>(()) + Ok::<_, Box>(()) })?; Ok(()) @@ -99,138 +96,52 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box>), - (StatusCode, Json>), ->; - #[axum::debug_handler] async fn store_request_handler( Extension(shards): Extension, TypedMultipart(request): TypedMultipart, -) -> ResultType { +) -> AxumJsonResultOf { // compute sha256 of data - let data_bytes = &request.data.contents; - let sha256 = sha2::Sha256::digest(data_bytes); - let sha256_str = format!("{:x}", sha256); - let num_shards = shards.0.len(); - // select shard - let shard_num = sha256[0] as usize % num_shards; - let conn = &shards.0[shard_num]; + let sha256 = Sha256::from_bytes(&request.data.contents); + let sha256_str = sha256.hex_string(); + let shard = shards.shard_for(&sha256); if let Some(req_sha256) = request.sha256 { if req_sha256 != sha256_str { - error!("sha256 mismatch: {} != {}", req_sha256, sha256_str); + error!( + "client sent mismatched sha256: (client) {} != (computed) {}", + req_sha256, sha256_str + ); let mut response = HashMap::new(); response.insert("status", "error".to_owned()); response.insert("message", "sha256 mismatch".to_owned()); - return Err((StatusCode::BAD_REQUEST, Json(response))); + return Err(( + StatusCode::BAD_REQUEST, + Json(StoreResponse::Error { + sha256: Some(req_sha256), + message: "sha256 mismatch".to_owned(), + }), + )); } } - // info!("storing {} on shard {}", sha256_str, shard_num); - let request_parsed = StoreRequestParsed { sha256: sha256_str, content_type: request.content_type, data: request.data.contents, }; - let conn = conn.borrow(); - perform_store(conn, request_parsed).await + shard.store(request_parsed).await } -async fn perform_store(conn: &Connection, store_request: StoreRequestParsed) -> ResultType { - conn.call(move |conn| { - let created_at = chrono::Utc::now().to_rfc3339(); - let maybe_error = conn.execute( - "INSERT INTO entries (sha256, content_type, size, data, created_at) VALUES (?, ?, ?, ?, ?)", - params![ - store_request.sha256, - store_request.content_type, - store_request.data.len() as i64, - store_request.data.to_vec(), - created_at, - ], - ); - - let mut response = HashMap::new(); - response.insert("sha256", store_request.sha256.clone()); - - if let Err(e) = &maybe_error { - if is_duplicate_entry_err(e) { - info!("entry {} already exists", store_request.sha256); - response.insert("status","ok".to_owned()); - response.insert("message", "already exists".to_owned()); - return Ok((StatusCode::OK, Json(response))); - } - } - maybe_error?; - let mut response = HashMap::new(); - response.insert("status", "ok".to_owned()); - response.insert("message", "created".to_owned()); - Ok((StatusCode::CREATED, Json(response))) - }) - .await.map_err(|e| { - error!("store failed: {}", e); - let mut response = HashMap::new(); - response.insert("status", "error".to_owned()); - response.insert("message", e.to_string()); - (StatusCode::INTERNAL_SERVER_ERROR, Json(response)) - }) -} - -fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool { - if let SqliteFailure( - ffi::Error { - code: ffi::ErrorCode::ConstraintViolation, - .. - }, - Some(err_str), - ) = error - { - if err_str.contains("UNIQUE constraint failed: entries.sha256") { - return true; - } - } - false -} - -async fn migrate(conn: &Connection) -> Result<(), Box> { - // create tables, indexes, etc - conn.call(|conn| { - conn.execute( - "CREATE TABLE IF NOT EXISTS entries ( - sha256 BLOB PRIMARY KEY, - content_type TEXT NOT NULL, - size INTEGER NOT NULL, - data BLOB NOT NULL, - created_at TEXT NOT NULL - )", - [], - )?; - Ok(()) - }) - .await?; - Ok(()) -} - -async fn num_entries_in(conn: &Connection) -> Result> { - conn.call(|conn| { - let count: i64 = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?; - Ok(count) - }) - .await - .map_err(|e| e.into()) -} - -fn validate_manifest(args: Args) -> Result> { +fn validate_manifest(args: &Args) -> Result> { let manifest_path = PathBuf::from(&args.db_path).join("manifest.json"); if manifest_path.exists() { let file_content = std::fs::read_to_string(manifest_path)?; diff --git a/src/sha256.rs b/src/sha256.rs new file mode 100644 index 0000000..86aae24 --- /dev/null +++ b/src/sha256.rs @@ -0,0 +1,27 @@ +use std::fmt::LowerHex; + +use sha2::Digest; + +#[derive(Clone, Copy)] +pub struct Sha256([u8; 32]); +impl Sha256 { + pub fn from_bytes(bytes: &[u8]) -> Self { + let hash = sha2::Sha256::digest(bytes); + Self(hash.into()) + } + pub fn hex_string(&self) -> String { + format!("{:x}", self) + } + pub fn modulo(&self, num: usize) -> usize { + self.0[0] as usize % num + } +} + +impl LowerHex for Sha256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for byte in self.0.iter() { + write!(f, "{:02x}", byte)?; + } + Ok(()) + } +} diff --git a/src/shard/mod.rs b/src/shard/mod.rs new file mode 100644 index 0000000..799ca97 --- /dev/null +++ b/src/shard/mod.rs @@ -0,0 +1,142 @@ +use std::{error::Error, path::Path}; + +use rusqlite::params; +use tokio_rusqlite::Connection; +use tracing::{error, info}; + +use crate::{ + axum_result_type::AxumJsonResultOf, + sha256::Sha256, + store_request::{StoreRequestParsed, StoreResponse}, +}; +use axum::{http::StatusCode, Json}; + +#[derive(Clone)] +pub struct Shards(Vec); +impl Shards { + pub fn new(shards: Vec) -> Self { + Self(shards) + } + + pub fn shard_for(&self, sha256: &Sha256) -> &Shard { + let shard_id = sha256.modulo(self.0.len()); + &self.0[shard_id] + } + + pub async fn close_all(self) -> Result<(), Box> { + for shard in self.0 { + shard.close().await?; + } + Ok(()) + } +} + +#[derive(Clone)] +pub struct Shard { + id: usize, + sqlite: Connection, +} + +impl Shard { + pub async fn open(id: usize, db_path: &Path) -> Result> { + let sqlite = Connection::open(db_path).await?; + migrate(&sqlite).await?; + Ok(Self { id, sqlite }) + } + + pub async fn close(self) -> Result<(), Box> { + self.sqlite.close().await?; + Ok(()) + } + + pub fn id(&self) -> usize { + self.id + } + + pub async fn store( + &self, + store_request: StoreRequestParsed, + ) -> AxumJsonResultOf { + let sha256 = store_request.sha256.clone(); + + self.sqlite.call(move |conn| { + let created_at = chrono::Utc::now().to_rfc3339(); + let maybe_error = conn.execute( + "INSERT INTO entries (sha256, content_type, size, data, created_at) VALUES (?, ?, ?, ?, ?)", + params![ + store_request.sha256, + store_request.content_type, + store_request.data.len() as i64, + store_request.data.to_vec(), + created_at, + ], + ); + + if let Err(e) = &maybe_error { + if is_duplicate_entry_err(e) { + info!("entry {} already exists", store_request.sha256); + return Ok((StatusCode::OK, Json(StoreResponse::Ok{ + sha256: store_request.sha256, + message: "exists", + }))); + } + } + maybe_error?; + Ok((StatusCode::CREATED, Json(StoreResponse::Ok { sha256: store_request.sha256, message: "created" }))) + }) + .await.map_err(|e| { + error!("store failed: {}", e); + (StatusCode::INTERNAL_SERVER_ERROR, Json(StoreResponse::Error { sha256: Some(sha256), message: e.to_string() })) + }) + } + + pub async fn num_entries(&self) -> Result> { + get_num_entries(&self.sqlite).await + } +} + +fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool { + use rusqlite::*; + + if let Error::SqliteFailure( + ffi::Error { + code: ffi::ErrorCode::ConstraintViolation, + .. + }, + Some(err_str), + ) = error + { + if err_str.contains("UNIQUE constraint failed: entries.sha256") { + return true; + } + } + false +} + +async fn migrate(conn: &Connection) -> Result<(), Box> { + // create tables, indexes, etc + conn.call(|conn| { + conn.execute( + "CREATE TABLE IF NOT EXISTS entries ( + sha256 BLOB PRIMARY KEY, + content_type TEXT NOT NULL, + size INTEGER NOT NULL, + data BLOB NOT NULL, + created_at TEXT NOT NULL + )", + [], + )?; + Ok(()) + }) + .await?; + Ok(()) +} + +async fn get_num_entries(conn: &Connection) -> Result> { + conn.call(|conn| { + let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?; + Ok(count) + }) + .await + .map_err(|e| e.into()) +} diff --git a/src/store_request.rs b/src/store_request.rs new file mode 100644 index 0000000..9ca0fe8 --- /dev/null +++ b/src/store_request.rs @@ -0,0 +1,28 @@ +use axum::body::Bytes; +use axum_typed_multipart::{FieldData, TryFromMultipart}; +use serde::Serialize; + +#[derive(TryFromMultipart)] +pub struct StoreRequest { + pub sha256: Option, + pub content_type: String, + pub data: FieldData, +} + +pub struct StoreRequestParsed { + pub sha256: String, + pub content_type: String, + pub data: Bytes, +} + +#[derive(Serialize)] +pub enum StoreResponse { + Ok { + sha256: String, + message: &'static str, + }, + Error { + sha256: Option, + message: String, + }, +}