diff --git a/src/compression_manager.rs b/src/compression_manager.rs
index 1d1b00d..63857ef 100644
--- a/src/compression_manager.rs
+++ b/src/compression_manager.rs
@@ -1,20 +1,17 @@
 use crate::{
     compressible_data::CompressibleData,
     compression_stats::{CompressionStat, CompressionStats},
-    compressor::{BrotliGenericCompressor, Compressor, ZstdDictCompressor, ZstdGenericCompressor},
-    sql_types::{CompressionId, ZstdDictId},
-    zstd_dict::{ZstdDict, ZstdDictArc},
+    compressor::{BrotliGenericCompressor, Compressor, ZstdGenericCompressor},
+    sql_types::CompressionId,
     AsyncBoxError, CompressionPolicy,
 };
-use std::{collections::HashMap, sync::Arc};
+use std::sync::Arc;
 use tokio::sync::{RwLock, RwLockReadGuard};

 pub type CompressionManagerArc = Arc<RwLock<CompressionManager>>;
 pub type CompressionStatsArc = Arc<RwLock<CompressionStats>>;

 pub struct CompressionManager {
-    zstd_dict_by_id: HashMap<ZstdDictId, ZstdDictArc>,
-    zstd_dict_by_name: HashMap<String, Vec<ZstdDictArc>>,
     compression_policy: CompressionPolicy,
     compression_stats: CompressionStatsArc,
 }
@@ -22,8 +19,6 @@ pub struct CompressionManager {
 impl CompressionManager {
     pub fn new(compression_policy: CompressionPolicy) -> Self {
         Self {
-            zstd_dict_by_id: HashMap::new(),
-            zstd_dict_by_name: HashMap::new(),
             compression_policy,
             compression_stats: Arc::new(RwLock::new(CompressionStats::default())),
         }
     }
@@ -52,43 +47,6 @@ impl CompressionManager {
         *compression_stats.for_id_mut(id) = stat;
     }

-    pub fn add_zstd_dict(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
-        self.check(zstd_dict.id())?;
-        let zstd_dict = Arc::new(zstd_dict);
-
-        self.zstd_dict_by_id
-            .insert(zstd_dict.id(), zstd_dict.clone());
-
-        let name = zstd_dict.name();
-        if let Some(by_name_list) = self.zstd_dict_by_name.get_mut(name) {
-            by_name_list.push(zstd_dict.clone());
-        } else {
-            self.zstd_dict_by_name
-                .insert(name.to_string(), vec![zstd_dict.clone()]);
-        }
-        Ok(zstd_dict)
-    }
-
-    pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> {
-        self.zstd_dict_by_id.get(&id).map(|d| &**d)
-    }
-
-    #[cfg(test)]
-    pub fn by_name(&self, name: &str) -> Option<&ZstdDict> {
-        if let Some(by_name_list) = self.zstd_dict_by_name.get(name) {
-            // take a random element
-            Some(
-                &by_name_list[rand::Rng::gen_range(&mut rand::thread_rng(), 0..by_name_list.len())],
-            )
-        } else {
-            None
-        }
-    }
-
-    pub fn names(&self) -> impl Iterator<Item = &String> {
-        self.zstd_dict_by_name.keys()
-    }
-
     pub async fn compress<Data: Into<CompressibleData>>(
         &self,
         content_type: &str,
@@ -152,10 +110,6 @@ impl CompressionManager {
     fn get_compressor(&self, compression_id: CompressionId) -> Result<Box<dyn Compressor>> {
         Ok(match compression_id {
             CompressionId::Zstd => Box::new(ZstdGenericCompressor),
-            CompressionId::ZstdDict(zstd_dict_id) => {
-                let zstd_dict = self.zstd_dict_by_id.get(&zstd_dict_id).unwrap().clone();
-                Box::new(ZstdDictCompressor::new(zstd_dict))
-            }
             CompressionId::Brotli => Box::new(BrotliGenericCompressor),
             CompressionId::None => return Err("no compressor for CompressionId::None".into()),
         })
@@ -169,13 +123,6 @@ impl CompressionManager {
         let data = data.into();
         match compression_id {
             CompressionId::None => Ok(data),
-            CompressionId::ZstdDict(id) => {
-                if let Some(dict) = self.by_id(id) {
-                    Ok(CompressibleData::Vec(dict.decompress(data.as_ref())?))
-                } else {
-                    Err(format!("zstd dictionary {:?} not found", id).into())
-                }
-            }
             CompressionId::Zstd => Ok(CompressibleData::Vec(zstd::stream::decode_all(
                 data.as_ref(),
             )?)),
@@ -191,13 +138,6 @@ impl CompressionManager {
     pub async fn stats(&self) -> RwLockReadGuard<CompressionStats> {
         self.compression_stats.read().await
     }
-
-    fn check(&self, id: ZstdDictId) -> Result<()> {
-        if self.zstd_dict_by_id.contains_key(&id) {
-            return Err(format!("zstd dictionary id {} already exists", id.0).into());
-        }
-        Ok(())
-    }
 }

 fn auto_compressible_content_type(content_type: &str) -> bool {
@@ -213,22 +153,15 @@ fn auto_compressible_content_type(content_type: &str) -> bool {

 #[cfg(test)]
 pub mod test {
-    use std::collections::HashSet;
-
-    use rstest::rstest;
-
     use super::*;
-    use crate::zstd_dict::test::make_zstd_dict;
+    use rstest::rstest;

     pub fn make_compressor() -> CompressionManager {
         make_compressor_with(CompressionPolicy::Auto)
     }

     pub fn make_compressor_with(compression_policy: CompressionPolicy) -> CompressionManager {
-        let mut compressor = CompressionManager::new(compression_policy);
-        let zstd_dict = make_zstd_dict(1.into(), "dict1");
-        compressor.add_zstd_dict(zstd_dict).unwrap();
-        compressor
+        CompressionManager::new(compression_policy)
     }

     #[test]
@@ -286,18 +219,4 @@ pub mod test {
         assert_ne!(compressed, data);
         assert!(compressed.len() < data.len());
     }
-
-    #[tokio::test]
-    async fn test_multiple_dicts_same_name() {
-        let mut compressor = make_compressor();
-        let zstd_dict = make_zstd_dict(2.into(), "dict1");
-        compressor.add_zstd_dict(zstd_dict).unwrap();
-
-        let mut seen_ids = HashSet::new();
-        for _ in 0..1000 {
-            let zstd_dict = compressor.by_name("dict1").unwrap();
-            seen_ids.insert(zstd_dict.id());
-        }
-        assert_eq!(seen_ids, [1, 2].iter().copied().map(ZstdDictId).collect());
-    }
 }
diff --git a/src/compressor/compressor_trait.rs b/src/compressor/compressor_trait.rs
index fb31f73..2a1d4ef 100644
--- a/src/compressor/compressor_trait.rs
+++ b/src/compressor/compressor_trait.rs
@@ -7,12 +7,9 @@ pub trait Compressor: Send {

 #[cfg(test)]
 pub(self) mod test {
-    use rstest::*;
-
     use super::*;
-    use crate::{
-        compressor::*, into_arc::IntoArc, sql_types::ZstdDictId, zstd_dict::test::make_zstd_dict,
-    };
+    use crate::compressor::*;
+    use rstest::*;

     fn random_bytes(len: usize) -> Vec<u8> {
         (0..len).map(|_| rand::random::<u8>()).collect()
     }
@@ -20,7 +17,6 @@ pub(self) mod test {
     #[rstest]
     #[case(BrotliGenericCompressor)]
     #[case(ZstdGenericCompressor)]
-    #[case(ZstdDictCompressor::new(make_zstd_dict(ZstdDictId(1), "test_dict").into_arc()))]
     fn test_compressor(
         #[case] compressor: impl Compressor,
         #[values(
diff --git a/src/compressor/mod.rs b/src/compressor/mod.rs
index bace09f..cb2af94 100644
--- a/src/compressor/mod.rs
+++ b/src/compressor/mod.rs
@@ -1,9 +1,7 @@
 mod brotli_generic_compressor;
 mod compressor_trait;
-mod zstd_dict_compressor;
 mod zstd_generic_compressor;

 pub use brotli_generic_compressor::BrotliGenericCompressor;
 pub use compressor_trait::Compressor;
-pub use zstd_dict_compressor::ZstdDictCompressor;
 pub use zstd_generic_compressor::ZstdGenericCompressor;
diff --git a/src/compressor/zstd_dict_compressor.rs b/src/compressor/zstd_dict_compressor.rs
deleted file mode 100644
index 9cec95f..0000000
--- a/src/compressor/zstd_dict_compressor.rs
+++ /dev/null
@@ -1,20 +0,0 @@
-use super::compressor_trait::Compressor;
-use crate::{compressible_data::CompressibleData, zstd_dict::ZstdDictArc, AsyncBoxError};
-
-pub struct ZstdDictCompressor(ZstdDictArc);
-
-impl ZstdDictCompressor {
-    pub fn new(zstd_dict: ZstdDictArc) -> Self {
-        Self(zstd_dict)
-    }
-}
-
-impl Compressor for ZstdDictCompressor {
-    fn compress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
-        Ok(self.0.compress(data)?.into())
-    }
-
-    fn decompress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
-        Ok(self.0.decompress(data)?.into())
-    }
-}
diff --git a/src/handlers/store_handler.rs b/src/handlers/store_handler.rs
index a4d6593..d921d77 100644
--- a/src/handlers/store_handler.rs
+++ b/src/handlers/store_handler.rs
@@ -13,7 +13,6 @@ use tracing::error;
 pub struct StoreRequest {
     pub sha256: Option<String>,
     pub content_type: String,
-    pub compression_hint: Option<String>,
     pub data: FieldData<Bytes>,
 }

@@ -106,7 +105,6 @@ pub async fn store_handler(
             sha256,
             content_type: request.content_type,
             data: request.data.contents,
-            compression_hint: request.compression_hint,
         })
         .await
     {
@@ -138,7 +136,6 @@ pub mod test {
         TypedMultipart(StoreRequest {
             sha256: sha256.map(|s| s.hex_string()),
             content_type: content_type.to_string(),
-            compression_hint: None,
             data: FieldData {
                 metadata: Default::default(),
                 contents: data.into(),
diff --git a/src/loops/create_new_dicts_loop.rs b/src/loops/create_new_dicts_loop.rs
deleted file mode 100644
index 1749fbb..0000000
--- a/src/loops/create_new_dicts_loop.rs
+++ /dev/null
@@ -1,68 +0,0 @@
-use std::sync::Arc;
-
-use tokio::select;
-use tracing::{debug, info};
-
-use crate::{manifest::Manifest, shards::ShardsArc, AsyncBoxError};
-
-pub async fn create_new_dicts_loop(
-    manifest: Arc<Manifest>,
-    shards: ShardsArc,
-) -> Result<(), AsyncBoxError> {
-    return Ok(());
-
-    loop {
-        let hint_names = shards.hint_names().await?;
-        let mut new_hint_names = hint_names.clone();
-        {
-            // find what hint names don't have a corresponding dictionary
-            let compressor = manifest.compression_manager();
-            let compressor = compressor.read().await;
-            compressor.names().for_each(|name| {
-                new_hint_names.remove(name);
-            });
-        }
-
-        for hint_name in new_hint_names {
-            let samples = shards.samples_for_hint_name(&hint_name, 100).await?;
-            let num_samples = samples.len();
-            let bytes_samples = samples.iter().map(|s| s.data.len()).sum::<usize>();
-            let bytes_samples_human = humansize::format_size(bytes_samples, humansize::BINARY);
-            if num_samples < 10 || bytes_samples < 1024 {
-                debug!(
-                    "skipping dictionary for {} - not enough samples: ({} samples, {})",
-                    hint_name, num_samples, bytes_samples_human
-                );
-                continue;
-            }
-
-            debug!("building dictionary for {}", hint_name);
-            let now = chrono::Utc::now();
-            let ref_samples = samples.iter().map(|s| s.data.as_ref()).collect::<Vec<_>>();
-            let dict = manifest
-                .insert_zstd_dict_from_samples(&hint_name, ref_samples)
-                .await?;
-
-            let duration = chrono::Utc::now() - now;
-            let bytes_dict_human =
-                humansize::format_size(dict.dict_bytes().len(), humansize::BINARY);
-            debug!(
-                "built dictionary {} ({}) in {}ms: {} samples / {} sample bytes, {} dict size",
-                dict.id(),
-                hint_name,
-                duration.num_milliseconds(),
-                num_samples,
-                bytes_samples_human,
-                bytes_dict_human,
-            );
-        }
-
-        select! {
-            _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {}
-            _ = crate::shutdown_signal::shutdown_signal() => {
-                info!("new_dict_loop: shutdown signal received");
-                return Ok(());
-            }
-        }
-    }
-}
diff --git a/src/loops/mod.rs b/src/loops/mod.rs
index 31b5f8e..b9d2659 100644
--- a/src/loops/mod.rs
+++ b/src/loops/mod.rs
@@ -1,9 +1,7 @@
 mod axum_server_loop;
-mod create_new_dicts_loop;
 mod dict_stats_printer_loop;
 mod save_compression_stats_loop;

 pub use axum_server_loop::axum_server_loop;
-pub use create_new_dicts_loop::create_new_dicts_loop;
 pub use dict_stats_printer_loop::dict_stats_printer_loop;
 pub use save_compression_stats_loop::save_compression_stats_loop;
diff --git a/src/main.rs b/src/main.rs
index 2845cbf..c9bc6a4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -13,7 +13,6 @@ mod shard;
 mod shards;
 mod shutdown_signal;
 mod sql_types;
-mod zstd_dict;

 use crate::{manifest::Manifest, shards::Shards};
 use app_time_formatter::AppTimeFormatter;
@@ -129,24 +128,12 @@ fn main() -> Result<(), AsyncBoxError> {
     );

     let compression_manager = manifest.compression_manager();
-    {
-        let compressor = compression_manager.read().await;
-        debug!(
-            "loaded compression dictionaries: {:?}",
-            compressor.names().collect::<Vec<_>>()
-        );
-    }
-
     let shards = Arc::new(Shards::new(shards_vec).ok_or("num shards must be > 0")?);

     let join_handles = vec![
         spawn(loops::save_compression_stats_loop(manifest.clone())),
         spawn(loops::dict_stats_printer_loop(
             manifest.compression_manager(),
         )),
-        spawn(loops::create_new_dicts_loop(
-            manifest.clone(),
-            shards.clone(),
-        )),
         spawn(loops::axum_server_loop(server, shards, compression_manager)),
     ];
     for handle in join_handles {
diff --git a/src/manifest/mod.rs b/src/manifest/mod.rs
index 2c9f77d..7f50c13 100644
--- a/src/manifest/mod.rs
+++ b/src/manifest/mod.rs
@@ -4,8 +4,7 @@ use crate::{
     compression_manager::{CompressionManager, CompressionManagerArc},
     compression_stats::{CompressionStat, CompressionStats},
     concat_lines, into_tokio_rusqlite_err,
-    sql_types::{CompressionId, ZstdDictId},
-    zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder, ZSTD_LEVEL},
+    sql_types::CompressionId,
     AsyncBoxError,
 };
 use rusqlite::params;
@@ -46,35 +45,6 @@ impl Manifest {
             .await?;
         Ok(())
     }
-
-    pub async fn insert_zstd_dict_from_samples<Str: Into<String>>(
-        &self,
-        name: Str,
-        samples: Vec<&[u8]>,
-    ) -> Result<ZstdDictArc, AsyncBoxError> {
-        let name = name.into();
-        let encoder = ZstdEncoder::from_samples(ZSTD_LEVEL, samples);
-        let zstd_dict = self
-            .conn
-            .call(move |conn| {
-                let mut stmt = conn.prepare(concat_lines!(
-                    "INSERT INTO zstd_dictionaries (name, encoder_json)",
-                    "VALUES (?, ?)",
-                    "RETURNING zstd_dict_id"
-                ))?;
-                let zstd_dict_id = stmt.query_row(params![name, encoder], |row| row.get(0))?;
-                Ok(ZstdDict::new(zstd_dict_id, name, encoder))
-            })
-            .await?;
-        let mut compressor = self.compressor.write().await;
-        compressor
-            .set_stat(
-                CompressionId::ZstdDict(zstd_dict.id()),
-                CompressionStat::default(),
-            )
-            .await;
-        compressor.add_zstd_dict(zstd_dict)
-    }
 }

 async fn migrate_manifest(conn: &Connection) -> Result<(), tokio_rusqlite::Error> {
@@ -167,33 +137,6 @@ async fn load_and_store_num_shards(
     })
 }

-async fn load_zstd_dicts(
-    conn: &Connection,
-    compressor: &mut CompressionManager,
-) -> Result<(), AsyncBoxError> {
-    type ZstdDictRow = (ZstdDictId, String, ZstdEncoder);
-    let rows: Vec<ZstdDictRow> = conn
-        .call(|conn| {
-            let rows = conn
-                .prepare(concat_lines!(
-                    "SELECT",
-                    "  zstd_dict_id, name, encoder_json",
-                    "FROM zstd_dictionaries",
-                ))?
-                .query_map([], |row| ZstdDictRow::try_from(row))?
-                .collect::<Result<Vec<_>, _>>()?;
-            Ok(rows)
-        })
-        .await?;
-
-    debug!("loaded {} zstd dictionaries from manifest", rows.len());
-    for (zstd_dict_id, name, encoder) in rows {
-        compressor.add_zstd_dict(ZstdDict::new(zstd_dict_id, name, encoder))?;
-    }
-
-    Ok(())
-}
-
 async fn load_compression_stats(
     conn: &Connection,
     compressor: &mut CompressionManager,
@@ -264,7 +207,6 @@ fn save_compression_stats(

 async fn load_compressor(conn: &Connection) -> Result<CompressionManagerArc, AsyncBoxError> {
     let mut compressor = CompressionManager::default();
-    load_zstd_dicts(conn, &mut compressor).await?;
     load_compression_stats(conn, &mut compressor).await?;
     Ok(compressor.into_arc())
 }
@@ -294,45 +236,7 @@ mod tests {
         let compressor = manifest.compression_manager();
         let compressor = compressor.read().await;
         let stats = compressor.stats().await;
-        assert_eq!(stats.len(), 2);
-        assert_eq!(stats.iter().count(), 2);
-    }
-
-    #[tokio::test]
-    async fn test_manifest() {
-        let conn = Connection::open_in_memory().await.unwrap();
-        let manifest = initialize(conn, Some(4)).await.unwrap();
-
-        let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
-        let zstd_dict = manifest
-            .insert_zstd_dict_from_samples("test", samples)
-            .await
-            .unwrap();
-
-        // test that indexes are created correctly
-        assert_eq!(
-            zstd_dict.as_ref(),
-            manifest
-                .compression_manager()
-                .read()
-                .await
-                .by_id(zstd_dict.id())
-                .unwrap()
-        );
-        assert_eq!(
-            zstd_dict.as_ref(),
-            manifest
-                .compression_manager()
-                .read()
-                .await
-                .by_name(zstd_dict.name())
-                .unwrap()
-        );
-
-        let data = b"hello world, this is a test of a sort of long string";
-        let compressed = zstd_dict.compress(data).unwrap();
-        let decompressed = zstd_dict.decompress(&compressed).unwrap();
-        assert_eq!(decompressed, data);
-        assert!(data.len() > compressed.len());
+        assert_eq!(stats.len(), 3);
+        assert_eq!(stats.iter().count(), 3);
     }
 }
diff --git a/src/shard/fn_hint_names.rs b/src/shard/fn_hint_names.rs
deleted file mode 100644
index af330c2..0000000
--- a/src/shard/fn_hint_names.rs
+++ /dev/null
@@ -1,15 +0,0 @@
-use super::Shard;
-use crate::AsyncBoxError;
-
-impl Shard {
-    pub async fn hint_names(&self) -> Result<Vec<String>, AsyncBoxError> {
-        self.conn
-            .call(|conn| {
-                let mut stmt = conn.prepare("SELECT DISTINCT name FROM compression_hints")?;
-                let rows = stmt.query_map([], |row| row.get(0))?;
-                Ok(rows.collect::<Result<Vec<String>, _>>()?)
-            })
-            .await
-            .map_err(Into::into)
-    }
-}
diff --git a/src/shard/fn_samples_for_hint.rs b/src/shard/fn_samples_for_hint.rs
deleted file mode 100644
index 62b91e4..0000000
--- a/src/shard/fn_samples_for_hint.rs
+++ /dev/null
@@ -1,37 +0,0 @@
-use super::Shard;
-use crate::{into_tokio_rusqlite_err, sha256::Sha256, AsyncBoxError};
-use rusqlite::params;
-
-pub struct SampleForHintResult {
-    pub sha256: Sha256,
-    pub data: Vec<u8>,
-}
-
-impl Shard {
-    pub async fn samples_for_hint_name(
-        &self,
-        hint_name: &str,
-        limit: usize,
-    ) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
-        let hint_name = hint_name.to_owned();
-        let result = self
-            .conn
-            .call(move |conn| {
-                let mut stmt = conn.prepare(
-                    "SELECT sha256, data FROM entries WHERE id IN (
-                        SELECT entry_id FROM compression_hints WHERE name = ?1 ORDER BY ordering
-                    ) LIMIT ?2",
-                )?;
-                let rows = stmt.query_map(params![hint_name, limit], |row| {
-                    let sha256: Sha256 = row.get(0)?;
-                    let data: Vec<u8> = row.get(1)?;
-                    Ok(SampleForHintResult { sha256, data })
-                })?;
-                rows.collect::<Result<Vec<_>, _>>()
-                    .map_err(into_tokio_rusqlite_err)
-            })
-            .await?;
-
-        Ok(result)
-    }
-}
diff --git a/src/shard/fn_store.rs b/src/shard/fn_store.rs
index 54c0df5..80d685b 100644
--- a/src/shard/fn_store.rs
+++ b/src/shard/fn_store.rs
@@ -2,7 +2,7 @@ use super::*;

 use crate::{
     compressible_data::CompressibleData,
     concat_lines, into_tokio_rusqlite_err,
-    sql_types::{CompressionId, EntryId, UtcDateTime},
+    sql_types::{CompressionId, UtcDateTime},
     AsyncBoxError,
 };
@@ -25,7 +25,6 @@ pub struct StoreArgs {
     pub sha256: Sha256,
     pub content_type: String,
     pub data: Bytes,
-    pub compression_hint: Option<String>,
 }

 impl Shard {
@@ -35,7 +34,6 @@ impl Shard {
             sha256,
             data,
             content_type,
-            compression_hint,
         }: StoreArgs,
     ) -> Result<StoreResult, AsyncBoxError> {
         let existing_entry = self
@@ -63,7 +61,6 @@ impl Shard {
                     dict_id,
                     uncompressed_size,
                     data,
-                    compression_hint,
                 )
                 .map_err(into_tokio_rusqlite_err)
             })
@@ -101,18 +98,15 @@ fn insert(
     dict_id: CompressionId,
     uncompressed_size: usize,
     data: CompressibleData,
-    compression_hint: Option<String>,
 ) -> Result<StoreResult, rusqlite::Error> {
     let created_at = UtcDateTime::now();
     let compressed_size = data.len();

-    let tx = conn.transaction()?;
-    let entry_id: EntryId = tx.query_row(
+    conn.execute(
         concat_lines!(
             "INSERT INTO entries",
             "  (sha256, content_type, dict_id, uncompressed_size, compressed_size, data, created_at)",
             "VALUES (?, ?, ?, ?, ?, ?, ?)",
-            "RETURNING id"
         ),
         params![
             sha256,
             content_type,
             dict_id,
             uncompressed_size,
             compressed_size,
             data.as_ref(),
             created_at,
         ],
-        |row| row.get(0),
     )?;

-    if let Some(compression_hint) = compression_hint {
-        let rand_ordering = rand::random::<i64>();
-        tx.execute(
-            concat_lines!(
-                "INSERT INTO compression_hints (name, ordering, entry_id)",
-                "VALUES (?, ?, ?)"
-            ),
-            params![compression_hint, rand_ordering, entry_id],
-        )?;
-    }
-    tx.commit()?;
-
     Ok(StoreResult::Created {
         stored_size: compressed_size,
         data_size: uncompressed_size,
diff --git a/src/shard/mod.rs b/src/shard/mod.rs
index 7195cad..5b0d634 100644
--- a/src/shard/mod.rs
+++ b/src/shard/mod.rs
@@ -1,13 +1,10 @@
 mod fn_get;
-mod fn_hint_names;
 mod fn_migrate;
-mod fn_samples_for_hint;
 mod fn_store;
 pub mod shard_error;
 mod shard_struct;

 pub use fn_get::{GetArgs, GetResult};
-pub use fn_samples_for_hint::SampleForHintResult;
 pub use fn_store::{StoreArgs, StoreResult};
 pub use shard_struct::Shard;
diff --git a/src/shard/shard_struct.rs b/src/shard/shard_struct.rs
index f46bc08..99ae756 100644
--- a/src/shard/shard_struct.rs
+++ b/src/shard/shard_struct.rs
@@ -62,8 +62,6 @@ async fn get_num_entries(conn: &Connection) -> Result
 #[cfg(test)]
 pub mod test {
-    use std::collections::{HashMap, HashSet};
-
     pub async fn make_shard_with(compressor: CompressionManagerArc) -> super::Shard {
         let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
         super::Shard::open(0, conn, compressor).await.unwrap()
     }
@@ -228,109 +223,4 @@ pub mod test {
         assert_eq!(get_result.content_type, content_type);
         assert_eq!(get_result.data, data);
     }
-
-    #[tokio::test]
-    async fn test_compression_hint() {
-        let shard = make_shard().await;
-        let data = "hello, world!".as_bytes();
-        let sha256 = Sha256::from_bytes(data);
-        let store_result = shard
-            .store(StoreArgs {
-                sha256,
-                content_type: "text/plain".to_string(),
-                data: data.into(),
-                compression_hint: Some("hint1".to_string()),
-            })
-            .await
-            .unwrap();
-        assert!(matches!(store_result, StoreResult::Created { .. }));
-
-        let results = shard.samples_for_hint_name("hint1", 10).await.unwrap();
-        assert_eq!(results.len(), 1);
-        assert_eq!(results[0].sha256, sha256);
-        assert_eq!(results[0].data, data);
-    }
-
-    async fn store_random(shard: &Shard, hint: &str) -> (Sha256, Vec<u8>) {
-        let data = (0..1024).map(|_| rand::random::<u8>()).collect::<Vec<_>>();
-        let sha256 = Sha256::from_bytes(&data);
-        let store_result = shard
-            .store(StoreArgs {
-                sha256,
-                content_type: "text/plain".to_string(),
-                data: data.clone().into(),
-                compression_hint: Some(hint.to_string()),
-            })
-            .await
-            .unwrap();
-        assert!(matches!(store_result, StoreResult::Created { .. }));
-        (sha256, data)
-    }
-
-    // TODO - remove this test
-    // #[tokio::test]
-    async fn test_compression_hint_limits() {
-        let get_keys_set = |hash_map: &HashMap<Sha256, Vec<u8>>| {
-            hash_map
-                .keys()
-                .into_iter()
-                .map(|k| *k)
-                .collect::<HashSet<_>>()
-        };
-
-        let insert_num = 50;
-        let sample_num = 10;
-        let shard = make_shard().await;
-        let mut hint1 = HashMap::new();
-        let mut hint2 = HashMap::new();
-
-        for _ in 0..insert_num {
-            let (a, b) = store_random(&shard, "hint1").await;
-            hint1.insert(a, b);
-        }
-
-        for _ in 0..insert_num {
-            let (a, b) = store_random(&shard, "hint2").await;
-            hint2.insert(a, b);
-        }
-
-        assert_eq!(hint1.len(), insert_num);
-        assert_eq!(hint2.len(), insert_num);
-
-        let hint1samples = shard
-            .samples_for_hint_name("hint1", sample_num)
-            .await
-            .unwrap()
-            .into_iter()
-            .map(|r| (r.sha256, r.data))
-            .collect::<HashMap<_, _>>();
-
-        let hint2samples = shard
-            .samples_for_hint_name("hint2", sample_num)
-            .await
-            .unwrap()
-            .into_iter()
-            .map(|r| (r.sha256, r.data))
-            .collect::<HashMap<_, _>>();
-
-        let hint1_keys = get_keys_set(&hint1);
-        let hint2_keys = get_keys_set(&hint2);
-        let hint1samples_keys = get_keys_set(&hint1samples);
-        let hint2samples_keys = get_keys_set(&hint2samples);
-
-        assert_eq!(hint1_keys.len(), insert_num);
-        assert_eq!(hint2_keys.len(), insert_num);
-        assert!(hint1_keys.is_disjoint(&hint2_keys));
-        assert!(
-            hint1samples_keys.is_disjoint(&hint2samples_keys),
-            "hint1: {:?}, hint2: {:?}",
-            hint1samples_keys,
-            hint2samples_keys
-        );
-        assert_eq!(hint1samples.len(), sample_num);
-        assert_eq!(hint2samples.len(), sample_num);
-        assert_eq!(hint2samples_keys.len(), sample_num);
-        assert!(hint1_keys.is_superset(&hint1samples_keys));
-        assert!(hint2_keys.is_superset(&hint2samples_keys));
-    }
 }
diff --git a/src/shards.rs b/src/shards.rs
index 355277c..3fa39c8 100644
--- a/src/shards.rs
+++ b/src/shards.rs
@@ -1,10 +1,5 @@
-use crate::{
-    sha256::Sha256,
-    shard::{SampleForHintResult, Shard},
-    AsyncBoxError,
-};
-use std::{collections::HashSet, sync::Arc};
-use tokio::task::JoinSet;
+use crate::{sha256::Sha256, shard::Shard};
+use std::sync::Arc;

 pub type ShardsArc = Arc<Shards>;

@@ -30,36 +25,6 @@ impl Shards {
     pub fn len(&self) -> usize {
         self.0.len()
     }
-
-    pub async fn hint_names(&self) -> Result<HashSet<String>, AsyncBoxError> {
-        let mut hint_names = HashSet::new();
-        for shard in self.iter() {
-            hint_names.extend(shard.hint_names().await?);
-        }
-        Ok(hint_names)
-    }
-
-    pub async fn samples_for_hint_name(
-        &self,
-        hint_name: &str,
-        limit: usize,
-    ) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
-        let mut tasks = JoinSet::new();
-
-        for shard in self.0.iter() {
-            let shard = shard.clone();
-            let hint_name = hint_name.to_owned();
-            tasks.spawn(async move { shard.samples_for_hint_name(&hint_name, limit).await });
-        }
-
-        let mut hints: Vec<SampleForHintResult> = Vec::new();
-        while let Some(result) = tasks.join_next().await {
-            let result = result??;
-            hints.extend(result);
-        }
-
-        Ok(hints)
-    }
 }

 #[cfg(test)]
diff --git a/src/sql_types/compression_id.rs b/src/sql_types/compression_id.rs
index 7ee5ded..b7a3ec2 100644
--- a/src/sql_types/compression_id.rs
+++ b/src/sql_types/compression_id.rs
@@ -1,4 +1,3 @@
-use super::ZstdDictId;
 use crate::AsyncBoxError;
 use rusqlite::{
     types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, ValueRef},
@@ -9,13 +8,11 @@ use rusqlite::{
 pub enum CompressionId {
     None,
     Zstd,
-    ZstdDict(ZstdDictId),
     Brotli,
 }

 const NONE_PREFIX: &str = "none";
 const ZSTD_PREFIX: &str = "zstd";
-const ZSTD_DICT_PREFIX: &str = "zstd_dict";
 const BROTLI_PREFIX: &str = "brotli";

 impl CompressionId {
@@ -23,31 +20,17 @@ impl CompressionId {
         match self {
             CompressionId::None => NONE_PREFIX,
             CompressionId::Zstd => ZSTD_PREFIX,
-            CompressionId::ZstdDict(_) => ZSTD_DICT_PREFIX,
             CompressionId::Brotli => BROTLI_PREFIX,
         }
     }

     fn from_str(id_as_str: &str) -> Result<Self, AsyncBoxError> {
         match id_as_str {
-            NONE_PREFIX => return Ok(CompressionId::None),
-            ZSTD_PREFIX => return Ok(CompressionId::Zstd),
-            BROTLI_PREFIX => return Ok(CompressionId::Brotli),
-            _ => {}
-        };
-
-        let parse_int = |s: &str| -> Result<i64, AsyncBoxError> {
-            Ok(s.parse()
-                .map_err(|e| format!("invalid i64 ({}): {}", id_as_str, e))?)
-        };
-
-        if let Some(int_str) = id_as_str.strip_prefix(ZSTD_DICT_PREFIX) {
-            return Ok(CompressionId::ZstdDict(ZstdDictId(parse_int(
-                &int_str[1..],
-            )?)));
+            NONE_PREFIX => Ok(CompressionId::None),
+            ZSTD_PREFIX => Ok(CompressionId::Zstd),
+            BROTLI_PREFIX => Ok(CompressionId::Brotli),
+            _ => Err(format!("invalid DictId: {}", id_as_str).into()),
         }
-
-        Err(format!("invalid DictId: {}", id_as_str).into())
     }
 }

@@ -56,7 +39,6 @@ impl ToString for CompressionId {
         match self {
             CompressionId::None => NONE_PREFIX.to_string(),
             CompressionId::Zstd => ZSTD_PREFIX.to_string(),
-            CompressionId::ZstdDict(id) => format!("{}:{}", ZSTD_DICT_PREFIX, id.0),
             CompressionId::Brotli => BROTLI_PREFIX.to_string(),
         }
     }
 }

@@ -64,12 +46,7 @@ impl ToSql for CompressionId {
     fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
-        let prefix = self.prefix();
-        if let CompressionId::ZstdDict(id) = self {
-            Ok(ToSqlOutput::from(format!("{}:{}", prefix, id.0)))
-        } else {
-            prefix.to_sql()
-        }
+        self.prefix().to_sql()
     }
 }

@@ -88,7 +65,6 @@ mod tests {
     #[rstest(
         case(CompressionId::None, "none"),
         case(CompressionId::Zstd, "zstd"),
-        case(CompressionId::ZstdDict(ZstdDictId(42)), "zstd_dict:42"),
         case(CompressionId::Brotli, "brotli")
     )]
     #[test]
diff --git a/src/sql_types/mod.rs b/src/sql_types/mod.rs
index 03fcab1..346bf34 100644
--- a/src/sql_types/mod.rs
+++ b/src/sql_types/mod.rs
@@ -1,9 +1,5 @@
 mod compression_id;
-mod entry_id;
 mod utc_date_time;
-mod zstd_dict_id;

 pub use compression_id::CompressionId;
-pub use entry_id::EntryId;
 pub use utc_date_time::UtcDateTime;
-pub use zstd_dict_id::ZstdDictId;
diff --git a/src/sql_types/zstd_dict_id.rs b/src/sql_types/zstd_dict_id.rs
deleted file mode 100644
index 9d3c9a6..0000000
--- a/src/sql_types/zstd_dict_id.rs
+++ /dev/null
@@ -1,31 +0,0 @@
-use rusqlite::{
-    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
-    ToSql,
-};
-use std::fmt::Display;
-
-#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, PartialOrd, Ord)]
-pub struct ZstdDictId(pub i64);
-impl From<i64> for ZstdDictId {
-    fn from(id: i64) -> Self {
-        Self(id)
-    }
-}
-
-impl Display for ZstdDictId {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "zstd_dict:{}", self.0)
-    }
-}
-
-impl FromSql for ZstdDictId {
-    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
-        Ok(ZstdDictId(value.as_i64()?))
-    }
-}
-
-impl ToSql for ZstdDictId {
-    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
-        self.0.to_sql()
-    }
-}
diff --git a/src/zstd_dict.rs b/src/zstd_dict.rs
deleted file mode 100644
index c63615d..0000000
--- a/src/zstd_dict.rs
+++ /dev/null
@@ -1,216 +0,0 @@
-use crate::{sql_types::ZstdDictId, AsyncBoxError};
-use ouroboros::self_referencing;
-use rusqlite::{
-    types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, Value, ValueRef},
-    Error::ToSqlConversionFailure,
-    ToSql,
-};
-use serde::{Deserialize, Serialize};
-use std::{io, sync::Arc};
-use zstd::dict::{DecoderDictionary, EncoderDictionary};
-
-const ENCODER_DICT_SIZE: usize = 2 * 1024 * 1024;
-pub const ZSTD_LEVEL: i32 = 9;
-
-pub type ZstdDictArc = Arc<ZstdDict>;
-
-#[self_referencing]
-pub struct ZstdEncoder {
-    level: i32,
-    dict_bytes: Vec<u8>,
-
-    #[borrows(dict_bytes)]
-    #[not_covariant]
-    encoder_dict: EncoderDictionary<'this>,
-
-    #[borrows(dict_bytes)]
-    #[not_covariant]
-    decoder_dict: DecoderDictionary<'this>,
-}
-
-impl Serialize for ZstdEncoder {
-    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
-        use serde::ser::SerializeTuple;
-
-        // serialize level and dict_bytes as a tuple
-        let mut state = serializer.serialize_tuple(2)?;
-        state.serialize_element(&self.level())?;
-        state.serialize_element(self.dict_bytes())?;
-        state.end()
-    }
-}
-
-impl<'de> Deserialize<'de> for ZstdEncoder {
-    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
-        use serde::de::SeqAccess;
-
-        struct ZstdEncoderVisitor;
-        impl<'de> serde::de::Visitor<'de> for ZstdEncoderVisitor {
-            type Value = ZstdEncoder;
-
-            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
-                formatter.write_str("a tuple of (i32, Vec<u8>)")
-            }
-
-            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
-                let level = seq
-                    .next_element()?
-                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
-                let dict_bytes = seq
-                    .next_element()?
-                    .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
-                Ok(ZstdEncoder::from_dict_bytes(level, dict_bytes))
-            }
-        }
-
-        deserializer.deserialize_tuple(2, ZstdEncoderVisitor)
-    }
-}
-
-impl ToSql for ZstdEncoder {
-    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
-        let json = serde_json::to_string(self).map_err(|e| ToSqlConversionFailure(Box::new(e)))?;
-        Ok(ToSqlOutput::Owned(Value::Text(json)))
-    }
-}
-
-impl FromSql for ZstdEncoder {
-    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
-        let json = value.as_str()?;
-        serde_json::from_str(json).map_err(|e| FromSqlError::Other(e.into()))
-    }
-}
-
-impl ZstdEncoder {
-    pub fn from_samples(level: i32, samples: Vec<&[u8]>) -> Self {
-        let dict_bytes = zstd::dict::from_samples(&samples, ENCODER_DICT_SIZE).unwrap();
-        Self::from_dict_bytes(level, dict_bytes)
-    }
-
-    pub fn from_dict_bytes(level: i32, dict_bytes: Vec<u8>) -> Self {
-        ZstdEncoderBuilder {
-            level,
-            dict_bytes,
-            encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
-            decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
-        }
-        .build()
-    }
-
-    pub fn level(&self) -> i32 {
-        *self.borrow_level()
-    }
-    pub fn dict_bytes(&self) -> &[u8] {
-        self.borrow_dict_bytes()
-    }
-}
-
-pub struct ZstdDict {
-    id: ZstdDictId,
-    name: String,
-    encoder: ZstdEncoder,
-}
-
-impl PartialEq for ZstdDict {
-    fn eq(&self, other: &Self) -> bool {
-        self.id() == other.id()
-            && self.name() == other.name()
-            && self.level() == other.level()
-            && self.dict_bytes() == other.dict_bytes()
-    }
-}
-
-impl std::fmt::Debug for ZstdDict {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("ZstdDict")
-            .field("name", &self.name())
-            .field("level", &self.level())
-            .field("dict_bytes.len", &self.dict_bytes().len())
-            .finish()
-    }
-}
-
-impl ZstdDict {
-    pub fn new(id: ZstdDictId, name: String, encoder: ZstdEncoder) -> Self {
-        Self { id, name, encoder }
-    }
-
-    pub fn id(&self) -> ZstdDictId {
-        self.id
-    }
-    pub fn name(&self) -> &str {
-        &self.name
-    }
-    pub fn level(&self) -> i32 {
-        *self.encoder.borrow_level()
-    }
-    pub fn dict_bytes(&self) -> &[u8] {
-        self.encoder.borrow_dict_bytes()
-    }
-
-    pub fn compress<DataRef: AsRef<[u8]>>(&self, data: DataRef) -> Result<Vec<u8>, AsyncBoxError> {
-        let as_ref = data.as_ref();
-        let mut wrapper = io::Cursor::new(as_ref);
-        let mut out_buffer = Vec::with_capacity(as_ref.len());
-        let mut output_wrapper = io::Cursor::new(&mut out_buffer);
-
-        self.encoder.with_encoder_dict(|encoder_dict| {
-            let mut encoder =
-                zstd::stream::Encoder::with_prepared_dictionary(&mut output_wrapper, encoder_dict)?;
-            io::copy(&mut wrapper, &mut encoder)?;
-            encoder.finish()
-        })?;
-        Ok(out_buffer)
-    }
-
-    pub fn decompress<DataRef: AsRef<[u8]>>(
-        &self,
-        data: DataRef,
-    ) -> Result<Vec<u8>, AsyncBoxError> {
-        let as_ref = data.as_ref();
-        let mut wrapper = io::Cursor::new(as_ref);
-        let mut out_buffer = Vec::with_capacity(as_ref.len());
-        let mut output_wrapper = io::Cursor::new(&mut out_buffer);
-
-        self.encoder.with_decoder_dict(|decoder_dict| {
-            let mut decoder =
-                zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, decoder_dict)?;
-            io::copy(&mut decoder, &mut output_wrapper)
-        })?;
-        Ok(out_buffer)
-    }
-}
-
-#[cfg(test)]
-pub mod test {
-    use crate::sql_types::ZstdDictId;
-
-    use super::ZstdEncoder;
-
-    pub fn make_zstd_dict(id: ZstdDictId, name: &str) -> super::ZstdDict {
-        super::ZstdDict::new(
-            id,
-            name.to_owned(),
-            ZstdEncoder::from_samples(
-                5,
-                vec![
-                    "hello, world",
-                    "this is a test",
-                    "of the emergency broadcast system",
-                ]
-                .into_iter()
-                .chain(vec!["foo", "bar", "baz"].repeat(100))
-                .map(|s| s.as_bytes())
-                .collect(),
-            ),
-        )
-    }
-
-    #[test]
-    fn test_zstd_dict_basics() {
-        let zstd_dict = make_zstd_dict(1.into(), "dict1");
-        let compressed = zstd_dict.compress(b"hello world").unwrap();
-        let decompressed = zstd_dict.decompress(&compressed).unwrap();
-        assert_eq!(decompressed, b"hello world");
-    }
-}
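
Post-removal sanity sketch (not part of the patch): with the dictionary path gone, get_compressor only dispatches between the two stateless codecs, and CompressionId::Zstd round-trips through plain zstd::stream calls. The trait shape below is simplified to &[u8]/io::Result so it runs standalone; it is an assumption for illustration, not the crate's actual Compressor trait, which operates on CompressibleData and AsyncBoxError.

use std::io;

// Reduced stand-in for the crate's `Compressor` trait; byte slices instead
// of `CompressibleData` keep the example self-contained.
trait Compressor {
    fn compress(&self, data: &[u8]) -> io::Result<Vec<u8>>;
    fn decompress(&self, data: &[u8]) -> io::Result<Vec<u8>>;
}

struct ZstdGenericCompressor;

impl Compressor for ZstdGenericCompressor {
    fn compress(&self, data: &[u8]) -> io::Result<Vec<u8>> {
        // Level 0 selects zstd's default compression level.
        zstd::stream::encode_all(data, 0)
    }

    fn decompress(&self, data: &[u8]) -> io::Result<Vec<u8>> {
        // Same call the manager now makes directly for CompressionId::Zstd.
        zstd::stream::decode_all(data)
    }
}

fn main() -> io::Result<()> {
    let codec = ZstdGenericCompressor;
    let input: &[u8] = b"hello world, this is a test of a sort of long string";
    let compressed = codec.compress(input)?;
    assert_eq!(codec.decompress(&compressed)?, input);
    Ok(())
}

One consequence worth noting: entries already stored with a zstd_dict:* compression id would no longer decode after this change, since CompressionId::from_str now rejects that prefix, so the patch presumably assumes no such rows exist in deployed shards.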