diff --git a/src/compressible_data.rs b/src/compressible_data.rs
new file mode 100644
index 0000000..3b14f55
--- /dev/null
+++ b/src/compressible_data.rs
@@ -0,0 +1,67 @@
+use axum::{body::Bytes, response::IntoResponse};
+
+#[derive(Debug, Eq, PartialEq)]
+pub enum CompressibleData {
+    Bytes(Bytes),
+    Vec(Vec<u8>),
+}
+
+impl CompressibleData {
+    pub fn len(&self) -> usize {
+        match self {
+            CompressibleData::Bytes(b) => b.len(),
+            CompressibleData::Vec(v) => v.len(),
+        }
+    }
+}
+
+impl AsRef<[u8]> for CompressibleData {
+    fn as_ref(&self) -> &[u8] {
+        match self {
+            CompressibleData::Bytes(b) => b.as_ref(),
+            CompressibleData::Vec(v) => v.as_ref(),
+        }
+    }
+}
+
+impl From<Bytes> for CompressibleData {
+    fn from(b: Bytes) -> Self {
+        CompressibleData::Bytes(b)
+    }
+}
+
+impl From<Vec<u8>> for CompressibleData {
+    fn from(v: Vec<u8>) -> Self {
+        CompressibleData::Vec(v)
+    }
+}
+
+impl IntoResponse for CompressibleData {
+    fn into_response(self) -> axum::response::Response {
+        match self {
+            CompressibleData::Bytes(b) => b.into_response(),
+            CompressibleData::Vec(v) => v.into_response(),
+        }
+    }
+}
+
+impl PartialEq<&[u8]> for CompressibleData {
+    fn eq(&self, other: &&[u8]) -> bool {
+        let as_ref = self.as_ref();
+        as_ref == *other
+    }
+}
+
+impl PartialEq<Vec<u8>> for CompressibleData {
+    fn eq(&self, other: &Vec<u8>) -> bool {
+        let as_ref = self.as_ref();
+        as_ref == other.as_slice()
+    }
+}
+
+impl PartialEq<Bytes> for CompressibleData {
+    fn eq(&self, other: &Bytes) -> bool {
+        let as_ref = self.as_ref();
+        as_ref == other.as_ref()
+    }
+}
diff --git a/src/compressor.rs b/src/compressor.rs
new file mode 100644
index 0000000..9a815fd
--- /dev/null
+++ b/src/compressor.rs
@@ -0,0 +1,218 @@
+use std::{collections::HashMap, sync::Arc};
+
+use tokio::sync::RwLock;
+
+use crate::{
+    compressible_data::CompressibleData,
+    sql_types::{CompressionId, ZstdDictId},
+    zstd_dict::{ZstdDict, ZstdDictArc},
+    AsyncBoxError, CompressionPolicy,
+};
+
+pub type CompressorArc = Arc<RwLock<Compressor>>;
+
+pub struct Compressor {
+    zstd_dict_by_id: HashMap<ZstdDictId, ZstdDictArc>,
+    zstd_dict_by_name: HashMap<String, ZstdDictArc>,
+    compression_policy: CompressionPolicy,
+}
+impl Compressor {
+    pub fn new(compression_policy: CompressionPolicy) -> Self {
+        Self {
+            zstd_dict_by_id: HashMap::new(),
+            zstd_dict_by_name: HashMap::new(),
+            compression_policy,
+        }
+    }
+}
+
+impl Default for Compressor {
+    fn default() -> Self {
+        Self::new(CompressionPolicy::Auto)
+    }
+}
+
+type Result<T> = std::result::Result<T, AsyncBoxError>;
+
+impl Compressor {
+    pub fn into_arc(self) -> CompressorArc {
+        Arc::new(RwLock::new(self))
+    }
+
+    fn _add_from_samples<Str: Into<String>>(
+        &mut self,
+        id: ZstdDictId,
+        name: Str,
+        samples: Vec<&[u8]>,
+    ) -> Result<ZstdDictArc> {
+        let name = name.into();
+        self.check(id, &name)?;
+        let zstd_dict = ZstdDict::from_samples(id, name, 3, samples)?;
+        Ok(self.add(zstd_dict))
+    }
+
+    pub fn add_from_bytes<Str: Into<String>>(
+        &mut self,
+        id: ZstdDictId,
+        name: Str,
+        level: i32,
+        dict_bytes: Vec<u8>,
+    ) -> Result<ZstdDictArc> {
+        let name = name.into();
+        self.check(id, &name)?;
+        let zstd_dict = ZstdDict::from_dict_bytes(id, name, level, dict_bytes);
+        Ok(self.add(zstd_dict))
+    }
+
+    pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> {
+        self.zstd_dict_by_id.get(&id).map(|d| &**d)
+    }
+    pub fn by_name(&self, name: &str) -> Option<&ZstdDict> {
+        self.zstd_dict_by_name.get(name).map(|d| &**d)
+    }
+
+    pub fn compress<Data: Into<CompressibleData>>(
+        &self,
+        name: &str,
+        content_type: &str,
+        data: Data,
+    ) -> Result<(CompressionId, CompressibleData)> {
+        let data = data.into();
+        let should_compress = match self.compression_policy {
+            CompressionPolicy::None => false,
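+            // Auto compresses only content types that typically shrink
+            // (text/*, XML, JSON, JavaScript); see auto_compressible_content_type.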
+            CompressionPolicy::Auto => auto_compressible_content_type(content_type),
+            CompressionPolicy::ForceZstd => true,
+        };
+
+        if should_compress {
+            let (id, compressed): (_, CompressibleData) = if let Some(dict) = self.by_name(name) {
+                (
+                    CompressionId::ZstdDictId(dict.id()),
+                    dict.compress(&data)?.into(),
+                )
+            } else {
+                (
+                    CompressionId::ZstdGeneric,
+                    zstd::stream::encode_all(data.as_ref(), 3)?.into(),
+                )
+            };
+
+            if compressed.len() < data.len() {
+                return Ok((id, compressed));
+            }
+        }
+
+        Ok((CompressionId::None, data))
+    }
+
+    pub fn decompress<Data: Into<CompressibleData>>(
+        &self,
+        compression_id: CompressionId,
+        data: Data,
+    ) -> Result<CompressibleData> {
+        let data = data.into();
+        match compression_id {
+            CompressionId::None => Ok(data),
+            CompressionId::ZstdDictId(id) => {
+                if let Some(dict) = self.by_id(id) {
+                    Ok(CompressibleData::Vec(dict.decompress(data.as_ref())?))
+                } else {
+                    Err(format!("zstd dictionary {:?} not found", id).into())
+                }
+            }
+            CompressionId::ZstdGeneric => Ok(CompressibleData::Vec(zstd::stream::decode_all(
+                data.as_ref(),
+            )?)),
+        }
+    }
+
+    fn check(&self, id: ZstdDictId, name: &str) -> Result<()> {
+        if self.zstd_dict_by_id.contains_key(&id) {
+            return Err(format!("zstd dictionary {:?} already exists", id).into());
+        }
+        if self.zstd_dict_by_name.contains_key(name) {
+            return Err(format!("zstd dictionary {} already exists", name).into());
+        }
+        Ok(())
+    }
+
+    fn add(&mut self, zstd_dict: ZstdDict) -> ZstdDictArc {
+        let zstd_dict = Arc::new(zstd_dict);
+        self.zstd_dict_by_id
+            .insert(zstd_dict.id(), zstd_dict.clone());
+        self.zstd_dict_by_name
+            .insert(zstd_dict.name().to_string(), zstd_dict.clone());
+        zstd_dict
+    }
+}
+
+fn auto_compressible_content_type(content_type: &str) -> bool {
+    [
+        "text/",
+        "application/xml",
+        "application/json",
+        "application/javascript",
+    ]
+    .iter()
+    .any(|ct| content_type.starts_with(ct))
+}
+
+#[cfg(test)]
+pub mod test {
+    use rstest::rstest;
+
+    use super::*;
+    use crate::zstd_dict::test::make_zstd_dict;
+
+    pub fn make_compressor() -> Compressor {
+        make_compressor_with_policy(CompressionPolicy::Auto)
+    }
+
+    pub fn make_compressor_with_policy(compression_policy: CompressionPolicy) -> Compressor {
+        let mut compressor = Compressor::new(compression_policy);
+        let zstd_dict = make_zstd_dict(1.into(), "dict1");
+        compressor.add(zstd_dict);
+        compressor
+    }
+
+    #[test]
+    fn test_auto_compressible_content_type() {
+        assert!(auto_compressible_content_type("text/plain"));
+        assert!(auto_compressible_content_type("application/xml"));
+        assert!(auto_compressible_content_type("application/json"));
+        assert!(auto_compressible_content_type("application/javascript"));
+        assert!(!auto_compressible_content_type("image/png"));
+    }
+
+    #[rstest]
+    #[test]
+    fn test_compression_policies(
+        #[values(
+            CompressionPolicy::Auto,
+            CompressionPolicy::None,
+            CompressionPolicy::ForceZstd
+        )]
+        compression_policy: CompressionPolicy,
+        #[values("text/plain", "application/json", "image/png")] content_type: &str,
+    ) {
+        let compressor = make_compressor_with_policy(compression_policy);
+        let data = b"hello, world!".to_vec();
+        let (compression_id, compressed) = compressor
+            .compress("dict1", content_type, data.clone())
+            .unwrap();
+
+        let data_uncompressed = compressor.decompress(compression_id, compressed).unwrap();
+        assert_eq!(data_uncompressed, data);
+    }
+
+    #[test]
+    fn test_skip_compressing_small_data() {
+        let compressor = make_compressor();
+        let data = b"hello, world".to_vec();
+        let (compression_id, compressed) = compressor
+            .compress("dict1", "text/plain", data.clone())
+            .unwrap();
+        assert_eq!(compression_id, CompressionId::None);
+        assert_eq!(compressed, data);
+    }
+}
diff --git a/src/handlers/get_handler.rs b/src/handlers/get_handler.rs
index 47e3caa..f04927a 100644
--- a/src/handlers/get_handler.rs
+++ b/src/handlers/get_handler.rs
@@ -1,16 +1,22 @@
-use crate::{sha256::Sha256, shard::GetResult, shards::Shards};
+use crate::{
+    compressor::CompressorArc,
+    sha256::Sha256,
+    shard::{GetArgs, GetResult},
+    shards::Shards,
+    AsyncBoxError,
+};
 use axum::{
     extract::Path,
     http::{header, HeaderMap, HeaderValue, StatusCode},
     response::IntoResponse,
     Extension, Json,
 };
-use std::{collections::HashMap, error::Error};
+use std::{collections::HashMap, sync::Arc};
 
 pub enum GetResponse {
     MissingSha256,
     InvalidSha256 { message: String },
-    InternalError { error: Box<dyn Error> },
+    InternalError { error: AsyncBoxError },
     NotFound,
     Found { get_result: GetResult },
 }
@@ -69,7 +75,7 @@ fn make_found_response(
         Err(e) => return GetResponse::from(e).into_response(),
     };
 
-    let created_at = match HeaderValue::from_str(&created_at.to_rfc3339()) {
+    let created_at = match HeaderValue::from_str(&created_at.to_string()) {
         Ok(created_at) => created_at,
         Err(e) => return GetResponse::from(e).into_response(),
     };
@@ -98,7 +104,7 @@ fn make_found_response(
     (StatusCode::OK, headers, data).into_response()
 }
 
-impl<E: Into<Box<dyn Error>>> From<E> for GetResponse {
+impl<E: Into<AsyncBoxError>> From<E> for GetResponse {
     fn from(error: E) -> Self {
         GetResponse::InternalError {
             error: error.into(),
@@ -109,7 +115,8 @@ impl<E: Into<Box<dyn Error>>> From<E> for GetResponse {
 #[axum::debug_handler]
 pub async fn get_handler(
     Path(params): Path<HashMap<String, String>>,
-    Extension(shards): Extension<Shards>,
+    Extension(shards): Extension<Arc<Shards>>,
+    Extension(compressor): Extension<CompressorArc>,
 ) -> GetResponse {
     let sha256_str = match params.get("sha256") {
         Some(sha256_str) => sha256_str.clone(),
@@ -128,7 +135,7 @@ pub async fn get_handler(
     };
 
     let shard = shards.shard_for(&sha256);
-    let get_result = match shard.get(sha256).await {
+    let get_result = match shard.get(GetArgs { sha256, compressor }).await {
         Ok(get_result) => get_result,
         Err(e) => return e.into(),
     };
@@ -141,7 +148,10 @@ pub async fn get_handler(
 
 #[cfg(test)]
 mod test {
-    use crate::{sha256::Sha256, shard::GetResult, shards::test::make_shards};
+    use crate::{
+        compressor::test::make_compressor, sha256::Sha256, shard::GetResult,
+        shards::test::make_shards, sql_types::UtcDateTime,
+    };
     use axum::{extract::Path, response::IntoResponse, Extension};
     use std::collections::HashMap;
 
     #[tokio::test]
     async fn test_get_invalid_sha256() {
         let shards = Extension(make_shards().await);
-        let response = super::get_handler(Path(HashMap::new()), shards.clone()).await;
+        let compressor = Extension(make_compressor().into_arc());
+
+        let response =
+            super::get_handler(Path(HashMap::new()), shards.clone(), compressor.clone()).await;
         assert!(matches!(response, super::GetResponse::MissingSha256 { .. }));
 
-        let shards = Extension(make_shards().await);
         let response = super::get_handler(
             Path(HashMap::from([(String::from("sha256"), String::from(""))])),
             shards.clone(),
+            compressor.clone(),
         )
         .await;
         assert!(matches!(response, super::GetResponse::InvalidSha256 { .. }));
 
         let response = super::get_handler(
             Path(HashMap::from([(
                 String::from("sha256"),
                 String::from("invalid"),
             )])),
             shards.clone(),
+            compressor.clone(),
         )
         .await;
         assert!(matches!(response, super::GetResponse::InvalidSha256 { .. }));
@@ -174,8 +188,8 @@ mod test {
 
     #[test]
     fn test_get_response_found_into_response() {
-        let data = "hello, world!";
-        let sha256 = Sha256::from_bytes(data.as_bytes());
+        let data = "hello, world!".as_bytes().to_owned();
+        let sha256 = Sha256::from_bytes(&data);
         let sha256_str = sha256.hex_string();
         let created_at = "2022-03-04T08:12:34+00:00";
         let response = GetResponse::Found {
@@ -183,9 +197,7 @@ mod test {
             sha256,
             content_type: "text/plain".to_string(),
             stored_size: 12345,
-            created_at: chrono::DateTime::parse_from_rfc3339(created_at)
-                .unwrap()
-                .to_utc(),
+            created_at: UtcDateTime::from_string(created_at).unwrap(),
             data: data.into(),
         },
     }
diff --git a/src/handlers/info_handler.rs b/src/handlers/info_handler.rs
index aadeb9d..d334e9c 100644
--- a/src/handlers/info_handler.rs
+++ b/src/handlers/info_handler.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use crate::shards::Shards;
 use axum::{http::StatusCode, Extension, Json};
 
@@ -19,7 +21,7 @@ pub struct ShardInfo {
 
 #[axum::debug_handler]
 pub async fn info_handler(
-    Extension(shards): Extension<Shards>,
+    Extension(shards): Extension<Arc<Shards>>,
 ) -> Result<(StatusCode, Json), StatusCode> {
     let mut shard_infos = vec![];
     let mut total_db_size_bytes = 0;
diff --git a/src/handlers/store_handler.rs b/src/handlers/store_handler.rs
index 3048409..4d6ee1f 100644
--- a/src/handlers/store_handler.rs
+++ b/src/handlers/store_handler.rs
@@ -1,7 +1,8 @@
 use crate::{
+    compressor::CompressorArc,
     sha256::Sha256,
     shard::{StoreArgs, StoreResult},
-    shards::Shards,
+    shards::ShardsArc,
 };
 use axum::{body::Bytes, response::IntoResponse, Extension, Json};
 use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
@@ -59,7 +60,7 @@ impl From<StoreResult> for StoreResponse {
             } => StoreResponse::Created {
                 stored_size,
                 data_size,
-                created_at: created_at.to_rfc3339(),
+                created_at: created_at.to_string(),
             },
             StoreResult::Exists {
                 stored_size,
@@ -68,7 +69,7 @@ impl From<StoreResult> for StoreResponse {
             } => StoreResponse::Exists {
                 stored_size,
                 data_size,
-                created_at: created_at.to_rfc3339(),
+                created_at: created_at.to_string(),
             },
         }
     }
@@ -82,7 +83,8 @@ impl IntoResponse for StoreResponse {
 
 #[axum::debug_handler]
 pub async fn store_handler(
-    Extension(shards): Extension<Shards>,
+    Extension(shards): Extension<ShardsArc>,
+    Extension(compressor): Extension<CompressorArc>,
     TypedMultipart(request): TypedMultipart<StoreRequest>,
 ) -> StoreResponse {
     let sha256 = Sha256::from_bytes(&request.data.contents);
@@ -106,6 +108,7 @@ pub async fn store_handler(
             sha256,
             content_type: request.content_type,
             data: request.data.contents,
+            compressor,
         })
         .await
     {
@@ -118,20 +121,23 @@ pub async fn store_handler(
 
 #[cfg(test)]
 pub mod test {
+    use crate::{compressor::Compressor, shards::test::make_shards};
+
     use super::*;
-    use crate::{shards::test::make_shards_with_compression, UseCompression};
+    use crate::CompressionPolicy;
    use axum::body::Bytes;
     use axum_typed_multipart::FieldData;
     use rstest::rstest;
 
     async fn send_request<D: Into<Bytes>>(
+        compression_policy: CompressionPolicy,
         sha256: Option<Sha256>,
         content_type: &str,
-        use_compression: UseCompression,
         data: D,
     ) -> StoreResponse {
         store_handler(
-            Extension(make_shards_with_compression(use_compression).await),
+            Extension(make_shards().await),
+            Extension(Compressor::new(compression_policy).into_arc()),
             TypedMultipart(StoreRequest {
                 sha256: sha256.map(|s| s.hex_string()),
                 content_type: content_type.to_string(),
@@ -146,8 +152,9 @@ pub mod test {
 
     #[tokio::test]
     async fn test_store_handler() {
-        let result = send_request(None, "text/plain", UseCompression::Auto, "hello, world!").await;
-        assert_eq!(result.status_code(), StatusCode::CREATED);
+        let result =
+            send_request(CompressionPolicy::Auto, None, "text/plain", "hello, world!").await;
+        assert_eq!(result.status_code(), StatusCode::CREATED, "{:?}", result);
         assert!(matches!(result, StoreResponse::Created { .. }));
     }
 
@@ -156,9 +163,9 @@ pub mod test {
         let not_hello_world = Sha256::from_bytes("goodbye, planet!".as_bytes());
         let hello_world = Sha256::from_bytes("hello, world!".as_bytes());
         let result = send_request(
+            CompressionPolicy::Auto,
             Some(not_hello_world),
             "text/plain",
-            UseCompression::Auto,
             "hello, world!",
         )
         .await;
@@ -175,9 +182,9 @@ pub mod test {
     async fn test_store_handler_matching_sha256() {
         let hello_world = Sha256::from_bytes("hello, world!".as_bytes());
         let result = send_request(
+            CompressionPolicy::Auto,
             Some(hello_world),
             "text/plain",
-            UseCompression::Auto,
             "hello, world!",
         )
         .await;
@@ -194,20 +201,20 @@ pub mod test {
 
     #[rstest]
     // textual should be compressed by default
-    #[case("text/plain", UseCompression::Auto, make_assert_lt(1024))]
-    #[case("text/plain", UseCompression::Zstd, make_assert_lt(1024))]
-    #[case("text/plain", UseCompression::None, make_assert_eq(1024))]
+    #[case("text/plain", CompressionPolicy::Auto, make_assert_lt(1024))]
+    #[case("text/plain", CompressionPolicy::ForceZstd, make_assert_lt(1024))]
+    #[case("text/plain", CompressionPolicy::None, make_assert_eq(1024))]
     // images, etc should not be compressed by default
-    #[case("image/jpg", UseCompression::Auto, make_assert_eq(1024))]
-    #[case("image/jpg", UseCompression::Zstd, make_assert_lt(1024))]
-    #[case("image/jpg", UseCompression::None, make_assert_eq(1024))]
+    #[case("image/jpg", CompressionPolicy::Auto, make_assert_eq(1024))]
+    #[case("image/jpg", CompressionPolicy::ForceZstd, make_assert_lt(1024))]
+    #[case("image/jpg", CompressionPolicy::None, make_assert_eq(1024))]
     #[tokio::test]
     async fn test_compressible_data<F: Fn(usize)>(
         #[case] content_type: &str,
-        #[case] use_compression: UseCompression,
+        #[case] compression_policy: CompressionPolicy,
         #[case] assert_stored_size: F,
     ) {
-        let result = send_request(None, content_type, use_compression, vec![0; 1024]).await;
+        let result = send_request(compression_policy, None, content_type, vec![0; 1024]).await;
         assert_eq!(result.status_code(), StatusCode::CREATED);
         match result {
             StoreResponse::Created {
diff --git a/src/main.rs b/src/main.rs
index ce7ae0a..27ab61d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,9 +1,13 @@
+mod compressible_data;
+mod compressor;
 mod handlers;
 mod manifest;
 mod sha256;
 mod shard;
 mod shards;
 mod shutdown_signal;
+mod sql_types;
+mod zstd_dict;
 
 use crate::{manifest::Manifest, shards::Shards};
 use axum::{
@@ -11,10 +15,12 @@ use axum::{
     Extension, Router,
 };
 use clap::{Parser, ValueEnum};
+use compressor::CompressorArc;
 use futures::executor::block_on;
 use shard::Shard;
-use std::{error::Error, path::PathBuf};
-use tokio::net::TcpListener;
+use shards::ShardsArc;
+use std::{error::Error, path::PathBuf, sync::Arc};
+use tokio::{net::TcpListener, select, spawn};
 use tokio_rusqlite::Connection;
 use tracing::info;
 
@@ -39,18 +45,24 @@ struct Args {
 
     /// How to compress stored data
     #[arg(short, long, default_value = "auto")]
-    compression: UseCompression,
+    compression: CompressionPolicy,
 }
 
 #[derive(Default, PartialEq, Debug, Copy, Clone, ValueEnum)]
-pub enum UseCompression {
+pub enum CompressionPolicy {
     #[default]
     Auto,
     None,
-    Zstd,
+    ForceZstd,
 }
 
-fn main() -> Result<(), Box<dyn Error>> {
+pub type AsyncBoxError = Box<dyn Error + Send + Sync>;
+
+pub fn into_tokio_rusqlite_err<E: Into<AsyncBoxError>>(e: E) -> tokio_rusqlite::Error {
+    tokio_rusqlite::Error::Other(e.into())
+}
+
+fn main() -> Result<(), AsyncBoxError> {
     tracing_subscriber::fmt()
         .with_max_level(tracing::Level::DEBUG)
         .init();
@@ -85,7 +97,7 @@ fn main() -> Result<(), Box<dyn Error>> {
     for shard_id in 0..manifest.num_shards() {
         let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
         let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
-        let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?;
+        let shard = Shard::open(shard_id, shard_sqlite_conn).await?;
         info!(
             "shard {} has {} entries",
             shard.id(),
@@ -94,23 +106,46 @@ fn main() -> Result<(), Box<dyn Error>> {
         shards_vec.push(shard);
     }
 
-    let shards = Shards::new(shards_vec).ok_or("num shards must be > 0")?;
-    server_loop(server, shards.clone()).await?;
-    info!("shutting down server...");
-    shards.close_all().await?;
+    let shards = Arc::new(Shards::new(shards_vec).ok_or("num shards must be > 0")?);
+    let compressor = manifest.compressor();
+    let dict_loop_handle = spawn(dict_loop(manifest, shards.clone()));
+    let server_handle = spawn(server_loop(server, shards, compressor));
+    dict_loop_handle.await?;
+    server_handle.await??;
     info!("server closed sqlite connections. bye!");
-    Ok::<_, Box<dyn Error>>(())
+    Ok::<_, AsyncBoxError>(())
     })?;
     Ok(())
 }
 
-async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn Error>> {
+async fn dict_loop(manifest: Manifest, shards: ShardsArc) {
+    loop {
+        select! {
+            _ = tokio::time::sleep(tokio::time::Duration::from_secs(1)) => {}
+            _ = crate::shutdown_signal::shutdown_signal() => {
+                info!("dict loop: shutdown signal received");
+                break;
+            }
+        }
+        info!("dict loop: running");
+        let compressor = manifest.compressor();
+        let _compressor = compressor.read().await;
+        for _shard in shards.iter() {}
+    }
+}
+
+async fn server_loop(
+    server: TcpListener,
+    shards: ShardsArc,
+    compressor: CompressorArc,
+) -> Result<(), AsyncBoxError> {
     let app = Router::new()
         .route("/store", post(handlers::store_handler::store_handler))
         .route("/get/:sha256", get(handlers::get_handler::get_handler))
         .route("/info", get(handlers::info_handler::info_handler))
-        .layer(Extension(shards));
+        .layer(Extension(shards))
+        .layer(Extension(compressor));
 
     axum::serve(server, app.into_make_service())
         .with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
diff --git a/src/manifest/dict_id.rs b/src/manifest/dict_id.rs
deleted file mode 100644
index 0042e54..0000000
--- a/src/manifest/dict_id.rs
+++ /dev/null
@@ -1,9 +0,0 @@
-use rusqlite::types::FromSql;
-
-#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
-pub struct DictId(i64);
-impl FromSql for DictId {
-    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
-        Ok(DictId(value.as_i64()?))
-    }
-}
diff --git a/src/manifest/mod.rs b/src/manifest/mod.rs
index 61aa09f..a2eeca5 100644
--- a/src/manifest/mod.rs
+++ b/src/manifest/mod.rs
@@ -1,99 +1,81 @@
-mod dict_id;
 mod manifest_key;
-mod zstd_dict;
 
-use std::{collections::HashMap, error::Error, sync::Arc};
+use std::{error::Error, sync::Arc};
 
 use rusqlite::params;
+use tokio::sync::RwLock;
 use tokio_rusqlite::Connection;
 
-use crate::shards::Shards;
-
-use self::{
-    dict_id::DictId,
-    manifest_key::{get_manifest_key, set_manifest_key, NumShards},
-    zstd_dict::ZstdDict,
+use crate::{
+    compressor::{Compressor, CompressorArc},
+    into_tokio_rusqlite_err,
+    zstd_dict::ZstdDictArc,
+    AsyncBoxError,
 };
 
-pub type ZstdDictArc = Arc<ZstdDict>;
+use self::manifest_key::{get_manifest_key, set_manifest_key, NumShards};
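 
+/// Tracks the shard count and the zstd dictionaries persisted in the manifest
+/// database, and hands out the shared `Compressor` built from them.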
 pub struct Manifest {
     conn: Connection,
     num_shards: usize,
-    zstd_dict_by_id: HashMap<DictId, ZstdDictArc>,
-    zstd_dict_by_name: HashMap<String, ZstdDictArc>,
+    compressor: Arc<RwLock<Compressor>>,
 }
 
+pub type ManifestArc = Arc<Manifest>;
+
 impl Manifest {
-    pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, Box<dyn Error>> {
+    pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, AsyncBoxError> {
         initialize(conn, num_shards).await
     }
+
+    pub fn into_arc(self) -> ManifestArc {
+        Arc::new(self)
+    }
+
     pub fn num_shards(&self) -> usize {
         self.num_shards
     }
 
-    async fn train_zstd_dict_with_tag(
-        &mut self,
-        _name: &str,
-        _shards: Shards,
-    ) -> Result<ZstdDictArc, Box<dyn Error>> {
-        // let mut queries = vec![];
-        // for shard in shards.iter() {
-        //     queries.push(shard.entries_for_tag(name));
-        // }
-        todo!();
+    pub fn compressor(&self) -> CompressorArc {
+        self.compressor.clone()
     }
 
-    async fn create_zstd_dict_from_samples(
-        &mut self,
-        name: &str,
+    pub async fn insert_zstd_dict_from_samples<Str: Into<String>>(
+        &self,
+        name: Str,
         samples: Vec<&[u8]>,
     ) -> Result<ZstdDictArc, Box<dyn Error>> {
-        if self.zstd_dict_by_name.contains_key(name) {
-            return Err(format!("dictionary {} already exists", name).into());
-        }
-
-        let level = 3;
+        let name = name.into();
         let dict_bytes = zstd::dict::from_samples(
             &samples,
             1024 * 1024, // 1MB max dictionary size
-        )
-        .unwrap();
-
-        let name_copy = name.to_string();
-        let (dict_id, zstd_dict) = self
+        )?;
+        let compressor = self.compressor.clone();
+        let zstd_dict = self
             .conn
             .call(move |conn| {
+                let level = 3;
                 let mut stmt = conn.prepare(
                     "INSERT INTO dictionaries (name, level, dict)
-                    VALUES (?, ?, ?)
-                    RETURNING id",
+                     VALUES (?, ?, ?)
+                     RETURNING id",
                 )?;
-                let dict_id =
-                    stmt.query_row(params![name_copy, level, dict_bytes], |row| row.get(0))?;
-                let zstd_dict = Arc::new(ZstdDict::create(dict_id, name_copy, level, dict_bytes));
-                Ok((dict_id, zstd_dict))
+                let dict_id = stmt.query_row(params![name, level, dict_bytes], |row| row.get(0))?;
+                let mut compressor = compressor.blocking_write();
+                let zstd_dict = compressor
+                    .add_from_bytes(dict_id, name, level, dict_bytes)
+                    .map_err(into_tokio_rusqlite_err)?;
+                Ok(zstd_dict)
             })
             .await?;
-
-        self.zstd_dict_by_id.insert(dict_id, zstd_dict.clone());
-        self.zstd_dict_by_name
-            .insert(name.to_string(), zstd_dict.clone());
         Ok(zstd_dict)
     }
-
-    fn get_dictionary_by_id(&self, id: DictId) -> Option<&ZstdDict> {
-        self.zstd_dict_by_id.get(&id).map(|d| &**d)
-    }
-    fn get_dictionary_by_name(&self, name: &str) -> Option<&ZstdDict> {
-        self.zstd_dict_by_name.get(name).map(|d| &**d)
-    }
 }
 
 async fn initialize(
     conn: Connection,
     num_shards: Option<usize>,
-) -> Result<Manifest, Box<dyn Error>> {
+) -> Result<Manifest, AsyncBoxError> {
     let stored_num_shards: Option<usize> = conn
         .call(|conn| {
             conn.execute(
@@ -161,19 +143,15 @@ async fn initialize(
         })
         .await?;
 
-    let mut zstd_dicts_by_id = HashMap::new();
-    let mut zstd_dicts_by_name = HashMap::new();
+    let mut compressor = Compressor::default();
     for (id, name, level, dict_bytes) in rows {
-        let zstd_dict = Arc::new(ZstdDict::create(id, name.clone(), level, dict_bytes));
-        zstd_dicts_by_id.insert(id, zstd_dict.clone());
-        zstd_dicts_by_name.insert(name, zstd_dict);
+        compressor.add_from_bytes(id, name, level, dict_bytes)?;
     }
-
+    let compressor = compressor.into_arc();
     Ok(Manifest {
         conn,
         num_shards,
-        zstd_dict_by_id: zstd_dicts_by_id,
-        zstd_dict_by_name: zstd_dicts_by_name,
+        compressor,
     })
 }
 
@@ -184,22 +162,32 @@ mod tests {
 
     #[tokio::test]
     async fn test_manifest() {
         let conn = Connection::open_in_memory().await.unwrap();
-        let mut manifest = initialize(conn, Some(3)).await.unwrap();
+        let manifest = initialize(conn, Some(3)).await.unwrap();
 
         let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
vec![b"hello world test of long string"; 100]; let zstd_dict = manifest - .create_zstd_dict_from_samples("test", samples) + .insert_zstd_dict_from_samples("test", samples) .await .unwrap(); // test that indexes are created correctly assert_eq!( zstd_dict.as_ref(), - manifest.get_dictionary_by_id(zstd_dict.id()).unwrap() + manifest + .compressor() + .read() + .await + .by_id(zstd_dict.id()) + .unwrap() ); assert_eq!( zstd_dict.as_ref(), - manifest.get_dictionary_by_name(zstd_dict.name()).unwrap() + manifest + .compressor() + .read() + .await + .by_name(zstd_dict.name()) + .unwrap() ); let data = b"hello world, this is a test of a sort of long string"; diff --git a/src/shard/fn_get.rs b/src/shard/fn_get.rs index 48f4e16..f455ba0 100644 --- a/src/shard/fn_get.rs +++ b/src/shard/fn_get.rs @@ -1,73 +1,71 @@ -use std::io::Read; +use crate::{ + compressible_data::CompressibleData, + compressor::CompressorArc, + into_tokio_rusqlite_err, + sql_types::{CompressionId, UtcDateTime}, + AsyncBoxError, +}; use super::*; +pub struct GetArgs { + pub sha256: Sha256, + pub compressor: CompressorArc, +} + pub struct GetResult { pub sha256: Sha256, pub content_type: String, pub stored_size: usize, pub created_at: UtcDateTime, - pub data: Vec, + pub data: CompressibleData, } impl Shard { - pub async fn get(&self, sha256: Sha256) -> Result, Box> { - self.conn - .call(move |conn| get_impl(conn, sha256)) + pub async fn get(&self, args: GetArgs) -> Result, AsyncBoxError> { + let sha256 = args.sha256; + let maybe_row = self + .conn + .call(move |conn| get_compressed_row(conn, sha256).map_err(into_tokio_rusqlite_err)) .await .map_err(|e| { error!("get failed: {}", e); - e.into() - }) + Box::new(e) + })?; + + if let Some((content_type, stored_size, created_at, compression_id, data)) = maybe_row { + let compressor = args.compressor.read().await; + let data = compressor.decompress(compression_id, data)?; + Ok(Some(GetResult { + sha256: args.sha256, + content_type, + stored_size, + created_at, + data, + })) + } else { + Ok(None) + } } } -fn get_impl( +fn get_compressed_row( conn: &mut rusqlite::Connection, sha256: Sha256, -) -> Result, tokio_rusqlite::Error> { - let maybe_row = conn - .query_row( - "SELECT content_type, compressed_size, created_at, compression, data +) -> Result)>, rusqlite::Error> { + conn.query_row( + "SELECT content_type, compressed_size, created_at, compression_id, data FROM entries WHERE sha256 = ?", - params![sha256.hex_string()], - |row| { - let content_type = row.get(0)?; - let stored_size = row.get(1)?; - let created_at = parse_created_at_str(row.get(2)?)?; - let compression = row.get(3)?; - let data: Vec = row.get(4)?; - Ok((content_type, stored_size, created_at, compression, data)) - }, - ) - .optional() - .map_err(into_tokio_rusqlite_err)?; - - let row = match maybe_row { - Some(row) => row, - None => return Ok(None), - }; - - let (content_type, stored_size, created_at, compression, data) = row; - let data = match compression { - Compression::None => data, - Compression::Zstd => { - let mut decoder = - zstd::Decoder::new(data.as_slice()).map_err(into_tokio_rusqlite_err)?; - let mut decompressed = vec![]; - decoder - .read_to_end(&mut decompressed) - .map_err(into_tokio_rusqlite_err)?; - decompressed - } - }; - - Ok(Some(GetResult { - sha256, - content_type, - stored_size, - created_at, - data, - })) + params![sha256.hex_string()], + |row| { + let content_type = row.get(0)?; + let stored_size = row.get(1)?; + let created_at = row.get(2)?; + let compression_id = row.get(3)?; + let data: Vec = 
+            Ok((content_type, stored_size, created_at, compression_id, data))
+        },
+    )
+    .optional()
 }
diff --git a/src/shard/fn_migrate.rs b/src/shard/fn_migrate.rs
index 20db125..c937924 100644
--- a/src/shard/fn_migrate.rs
+++ b/src/shard/fn_migrate.rs
@@ -1,7 +1,9 @@
+use crate::AsyncBoxError;
+
 use super::*;
 
 impl Shard {
-    pub(super) async fn migrate(&self) -> Result<(), Box<dyn Error>> {
+    pub(super) async fn migrate(&self) -> Result<(), AsyncBoxError> {
         let shard_id = self.id();
         // create tables, indexes, etc
         self.conn
@@ -60,7 +62,7 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err
         "CREATE TABLE IF NOT EXISTS entries (
             sha256 BLOB PRIMARY KEY,
             content_type TEXT NOT NULL,
-            compression INTEGER NOT NULL,
+            compression_id INTEGER NOT NULL,
             uncompressed_size INTEGER NOT NULL,
             compressed_size INTEGER NOT NULL,
             data BLOB NOT NULL,
diff --git a/src/shard/fn_store.rs b/src/shard/fn_store.rs
index cb15409..d4aa36b 100644
--- a/src/shard/fn_store.rs
+++ b/src/shard/fn_store.rs
@@ -1,3 +1,10 @@
+use crate::{
+    compressor::CompressorArc,
+    into_tokio_rusqlite_err,
+    sql_types::{CompressionId, UtcDateTime},
+    AsyncBoxError,
+};
+
 use super::*;
 
 #[derive(PartialEq, Debug)]
@@ -19,89 +26,100 @@
 pub struct StoreArgs {
     pub sha256: Sha256,
     pub content_type: String,
     pub data: Bytes,
+    pub compressor: CompressorArc,
 }
 
 impl Shard {
-    pub async fn store(&self, store_args: StoreArgs) -> Result<StoreResult, Box<dyn Error>> {
-        let use_compression = self.use_compression;
-        self.conn
-            .call(move |conn| store(conn, use_compression, store_args))
-            .await
-            .map_err(|e| {
-                error!("store failed: {}", e);
-                e.into()
-            })
+    pub async fn store(
+        &self,
+        StoreArgs {
+            sha256,
+            data,
+            content_type,
+            compressor,
+        }: StoreArgs,
+    ) -> Result<StoreResult, AsyncBoxError> {
+        let sha256 = sha256.hex_string();
+
+        // check for existing entry
+        let sha256_clone = sha256.clone();
+        let maybe_existing_entry = self
+            .conn
+            .call(move |conn| {
+                find_with_sha256(conn, sha256_clone.as_str()).map_err(into_tokio_rusqlite_err)
+            })
+            .await?;
+
+        if let Some(entry) = maybe_existing_entry {
+            return Ok(entry);
+        }
+
+        let uncompressed_size = data.len();
+
+        let compressor = compressor.read().await;
+        let (compression_id, data) = compressor.compress("foobar", &content_type, data)?;
+
+        self.conn
+            .call(move |conn| {
+                insert(
+                    conn,
+                    sha256,
+                    content_type,
+                    compression_id,
+                    uncompressed_size,
+                    data.as_ref(),
+                )
+                .map_err(into_tokio_rusqlite_err)
+            })
+            .await
+            .map_err(|e| e.into())
     }
 }
 
-fn store(
+fn find_with_sha256(
     conn: &mut rusqlite::Connection,
-    use_compression: UseCompression,
-    StoreArgs {
-        sha256,
-        content_type,
-        data,
-    }: StoreArgs,
-) -> Result<StoreResult, tokio_rusqlite::Error> {
-    let sha256 = sha256.hex_string();
+    sha256: &str,
+) -> Result<Option<StoreResult>, rusqlite::Error> {
+    conn.query_row(
+        "SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?",
+        params![sha256],
+        |row| {
+            Ok(StoreResult::Exists {
+                stored_size: row.get(0)?,
+                data_size: row.get(1)?,
+                created_at: row.get(2)?,
+            })
+        },
+    )
+    .optional()
+}
 
-    // check for existing entry
-    let maybe_existing: Option<StoreResult> = conn
-        .query_row(
-            "SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?",
-            params![sha256],
-            |row| {
-                Ok(StoreResult::Exists {
-                    stored_size: row.get(0)?,
-                    data_size: row.get(1)?,
-                    created_at: parse_created_at_str(row.get(2)?)?,
-                })
-            },
-        )
-        .optional()?;
-
-    if let Some(existing) = maybe_existing {
-        return Ok(existing);
-    }
-
-    let created_at = chrono::Utc::now();
-    let uncompressed_size = data.len();
-    let tmp_data_holder;
-
-    let use_compression = match use_compression {
-        UseCompression::None => false,
-        UseCompression::Auto => auto_compressible_content_type(&content_type),
-        UseCompression::Zstd => true,
-    };
-
-    let (compression, data) = if use_compression {
-        tmp_data_holder = zstd::encode_all(&data[..], 0).map_err(into_tokio_rusqlite_err)?;
-        if tmp_data_holder.len() < data.len() {
-            (Compression::Zstd, &tmp_data_holder[..])
-        } else {
-            (Compression::None, &data[..])
-        }
-    } else {
-        (Compression::None, &data[..])
-    };
+fn insert(
+    conn: &mut rusqlite::Connection,
+    sha256: String,
+    content_type: String,
+    compression_id: CompressionId,
+    uncompressed_size: usize,
+    data: &[u8],
+) -> Result<StoreResult, rusqlite::Error> {
+    let created_at = UtcDateTime::now();
     let compressed_size = data.len();
 
-    conn.execute(
-        "INSERT INTO entries
-        (sha256, content_type, compression, uncompressed_size, compressed_size, data, created_at)
-        VALUES
-        (?, ?, ?, ?, ?, ?, ?)
-        ",
-        params![
-            sha256,
-            content_type,
-            compression,
-            uncompressed_size,
-            compressed_size,
-            data,
-            created_at.to_rfc3339(),
-        ],
-    )?;
+    conn.execute("INSERT INTO entries
+        (sha256, content_type, compression_id, uncompressed_size, compressed_size, data, created_at)
+        VALUES
+        (?, ?, ?, ?, ?, ?, ?)
+        ",
+        params![
+            sha256,
+            content_type,
+            compression_id,
+            uncompressed_size,
+            compressed_size,
+            data,
+            created_at,
+        ],
+    )?;
 
     Ok(StoreResult::Created {
         stored_size: compressed_size,
@@ -109,14 +127,3 @@ fn store(
         created_at,
     })
 }
-
-fn auto_compressible_content_type(content_type: &str) -> bool {
-    [
-        "text/",
-        "application/xml",
-        "application/json",
-        "application/javascript",
-    ]
-    .iter()
-    .any(|ct| content_type.starts_with(ct))
-}
diff --git a/src/shard/mod.rs b/src/shard/mod.rs
index ec20757..5607c47 100644
--- a/src/shard/mod.rs
+++ b/src/shard/mod.rs
@@ -4,7 +4,7 @@ mod fn_store;
 mod shard;
 pub mod shard_error;
 
-pub use fn_get::GetResult;
+pub use fn_get::{GetArgs, GetResult};
 pub use fn_store::{StoreArgs, StoreResult};
 pub use shard::Shard;
 
@@ -13,46 +13,9 @@ pub mod test {
     pub use super::shard::test::*;
 }
 
-use crate::{sha256::Sha256, shard::shard_error::ShardError, UseCompression};
+use crate::{sha256::Sha256, shard::shard_error::ShardError};
 use axum::body::Bytes;
-use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
-use std::error::Error;
+use rusqlite::{params, types::FromSql, OptionalExtension};
+
 use tokio_rusqlite::Connection;
 use tracing::{debug, error};
-
-pub type UtcDateTime = chrono::DateTime<chrono::Utc>;
-
-#[derive(Debug, PartialEq, Eq)]
-enum Compression {
-    None,
-    Zstd,
-}
-impl ToSql for Compression {
-    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
-        match self {
-            Compression::None => 0.to_sql(),
-            Compression::Zstd => 1.to_sql(),
-        }
-    }
-}
-impl FromSql for Compression {
-    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
-        match value.as_i64()? {
-            0 => Ok(Compression::None),
-            1 => Ok(Compression::Zstd),
-            _ => Err(rusqlite::types::FromSqlError::InvalidType),
-        }
-    }
-}
-
-fn parse_created_at_str(created_at_str: String) -> Result<UtcDateTime, rusqlite::Error> {
-    let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str)
-        .map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
-    Ok(parsed.to_utc())
-}
-
-fn into_tokio_rusqlite_err<E: Into<Box<dyn Error + Send + Sync>>>(
-    e: E,
-) -> tokio_rusqlite::Error {
-    tokio_rusqlite::Error::Other(e.into())
-}
diff --git a/src/shard/shard.rs b/src/shard/shard.rs
index 16b9395..d8f2673 100644
--- a/src/shard/shard.rs
+++ b/src/shard/shard.rs
@@ -1,36 +1,25 @@
+use crate::AsyncBoxError;
+
 use super::*;
 
 #[derive(Clone)]
 pub struct Shard {
     pub(super) id: usize,
     pub(super) conn: Connection,
-    pub(super) use_compression: UseCompression,
 }
 
 impl Shard {
-    pub async fn open(
-        id: usize,
-        use_compression: UseCompression,
-        conn: Connection,
-    ) -> Result<Self, Box<dyn Error>> {
-        let shard = Self {
-            id,
-            use_compression,
-            conn,
-        };
+    pub async fn open(id: usize, conn: Connection) -> Result<Self, AsyncBoxError> {
+        let shard = Self { id, conn };
         shard.migrate().await?;
         Ok(shard)
     }
 
-    pub async fn close(self) -> Result<(), Box<dyn Error>> {
-        self.conn.close().await.map_err(|e| e.into())
-    }
-
     pub fn id(&self) -> usize {
         self.id
     }
 
-    pub async fn db_size_bytes(&self) -> Result<usize, Box<dyn Error>> {
+    pub async fn db_size_bytes(&self) -> Result<usize, AsyncBoxError> {
         self.query_single_row(
             "SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
         )
@@ -40,7 +29,7 @@ impl Shard {
     async fn query_single_row<T: FromSql + Send + 'static>(
         &self,
         query: &'static str,
-    ) -> Result<T, Box<dyn Error>> {
+    ) -> Result<T, AsyncBoxError> {
         self.conn
             .call(move |conn| {
                 let value: T = conn.query_row(query, [], |row| row.get(0))?;
@@ -50,7 +39,7 @@ impl Shard {
             .map_err(|e| e.into())
     }
 
-    pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
+    pub async fn num_entries(&self) -> Result<usize, AsyncBoxError> {
         get_num_entries(&self.conn).await.map_err(|e| e.into())
     }
 }
@@ -65,18 +54,17 @@ async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Er
-    pub async fn make_shard_with_compression(use_compression: UseCompression) -> super::Shard {
-        let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
-        super::Shard::open(0, use_compression, conn).await.unwrap()
-    }
-
     pub async fn make_shard() -> super::Shard {
+        let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
+        super::Shard::open(0, conn).await.unwrap()
-        make_shard_with_compression(UseCompression::Auto).await
     }
 
@@ -96,8 +84,9 @@ pub mod test {
     #[tokio::test]
     async fn test_not_found_get() {
         let shard = make_shard().await;
+        let compressor = make_compressor().into_arc();
         let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
-        let get_result = shard.get(sha256).await.unwrap();
+        let get_result = shard.get(GetArgs { sha256, compressor }).await.unwrap();
         assert!(get_result.is_none());
     }
 
@@ -106,11 +95,13 @@ pub mod test {
         let shard = make_shard().await;
         let data = "hello, world!".as_bytes();
         let sha256 = Sha256::from_bytes(data);
+        let compressor = make_compressor().into_arc();
         let store_result = shard
             .store(StoreArgs {
                 sha256,
                 content_type: "text/plain".to_string(),
                 data: data.into(),
+                compressor: compressor.clone(),
             })
             .await
             .unwrap();
@@ -128,7 +119,11 @@ pub mod test {
         }
 
         assert_eq!(shard.num_entries().await.unwrap(), 1);
-        let get_result = shard.get(sha256).await.unwrap().unwrap();
+        let get_result = shard
+            .get(GetArgs { sha256, compressor })
+            .await
+            .unwrap()
+            .unwrap();
         assert_eq!(get_result.content_type, "text/plain");
         assert_eq!(get_result.data, data);
         assert_eq!(get_result.stored_size, data.len());
@@ -145,6 +140,7 @@ pub mod test {
                 sha256,
                 content_type: "text/plain".to_string(),
                 data: data.into(),
+                ..Default::default()
             })
             .await
             .unwrap();
@@ -168,6 +164,7 @@ pub mod test {
                 sha256,
                 content_type: "text/plain".to_string(),
                 data: data.into(),
+                ..Default::default()
             })
             .await
             .unwrap();
@@ -185,12 +182,19 @@ pub mod test {
     #[rstest]
     #[tokio::test]
     async fn test_compression_store_get(
-        #[values(UseCompression::Auto, UseCompression::None, UseCompression::Zstd)]
-        use_compression: UseCompression,
+        #[values(
+            CompressionPolicy::Auto,
+            CompressionPolicy::None,
+            CompressionPolicy::ForceZstd
+        )]
+        compression_policy: CompressionPolicy,
         #[values(true, false)] incompressible_data: bool,
         #[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
     ) {
-        let shard = make_shard_with_compression(use_compression).await;
+        use crate::compressor::Compressor;
+
+        let shard = make_shard().await;
+        let compressor = Compressor::new(compression_policy).into_arc();
         let mut data = vec![b'.'; 1024];
         if incompressible_data {
             for byte in data.iter_mut() {
@@ -204,12 +208,17 @@ pub mod test {
                 sha256,
                 content_type: content_type.clone(),
                 data: data.clone().into(),
+                compressor: compressor.clone(),
             })
             .await
             .unwrap();
         assert!(matches!(store_result, StoreResult::Created { .. }));
 
-        let get_result = shard.get(sha256).await.unwrap().unwrap();
+        let get_result = shard
+            .get(GetArgs { sha256, compressor })
+            .await
+            .unwrap()
+            .unwrap();
         assert_eq!(get_result.content_type, content_type);
         assert_eq!(get_result.data, data);
     }
diff --git a/src/shards.rs b/src/shards.rs
index 21cba22..160368e 100644
--- a/src/shards.rs
+++ b/src/shards.rs
@@ -1,5 +1,8 @@
+use std::sync::Arc;
+
 use crate::{sha256::Sha256, shard::Shard};
-use std::error::Error;
+
+pub type ShardsArc = Arc<Shards>;
 
 #[derive(Clone)]
 pub struct Shards(Vec<Shard>);
@@ -16,13 +19,6 @@ impl Shards {
         &self.0[shard_id]
     }
 
-    pub async fn close_all(self) -> Result<(), Box<dyn Error>> {
-        for shard in self.0 {
-            shard.close().await?;
-        }
-        Ok(())
-    }
-
     pub fn iter(&self) -> std::slice::Iter<'_, Shard> {
         self.0.iter()
     }
@@ -34,15 +30,13 @@ impl Shards {
 
 #[cfg(test)]
 pub mod test {
-    use crate::{shard::test::make_shard_with_compression, UseCompression};
+    use std::sync::Arc;
 
-    use super::Shards;
+    use crate::shard::test::make_shard;
 
-    pub async fn make_shards_with_compression(use_compression: UseCompression) -> Shards {
-        Shards::new(vec![make_shard_with_compression(use_compression).await]).unwrap()
-    }
+    use super::{Shards, ShardsArc};
 
-    pub async fn make_shards() -> Shards {
-        make_shards_with_compression(UseCompression::Auto).await
+    pub async fn make_shards() -> ShardsArc {
+        Arc::new(Shards::new(vec![make_shard().await]).unwrap())
     }
 }
diff --git a/src/sql_types/compression_id.rs b/src/sql_types/compression_id.rs
new file mode 100644
index 0000000..0d7a341
--- /dev/null
+++ b/src/sql_types/compression_id.rs
@@ -0,0 +1,61 @@
+use rusqlite::{
+    types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, Value::Integer, ValueRef},
+    Error::ToSqlConversionFailure,
+    ToSql,
+};
+
+use crate::AsyncBoxError;
+
+use super::ZstdDictId;
+
+#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
+pub enum CompressionId {
+    None,
+    ZstdGeneric,
+    ZstdDictId(ZstdDictId),
+}
+
+impl FromSql for CompressionId {
+    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
+        Ok(match value.as_i64()? {
+            -1 => CompressionId::None,
+            -2 => CompressionId::ZstdGeneric,
+            id => CompressionId::ZstdDictId(ZstdDictId(id)),
+        })
+    }
+}
+
+impl ToSql for CompressionId {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
+        let value = match self {
+            CompressionId::None => -1,
+            CompressionId::ZstdGeneric => -2,
+            CompressionId::ZstdDictId(ZstdDictId(id)) => *id,
+        };
+        Ok(ToSqlOutput::Owned(Integer(value)))
+    }
+}
+
+impl FromSql for ZstdDictId {
+    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
+        match value.as_i64()? {
+            id @ (-1 | -2) => Err(FromSqlError::Other(invalid_zstd_dict_id_err(id))),
+            id => Ok(ZstdDictId(id)),
+        }
+    }
+}
+
+impl ToSql for ZstdDictId {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
+        let value = match self.0 {
+            id @ (-1 | -2) => return Err(ToSqlConversionFailure(invalid_zstd_dict_id_err(id))),
+            id => id,
+        };
+
+        Ok(ToSqlOutput::Owned(Integer(value)))
+    }
+}
+
+fn invalid_zstd_dict_id_err(id: i64) -> AsyncBoxError {
+    format!("Invalid ZstdDictId: {}", id).into()
+}
diff --git a/src/sql_types/mod.rs b/src/sql_types/mod.rs
new file mode 100644
index 0000000..c53ffa2
--- /dev/null
+++ b/src/sql_types/mod.rs
@@ -0,0 +1,7 @@
+mod compression_id;
+mod utc_date_time;
+mod zstd_dict_id;
+
+pub use compression_id::CompressionId;
+pub use utc_date_time::UtcDateTime;
+pub use zstd_dict_id::ZstdDictId;
diff --git a/src/sql_types/utc_date_time.rs b/src/sql_types/utc_date_time.rs
new file mode 100644
index 0000000..650712c
--- /dev/null
+++ b/src/sql_types/utc_date_time.rs
@@ -0,0 +1,46 @@
+use chrono::DateTime;
+use rusqlite::{
+    types::{FromSql, FromSqlError, ToSqlOutput, ValueRef},
+    Result, ToSql,
+};
+
+#[derive(PartialEq, Debug, PartialOrd)]
+pub struct UtcDateTime(DateTime<chrono::Utc>);
+
+impl UtcDateTime {
+    pub fn now() -> Self {
+        Self(chrono::Utc::now())
+    }
+    pub fn to_string(&self) -> String {
+        self.0.to_rfc3339()
+    }
+    pub fn from_string(s: &str) -> Result<Self, chrono::ParseError> {
+        Ok(Self(DateTime::parse_from_rfc3339(s)?.to_utc()))
+    }
+}
+
+impl PartialEq<DateTime<chrono::Utc>> for UtcDateTime {
+    fn eq(&self, other: &DateTime<chrono::Utc>) -> bool {
+        self.0 == *other
+    }
+}
+
+impl PartialOrd<DateTime<chrono::Utc>> for UtcDateTime {
+    fn partial_cmp(&self, other: &DateTime<chrono::Utc>) -> Option<std::cmp::Ordering> {
+        self.0.partial_cmp(other)
+    }
+}
+
+impl ToSql for UtcDateTime {
+    fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
+        Ok(ToSqlOutput::from(self.0.to_rfc3339()))
+    }
+}
+
+impl FromSql for UtcDateTime {
+    fn column_result(value: ValueRef<'_>) -> Result<Self, FromSqlError> {
+        let parsed = DateTime::parse_from_rfc3339(value.as_str()?)
+            .map_err(|e| FromSqlError::Other(e.into()))?;
+        Ok(UtcDateTime(parsed.to_utc()))
+    }
+}
diff --git a/src/sql_types/zstd_dict_id.rs b/src/sql_types/zstd_dict_id.rs
new file mode 100644
index 0000000..8f1306a
--- /dev/null
+++ b/src/sql_types/zstd_dict_id.rs
@@ -0,0 +1,7 @@
+#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
+pub struct ZstdDictId(pub i64);
+impl From<i64> for ZstdDictId {
+    fn from(id: i64) -> Self {
+        Self(id)
+    }
+}
diff --git a/src/manifest/zstd_dict.rs b/src/zstd_dict.rs
similarity index 50%
rename from src/manifest/zstd_dict.rs
rename to src/zstd_dict.rs
index bb9dac1..6a88401 100644
--- a/src/manifest/zstd_dict.rs
+++ b/src/zstd_dict.rs
@@ -1,11 +1,14 @@
-use super::dict_id::DictId;
 use ouroboros::self_referencing;
-use std::{error::Error, io};
+use std::{error::Error, io, sync::Arc};
 use zstd::dict::{DecoderDictionary, EncoderDictionary};
 
+use crate::{sql_types::ZstdDictId, AsyncBoxError};
+
+pub type ZstdDictArc = Arc<ZstdDict>;
+
 #[self_referencing]
 pub struct ZstdDict {
-    id: DictId,
+    id: crate::sql_types::ZstdDictId,
     name: String,
     level: i32,
     dict_bytes: Vec<u8>,
@@ -31,7 +34,6 @@ impl PartialEq for ZstdDict {
 impl std::fmt::Debug for ZstdDict {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("ZstdDict")
-            .field("id", &self.id())
             .field("name", &self.name())
             .field("level", &self.level())
             .field("dict_bytes.len", &self.dict_bytes().len())
@@ -40,7 +42,20 @@ impl std::fmt::Debug for ZstdDict {
 }
 
 impl ZstdDict {
-    pub fn create(id: DictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
+    pub fn from_samples(
+        id: ZstdDictId,
+        name: String,
+        level: i32,
+        samples: Vec<&[u8]>,
+    ) -> Result<Self, Box<dyn Error>> {
+        let dict_bytes = zstd::dict::from_samples(
+            &samples,
+            1024 * 1024, // 1MB max dictionary size
+        )?;
+        Ok(Self::from_dict_bytes(id, name, level, dict_bytes))
+    }
+
+    pub fn from_dict_bytes(id: ZstdDictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
         ZstdDictBuilder {
             id,
             name,
@@ -52,8 +67,8 @@ impl ZstdDict {
         .build()
     }
 
-    pub fn id(&self) -> DictId {
-        self.with_id(|id| *id)
+    pub fn id(&self) -> ZstdDictId {
+        *self.borrow_id()
     }
     pub fn name(&self) -> &str {
         self.borrow_name()
@@ -65,9 +80,10 @@ impl ZstdDict {
         self.borrow_dict_bytes()
     }
 
-    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
-        let mut wrapper = io::Cursor::new(data);
-        let mut out_buffer = Vec::with_capacity(data.len());
+    pub fn compress<DataRef: AsRef<[u8]>>(&self, data: DataRef) -> Result<Vec<u8>, AsyncBoxError> {
+        let as_ref = data.as_ref();
+        let mut wrapper = io::Cursor::new(as_ref);
+        let mut out_buffer = Vec::with_capacity(as_ref.len());
         let mut output_wrapper = io::Cursor::new(&mut out_buffer);
 
         self.with_encoder_dict(|encoder_dict| {
@@ -79,9 +95,13 @@ impl ZstdDict {
         Ok(out_buffer)
     }
 
-    pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
-        let mut wrapper = io::Cursor::new(data);
-        let mut out_buffer = Vec::with_capacity(data.len());
+    pub fn decompress<DataRef: AsRef<[u8]>>(
+        &self,
+        data: DataRef,
+    ) -> Result<Vec<u8>, AsyncBoxError> {
+        let as_ref = data.as_ref();
+        let mut wrapper = io::Cursor::new(as_ref);
+        let mut out_buffer = Vec::with_capacity(as_ref.len());
         let mut output_wrapper = io::Cursor::new(&mut out_buffer);
 
         self.with_decoder_dict(|decoder_dict| {
@@ -92,3 +112,35 @@ impl ZstdDict {
         Ok(out_buffer)
     }
 }
+
+#[cfg(test)]
+pub mod test {
+    use crate::sql_types::ZstdDictId;
+
+    pub fn make_zstd_dict(id: ZstdDictId, name: &str) -> super::ZstdDict {
+        super::ZstdDict::from_dict_bytes(
+            id,
+            name.to_owned(),
+            3,
+            vec![
+                "hello, world",
+                "this is a test",
+                "of the emergency broadcast system",
+ ] + .into_iter() + .chain(vec!["foo", "bar", "baz"].repeat(100)) + .map(|s| s.as_bytes().to_owned()) + .flat_map(|s| s.into_iter()) + .collect(), + ) + } + + #[test] + fn test_zstd_dict() { + let dict_bytes = vec![1, 2, 3, 4]; + let zstd_dict = make_zstd_dict(1.into(), "dict1"); + let compressed = zstd_dict.compress(b"hello world").unwrap(); + let decompressed = zstd_dict.decompress(&compressed).unwrap(); + assert_eq!(decompressed, b"hello world"); + } +}