diff --git a/src/handlers/get_handler.rs b/src/handlers/get_handler.rs index 77bf0ec..47e3caa 100644 --- a/src/handlers/get_handler.rs +++ b/src/handlers/get_handler.rs @@ -1,89 +1,202 @@ -use std::collections::HashMap; - +use crate::{sha256::Sha256, shard::GetResult, shards::Shards}; use axum::{ extract::Path, - http::{header, HeaderMap, HeaderName, HeaderValue, StatusCode}, + http::{header, HeaderMap, HeaderValue, StatusCode}, + response::IntoResponse, Extension, Json, }; +use std::{collections::HashMap, error::Error}; -use crate::shards::Shards; +pub enum GetResponse { + MissingSha256, + InvalidSha256 { message: String }, + InternalError { error: Box }, + NotFound, + Found { get_result: GetResult }, +} -#[derive(Debug, serde::Serialize)] -pub struct GetError { - sha256: Option, - message: String, +impl From for GetResponse { + fn from(get_result: GetResult) -> Self { + GetResponse::Found { get_result } + } +} + +impl IntoResponse for GetResponse { + fn into_response(self) -> axum::response::Response { + match self { + GetResponse::MissingSha256 => ( + StatusCode::BAD_REQUEST, + Json(HashMap::from([("status", "missing_sha256")])), + ) + .into_response(), + GetResponse::InvalidSha256 { message } => ( + StatusCode::BAD_REQUEST, + Json(HashMap::from([ + ("status", "invalid_sha256".to_owned()), + ("message", message), + ])), + ) + .into_response(), + GetResponse::InternalError { error } => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(HashMap::from([ + ("status", "internal_error".to_owned()), + ("message", error.to_string()), + ])), + ) + .into_response(), + GetResponse::NotFound => ( + StatusCode::NOT_FOUND, + Json(HashMap::from([("status", "not_found")])), + ) + .into_response(), + GetResponse::Found { get_result } => make_found_response(get_result).into_response(), + } + } +} + +fn make_found_response( + GetResult { + sha256, + content_type, + created_at, + stored_size, + data, + }: GetResult, +) -> impl IntoResponse { + let content_type = match HeaderValue::from_str(&content_type) { + Ok(content_type) => content_type, + Err(e) => return GetResponse::from(e).into_response(), + }; + + let created_at = match HeaderValue::from_str(&created_at.to_rfc3339()) { + Ok(created_at) => created_at, + Err(e) => return GetResponse::from(e).into_response(), + }; + + let stored_size = match HeaderValue::from_str(&stored_size.to_string()) { + Ok(stored_size) => stored_size, + Err(e) => return GetResponse::from(e).into_response(), + }; + + let mut headers = HeaderMap::new(); + headers.insert(header::CONTENT_TYPE, content_type); + headers.insert( + header::CACHE_CONTROL, + HeaderValue::from_static("public, max-age=31536000"), + ); + headers.insert( + header::ETAG, + HeaderValue::from_str(&sha256.hex_string()).unwrap(), + ); + headers.insert(header::HeaderName::from_static("x-stored-at"), created_at); + headers.insert( + header::HeaderName::from_static("x-stored-size"), + stored_size, + ); + + (StatusCode::OK, headers, data).into_response() +} + +impl>> From for GetResponse { + fn from(error: E) -> Self { + GetResponse::InternalError { + error: error.into(), + } + } } #[axum::debug_handler] pub async fn get_handler( Path(params): Path>, Extension(shards): Extension, -) -> Result<(StatusCode, HeaderMap, Vec), (StatusCode, Json)> { +) -> GetResponse { let sha256_str = match params.get("sha256") { Some(sha256_str) => sha256_str.clone(), None => { - return Err(( - StatusCode::BAD_REQUEST, - Json(GetError { - sha256: None, - message: "missing sha256 parameter".to_owned(), - }), - )); + return GetResponse::MissingSha256; } }; - let sha256 = crate::sha256::Sha256::from_hex_string(&sha256_str).map_err(|e| { - ( - StatusCode::BAD_REQUEST, - Json(GetError { - sha256: Some(sha256_str), + let sha256 = match Sha256::from_hex_string(&sha256_str) { + Ok(sha256) => sha256, + Err(e) => { + return GetResponse::InvalidSha256 { message: e.to_string(), - }), - ) - })?; - - let internal_error = |message| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(GetError { - sha256: Some(sha256.hex_string()), - message, - }), - ) + }; + } }; let shard = shards.shard_for(&sha256); - let response = shard - .get(sha256) - .await - .map_err(|e| internal_error(e.to_string()))?; + let get_result = match shard.get(sha256).await { + Ok(get_result) => get_result, + Err(e) => return e.into(), + }; - let sha256_str = sha256.hex_string(); - match response { - Some(response) => { - let content_type = HeaderValue::from_str(&response.content_type) - .map_err(|e| internal_error(e.to_string()))?; - let created_at = HeaderValue::from_str(&response.created_at.to_rfc3339()) - .map_err(|e| internal_error(e.to_string()))?; - - let mut headers = HeaderMap::new(); - headers.insert(header::CONTENT_TYPE, content_type); - headers.insert( - header::CACHE_CONTROL, - HeaderValue::from_static("public, max-age=31536000"), - ); - headers.insert(header::ETAG, HeaderValue::from_str(&sha256_str).unwrap()); - headers.insert(HeaderName::from_static("x-stored-at"), created_at); - - Ok((StatusCode::OK, headers, response.data)) - } - None => Err(( - StatusCode::NOT_FOUND, - GetError { - sha256: Some(sha256_str), - message: "not found".to_owned(), - } - .into(), - )), + match get_result { + None => GetResponse::NotFound, + Some(result) => result.into(), + } +} + +#[cfg(test)] +mod test { + use crate::{sha256::Sha256, shard::GetResult, shards::test::make_shards}; + use axum::{extract::Path, response::IntoResponse, Extension}; + use std::collections::HashMap; + + use super::GetResponse; + + #[tokio::test] + async fn test_get_invalid_sha256() { + let shards = Extension(make_shards().await); + let response = super::get_handler(Path(HashMap::new()), shards.clone()).await; + assert!(matches!(response, super::GetResponse::MissingSha256 { .. })); + + let shards = Extension(make_shards().await); + let response = super::get_handler( + Path(HashMap::from([(String::from("sha256"), String::from(""))])), + shards.clone(), + ) + .await; + assert!(matches!(response, super::GetResponse::InvalidSha256 { .. })); + + let response = super::get_handler( + Path(HashMap::from([( + String::from("sha256"), + String::from("invalid"), + )])), + shards.clone(), + ) + .await; + assert!(matches!(response, super::GetResponse::InvalidSha256 { .. })); + } + + #[test] + fn test_get_response_found_into_response() { + let data = "hello, world!"; + let sha256 = Sha256::from_bytes(data.as_bytes()); + let sha256_str = sha256.hex_string(); + let created_at = "2022-03-04T08:12:34+00:00"; + let response = GetResponse::Found { + get_result: GetResult { + sha256, + content_type: "text/plain".to_string(), + stored_size: 12345, + created_at: chrono::DateTime::parse_from_rfc3339(created_at) + .unwrap() + .to_utc(), + data: data.into(), + }, + } + .into_response(); + assert_eq!(response.status(), 200); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/plain" + ); + assert_eq!(response.headers().get("etag").unwrap(), &sha256_str); + assert_eq!(response.headers().get("x-stored-size").unwrap(), "12345"); + assert_eq!(response.headers().get("x-stored-at").unwrap(), created_at); } } diff --git a/src/handlers/store_handler.rs b/src/handlers/store_handler.rs index 89be049..123a500 100644 --- a/src/handlers/store_handler.rs +++ b/src/handlers/store_handler.rs @@ -103,16 +103,12 @@ pub async fn store_handler( } #[cfg(test)] -mod test { +pub mod test { use super::*; - use crate::shard::test::make_shard; + use crate::shards::test::make_shards; use axum::body::Bytes; use axum_typed_multipart::FieldData; - async fn make_shards() -> Shards { - Shards::new(vec![make_shard().await]).unwrap() - } - async fn send_request(sha256: Option, data: Bytes) -> StoreResponse { store_handler( Extension(make_shards().await), @@ -137,7 +133,7 @@ mod test { #[tokio::test] async fn test_store_handler_mismatched_sha256() { - let not_hello_world = Sha256::from_bytes("not hello, world!".as_bytes()); + let not_hello_world = Sha256::from_bytes("goodbye, planet!".as_bytes()); let hello_world = Sha256::from_bytes("hello, world!".as_bytes()); let result = send_request(Some(not_hello_world), "hello, world!".as_bytes().into()).await; assert_eq!(result.status_code(), StatusCode::BAD_REQUEST); diff --git a/src/main.rs b/src/main.rs index cb00039..616a81f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,8 @@ mod sha256; mod shard; mod shards; mod shutdown_signal; + +use crate::shards::Shards; use axum::{ routing::{get, post}, Extension, Router, @@ -14,8 +16,6 @@ use tokio::net::TcpListener; use tokio_rusqlite::Connection; use tracing::info; -use crate::shards::Shards; - #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Args { diff --git a/src/shard/fn_get.rs b/src/shard/fn_get.rs index c69a9eb..a020c91 100644 --- a/src/shard/fn_get.rs +++ b/src/shard/fn_get.rs @@ -1,5 +1,13 @@ use super::*; +pub struct GetResult { + pub sha256: Sha256, + pub content_type: String, + pub stored_size: usize, + pub created_at: UtcDateTime, + pub data: Vec, +} + impl Shard { pub async fn get(&self, sha256: Sha256) -> Result, Box> { self.conn @@ -25,6 +33,7 @@ fn get_impl( let created_at = parse_created_at_str(row.get(2)?)?; let data = row.get(3)?; Ok(GetResult { + sha256, content_type, stored_size, created_at, @@ -34,3 +43,9 @@ fn get_impl( ) .optional() } + +fn parse_created_at_str(created_at_str: String) -> Result { + let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str) + .map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?; + Ok(parsed.to_utc()) +} diff --git a/src/shard/mod.rs b/src/shard/mod.rs index be3c7c4..2b4fd6b 100644 --- a/src/shard/mod.rs +++ b/src/shard/mod.rs @@ -3,10 +3,11 @@ mod fn_migrate; mod fn_store; pub mod shard_error; -use crate::{sha256::Sha256, shard::shard_error::ShardError}; -use axum::body::Bytes; +pub use fn_get::GetResult; pub use fn_store::StoreResult; +use crate::{sha256::Sha256, shard::shard_error::ShardError}; +use axum::body::Bytes; use rusqlite::{params, types::FromSql, OptionalExtension}; use std::error::Error; use tokio_rusqlite::Connection; @@ -20,13 +21,6 @@ pub struct Shard { conn: Connection, } -pub struct GetResult { - pub content_type: String, - pub stored_size: usize, - pub created_at: UtcDateTime, - pub data: Vec, -} - impl Shard { pub async fn open(id: usize, conn: Connection) -> Result> { let shard = Self { id, conn }; @@ -67,12 +61,6 @@ impl Shard { } } -fn parse_created_at_str(created_at_str: String) -> Result { - let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str) - .map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?; - Ok(parsed.to_utc()) -} - async fn get_num_entries(conn: &Connection) -> Result { conn.call(|conn| { let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?; @@ -83,6 +71,7 @@ async fn get_num_entries(conn: &Connection) -> Result super::Shard { @@ -123,7 +112,7 @@ pub mod test { .unwrap(); assert_eq!( store_result, - super::StoreResult::Created { + StoreResult::Created { data_size: data.len(), stored_size: data.len() } @@ -148,7 +137,7 @@ pub mod test { .unwrap(); assert_eq!( store_result, - super::StoreResult::Created { + StoreResult::Created { data_size: data.len(), stored_size: data.len() } @@ -160,7 +149,7 @@ pub mod test { .unwrap(); assert_eq!( store_result, - super::StoreResult::Exists { + StoreResult::Exists { data_size: data.len(), stored_size: data.len() } diff --git a/src/shards.rs b/src/shards.rs index 1d0ad8e..cb53e9b 100644 --- a/src/shards.rs +++ b/src/shards.rs @@ -1,6 +1,5 @@ -use std::error::Error; - use crate::{sha256::Sha256, shard::Shard}; +use std::error::Error; #[derive(Clone)] pub struct Shards(Vec); @@ -32,3 +31,14 @@ impl Shards { self.0.len() } } + +#[cfg(test)] +pub mod test { + use crate::shard::test::make_shard; + + use super::Shards; + + pub async fn make_shards() -> Shards { + Shards::new(vec![make_shard().await]).unwrap() + } +}