tests for get_handler

This commit is contained in:
Dylan Knutson
2024-04-25 18:24:09 -07:00
parent e63403b4ee
commit 12f791af75
6 changed files with 216 additions and 93 deletions

View File

@@ -1,89 +1,202 @@
use std::collections::HashMap; use crate::{sha256::Sha256, shard::GetResult, shards::Shards};
use axum::{ use axum::{
extract::Path, extract::Path,
http::{header, HeaderMap, HeaderName, HeaderValue, StatusCode}, http::{header, HeaderMap, HeaderValue, StatusCode},
response::IntoResponse,
Extension, Json, Extension, Json,
}; };
use std::{collections::HashMap, error::Error};
use crate::shards::Shards; pub enum GetResponse {
MissingSha256,
InvalidSha256 { message: String },
InternalError { error: Box<dyn Error> },
NotFound,
Found { get_result: GetResult },
}
#[derive(Debug, serde::Serialize)] impl From<GetResult> for GetResponse {
pub struct GetError { fn from(get_result: GetResult) -> Self {
sha256: Option<String>, GetResponse::Found { get_result }
message: String, }
}
impl IntoResponse for GetResponse {
fn into_response(self) -> axum::response::Response {
match self {
GetResponse::MissingSha256 => (
StatusCode::BAD_REQUEST,
Json(HashMap::from([("status", "missing_sha256")])),
)
.into_response(),
GetResponse::InvalidSha256 { message } => (
StatusCode::BAD_REQUEST,
Json(HashMap::from([
("status", "invalid_sha256".to_owned()),
("message", message),
])),
)
.into_response(),
GetResponse::InternalError { error } => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(HashMap::from([
("status", "internal_error".to_owned()),
("message", error.to_string()),
])),
)
.into_response(),
GetResponse::NotFound => (
StatusCode::NOT_FOUND,
Json(HashMap::from([("status", "not_found")])),
)
.into_response(),
GetResponse::Found { get_result } => make_found_response(get_result).into_response(),
}
}
}
fn make_found_response(
GetResult {
sha256,
content_type,
created_at,
stored_size,
data,
}: GetResult,
) -> impl IntoResponse {
let content_type = match HeaderValue::from_str(&content_type) {
Ok(content_type) => content_type,
Err(e) => return GetResponse::from(e).into_response(),
};
let created_at = match HeaderValue::from_str(&created_at.to_rfc3339()) {
Ok(created_at) => created_at,
Err(e) => return GetResponse::from(e).into_response(),
};
let stored_size = match HeaderValue::from_str(&stored_size.to_string()) {
Ok(stored_size) => stored_size,
Err(e) => return GetResponse::from(e).into_response(),
};
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, content_type);
headers.insert(
header::CACHE_CONTROL,
HeaderValue::from_static("public, max-age=31536000"),
);
headers.insert(
header::ETAG,
HeaderValue::from_str(&sha256.hex_string()).unwrap(),
);
headers.insert(header::HeaderName::from_static("x-stored-at"), created_at);
headers.insert(
header::HeaderName::from_static("x-stored-size"),
stored_size,
);
(StatusCode::OK, headers, data).into_response()
}
impl<E: Into<Box<dyn Error>>> From<E> for GetResponse {
fn from(error: E) -> Self {
GetResponse::InternalError {
error: error.into(),
}
}
} }
#[axum::debug_handler] #[axum::debug_handler]
pub async fn get_handler( pub async fn get_handler(
Path(params): Path<HashMap<String, String>>, Path(params): Path<HashMap<String, String>>,
Extension(shards): Extension<Shards>, Extension(shards): Extension<Shards>,
) -> Result<(StatusCode, HeaderMap, Vec<u8>), (StatusCode, Json<GetError>)> { ) -> GetResponse {
let sha256_str = match params.get("sha256") { let sha256_str = match params.get("sha256") {
Some(sha256_str) => sha256_str.clone(), Some(sha256_str) => sha256_str.clone(),
None => { None => {
return Err(( return GetResponse::MissingSha256;
StatusCode::BAD_REQUEST,
Json(GetError {
sha256: None,
message: "missing sha256 parameter".to_owned(),
}),
));
} }
}; };
let sha256 = crate::sha256::Sha256::from_hex_string(&sha256_str).map_err(|e| { let sha256 = match Sha256::from_hex_string(&sha256_str) {
( Ok(sha256) => sha256,
StatusCode::BAD_REQUEST, Err(e) => {
Json(GetError { return GetResponse::InvalidSha256 {
sha256: Some(sha256_str),
message: e.to_string(), message: e.to_string(),
}), };
) }
})?;
let internal_error = |message| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(GetError {
sha256: Some(sha256.hex_string()),
message,
}),
)
}; };
let shard = shards.shard_for(&sha256); let shard = shards.shard_for(&sha256);
let response = shard let get_result = match shard.get(sha256).await {
.get(sha256) Ok(get_result) => get_result,
.await Err(e) => return e.into(),
.map_err(|e| internal_error(e.to_string()))?; };
let sha256_str = sha256.hex_string(); match get_result {
match response { None => GetResponse::NotFound,
Some(response) => { Some(result) => result.into(),
let content_type = HeaderValue::from_str(&response.content_type) }
.map_err(|e| internal_error(e.to_string()))?; }
let created_at = HeaderValue::from_str(&response.created_at.to_rfc3339())
.map_err(|e| internal_error(e.to_string()))?; #[cfg(test)]
mod test {
let mut headers = HeaderMap::new(); use crate::{sha256::Sha256, shard::GetResult, shards::test::make_shards};
headers.insert(header::CONTENT_TYPE, content_type); use axum::{extract::Path, response::IntoResponse, Extension};
headers.insert( use std::collections::HashMap;
header::CACHE_CONTROL,
HeaderValue::from_static("public, max-age=31536000"), use super::GetResponse;
);
headers.insert(header::ETAG, HeaderValue::from_str(&sha256_str).unwrap()); #[tokio::test]
headers.insert(HeaderName::from_static("x-stored-at"), created_at); async fn test_get_invalid_sha256() {
let shards = Extension(make_shards().await);
Ok((StatusCode::OK, headers, response.data)) let response = super::get_handler(Path(HashMap::new()), shards.clone()).await;
} assert!(matches!(response, super::GetResponse::MissingSha256 { .. }));
None => Err((
StatusCode::NOT_FOUND, let shards = Extension(make_shards().await);
GetError { let response = super::get_handler(
sha256: Some(sha256_str), Path(HashMap::from([(String::from("sha256"), String::from(""))])),
message: "not found".to_owned(), shards.clone(),
} )
.into(), .await;
)), assert!(matches!(response, super::GetResponse::InvalidSha256 { .. }));
let response = super::get_handler(
Path(HashMap::from([(
String::from("sha256"),
String::from("invalid"),
)])),
shards.clone(),
)
.await;
assert!(matches!(response, super::GetResponse::InvalidSha256 { .. }));
}
#[test]
fn test_get_response_found_into_response() {
let data = "hello, world!";
let sha256 = Sha256::from_bytes(data.as_bytes());
let sha256_str = sha256.hex_string();
let created_at = "2022-03-04T08:12:34+00:00";
let response = GetResponse::Found {
get_result: GetResult {
sha256,
content_type: "text/plain".to_string(),
stored_size: 12345,
created_at: chrono::DateTime::parse_from_rfc3339(created_at)
.unwrap()
.to_utc(),
data: data.into(),
},
}
.into_response();
assert_eq!(response.status(), 200);
assert_eq!(
response.headers().get("content-type").unwrap(),
"text/plain"
);
assert_eq!(response.headers().get("etag").unwrap(), &sha256_str);
assert_eq!(response.headers().get("x-stored-size").unwrap(), "12345");
assert_eq!(response.headers().get("x-stored-at").unwrap(), created_at);
} }
} }

