move compressor into Shard field

This commit is contained in:
Dylan Knutson
2024-05-05 21:03:31 -07:00
parent a3b550526e
commit b0955c9c64
8 changed files with 43 additions and 52 deletions

View File

@@ -173,10 +173,10 @@ pub mod test {
use crate::zstd_dict::test::make_zstd_dict; use crate::zstd_dict::test::make_zstd_dict;
pub fn make_compressor() -> Compressor { pub fn make_compressor() -> Compressor {
make_compressor_with_policy(CompressionPolicy::Auto) make_compressor_with(CompressionPolicy::Auto)
} }
pub fn make_compressor_with_policy(compression_policy: CompressionPolicy) -> Compressor { pub fn make_compressor_with(compression_policy: CompressionPolicy) -> Compressor {
let mut compressor = Compressor::new(compression_policy); let mut compressor = Compressor::new(compression_policy);
let zstd_dict = make_zstd_dict(1.into(), "dict1"); let zstd_dict = make_zstd_dict(1.into(), "dict1");
compressor.add(zstd_dict); compressor.add(zstd_dict);
@@ -203,7 +203,7 @@ pub mod test {
compression_policy: CompressionPolicy, compression_policy: CompressionPolicy,
#[values("text/plain", "application/json", "image/png")] content_type: &str, #[values("text/plain", "application/json", "image/png")] content_type: &str,
) { ) {
let compressor = make_compressor_with_policy(compression_policy); let compressor = make_compressor_with(compression_policy);
let data = b"hello, world!".to_vec(); let data = b"hello, world!".to_vec();
let (compression_id, compressed) = compressor let (compression_id, compressed) = compressor
.compress(Some("dict1"), content_type, data.clone()) .compress(Some("dict1"), content_type, data.clone())

View File

@@ -1,5 +1,4 @@
use crate::{ use crate::{
compressor::CompressorArc,
sha256::Sha256, sha256::Sha256,
shard::{GetArgs, GetResult}, shard::{GetArgs, GetResult},
shards::Shards, shards::Shards,
@@ -116,7 +115,6 @@ impl<E: Into<AsyncBoxError>> From<E> for GetResponse {
pub async fn get_handler( pub async fn get_handler(
Path(params): Path<HashMap<String, String>>, Path(params): Path<HashMap<String, String>>,
Extension(shards): Extension<Arc<Shards>>, Extension(shards): Extension<Arc<Shards>>,
Extension(compressor): Extension<CompressorArc>,
) -> GetResponse { ) -> GetResponse {
let sha256_str = match params.get("sha256") { let sha256_str = match params.get("sha256") {
Some(sha256_str) => sha256_str.clone(), Some(sha256_str) => sha256_str.clone(),
@@ -135,7 +133,7 @@ pub async fn get_handler(
}; };
let shard = shards.shard_for(&sha256); let shard = shards.shard_for(&sha256);
let get_result = match shard.get(GetArgs { sha256, compressor }).await { let get_result = match shard.get(GetArgs { sha256 }).await {
Ok(get_result) => get_result, Ok(get_result) => get_result,
Err(e) => return e.into(), Err(e) => return e.into(),
}; };
@@ -149,8 +147,7 @@ pub async fn get_handler(
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::{ use crate::{
compressor::test::make_compressor, sha256::Sha256, shard::GetResult, sha256::Sha256, shard::GetResult, shards::test::make_shards, sql_types::UtcDateTime,
shards::test::make_shards, sql_types::UtcDateTime,
}; };
use axum::{extract::Path, response::IntoResponse, Extension}; use axum::{extract::Path, response::IntoResponse, Extension};
use std::collections::HashMap; use std::collections::HashMap;
@@ -160,16 +157,13 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_get_invalid_sha256() { async fn test_get_invalid_sha256() {
let shards = Extension(make_shards().await); let shards = Extension(make_shards().await);
let compressor = Extension(make_compressor().into_arc());
let response = let response = super::get_handler(Path(HashMap::new()), shards.clone()).await;
super::get_handler(Path(HashMap::new()), shards.clone(), compressor.clone()).await;
assert!(matches!(response, super::GetResponse::MissingSha256 { .. })); assert!(matches!(response, super::GetResponse::MissingSha256 { .. }));
let response = super::get_handler( let response = super::get_handler(
Path(HashMap::from([(String::from("sha256"), String::from(""))])), Path(HashMap::from([(String::from("sha256"), String::from(""))])),
shards.clone(), shards.clone(),
compressor.clone(),
) )
.await; .await;
assert!(matches!(response, super::GetResponse::InvalidSha256 { .. })); assert!(matches!(response, super::GetResponse::InvalidSha256 { .. }));
@@ -180,7 +174,6 @@ mod test {
String::from("invalid"), String::from("invalid"),
)])), )])),
shards.clone(), shards.clone(),
compressor.clone(),
) )
.await; .await;
assert!(matches!(response, super::GetResponse::InvalidSha256 { .. })); assert!(matches!(response, super::GetResponse::InvalidSha256 { .. }));

View File

@@ -1,5 +1,4 @@
use crate::{ use crate::{
compressor::CompressorArc,
sha256::Sha256, sha256::Sha256,
shard::{StoreArgs, StoreResult}, shard::{StoreArgs, StoreResult},
shards::ShardsArc, shards::ShardsArc,
@@ -85,7 +84,6 @@ impl IntoResponse for StoreResponse {
#[axum::debug_handler] #[axum::debug_handler]
pub async fn store_handler( pub async fn store_handler(
Extension(shards): Extension<ShardsArc>, Extension(shards): Extension<ShardsArc>,
Extension(compressor): Extension<CompressorArc>,
TypedMultipart(request): TypedMultipart<StoreRequest>, TypedMultipart(request): TypedMultipart<StoreRequest>,
) -> StoreResponse { ) -> StoreResponse {
let sha256 = Sha256::from_bytes(&request.data.contents); let sha256 = Sha256::from_bytes(&request.data.contents);
@@ -110,7 +108,6 @@ pub async fn store_handler(
content_type: request.content_type, content_type: request.content_type,
data: request.data.contents, data: request.data.contents,
compression_hint: request.compression_hint, compression_hint: request.compression_hint,
compressor,
}) })
.await .await
{ {
@@ -123,7 +120,7 @@ pub async fn store_handler(
#[cfg(test)] #[cfg(test)]
pub mod test { pub mod test {
use crate::{compressor::Compressor, shards::test::make_shards}; use crate::{compressor::test::make_compressor_with, shards::test::make_shards_with};
use super::*; use super::*;
use crate::CompressionPolicy; use crate::CompressionPolicy;
@@ -138,8 +135,7 @@ pub mod test {
data: D, data: D,
) -> StoreResponse { ) -> StoreResponse {
store_handler( store_handler(
Extension(make_shards().await), Extension(make_shards_with(make_compressor_with(compression_policy).into_arc()).await),
Extension(Compressor::new(compression_policy).into_arc()),
TypedMultipart(StoreRequest { TypedMultipart(StoreRequest {
sha256: sha256.map(|s| s.hex_string()), sha256: sha256.map(|s| s.hex_string()),
content_type: content_type.to_string(), content_type: content_type.to_string(),

View File

@@ -104,7 +104,7 @@ fn main() -> Result<(), AsyncBoxError> {
for shard_id in 0..manifest.num_shards() { for shard_id in 0..manifest.num_shards() {
let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id)); let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?; let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
let shard = Shard::open(shard_id, shard_sqlite_conn).await?; let shard = Shard::open(shard_id, shard_sqlite_conn, manifest.compressor()).await?;
info!( info!(
"shard {} has {} entries", "shard {} has {} entries",
shard.id(), shard.id(),

View File

@@ -1,6 +1,5 @@
use crate::{ use crate::{
compressible_data::CompressibleData, compressible_data::CompressibleData,
compressor::CompressorArc,
into_tokio_rusqlite_err, into_tokio_rusqlite_err,
sql_types::{CompressionId, UtcDateTime}, sql_types::{CompressionId, UtcDateTime},
AsyncBoxError, AsyncBoxError,
@@ -10,7 +9,6 @@ use super::*;
pub struct GetArgs { pub struct GetArgs {
pub sha256: Sha256, pub sha256: Sha256,
pub compressor: CompressorArc,
} }
pub struct GetResult { pub struct GetResult {
@@ -35,7 +33,7 @@ impl Shard {
})?; })?;
if let Some((content_type, stored_size, created_at, compression_id, data)) = maybe_row { if let Some((content_type, stored_size, created_at, compression_id, data)) = maybe_row {
let compressor = args.compressor.read().await; let compressor = self.compressor.read().await;
let data = compressor.decompress(compression_id, data)?; let data = compressor.decompress(compression_id, data)?;
Ok(Some(GetResult { Ok(Some(GetResult {
sha256: args.sha256, sha256: args.sha256,

View File

@@ -1,6 +1,5 @@
use crate::{ use crate::{
compressible_data::CompressibleData, compressible_data::CompressibleData,
compressor::CompressorArc,
into_tokio_rusqlite_err, into_tokio_rusqlite_err,
sql_types::{CompressionId, UtcDateTime}, sql_types::{CompressionId, UtcDateTime},
AsyncBoxError, AsyncBoxError,
@@ -27,7 +26,6 @@ pub struct StoreArgs {
pub sha256: Sha256, pub sha256: Sha256,
pub content_type: String, pub content_type: String,
pub data: Bytes, pub data: Bytes,
pub compressor: CompressorArc,
pub compression_hint: Option<String>, pub compression_hint: Option<String>,
} }
@@ -38,7 +36,6 @@ impl Shard {
sha256, sha256,
data, data,
content_type, content_type,
compressor,
compression_hint, compression_hint,
}: StoreArgs, }: StoreArgs,
) -> Result<StoreResult, AsyncBoxError> { ) -> Result<StoreResult, AsyncBoxError> {
@@ -53,7 +50,7 @@ impl Shard {
let uncompressed_size = data.len(); let uncompressed_size = data.len();
let compressor = compressor.read().await; let compressor = self.compressor.read().await;
let (compression_id, data) = let (compression_id, data) =
compressor.compress(compression_hint.as_deref(), &content_type, data)?; compressor.compress(compression_hint.as_deref(), &content_type, data)?;

View File

@@ -1,4 +1,4 @@
use crate::AsyncBoxError; use crate::{compressor::CompressorArc, AsyncBoxError};
use super::*; use super::*;
@@ -6,11 +6,20 @@ use super::*;
pub struct Shard { pub struct Shard {
pub(super) id: usize, pub(super) id: usize,
pub(super) conn: Connection, pub(super) conn: Connection,
pub(super) compressor: CompressorArc,
} }
impl Shard { impl Shard {
pub async fn open(id: usize, conn: Connection) -> Result<Self, AsyncBoxError> { pub async fn open(
let shard = Self { id, conn }; id: usize,
conn: Connection,
compressor: CompressorArc,
) -> Result<Self, AsyncBoxError> {
let shard = Self {
id,
conn,
compressor,
};
shard.migrate().await?; shard.migrate().await?;
Ok(shard) Ok(shard)
} }
@@ -54,7 +63,8 @@ async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Err
#[cfg(test)] #[cfg(test)]
pub mod test { pub mod test {
use crate::compressor::Compressor; use crate::compressor::test::make_compressor_with;
use crate::compressor::{Compressor, CompressorArc};
use crate::{ use crate::{
compressor::test::make_compressor, compressor::test::make_compressor,
sha256::Sha256, sha256::Sha256,
@@ -64,9 +74,13 @@ pub mod test {
use rstest::rstest; use rstest::rstest;
pub async fn make_shard() -> super::Shard { pub async fn make_shard_with(compressor: CompressorArc) -> super::Shard {
let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap(); let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
super::Shard::open(0, conn).await.unwrap() super::Shard::open(0, conn, compressor).await.unwrap()
}
pub async fn make_shard() -> super::Shard {
make_shard_with(make_compressor().into_arc()).await
} }
#[tokio::test] #[tokio::test]
@@ -86,9 +100,8 @@ pub mod test {
#[tokio::test] #[tokio::test]
async fn test_not_found_get() { async fn test_not_found_get() {
let shard = make_shard().await; let shard = make_shard().await;
let compressor = make_compressor().into_arc();
let sha256 = Sha256::from_bytes("hello, world!".as_bytes()); let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
let get_result = shard.get(GetArgs { sha256, compressor }).await.unwrap(); let get_result = shard.get(GetArgs { sha256 }).await.unwrap();
assert!(get_result.is_none()); assert!(get_result.is_none());
} }
@@ -97,13 +110,11 @@ pub mod test {
let shard = make_shard().await; let shard = make_shard().await;
let data = "hello, world!".as_bytes(); let data = "hello, world!".as_bytes();
let sha256 = Sha256::from_bytes(data); let sha256 = Sha256::from_bytes(data);
let compressor = make_compressor().into_arc();
let store_result = shard let store_result = shard
.store(StoreArgs { .store(StoreArgs {
sha256, sha256,
content_type: "text/plain".to_string(), content_type: "text/plain".to_string(),
data: data.into(), data: data.into(),
compressor: compressor.clone(),
..Default::default() ..Default::default()
}) })
.await .await
@@ -122,11 +133,7 @@ pub mod test {
} }
assert_eq!(shard.num_entries().await.unwrap(), 1); assert_eq!(shard.num_entries().await.unwrap(), 1);
let get_result = shard let get_result = shard.get(GetArgs { sha256 }).await.unwrap().unwrap();
.get(GetArgs { sha256, compressor })
.await
.unwrap()
.unwrap();
assert_eq!(get_result.content_type, "text/plain"); assert_eq!(get_result.content_type, "text/plain");
assert_eq!(get_result.data, data); assert_eq!(get_result.data, data);
assert_eq!(get_result.stored_size, data.len()); assert_eq!(get_result.stored_size, data.len());
@@ -194,8 +201,7 @@ pub mod test {
#[values(true, false)] incompressible_data: bool, #[values(true, false)] incompressible_data: bool,
#[values("text/string", "image/jpg", "application/octet-stream")] content_type: String, #[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
) { ) {
let shard = make_shard().await; let shard = make_shard_with(make_compressor_with(compression_policy).into_arc()).await;
let compressor = Compressor::new(compression_policy).into_arc();
let mut data = vec![b'.'; 1024]; let mut data = vec![b'.'; 1024];
if incompressible_data { if incompressible_data {
for byte in data.iter_mut() { for byte in data.iter_mut() {
@@ -215,11 +221,7 @@ pub mod test {
.unwrap(); .unwrap();
assert!(matches!(store_result, StoreResult::Created { .. })); assert!(matches!(store_result, StoreResult::Created { .. }));
let get_result = shard let get_result = shard.get(GetArgs { sha256 }).await.unwrap().unwrap();
.get(GetArgs { sha256, compressor })
.await
.unwrap()
.unwrap();
assert_eq!(get_result.content_type, content_type); assert_eq!(get_result.content_type, content_type);
assert_eq!(get_result.data, data); assert_eq!(get_result.data, data);
} }
@@ -235,7 +237,6 @@ pub mod test {
sha256, sha256,
content_type: "text/plain".to_string(), content_type: "text/plain".to_string(),
data: data.into(), data: data.into(),
compressor: compressor.clone(),
compression_hint: Some("hint1".to_string()), compression_hint: Some("hint1".to_string()),
}) })
.await .await

View File

@@ -32,11 +32,17 @@ impl Shards {
pub mod test { pub mod test {
use std::sync::Arc; use std::sync::Arc;
use crate::shard::test::make_shard; use crate::{
compressor::{test::make_compressor, CompressorArc},
shard::test::make_shard_with,
};
use super::{Shards, ShardsArc}; use super::{Shards, ShardsArc};
pub async fn make_shards() -> ShardsArc { pub async fn make_shards() -> ShardsArc {
Arc::new(Shards::new(vec![make_shard().await]).unwrap()) make_shards_with(make_compressor().into_arc()).await
}
pub async fn make_shards_with(compressor: CompressorArc) -> ShardsArc {
Arc::new(Shards::new(vec![make_shard_with(compressor.clone()).await]).unwrap())
} }
} }