From b0955c9c6416f692c510d1e2eebd141d7588c621 Mon Sep 17 00:00:00 2001
From: Dylan Knutson
Date: Sun, 5 May 2024 21:03:31 -0700
Subject: [PATCH] move compression into Shard field

---
 src/compressor.rs             |  6 ++---
 src/handlers/get_handler.rs   | 13 +++-------
 src/handlers/store_handler.rs |  8 ++----
 src/main.rs                   |  2 +-
 src/shard/fn_get.rs           |  4 +--
 src/shard/fn_store.rs         |  5 +---
 src/shard/shard.rs            | 47 ++++++++++++++++++-----------------
 src/shards.rs                 | 10 ++++++--
 8 files changed, 43 insertions(+), 52 deletions(-)

diff --git a/src/compressor.rs b/src/compressor.rs
index 6bc7ae4..e836398 100644
--- a/src/compressor.rs
+++ b/src/compressor.rs
@@ -173,10 +173,10 @@ pub mod test {
     use crate::zstd_dict::test::make_zstd_dict;
 
     pub fn make_compressor() -> Compressor {
-        make_compressor_with_policy(CompressionPolicy::Auto)
+        make_compressor_with(CompressionPolicy::Auto)
     }
 
-    pub fn make_compressor_with_policy(compression_policy: CompressionPolicy) -> Compressor {
+    pub fn make_compressor_with(compression_policy: CompressionPolicy) -> Compressor {
         let mut compressor = Compressor::new(compression_policy);
         let zstd_dict = make_zstd_dict(1.into(), "dict1");
         compressor.add(zstd_dict);
@@ -203,7 +203,7 @@ pub mod test {
         compression_policy: CompressionPolicy,
         #[values("text/plain", "application/json", "image/png")] content_type: &str,
     ) {
-        let compressor = make_compressor_with_policy(compression_policy);
+        let compressor = make_compressor_with(compression_policy);
         let data = b"hello, world!".to_vec();
         let (compression_id, compressed) = compressor
             .compress(Some("dict1"), content_type, data.clone())
diff --git a/src/handlers/get_handler.rs b/src/handlers/get_handler.rs
index f04927a..d74fc21 100644
--- a/src/handlers/get_handler.rs
+++ b/src/handlers/get_handler.rs
@@ -1,5 +1,4 @@
 use crate::{
-    compressor::CompressorArc,
     sha256::Sha256,
     shard::{GetArgs, GetResult},
     shards::Shards,
@@ -116,7 +115,6 @@ impl<E: Into<AsyncBoxError>> From<E> for GetResponse {
 pub async fn get_handler(
     Path(params): Path<HashMap<String, String>>,
     Extension(shards): Extension<Arc<Shards>>,
-    Extension(compressor): Extension<CompressorArc>,
 ) -> GetResponse {
     let sha256_str = match params.get("sha256") {
         Some(sha256_str) => sha256_str.clone(),
@@ -135,7 +133,7 @@ pub async fn get_handler(
     };
 
     let shard = shards.shard_for(&sha256);
-    let get_result = match shard.get(GetArgs { sha256, compressor }).await {
+    let get_result = match shard.get(GetArgs { sha256 }).await {
         Ok(get_result) => get_result,
         Err(e) => return e.into(),
     };
@@ -149,8 +147,7 @@ pub async fn get_handler(
 #[cfg(test)]
 mod test {
     use crate::{
-        compressor::test::make_compressor, sha256::Sha256, shard::GetResult,
-        shards::test::make_shards, sql_types::UtcDateTime,
+        sha256::Sha256, shard::GetResult, shards::test::make_shards, sql_types::UtcDateTime,
     };
     use axum::{extract::Path, response::IntoResponse, Extension};
     use std::collections::HashMap;
@@ -160,16 +157,13 @@ mod test {
     #[tokio::test]
     async fn test_get_invalid_sha256() {
         let shards = Extension(make_shards().await);
 
-        let compressor = Extension(make_compressor().into_arc());
-        let response =
-            super::get_handler(Path(HashMap::new()), shards.clone(), compressor.clone()).await;
+        let response = super::get_handler(Path(HashMap::new()), shards.clone()).await;
         assert!(matches!(response, super::GetResponse::MissingSha256 { .. }));
 
         let response = super::get_handler(
             Path(HashMap::from([(String::from("sha256"), String::from(""))])),
             shards.clone(),
-            compressor.clone(),
         )
         .await;
         assert!(matches!(response, super::GetResponse::InvalidSha256 { .. }));
@@ -180,7 +174,6 @@ mod test {
                 String::from("invalid"),
             )])),
             shards.clone(),
-            compressor.clone(),
         )
         .await;
         assert!(matches!(response, super::GetResponse::InvalidSha256 { .. }));
diff --git a/src/handlers/store_handler.rs b/src/handlers/store_handler.rs
index b42b5f0..62442f2 100644
--- a/src/handlers/store_handler.rs
+++ b/src/handlers/store_handler.rs
@@ -1,5 +1,4 @@
 use crate::{
-    compressor::CompressorArc,
     sha256::Sha256,
     shard::{StoreArgs, StoreResult},
     shards::ShardsArc,
@@ -85,7 +84,6 @@ impl IntoResponse for StoreResponse {
 #[axum::debug_handler]
 pub async fn store_handler(
     Extension(shards): Extension<ShardsArc>,
-    Extension(compressor): Extension<CompressorArc>,
     TypedMultipart(request): TypedMultipart<StoreRequest>,
 ) -> StoreResponse {
     let sha256 = Sha256::from_bytes(&request.data.contents);
@@ -110,7 +108,6 @@ pub async fn store_handler(
             content_type: request.content_type,
             data: request.data.contents,
             compression_hint: request.compression_hint,
-            compressor,
         })
         .await
     {
@@ -123,7 +120,7 @@
 
 #[cfg(test)]
 pub mod test {
-    use crate::{compressor::Compressor, shards::test::make_shards};
+    use crate::{compressor::test::make_compressor_with, shards::test::make_shards_with};
 
     use super::*;
     use crate::CompressionPolicy;
@@ -138,8 +135,7 @@ pub mod test {
         data: D,
     ) -> StoreResponse {
         store_handler(
-            Extension(make_shards().await),
-            Extension(Compressor::new(compression_policy).into_arc()),
+            Extension(make_shards_with(make_compressor_with(compression_policy).into_arc()).await),
             TypedMultipart(StoreRequest {
                 sha256: sha256.map(|s| s.hex_string()),
                 content_type: content_type.to_string(),
diff --git a/src/main.rs b/src/main.rs
index 66010b8..6e161a2 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -104,7 +104,7 @@ fn main() -> Result<(), AsyncBoxError> {
     for shard_id in 0..manifest.num_shards() {
         let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
         let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
-        let shard = Shard::open(shard_id, shard_sqlite_conn).await?;
+        let shard = Shard::open(shard_id, shard_sqlite_conn, manifest.compressor()).await?;
         info!(
             "shard {} has {} entries",
             shard.id(),
diff --git a/src/shard/fn_get.rs b/src/shard/fn_get.rs
index 3912dfb..4c5dd43 100644
--- a/src/shard/fn_get.rs
+++ b/src/shard/fn_get.rs
@@ -1,6 +1,5 @@
 use crate::{
     compressible_data::CompressibleData,
-    compressor::CompressorArc,
     into_tokio_rusqlite_err,
     sql_types::{CompressionId, UtcDateTime},
     AsyncBoxError,
@@ -10,7 +9,6 @@ use super::*;
 
 pub struct GetArgs {
     pub sha256: Sha256,
-    pub compressor: CompressorArc,
 }
 
 pub struct GetResult {
@@ -35,7 +33,7 @@ impl Shard {
         })?;
 
         if let Some((content_type, stored_size, created_at, compression_id, data)) = maybe_row {
-            let compressor = args.compressor.read().await;
+            let compressor = self.compressor.read().await;
             let data = compressor.decompress(compression_id, data)?;
             Ok(Some(GetResult {
                 sha256: args.sha256,
diff --git a/src/shard/fn_store.rs b/src/shard/fn_store.rs
index 7383d7b..37481cf 100644
--- a/src/shard/fn_store.rs
+++ b/src/shard/fn_store.rs
@@ -1,6 +1,5 @@
 use crate::{
     compressible_data::CompressibleData,
-    compressor::CompressorArc,
     into_tokio_rusqlite_err,
     sql_types::{CompressionId, UtcDateTime},
     AsyncBoxError,
@@ -27,7 +26,6 @@ pub struct StoreArgs {
     pub sha256: Sha256,
     pub content_type: String,
     pub data: Bytes,
-    pub compressor: CompressorArc,
     pub compression_hint: Option<String>,
 }
 
@@ -38,7 +36,6 @@ impl Shard {
             sha256,
             data,
             content_type,
-            compressor,
             compression_hint,
         }: StoreArgs,
     ) -> Result<StoreResult, AsyncBoxError> {
@@ -53,7 +50,7 @@ impl Shard {
 
         let uncompressed_size = data.len();
 
-        let compressor = compressor.read().await;
+        let compressor = self.compressor.read().await;
         let (compression_id, data) = compressor
             .compress(compression_hint.as_deref(), &content_type, data)?;
 
diff --git a/src/shard/shard.rs b/src/shard/shard.rs
index ed31f2b..d71a29a 100644
--- a/src/shard/shard.rs
+++ b/src/shard/shard.rs
@@ -1,4 +1,4 @@
-use crate::AsyncBoxError;
+use crate::{compressor::CompressorArc, AsyncBoxError};
 
 use super::*;
 
@@ -6,11 +6,20 @@ use super::*;
 pub struct Shard {
     pub(super) id: usize,
     pub(super) conn: Connection,
+    pub(super) compressor: CompressorArc,
 }
 
 impl Shard {
-    pub async fn open(id: usize, conn: Connection) -> Result<Self, AsyncBoxError> {
-        let shard = Self { id, conn };
+    pub async fn open(
+        id: usize,
+        conn: Connection,
+        compressor: CompressorArc,
+    ) -> Result<Self, AsyncBoxError> {
+        let shard = Self {
+            id,
+            conn,
+            compressor,
+        };
         shard.migrate().await?;
         Ok(shard)
     }
@@ -54,7 +63,8 @@ async fn get_num_entries(conn: &Connection) -> Result
-    pub async fn make_shard() -> super::Shard {
+    pub async fn make_shard_with(compressor: CompressorArc) -> super::Shard {
         let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
-        super::Shard::open(0, conn).await.unwrap()
+        super::Shard::open(0, conn, compressor).await.unwrap()
+    }
+
+    pub async fn make_shard() -> super::Shard {
+        make_shard_with(make_compressor().into_arc()).await
     }
 
     #[tokio::test]
@@ -86,9 +100,8 @@ pub mod test {
     #[tokio::test]
     async fn test_not_found_get() {
         let shard = make_shard().await;
-        let compressor = make_compressor().into_arc();
         let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
-        let get_result = shard.get(GetArgs { sha256, compressor }).await.unwrap();
+        let get_result = shard.get(GetArgs { sha256 }).await.unwrap();
         assert!(get_result.is_none());
     }
 
@@ -97,13 +110,11 @@ pub mod test {
         let shard = make_shard().await;
         let data = "hello, world!".as_bytes();
         let sha256 = Sha256::from_bytes(data);
-        let compressor = make_compressor().into_arc();
         let store_result = shard
             .store(StoreArgs {
                 sha256,
                 content_type: "text/plain".to_string(),
                 data: data.into(),
-                compressor: compressor.clone(),
                 ..Default::default()
             })
             .await
@@ -122,11 +133,7 @@ pub mod test {
         }
 
         assert_eq!(shard.num_entries().await.unwrap(), 1);
-        let get_result = shard
-            .get(GetArgs { sha256, compressor })
-            .await
-            .unwrap()
-            .unwrap();
+        let get_result = shard.get(GetArgs { sha256 }).await.unwrap().unwrap();
         assert_eq!(get_result.content_type, "text/plain");
         assert_eq!(get_result.data, data);
         assert_eq!(get_result.stored_size, data.len());
@@ -194,8 +201,7 @@ pub mod test {
         #[values(true, false)] incompressible_data: bool,
         #[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
     ) {
-        let shard = make_shard().await;
-        let compressor = Compressor::new(compression_policy).into_arc();
+        let shard = make_shard_with(make_compressor_with(compression_policy).into_arc()).await;
         let mut data = vec![b'.'; 1024];
         if incompressible_data {
             for byte in data.iter_mut() {
@@ -215,11 +221,7 @@ pub mod test {
             .unwrap();
         assert!(matches!(store_result, StoreResult::Created { .. }));
 
-        let get_result = shard
-            .get(GetArgs { sha256, compressor })
-            .await
-            .unwrap()
-            .unwrap();
+        let get_result = shard.get(GetArgs { sha256 }).await.unwrap().unwrap();
         assert_eq!(get_result.content_type, content_type);
         assert_eq!(get_result.data, data);
     }
@@ -235,7 +237,6 @@ pub mod test {
                 sha256,
                 content_type: "text/plain".to_string(),
                 data: data.into(),
-                compressor: compressor.clone(),
                 compression_hint: Some("hint1".to_string()),
             })
             .await
diff --git a/src/shards.rs b/src/shards.rs
index 160368e..adb486d 100644
--- a/src/shards.rs
+++ b/src/shards.rs
@@ -32,11 +32,17 @@ impl Shards {
 pub mod test {
     use std::sync::Arc;
 
-    use crate::shard::test::make_shard;
+    use crate::{
+        compressor::{test::make_compressor, CompressorArc},
+        shard::test::make_shard_with,
+    };
 
     use super::{Shards, ShardsArc};
 
     pub async fn make_shards() -> ShardsArc {
-        Arc::new(Shards::new(vec![make_shard().await]).unwrap())
+        make_shards_with(make_compressor().into_arc()).await
+    }
+    pub async fn make_shards_with(compressor: CompressorArc) -> ShardsArc {
+        Arc::new(Shards::new(vec![make_shard_with(compressor.clone()).await]).unwrap())
     }
 }
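
Not part of the patch, for review context only: a minimal sketch of the call pattern after this change, condensed from the updated tests above. It assumes the in-tree test helpers (make_compressor_with, CompressionPolicy::Auto, an in-memory tokio_rusqlite connection) and a calling function whose error type accepts `?`; the StoreArgs field set is only what the diff shows.

    // Sketch: the CompressorArc is supplied once when the shard is opened,
    // so get()/store() no longer thread a compressor through their args.
    let conn = tokio_rusqlite::Connection::open_in_memory().await?;
    let compressor = make_compressor_with(CompressionPolicy::Auto).into_arc();
    let shard = Shard::open(0, conn, compressor).await?;

    let data = "hello, world!".as_bytes();
    let sha256 = Sha256::from_bytes(data);
    shard
        .store(StoreArgs {
            sha256,
            content_type: "text/plain".to_string(),
            data: data.into(),
            compression_hint: None,
        })
        .await?;
    let get_result = shard.get(GetArgs { sha256 }).await?;

In production the same wiring happens once in main.rs via Shard::open(shard_id, shard_sqlite_conn, manifest.compressor()), so the axum handlers only need the Shards extension.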