From 34e46ed020b9d7f3f3b6a7526c89445e19241f90 Mon Sep 17 00:00:00 2001 From: Dylan Knutson Date: Fri, 26 Apr 2024 11:42:42 -0700 Subject: [PATCH] add zstd compression, tests for store/get --- src/handlers/store_handler.rs | 89 +++++++++++++++++--- src/main.rs | 16 +++- src/sha256.rs | 2 +- src/shard/fn_get.rs | 72 ++++++++++------ src/shard/fn_store.rs | 86 ++++++++++++++----- src/shard/mod.rs | 153 +++++++++++++++++++++++++++++----- src/shards.rs | 8 +- 7 files changed, 342 insertions(+), 84 deletions(-) diff --git a/src/handlers/store_handler.rs b/src/handlers/store_handler.rs index 123a500..3048409 100644 --- a/src/handlers/store_handler.rs +++ b/src/handlers/store_handler.rs @@ -1,4 +1,8 @@ -use crate::{sha256::Sha256, shard::StoreResult, shards::Shards}; +use crate::{ + sha256::Sha256, + shard::{StoreArgs, StoreResult}, + shards::Shards, +}; use axum::{body::Bytes, response::IntoResponse, Extension, Json}; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; @@ -19,10 +23,12 @@ pub enum StoreResponse { Created { stored_size: usize, data_size: usize, + created_at: String, }, Exists { stored_size: usize, data_size: usize, + created_at: String, }, Sha256Mismatch { expected_sha256: String, @@ -49,16 +55,20 @@ impl From for StoreResponse { StoreResult::Created { stored_size, data_size, + created_at, } => StoreResponse::Created { stored_size, data_size, + created_at: created_at.to_rfc3339(), }, StoreResult::Exists { stored_size, data_size, + created_at, } => StoreResponse::Exists { stored_size, data_size, + created_at: created_at.to_rfc3339(), }, } } @@ -92,7 +102,11 @@ pub async fn store_handler( } } match shard - .store(sha256, request.content_type, request.data.contents) + .store(StoreArgs { + sha256, + content_type: request.content_type, + data: request.data.contents, + }) .await { Ok(store_result) => store_result.into(), @@ -105,19 +119,25 @@ pub async fn store_handler( #[cfg(test)] pub mod test { use super::*; - use crate::shards::test::make_shards; + use crate::{shards::test::make_shards_with_compression, UseCompression}; use axum::body::Bytes; use axum_typed_multipart::FieldData; + use rstest::rstest; - async fn send_request(sha256: Option, data: Bytes) -> StoreResponse { + async fn send_request>( + sha256: Option, + content_type: &str, + use_compression: UseCompression, + data: D, + ) -> StoreResponse { store_handler( - Extension(make_shards().await), + Extension(make_shards_with_compression(use_compression).await), TypedMultipart(StoreRequest { sha256: sha256.map(|s| s.hex_string()), - content_type: "text/plain".to_string(), + content_type: content_type.to_string(), data: FieldData { metadata: Default::default(), - contents: data, + contents: data.into(), }, }), ) @@ -126,7 +146,7 @@ pub mod test { #[tokio::test] async fn test_store_handler() { - let result = send_request(None, "hello, world!".as_bytes().into()).await; + let result = send_request(None, "text/plain", UseCompression::Auto, "hello, world!").await; assert_eq!(result.status_code(), StatusCode::CREATED); assert!(matches!(result, StoreResponse::Created { .. })); } @@ -135,7 +155,13 @@ pub mod test { async fn test_store_handler_mismatched_sha256() { let not_hello_world = Sha256::from_bytes("goodbye, planet!".as_bytes()); let hello_world = Sha256::from_bytes("hello, world!".as_bytes()); - let result = send_request(Some(not_hello_world), "hello, world!".as_bytes().into()).await; + let result = send_request( + Some(not_hello_world), + "text/plain", + UseCompression::Auto, + "hello, world!", + ) + .await; assert_eq!(result.status_code(), StatusCode::BAD_REQUEST); assert_eq!( result, @@ -148,8 +174,51 @@ pub mod test { #[tokio::test] async fn test_store_handler_matching_sha256() { let hello_world = Sha256::from_bytes("hello, world!".as_bytes()); - let result = send_request(Some(hello_world), "hello, world!".as_bytes().into()).await; + let result = send_request( + Some(hello_world), + "text/plain", + UseCompression::Auto, + "hello, world!", + ) + .await; assert_eq!(result.status_code(), StatusCode::CREATED); assert!(matches!(result, StoreResponse::Created { .. })); } + + fn make_assert_eq(value: T) -> impl Fn(T) { + move |actual| assert_eq!(actual, value) + } + fn make_assert_lt(value: T) -> impl Fn(T) { + move |actual| assert!(actual < value) + } + + #[rstest] + // textual should be compressed by default + #[case("text/plain", UseCompression::Auto, make_assert_lt(1024))] + #[case("text/plain", UseCompression::Zstd, make_assert_lt(1024))] + #[case("text/plain", UseCompression::None, make_assert_eq(1024))] + // images, etc should not be compressed by default + #[case("image/jpg", UseCompression::Auto, make_assert_eq(1024))] + #[case("image/jpg", UseCompression::Zstd, make_assert_lt(1024))] + #[case("image/jpg", UseCompression::None, make_assert_eq(1024))] + #[tokio::test] + async fn test_compressible_data( + #[case] content_type: &str, + #[case] use_compression: UseCompression, + #[case] assert_stored_size: F, + ) { + let result = send_request(None, content_type, use_compression, vec![0; 1024]).await; + assert_eq!(result.status_code(), StatusCode::CREATED); + match result { + StoreResponse::Created { + stored_size, + data_size, + .. + } => { + assert_stored_size(stored_size); + assert_eq!(data_size, 1024); + } + _ => panic!("expected StoreResponse::Created"), + }; + } } diff --git a/src/main.rs b/src/main.rs index 616a81f..b1a9871 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ use axum::{ routing::{get, post}, Extension, Router, }; -use clap::Parser; +use clap::{Parser, ValueEnum}; use shard::Shard; use std::{error::Error, path::PathBuf}; use tokio::net::TcpListener; @@ -34,6 +34,18 @@ struct Args { /// Number of shards #[arg(short, long)] shards: Option, + + /// How to compress stored data + #[arg(short, long, default_value = "auto")] + compression: UseCompression, +} + +#[derive(Default, PartialEq, Debug, Copy, Clone, ValueEnum)] +pub enum UseCompression { + #[default] + Auto, + None, + Zstd, } #[derive(Debug, serde::Deserialize, serde::Serialize)] @@ -67,7 +79,7 @@ fn main() -> Result<(), Box> { for shard_id in 0..num_shards { let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id)); let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?; - let shard = Shard::open(shard_id, shard_sqlite_conn).await?; + let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?; info!( "shard {} has {} entries", shard.id(), diff --git a/src/sha256.rs b/src/sha256.rs index 0ca39ec..74771f1 100644 --- a/src/sha256.rs +++ b/src/sha256.rs @@ -17,7 +17,7 @@ impl Display for Sha256Error { } impl Error for Sha256Error {} -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Default)] pub struct Sha256([u8; 32]); impl Sha256 { pub fn from_hex_string(hex: &str) -> Result> { diff --git a/src/shard/fn_get.rs b/src/shard/fn_get.rs index a020c91..48f4e16 100644 --- a/src/shard/fn_get.rs +++ b/src/shard/fn_get.rs @@ -1,3 +1,5 @@ +use std::io::Read; + use super::*; pub struct GetResult { @@ -11,7 +13,7 @@ pub struct GetResult { impl Shard { pub async fn get(&self, sha256: Sha256) -> Result, Box> { self.conn - .call(move |conn| get_impl(conn, sha256).map_err(|e| e.into())) + .call(move |conn| get_impl(conn, sha256)) .await .map_err(|e| { error!("get failed: {}", e); @@ -23,29 +25,49 @@ impl Shard { fn get_impl( conn: &mut rusqlite::Connection, sha256: Sha256, -) -> Result, rusqlite::Error> { - conn.query_row( - "SELECT content_type, compressed_size, created_at, data FROM entries WHERE sha256 = ?", - params![sha256.hex_string()], - |row| { - let content_type = row.get(0)?; - let stored_size = row.get(1)?; - let created_at = parse_created_at_str(row.get(2)?)?; - let data = row.get(3)?; - Ok(GetResult { - sha256, - content_type, - stored_size, - created_at, - data, - }) - }, - ) - .optional() -} +) -> Result, tokio_rusqlite::Error> { + let maybe_row = conn + .query_row( + "SELECT content_type, compressed_size, created_at, compression, data + FROM entries + WHERE sha256 = ?", + params![sha256.hex_string()], + |row| { + let content_type = row.get(0)?; + let stored_size = row.get(1)?; + let created_at = parse_created_at_str(row.get(2)?)?; + let compression = row.get(3)?; + let data: Vec = row.get(4)?; + Ok((content_type, stored_size, created_at, compression, data)) + }, + ) + .optional() + .map_err(into_tokio_rusqlite_err)?; -fn parse_created_at_str(created_at_str: String) -> Result { - let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str) - .map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?; - Ok(parsed.to_utc()) + let row = match maybe_row { + Some(row) => row, + None => return Ok(None), + }; + + let (content_type, stored_size, created_at, compression, data) = row; + let data = match compression { + Compression::None => data, + Compression::Zstd => { + let mut decoder = + zstd::Decoder::new(data.as_slice()).map_err(into_tokio_rusqlite_err)?; + let mut decompressed = vec![]; + decoder + .read_to_end(&mut decompressed) + .map_err(into_tokio_rusqlite_err)?; + decompressed + } + }; + + Ok(Some(GetResult { + sha256, + content_type, + stored_size, + created_at, + data, + })) } diff --git a/src/shard/fn_store.rs b/src/shard/fn_store.rs index 59fe488..cb15409 100644 --- a/src/shard/fn_store.rs +++ b/src/shard/fn_store.rs @@ -5,22 +5,27 @@ pub enum StoreResult { Created { stored_size: usize, data_size: usize, + created_at: UtcDateTime, }, Exists { stored_size: usize, data_size: usize, + created_at: UtcDateTime, }, } +#[derive(Default)] +pub struct StoreArgs { + pub sha256: Sha256, + pub content_type: String, + pub data: Bytes, +} + impl Shard { - pub async fn store( - &self, - sha256: Sha256, - content_type: String, - data: Bytes, - ) -> Result> { + pub async fn store(&self, store_args: StoreArgs) -> Result> { + let use_compression = self.use_compression; self.conn - .call(move |conn| store(conn, sha256, content_type, data).map_err(|e| e.into())) + .call(move |conn| store(conn, use_compression, store_args)) .await .map_err(|e| { error!("store failed: {}", e); @@ -31,10 +36,13 @@ impl Shard { fn store( conn: &mut rusqlite::Connection, - sha256: Sha256, - content_type: String, - data: Bytes, -) -> Result { + use_compression: UseCompression, + StoreArgs { + sha256, + content_type, + data, + }: StoreArgs, +) -> Result { let sha256 = sha256.hex_string(); // check for existing entry @@ -46,6 +54,7 @@ fn store( Ok(StoreResult::Exists { stored_size: row.get(0)?, data_size: row.get(1)?, + created_at: parse_created_at_str(row.get(2)?)?, }) }, ) @@ -55,24 +64,59 @@ fn store( return Ok(existing); } - let created_at = chrono::Utc::now().to_rfc3339(); - let data_size = data.len(); + let created_at = chrono::Utc::now(); + let uncompressed_size = data.len(); + let tmp_data_holder; + + let use_compression = match use_compression { + UseCompression::None => false, + UseCompression::Auto => auto_compressible_content_type(&content_type), + UseCompression::Zstd => true, + }; + + let (compression, data) = if use_compression { + tmp_data_holder = zstd::encode_all(&data[..], 0).map_err(into_tokio_rusqlite_err)?; + if tmp_data_holder.len() < data.len() { + (Compression::Zstd, &tmp_data_holder[..]) + } else { + (Compression::None, &data[..]) + } + } else { + (Compression::None, &data[..]) + }; + let compressed_size = data.len(); conn.execute( - "INSERT INTO entries (sha256, content_type, compression, uncompressed_size, compressed_size, data, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + "INSERT INTO entries + (sha256, content_type, compression, uncompressed_size, compressed_size, data, created_at) + VALUES + (?, ?, ?, ?, ?, ?, ?) + ", params![ sha256, content_type, - 0, - data_size, - data_size, - &data[..], - created_at, + compression, + uncompressed_size, + compressed_size, + data, + created_at.to_rfc3339(), ], )?; Ok(StoreResult::Created { - stored_size: data_size, - data_size, + stored_size: compressed_size, + data_size: uncompressed_size, + created_at, }) } + +fn auto_compressible_content_type(content_type: &str) -> bool { + [ + "text/", + "application/xml", + "application/json", + "application/javascript", + ] + .iter() + .any(|ct| content_type.starts_with(ct)) +} diff --git a/src/shard/mod.rs b/src/shard/mod.rs index 2b4fd6b..0fd46a6 100644 --- a/src/shard/mod.rs +++ b/src/shard/mod.rs @@ -4,11 +4,11 @@ mod fn_store; pub mod shard_error; pub use fn_get::GetResult; -pub use fn_store::StoreResult; +pub use fn_store::{StoreArgs, StoreResult}; -use crate::{sha256::Sha256, shard::shard_error::ShardError}; +use crate::{sha256::Sha256, shard::shard_error::ShardError, UseCompression}; use axum::body::Bytes; -use rusqlite::{params, types::FromSql, OptionalExtension}; +use rusqlite::{params, types::FromSql, OptionalExtension, ToSql}; use std::error::Error; use tokio_rusqlite::Connection; use tracing::{debug, error}; @@ -19,11 +19,43 @@ pub type UtcDateTime = chrono::DateTime; pub struct Shard { id: usize, conn: Connection, + use_compression: UseCompression, +} + +#[derive(Debug, PartialEq, Eq)] +enum Compression { + None, + Zstd, +} +impl ToSql for Compression { + fn to_sql(&self) -> rusqlite::Result> { + match self { + Compression::None => 0.to_sql(), + Compression::Zstd => 1.to_sql(), + } + } +} +impl FromSql for Compression { + fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult { + match value.as_i64()? { + 0 => Ok(Compression::None), + 1 => Ok(Compression::Zstd), + _ => Err(rusqlite::types::FromSqlError::InvalidType), + } + } } impl Shard { - pub async fn open(id: usize, conn: Connection) -> Result> { - let shard = Self { id, conn }; + pub async fn open( + id: usize, + use_compression: UseCompression, + conn: Connection, + ) -> Result> { + let shard = Self { + id, + use_compression, + conn, + }; shard.migrate().await?; Ok(shard) } @@ -69,14 +101,32 @@ async fn get_num_entries(conn: &Connection) -> Result Result { + let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str) + .map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?; + Ok(parsed.to_utc()) +} + +fn into_tokio_rusqlite_err>>( + e: E, +) -> tokio_rusqlite::Error { + tokio_rusqlite::Error::Other(e.into()) +} + #[cfg(test)] pub mod test { - use super::StoreResult; - use crate::sha256::Sha256; + use rstest::rstest; + + use super::{StoreResult, UseCompression}; + use crate::{sha256::Sha256, shard::StoreArgs}; + + pub async fn make_shard_with_compression(use_compression: UseCompression) -> super::Shard { + let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap(); + super::Shard::open(0, use_compression, conn).await.unwrap() + } pub async fn make_shard() -> super::Shard { - let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap(); - super::Shard::open(0, conn).await.unwrap() + make_shard_with_compression(UseCompression::Auto).await } #[tokio::test] @@ -107,16 +157,25 @@ pub mod test { let data = "hello, world!".as_bytes(); let sha256 = Sha256::from_bytes(data); let store_result = shard - .store(sha256, "text/plain".to_string(), data.into()) + .store(StoreArgs { + sha256, + content_type: "text/plain".to_string(), + data: data.into(), + }) .await .unwrap(); - assert_eq!( - store_result, + match store_result { StoreResult::Created { - data_size: data.len(), - stored_size: data.len() + stored_size, + data_size, + created_at, + } => { + assert_eq!(stored_size, data.len()); + assert_eq!(data_size, data.len()); + assert!(created_at > chrono::Utc::now() - chrono::Duration::seconds(1)); } - ); + _ => panic!("expected StoreResult::Created"), + } assert_eq!(shard.num_entries().await.unwrap(), 1); let get_result = shard.get(sha256).await.unwrap().unwrap(); @@ -132,28 +191,76 @@ pub mod test { let sha256 = Sha256::from_bytes(data); let store_result = shard - .store(sha256, "text/plain".to_string(), data.into()) + .store(StoreArgs { + sha256, + content_type: "text/plain".to_string(), + data: data.into(), + }) .await .unwrap(); - assert_eq!( - store_result, + assert!(matches!(store_result, StoreResult::Created { .. })); + + let created_at = match store_result { StoreResult::Created { - data_size: data.len(), - stored_size: data.len() + stored_size, + data_size, + created_at, + } => { + assert_eq!(stored_size, data.len()); + assert_eq!(data_size, data.len()); + created_at } - ); + _ => panic!("expected StoreResult::Created"), + }; let store_result = shard - .store(sha256, "text/plain".to_string(), data.into()) + .store(StoreArgs { + sha256, + content_type: "text/plain".to_string(), + data: data.into(), + }) .await .unwrap(); assert_eq!( store_result, StoreResult::Exists { data_size: data.len(), - stored_size: data.len() + stored_size: data.len(), + created_at } ); assert_eq!(shard.num_entries().await.unwrap(), 1); } + + #[rstest] + #[tokio::test] + async fn test_compression_store_get( + #[values(UseCompression::Auto, UseCompression::None, UseCompression::Zstd)] + use_compression: UseCompression, + #[values(true, false)] incompressible_data: bool, + #[values("text/string", "image/jpg", "application/octet-stream")] content_type: String, + ) { + let shard = make_shard_with_compression(use_compression).await; + let mut data = vec![b'.'; 1024]; + if incompressible_data { + for byte in data.iter_mut() { + *byte = rand::random(); + } + } + + let sha256 = Sha256::from_bytes(&data); + let store_result = shard + .store(StoreArgs { + sha256, + content_type: content_type.clone(), + data: data.clone().into(), + }) + .await + .unwrap(); + assert!(matches!(store_result, StoreResult::Created { .. })); + + let get_result = shard.get(sha256).await.unwrap().unwrap(); + assert_eq!(get_result.content_type, content_type); + assert_eq!(get_result.data, data); + } } diff --git a/src/shards.rs b/src/shards.rs index cb53e9b..21cba22 100644 --- a/src/shards.rs +++ b/src/shards.rs @@ -34,11 +34,15 @@ impl Shards { #[cfg(test)] pub mod test { - use crate::shard::test::make_shard; + use crate::{shard::test::make_shard_with_compression, UseCompression}; use super::Shards; + pub async fn make_shards_with_compression(use_compression: UseCompression) -> Shards { + Shards::new(vec![make_shard_with_compression(use_compression).await]).unwrap() + } + pub async fn make_shards() -> Shards { - Shards::new(vec![make_shard().await]).unwrap() + make_shards_with_compression(UseCompression::Auto).await } }