diff --git a/Cargo.lock b/Cargo.lock index ef2dde7..43ce8a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -265,6 +265,7 @@ dependencies = [ "chrono", "clap", "futures", + "hex", "kdam", "rand", "reqwest", @@ -711,6 +712,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "http" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index 6d123e0..b650165 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,3 +29,4 @@ tokio-rusqlite = "0.5.1" tracing = "0.1.40" tracing-subscriber = "0.3.18" reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] } +hex = "0.4.3" diff --git a/load_test/main.rs b/load_test/main.rs index 21f55f7..66b76b5 100644 --- a/load_test/main.rs +++ b/load_test/main.rs @@ -1,20 +1,33 @@ use std::sync::Arc; use std::sync::Mutex; +use clap::Parser; use kdam::tqdm; use kdam::Bar; use kdam::BarExt; use rand::Rng; +use reqwest::StatusCode; + +#[derive(Parser, Debug, Clone)] +#[command(version, about, long_about = None)] +struct Args { + #[arg(long)] + file_size: usize, + #[arg(long)] + num_threads: usize, +} fn main() -> Result<(), Box> { + let args = Args::parse(); + let pb = Arc::new(Mutex::new(tqdm!())); let mut handles = vec![]; - let num_shards = 8; - for _ in 0..num_shards { + for _ in 0..args.num_threads { let pb = pb.clone(); + let args = args.clone(); handles.push(std::thread::spawn(move || { - run_loop(pb).unwrap(); + run_loop(pb, args).unwrap(); })); } @@ -25,10 +38,10 @@ fn main() -> Result<(), Box> { Ok(()) } -fn run_loop(pb: Arc>) -> Result<(), Box> { +fn run_loop(pb: Arc>, args: Args) -> Result<(), Box> { let client = reqwest::blocking::Client::new(); let mut rng = rand::thread_rng(); - let mut rand_data = vec![0u8; 1024 * 1024]; + let mut rand_data = vec![0u8; args.file_size]; rng.fill(&mut rand_data[..]); loop { @@ -50,7 +63,8 @@ fn run_loop(pb: Arc>) -> Result<(), Box> { // update progress bar let mut pb = pb.lock().unwrap(); - pb.update(1)?; - if resp.status() != 200 && resp.status() != 201 {} + if resp.status() == StatusCode::CREATED { + pb.update(1)?; + } } } diff --git a/src/handlers/get_handler.rs b/src/handlers/get_handler.rs new file mode 100644 index 0000000..d485b32 --- /dev/null +++ b/src/handlers/get_handler.rs @@ -0,0 +1,89 @@ +use std::collections::HashMap; + +use axum::{ + extract::Path, + http::{header, HeaderMap, HeaderName, HeaderValue, StatusCode}, + Extension, Json, +}; + +use crate::shard::Shards; + +#[derive(Debug, serde::Serialize)] +pub struct GetError { + sha256: Option, + message: String, +} + +#[axum::debug_handler] +pub async fn get_handler( + Path(params): Path>, + Extension(shards): Extension, +) -> Result<(StatusCode, HeaderMap, Vec), (StatusCode, Json)> { + let sha256_str = match params.get("sha256") { + Some(sha256_str) => sha256_str.clone(), + None => { + return Err(( + StatusCode::BAD_REQUEST, + Json(GetError { + sha256: None, + message: "missing sha256 parameter".to_owned(), + }), + )); + } + }; + + let sha256 = crate::sha256::Sha256::from_hex_string(&sha256_str).map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(GetError { + sha256: Some(sha256_str), + message: e.to_string(), + }), + ) + })?; + + let internal_error = |message| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(GetError { + sha256: Some(sha256.hex_string()), + message, + }), + ) + }; + + let shard = shards.shard_for(&sha256); + let response = shard + .get(sha256) + .await + .map_err(|e| internal_error(e.to_string()))?; + + let sha256_str = sha256.hex_string(); + match response { + Some(response) => { + let content_type = HeaderValue::from_str(&response.content_type) + .map_err(|e| internal_error(e.to_string()))?; + let created_at = HeaderValue::from_str(&response.created_at.to_rfc3339()) + .map_err(|e| internal_error(e.to_string()))?; + + let mut headers = HeaderMap::new(); + headers.insert(header::CONTENT_TYPE, content_type); + headers.insert( + header::CACHE_CONTROL, + HeaderValue::from_static("public, max-age=31536000"), + ); + headers.insert(header::ETAG, HeaderValue::from_str(&sha256_str).unwrap()); + headers.insert(HeaderName::from_static("x-stored-at"), created_at); + + Ok((StatusCode::OK, headers, response.data)) + } + None => Err(( + StatusCode::NOT_FOUND, + GetError { + sha256: Some(sha256_str), + message: "not found".to_owned(), + } + .into(), + )), + } +} diff --git a/src/handlers/info_handler.rs b/src/handlers/info_handler.rs new file mode 100644 index 0000000..f6f727d --- /dev/null +++ b/src/handlers/info_handler.rs @@ -0,0 +1,46 @@ +use crate::shard::Shards; +use axum::{http::StatusCode, Extension, Json}; + +use tracing::error; + +#[derive(serde::Serialize)] +pub struct InfoResponse { + num_shards: usize, + shards: Vec, +} + +#[derive(serde::Serialize)] +pub struct ShardInfo { + id: usize, + num_entries: usize, + size_bytes: u64, +} + +#[axum::debug_handler] +pub async fn info_handler( + Extension(shards): Extension, +) -> Result<(StatusCode, Json), StatusCode> { + let mut shard_infos = vec![]; + for shard in shards.iter() { + let num_entries = shard.num_entries().await.map_err(|e| { + error!("error getting num entries: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + let size_bytes = shard.size_bytes().await.map_err(|e| { + error!("error getting size bytes: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + shard_infos.push(ShardInfo { + id: shard.id(), + num_entries, + size_bytes, + }); + } + Ok(( + StatusCode::OK, + Json(InfoResponse { + num_shards: shards.len(), + shards: shard_infos, + }), + )) +} diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index db8f00c..a5e5037 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -1 +1,3 @@ +pub mod get_handler; +pub mod info_handler; pub mod store_handler; diff --git a/src/handlers/store_handler.rs b/src/handlers/store_handler.rs index 85f8688..71d8873 100644 --- a/src/handlers/store_handler.rs +++ b/src/handlers/store_handler.rs @@ -14,7 +14,6 @@ pub async fn store_handler( Extension(shards): Extension, TypedMultipart(request): TypedMultipart, ) -> AxumJsonResultOf { - // compute sha256 of data let sha256 = Sha256::from_bytes(&request.data.contents); let sha256_str = sha256.hex_string(); let shard = shards.shard_for(&sha256); diff --git a/src/main.rs b/src/main.rs index b1930b5..f7a7bfc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,10 @@ mod sha256; mod shard; mod shutdown_signal; use crate::shard::Shards; -use axum::{routing::post, Extension, Router}; +use axum::{ + routing::{get, post}, + Extension, Router, +}; use clap::Parser; use shard::Shard; use std::{error::Error, path::PathBuf}; @@ -85,6 +88,8 @@ fn main() -> Result<(), Box> { async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box> { let app = Router::new() .route("/store", post(handlers::store_handler::store_handler)) + .route("/get/:sha256", get(handlers::get_handler::get_handler)) + .route("/info", get(handlers::info_handler::info_handler)) .layer(Extension(shards)); axum::serve(server, app.into_make_service()) diff --git a/src/sha256.rs b/src/sha256.rs index 86aae24..070b799 100644 --- a/src/sha256.rs +++ b/src/sha256.rs @@ -1,10 +1,39 @@ -use std::fmt::LowerHex; +use std::{ + error::Error, + fmt::{Display, LowerHex}, +}; use sha2::Digest; +#[derive(Debug)] +struct Sha256Error { + message: String, +} + +impl Display for Sha256Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} +impl Error for Sha256Error {} + #[derive(Clone, Copy)] pub struct Sha256([u8; 32]); impl Sha256 { + pub fn from_hex_string(hex: &str) -> Result> { + if hex.len() != 64 { + return Err(Box::new(Sha256Error { + message: "sha256 wrong length".to_owned(), + })); + } + + let mut hash = [0; 32]; + hex::decode_to_slice(hex, &mut hash).map_err(|e| Sha256Error { + message: format!("sha256 decode error: {}", e), + })?; + Ok(Self(hash)) + } + pub fn from_bytes(bytes: &[u8]) -> Self { let hash = sha2::Sha256::digest(bytes); Self(hash.into()) diff --git a/src/shard/mod.rs b/src/shard/mod.rs index 005d1d0..a1cf4eb 100644 --- a/src/shard/mod.rs +++ b/src/shard/mod.rs @@ -6,7 +6,10 @@ use crate::{ use axum::{http::StatusCode, Json}; use rusqlite::{params, OptionalExtension}; -use std::{error::Error, path::Path}; +use std::{ + error::Error, + path::{Path, PathBuf}, +}; use tokio_rusqlite::Connection; use tracing::{debug, error, info}; @@ -30,18 +33,37 @@ impl Shards { } Ok(()) } + + pub fn iter(&self) -> std::slice::Iter<'_, Shard> { + self.0.iter() + } + + pub fn len(&self) -> usize { + self.0.len() + } } #[derive(Clone)] pub struct Shard { id: usize, sqlite: Connection, + file_path: PathBuf, +} + +pub struct GetResult { + pub content_type: String, + pub created_at: UtcDateTime, + pub data: Vec, } impl Shard { - pub async fn open(id: usize, db_path: &Path) -> Result> { - let sqlite = Connection::open(db_path).await?; - let shard = Self { id, sqlite }; + pub async fn open(id: usize, file_path: &Path) -> Result> { + let sqlite = Connection::open(file_path).await?; + let shard = Self { + id, + sqlite, + file_path: file_path.to_owned(), + }; shard.migrate().await?; Ok(shard) } @@ -55,6 +77,12 @@ impl Shard { self.id } + pub async fn size_bytes(&self) -> Result> { + // stat the file to get its size in bytes + let metadata = tokio::fs::metadata(&self.file_path).await?; + Ok(metadata.len()) + } + pub async fn store( &self, store_request: StoreRequestWithSha256, @@ -92,6 +120,34 @@ impl Shard { }) } + pub async fn get(&self, sha256: Sha256) -> Result, Box> { + self.sqlite + .call(move |conn| { + let get_result = conn + .query_row( + "SELECT content_type, created_at, data FROM entries WHERE sha256 = ?", + params![sha256.hex_string()], + |row| { + let content_type = row.get(0)?; + let created_at = parse_created_at_str(row.get(1)?)?; + let data = row.get(2)?; + Ok(GetResult { + content_type, + created_at, + data, + }) + }, + ) + .optional()?; + Ok(get_result) + }) + .await + .map_err(|e| { + error!("get failed: {}", e); + e.into() + }) + } + pub async fn num_entries(&self) -> Result> { get_num_entries(&self.sqlite).await.map_err(|e| e.into()) } @@ -114,10 +170,8 @@ impl Shard { [], |row| { let ver = row.get(0)?; - let created_at_str: String = row.get(1)?; - let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str).map_err(|e| { - rusqlite::Error::ToSqlConversionFailure(e.into()) - })?.to_utc(); + // let created_at_str: String = row.get(1)?; + let created_at = parse_created_at_str(row.get(1)?)?; Ok((ver, created_at)) } ).optional()?; @@ -125,8 +179,7 @@ impl Shard { if let Some((version, date_time)) = schema_row { debug!( "shard {}: latest schema version: {} @ {}", - shard_id, - version, date_time + shard_id, version, date_time ); if version < 1 { @@ -143,6 +196,12 @@ impl Shard { } } +fn parse_created_at_str(created_at_str: String) -> Result { + let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str) + .map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?; + Ok(parsed.to_utc()) +} + fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool { use rusqlite::*; diff --git a/test/test.rb b/test/test.rb index f68a191..4601e34 100644 --- a/test/test.rb +++ b/test/test.rb @@ -1,29 +1,65 @@ require "rest_client" +require "json" FILES = { cat1: -> { File.new("cat1.jpg", mode: "rb") }, } -puts "response without sha256: " -puts RestClient.post("http://localhost:7692/store", { - content_type: "image/jpeg", - data: FILES[:cat1].call, -}) +def run + cat1_sha256 = Digest::SHA256.file("cat1.jpg").hexdigest -puts "response with correct sha256:" -puts RestClient.post("http://localhost:7692/store", { - content_type: "image/jpeg", - sha256: "e3705544cbf2fa93e16107d1821b312a7b825fc177fa28180a9c9a9d3ae8af37", - data: FILES[:cat1].call, -}) - -puts "response with incorrect sha256:" -begin - puts RestClient.post("http://localhost:7692/store", { + puts "store, with sha256" + dump_resp(RestClient.post("http://localhost:7692/store", { content_type: "image/jpeg", - sha256: "123", data: FILES[:cat1].call, - }) -rescue => e - puts e.response + })) + + puts "store, without sha256:" + dump_resp(RestClient.post("http://localhost:7692/store", { + content_type: "image/jpeg", + sha256: cat1_sha256, + data: FILES[:cat1].call, + })) + + puts "store, incorrect sha256:" + begin + RestClient.post("http://localhost:7692/store", { + content_type: "image/jpeg", + sha256: "123", + data: FILES[:cat1].call, + }) + puts "should have thrown!" + rescue => e + dump_resp(e.response) + end + + puts "get, with sha256:" + dump_resp(RestClient.get("http://localhost:7692/get/#{cat1_sha256}")) + + puts "get, 404 sha256:" + begin + RestClient.get("http://localhost:7692/get/e3705544cbf2fa93e16107d1821b312a7b825fc177fa28180a9c9a9d3ae8af3c") + raise "should have thrown!" + rescue => e + dump_resp(e.response) + raise "not 404" if e.response.code != 404 + end end + +def dump_resp(resp) + puts " -> code: #{resp.code}" + headers = resp.headers + content_type = headers[:content_type] + puts " -> headers: #{headers}" + puts " -> content_type: #{content_type}" + puts " -> size: #{resp.size} bytes" + if content_type == "application/json" + puts " -> body: #{JSON.parse(resp.body)}" + else + body_sha256 = Digest::SHA256.hexdigest(resp.body) + puts " -> body sha256: #{body_sha256}" + end + puts "-" * 80 +end + +run