View File

@@ -103,16 +103,12 @@ pub async fn store_handler(
} }
#[cfg(test)] #[cfg(test)]
mod test { pub mod test {
use super::*; use super::*;
use crate::shard::test::make_shard; use crate::shards::test::make_shards;
use axum::body::Bytes; use axum::body::Bytes;
use axum_typed_multipart::FieldData; use axum_typed_multipart::FieldData;
async fn make_shards() -> Shards {
Shards::new(vec![make_shard().await]).unwrap()
}
async fn send_request(sha256: Option<Sha256>, data: Bytes) -> StoreResponse { async fn send_request(sha256: Option<Sha256>, data: Bytes) -> StoreResponse {
store_handler( store_handler(
Extension(make_shards().await), Extension(make_shards().await),
@@ -137,7 +133,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_store_handler_mismatched_sha256() { async fn test_store_handler_mismatched_sha256() {
let not_hello_world = Sha256::from_bytes("not hello, world!".as_bytes()); let not_hello_world = Sha256::from_bytes("goodbye, planet!".as_bytes());
let hello_world = Sha256::from_bytes("hello, world!".as_bytes()); let hello_world = Sha256::from_bytes("hello, world!".as_bytes());
let result = send_request(Some(not_hello_world), "hello, world!".as_bytes().into()).await; let result = send_request(Some(not_hello_world), "hello, world!".as_bytes().into()).await;
assert_eq!(result.status_code(), StatusCode::BAD_REQUEST); assert_eq!(result.status_code(), StatusCode::BAD_REQUEST);

View File

@@ -3,6 +3,8 @@ mod sha256;
mod shard; mod shard;
mod shards; mod shards;
mod shutdown_signal; mod shutdown_signal;
use crate::shards::Shards;
use axum::{ use axum::{
routing::{get, post}, routing::{get, post},
Extension, Router, Extension, Router,
@@ -14,8 +16,6 @@ use tokio::net::TcpListener;
use tokio_rusqlite::Connection; use tokio_rusqlite::Connection;
use tracing::info; use tracing::info;
use crate::shards::Shards;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
struct Args { struct Args {

View File

@@ -1,5 +1,13 @@
use super::*; use super::*;
pub struct GetResult {
pub sha256: Sha256,
pub content_type: String,
pub stored_size: usize,
pub created_at: UtcDateTime,
pub data: Vec<u8>,
}
impl Shard { impl Shard {
pub async fn get(&self, sha256: Sha256) -> Result<Option<GetResult>, Box<dyn Error>> { pub async fn get(&self, sha256: Sha256) -> Result<Option<GetResult>, Box<dyn Error>> {
self.conn self.conn
@@ -25,6 +33,7 @@ fn get_impl(
let created_at = parse_created_at_str(row.get(2)?)?; let created_at = parse_created_at_str(row.get(2)?)?;
let data = row.get(3)?; let data = row.get(3)?;
Ok(GetResult { Ok(GetResult {
sha256,
content_type, content_type,
stored_size, stored_size,
created_at, created_at,
@@ -34,3 +43,9 @@ fn get_impl(
) )
.optional() .optional()
} }
fn parse_created_at_str(created_at_str: String) -> Result<UtcDateTime, rusqlite::Error> {
let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str)
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
Ok(parsed.to_utc())
}

View File

@@ -3,10 +3,11 @@ mod fn_migrate;
mod fn_store; mod fn_store;
pub mod shard_error; pub mod shard_error;
use crate::{sha256::Sha256, shard::shard_error::ShardError}; pub use fn_get::GetResult;
use axum::body::Bytes;
pub use fn_store::StoreResult; pub use fn_store::StoreResult;
use crate::{sha256::Sha256, shard::shard_error::ShardError};
use axum::body::Bytes;
use rusqlite::{params, types::FromSql, OptionalExtension}; use rusqlite::{params, types::FromSql, OptionalExtension};
use std::error::Error; use std::error::Error;
use tokio_rusqlite::Connection; use tokio_rusqlite::Connection;
@@ -20,13 +21,6 @@ pub struct Shard {
conn: Connection, conn: Connection,
} }
pub struct GetResult {
pub content_type: String,
pub stored_size: usize,
pub created_at: UtcDateTime,
pub data: Vec<u8>,
}
impl Shard { impl Shard {
pub async fn open(id: usize, conn: Connection) -> Result<Self, Box<dyn Error>> { pub async fn open(id: usize, conn: Connection) -> Result<Self, Box<dyn Error>> {
let shard = Self { id, conn }; let shard = Self { id, conn };
@@ -67,12 +61,6 @@ impl Shard {
} }
} }
fn parse_created_at_str(created_at_str: String) -> Result<UtcDateTime, rusqlite::Error> {
let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str)
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
Ok(parsed.to_utc())
}
async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Error> { async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Error> {
conn.call(|conn| { conn.call(|conn| {
let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?; let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
@@ -83,6 +71,7 @@ async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Err
#[cfg(test)] #[cfg(test)]
pub mod test { pub mod test {
use super::StoreResult;
use crate::sha256::Sha256; use crate::sha256::Sha256;
pub async fn make_shard() -> super::Shard { pub async fn make_shard() -> super::Shard {
@@ -123,7 +112,7 @@ pub mod test {
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
store_result, store_result,
super::StoreResult::Created { StoreResult::Created {
data_size: data.len(), data_size: data.len(),
stored_size: data.len() stored_size: data.len()
} }
@@ -148,7 +137,7 @@ pub mod test {
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
store_result, store_result,
super::StoreResult::Created { StoreResult::Created {
data_size: data.len(), data_size: data.len(),
stored_size: data.len() stored_size: data.len()
} }
@@ -160,7 +149,7 @@ pub mod test {
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
store_result, store_result,
super::StoreResult::Exists { StoreResult::Exists {
data_size: data.len(), data_size: data.len(),
stored_size: data.len() stored_size: data.len()
} }

View File

@@ -1,6 +1,5 @@
use std::error::Error;
use crate::{sha256::Sha256, shard::Shard}; use crate::{sha256::Sha256, shard::Shard};
use std::error::Error;
#[derive(Clone)] #[derive(Clone)]
pub struct Shards(Vec<Shard>); pub struct Shards(Vec<Shard>);
@@ -32,3 +31,14 @@ impl Shards {
self.0.len() self.0.len()
} }
} }
#[cfg(test)]
pub mod test {
use crate::shard::test::make_shard;
use super::Shards;
pub async fn make_shards() -> Shards {
Shards::new(vec![make_shard().await]).unwrap()
}
}