remove zstd dicts & hints

Dylan Knutson
2024-06-12 13:39:49 -07:00
parent f44d884761
commit eef6811625
19 changed files with 19 additions and 802 deletions

View File

@@ -1,20 +1,17 @@
use crate::{
compressible_data::CompressibleData,
compression_stats::{CompressionStat, CompressionStats},
compressor::{BrotliGenericCompressor, Compressor, ZstdDictCompressor, ZstdGenericCompressor},
sql_types::{CompressionId, ZstdDictId},
zstd_dict::{ZstdDict, ZstdDictArc},
compressor::{BrotliGenericCompressor, Compressor, ZstdGenericCompressor},
sql_types::CompressionId,
AsyncBoxError, CompressionPolicy,
};
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use tokio::sync::{RwLock, RwLockReadGuard};
pub type CompressionManagerArc = Arc<RwLock<CompressionManager>>;
pub type CompressionStatsArc = Arc<RwLock<CompressionStats>>;
pub struct CompressionManager {
zstd_dict_by_id: HashMap<ZstdDictId, ZstdDictArc>,
zstd_dict_by_name: HashMap<String, Vec<ZstdDictArc>>,
compression_policy: CompressionPolicy,
compression_stats: CompressionStatsArc,
}
@@ -22,8 +19,6 @@ pub struct CompressionManager {
impl CompressionManager {
pub fn new(compression_policy: CompressionPolicy) -> Self {
Self {
zstd_dict_by_id: HashMap::new(),
zstd_dict_by_name: HashMap::new(),
compression_policy,
compression_stats: Arc::new(RwLock::new(CompressionStats::default())),
}
@@ -52,43 +47,6 @@ impl CompressionManager {
*compression_stats.for_id_mut(id) = stat;
}
pub fn add_zstd_dict(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
self.check(zstd_dict.id())?;
let zstd_dict = Arc::new(zstd_dict);
self.zstd_dict_by_id
.insert(zstd_dict.id(), zstd_dict.clone());
let name = zstd_dict.name();
if let Some(by_name_list) = self.zstd_dict_by_name.get_mut(name) {
by_name_list.push(zstd_dict.clone());
} else {
self.zstd_dict_by_name
.insert(name.to_string(), vec![zstd_dict.clone()]);
}
Ok(zstd_dict)
}
pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> {
self.zstd_dict_by_id.get(&id).map(|d| &**d)
}
#[cfg(test)]
pub fn by_name(&self, name: &str) -> Option<&ZstdDict> {
if let Some(by_name_list) = self.zstd_dict_by_name.get(name) {
// take a random element
Some(
&by_name_list[rand::Rng::gen_range(&mut rand::thread_rng(), 0..by_name_list.len())],
)
} else {
None
}
}
pub fn names(&self) -> impl Iterator<Item = &String> {
self.zstd_dict_by_name.keys()
}
pub async fn compress<Data: Into<CompressibleData>>(
&self,
content_type: &str,
@@ -152,10 +110,6 @@ impl CompressionManager {
fn get_compressor(&self, compression_id: CompressionId) -> Result<Box<dyn Compressor>> {
Ok(match compression_id {
CompressionId::Zstd => Box::new(ZstdGenericCompressor),
CompressionId::ZstdDict(zstd_dict_id) => {
let zstd_dict = self.zstd_dict_by_id.get(&zstd_dict_id).unwrap().clone();
Box::new(ZstdDictCompressor::new(zstd_dict))
}
CompressionId::Brotli => Box::new(BrotliGenericCompressor),
CompressionId::None => return Err("no compressor for CompressionId::None".into()),
})
@@ -169,13 +123,6 @@ impl CompressionManager {
let data = data.into();
match compression_id {
CompressionId::None => Ok(data),
CompressionId::ZstdDict(id) => {
if let Some(dict) = self.by_id(id) {
Ok(CompressibleData::Vec(dict.decompress(data.as_ref())?))
} else {
Err(format!("zstd dictionary {:?} not found", id).into())
}
}
CompressionId::Zstd => Ok(CompressibleData::Vec(zstd::stream::decode_all(
data.as_ref(),
)?)),
@@ -191,13 +138,6 @@ impl CompressionManager {
pub async fn stats(&self) -> RwLockReadGuard<CompressionStats> {
self.compression_stats.read().await
}
fn check(&self, id: ZstdDictId) -> Result<()> {
if self.zstd_dict_by_id.contains_key(&id) {
return Err(format!("zstd dictionary id {} already exists", id.0).into());
}
Ok(())
}
}
fn auto_compressible_content_type(content_type: &str) -> bool {
@@ -213,22 +153,15 @@ fn auto_compressible_content_type(content_type: &str) -> bool {
#[cfg(test)]
pub mod test {
use std::collections::HashSet;
use rstest::rstest;
use super::*;
use crate::zstd_dict::test::make_zstd_dict;
use rstest::rstest;
pub fn make_compressor() -> CompressionManager {
make_compressor_with(CompressionPolicy::Auto)
}
pub fn make_compressor_with(compression_policy: CompressionPolicy) -> CompressionManager {
let mut compressor = CompressionManager::new(compression_policy);
let zstd_dict = make_zstd_dict(1.into(), "dict1");
compressor.add_zstd_dict(zstd_dict).unwrap();
compressor
CompressionManager::new(compression_policy)
}
#[test]
@@ -286,18 +219,4 @@ pub mod test {
assert_ne!(compressed, data);
assert!(compressed.len() < data.len());
}
#[tokio::test]
async fn test_multiple_dicts_same_name() {
let mut compressor = make_compressor();
let zstd_dict = make_zstd_dict(2.into(), "dict1");
compressor.add_zstd_dict(zstd_dict).unwrap();
let mut seen_ids = HashSet::new();
for _ in 0..1000 {
let zstd_dict = compressor.by_name("dict1").unwrap();
seen_ids.insert(zstd_dict.id());
}
assert_eq!(seen_ids, [1, 2].iter().copied().map(ZstdDictId).collect());
}
}

View File

@@ -7,12 +7,9 @@ pub trait Compressor: Send {
#[cfg(test)]
pub(self) mod test {
use rstest::*;
use super::*;
use crate::{
compressor::*, into_arc::IntoArc, sql_types::ZstdDictId, zstd_dict::test::make_zstd_dict,
};
use crate::compressor::*;
use rstest::*;
fn random_bytes(len: usize) -> Vec<u8> {
(0..len).map(|_| rand::random::<u8>()).collect()
@@ -21,7 +18,6 @@ pub(self) mod test {
#[rstest]
#[case(BrotliGenericCompressor)]
#[case(ZstdGenericCompressor)]
#[case(ZstdDictCompressor::new(make_zstd_dict(ZstdDictId(1), "test_dict").into_arc()))]
fn test_compressor(
#[case] compressor: impl Compressor,
#[values(

View File

@@ -1,9 +1,7 @@
mod brotli_generic_compressor;
mod compressor_trait;
mod zstd_dict_compressor;
mod zstd_generic_compressor;
pub use brotli_generic_compressor::BrotliGenericCompressor;
pub use compressor_trait::Compressor;
pub use zstd_dict_compressor::ZstdDictCompressor;
pub use zstd_generic_compressor::ZstdGenericCompressor;

View File

@@ -1,20 +0,0 @@
use super::compressor_trait::Compressor;
use crate::{compressible_data::CompressibleData, zstd_dict::ZstdDictArc, AsyncBoxError};
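// Thin adapter: implements the Compressor trait by delegating to a shared ZstdDict.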
pub struct ZstdDictCompressor(ZstdDictArc);
impl ZstdDictCompressor {
pub fn new(zstd_dict: ZstdDictArc) -> Self {
Self(zstd_dict)
}
}
impl Compressor for ZstdDictCompressor {
fn compress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
Ok(self.0.compress(data)?.into())
}
fn decompress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
Ok(self.0.decompress(data)?.into())
}
}

View File

@@ -13,7 +13,6 @@ use tracing::error;
pub struct StoreRequest {
pub sha256: Option<String>,
pub content_type: String,
pub compression_hint: Option<String>,
pub data: FieldData<Bytes>,
}
@@ -106,7 +105,6 @@ pub async fn store_handler(
sha256,
content_type: request.content_type,
data: request.data.contents,
compression_hint: request.compression_hint,
})
.await
{
@@ -138,7 +136,6 @@ pub mod test {
TypedMultipart(StoreRequest {
sha256: sha256.map(|s| s.hex_string()),
content_type: content_type.to_string(),
compression_hint: None,
data: FieldData {
metadata: Default::default(),
contents: data.into(),

View File

@@ -1,68 +0,0 @@
use std::sync::Arc;
use tokio::select;
use tracing::{debug, info};
use crate::{manifest::Manifest, shards::ShardsArc, AsyncBoxError};
pub async fn create_new_dicts_loop(
manifest: Arc<Manifest>,
shards: ShardsArc,
) -> Result<(), AsyncBoxError> {
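// NOTE: unconditional early return; the dictionary-building loop below is unreachable (evidently disabled).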
return Ok(());
loop {
let hint_names = shards.hint_names().await?;
let mut new_hint_names = hint_names.clone();
{
// find what hint names don't have a corresponding dictionary
let compressor = manifest.compression_manager();
let compressor = compressor.read().await;
compressor.names().for_each(|name| {
new_hint_names.remove(name);
});
}
for hint_name in new_hint_names {
let samples = shards.samples_for_hint_name(&hint_name, 100).await?;
let num_samples = samples.len();
let bytes_samples = samples.iter().map(|s| s.data.len()).sum::<usize>();
let bytes_samples_human = humansize::format_size(bytes_samples, humansize::BINARY);
if num_samples < 10 || bytes_samples < 1024 {
debug!(
"skipping dictionary for {} - not enough samples: ({} samples, {})",
hint_name, num_samples, bytes_samples_human
);
continue;
}
debug!("building dictionary for {}", hint_name);
let now = chrono::Utc::now();
let ref_samples = samples.iter().map(|s| s.data.as_ref()).collect::<Vec<_>>();
let dict = manifest
.insert_zstd_dict_from_samples(&hint_name, ref_samples)
.await?;
let duration = chrono::Utc::now() - now;
let bytes_dict_human =
humansize::format_size(dict.dict_bytes().len(), humansize::BINARY);
debug!(
"built dictionary {} ({}) in {}ms: {} samples / {} sample bytes, {} dict size",
dict.id(),
hint_name,
duration.num_milliseconds(),
num_samples,
bytes_samples_human,
bytes_dict_human,
);
}
select! {
_ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {}
_ = crate::shutdown_signal::shutdown_signal() => {
info!("new_dict_loop: shutdown signal received");
return Ok(());
}
}
}
}

View File

@@ -1,9 +1,7 @@
mod axum_server_loop;
mod create_new_dicts_loop;
mod dict_stats_printer_loop;
mod save_compression_stats_loop;
pub use axum_server_loop::axum_server_loop;
pub use create_new_dicts_loop::create_new_dicts_loop;
pub use dict_stats_printer_loop::dict_stats_printer_loop;
pub use save_compression_stats_loop::save_compression_stats_loop;

View File

@@ -13,7 +13,6 @@ mod shard;
mod shards;
mod shutdown_signal;
mod sql_types;
mod zstd_dict;
use crate::{manifest::Manifest, shards::Shards};
use app_time_formatter::AppTimeFormatter;
@@ -129,24 +128,12 @@ fn main() -> Result<(), AsyncBoxError> {
);
let compression_manager = manifest.compression_manager();
{
let compressor = compression_manager.read().await;
debug!(
"loaded compression dictionaries: {:?}",
compressor.names().collect::<Vec<_>>()
);
}
let shards = Arc::new(Shards::new(shards_vec).ok_or("num shards must be > 0")?);
let join_handles = vec![
spawn(loops::save_compression_stats_loop(manifest.clone())),
spawn(loops::dict_stats_printer_loop(
manifest.compression_manager(),
)),
spawn(loops::create_new_dicts_loop(
manifest.clone(),
shards.clone(),
)),
spawn(loops::axum_server_loop(server, shards, compression_manager)),
];
for handle in join_handles {

View File

@@ -4,8 +4,7 @@ use crate::{
compression_manager::{CompressionManager, CompressionManagerArc},
compression_stats::{CompressionStat, CompressionStats},
concat_lines, into_tokio_rusqlite_err,
sql_types::{CompressionId, ZstdDictId},
zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder, ZSTD_LEVEL},
sql_types::CompressionId,
AsyncBoxError,
};
use rusqlite::params;
@@ -46,35 +45,6 @@ impl Manifest {
.await?;
Ok(())
}
pub async fn insert_zstd_dict_from_samples<Str: Into<String>>(
&self,
name: Str,
samples: Vec<&[u8]>,
) -> Result<ZstdDictArc, AsyncBoxError> {
let name = name.into();
let encoder = ZstdEncoder::from_samples(ZSTD_LEVEL, samples);
let zstd_dict = self
.conn
.call(move |conn| {
let mut stmt = conn.prepare(concat_lines!(
"INSERT INTO zstd_dictionaries (name, encoder_json)",
"VALUES (?, ?)",
"RETURNING zstd_dict_id"
))?;
let zstd_dict_id = stmt.query_row(params![name, encoder], |row| row.get(0))?;
Ok(ZstdDict::new(zstd_dict_id, name, encoder))
})
.await?;
let mut compressor = self.compressor.write().await;
compressor
.set_stat(
CompressionId::ZstdDict(zstd_dict.id()),
CompressionStat::default(),
)
.await;
compressor.add_zstd_dict(zstd_dict)
}
}
async fn migrate_manifest(conn: &Connection) -> Result<(), tokio_rusqlite::Error> {
@@ -167,33 +137,6 @@ async fn load_and_store_num_shards(
})
}
async fn load_zstd_dicts(
conn: &Connection,
compressor: &mut CompressionManager,
) -> Result<(), AsyncBoxError> {
type ZstdDictRow = (ZstdDictId, String, ZstdEncoder);
let rows: Vec<ZstdDictRow> = conn
.call(|conn| {
let rows = conn
.prepare(concat_lines!(
"SELECT",
" zstd_dict_id, name, encoder_json",
"FROM zstd_dictionaries",
))?
.query_map([], |row| ZstdDictRow::try_from(row))?
.collect::<Result<_, _>>()?;
Ok(rows)
})
.await?;
debug!("loaded {} zstd dictionaries from manifest", rows.len());
for (zstd_dict_id, name, encoder) in rows {
compressor.add_zstd_dict(ZstdDict::new(zstd_dict_id, name, encoder))?;
}
Ok(())
}
async fn load_compression_stats(
conn: &Connection,
compressor: &mut CompressionManager,
@@ -264,7 +207,6 @@ fn save_compression_stats(
async fn load_compressor(conn: &Connection) -> Result<CompressionManagerArc, AsyncBoxError> {
let mut compressor = CompressionManager::default();
load_zstd_dicts(conn, &mut compressor).await?;
load_compression_stats(conn, &mut compressor).await?;
Ok(compressor.into_arc())
}
@@ -294,45 +236,7 @@ mod tests {
let compressor = manifest.compression_manager();
let compressor = compressor.read().await;
let stats = compressor.stats().await;
assert_eq!(stats.len(), 2);
assert_eq!(stats.iter().count(), 2);
}
#[tokio::test]
async fn test_manifest() {
let conn = Connection::open_in_memory().await.unwrap();
let manifest = initialize(conn, Some(4)).await.unwrap();
let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
let zstd_dict = manifest
.insert_zstd_dict_from_samples("test", samples)
.await
.unwrap();
// test that indexes are created correctly
assert_eq!(
zstd_dict.as_ref(),
manifest
.compression_manager()
.read()
.await
.by_id(zstd_dict.id())
.unwrap()
);
assert_eq!(
zstd_dict.as_ref(),
manifest
.compression_manager()
.read()
.await
.by_name(zstd_dict.name())
.unwrap()
);
let data = b"hello world, this is a test of a sort of long string";
let compressed = zstd_dict.compress(data).unwrap();
let decompressed = zstd_dict.decompress(&compressed).unwrap();
assert_eq!(decompressed, data);
assert!(data.len() > compressed.len());
assert_eq!(stats.len(), 3);
assert_eq!(stats.iter().count(), 3);
}
}

View File

@@ -1,15 +0,0 @@
use super::Shard;
use crate::AsyncBoxError;
impl Shard {
pub async fn hint_names(&self) -> Result<Vec<String>, AsyncBoxError> {
self.conn
.call(|conn| {
let mut stmt = conn.prepare("SELECT DISTINCT name FROM compression_hints")?;
let rows = stmt.query_map([], |row| row.get(0))?;
Ok(rows.collect::<Result<Vec<_>, _>>()?)
})
.await
.map_err(Into::into)
}
}

View File

@@ -1,37 +0,0 @@
use super::Shard;
use crate::{into_tokio_rusqlite_err, sha256::Sha256, AsyncBoxError};
use rusqlite::params;
pub struct SampleForHintResult {
pub sha256: Sha256,
pub data: Vec<u8>,
}
impl Shard {
pub async fn samples_for_hint_name(
&self,
hint_name: &str,
limit: usize,
) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
let hint_name = hint_name.to_owned();
let result = self
.conn
.call(move |conn| {
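// `ordering` is a random per-row key written at store time, intended to
// randomize which hinted entries feed the LIMIT ?2 sample.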
let mut stmt = conn.prepare(
"SELECT sha256, data FROM entries WHERE id IN (
SELECT entry_id FROM compression_hints WHERE name = ?1 ORDER BY ordering
) LIMIT ?2",
)?;
let rows = stmt.query_map(params![hint_name, limit], |row| {
let sha256: Sha256 = row.get(0)?;
let data: Vec<u8> = row.get(1)?;
Ok(SampleForHintResult { sha256, data })
})?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(into_tokio_rusqlite_err)
})
.await?;
Ok(result)
}
}

View File

@@ -2,7 +2,7 @@ use super::*;
use crate::{
compressible_data::CompressibleData,
concat_lines, into_tokio_rusqlite_err,
sql_types::{CompressionId, EntryId, UtcDateTime},
sql_types::{CompressionId, UtcDateTime},
AsyncBoxError,
};
@@ -25,7 +25,6 @@ pub struct StoreArgs {
pub sha256: Sha256,
pub content_type: String,
pub data: Bytes,
pub compression_hint: Option<String>,
}
impl Shard {
@@ -35,7 +34,6 @@ impl Shard {
sha256,
data,
content_type,
compression_hint,
}: StoreArgs,
) -> Result<StoreResult, AsyncBoxError> {
let existing_entry = self
@@ -63,7 +61,6 @@ impl Shard {
dict_id,
uncompressed_size,
data,
compression_hint,
)
.map_err(into_tokio_rusqlite_err)
})
@@ -101,18 +98,15 @@ fn insert(
dict_id: CompressionId,
uncompressed_size: usize,
data: CompressibleData,
compression_hint: Option<String>,
) -> Result<StoreResult, rusqlite::Error> {
let created_at = UtcDateTime::now();
let compressed_size = data.len();
let tx = conn.transaction()?;
let entry_id: EntryId = tx.query_row(
conn.execute(
concat_lines!(
"INSERT INTO entries",
" (sha256, content_type, dict_id, uncompressed_size, compressed_size, data, created_at)",
"VALUES (?, ?, ?, ?, ?, ?, ?)",
"RETURNING id"
),
params![
sha256,
@@ -123,21 +117,8 @@ fn insert(
data.as_ref(),
created_at,
],
|row| row.get(0),
)?;
if let Some(compression_hint) = compression_hint {
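// Random ordering key: lets samples_for_hint_name pull a pseudo-random subset via ORDER BY ordering.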
let rand_ordering = rand::random::<i64>();
tx.execute(
concat_lines!(
"INSERT INTO compression_hints (name, ordering, entry_id)",
"VALUES (?, ?, ?)"
),
params![compression_hint, rand_ordering, entry_id],
)?;
}
tx.commit()?;
Ok(StoreResult::Created {
stored_size: compressed_size,
data_size: uncompressed_size,

View File

@@ -1,13 +1,10 @@
mod fn_get;
mod fn_hint_names;
mod fn_migrate;
mod fn_samples_for_hint;
mod fn_store;
pub mod shard_error;
mod shard_struct;
pub use fn_get::{GetArgs, GetResult};
pub use fn_samples_for_hint::SampleForHintResult;
pub use fn_store::{StoreArgs, StoreResult};
pub use shard_struct::Shard;

View File

@@ -62,8 +62,6 @@ async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Err
#[cfg(test)]
pub mod test {
use std::collections::{HashMap, HashSet};
use crate::compression_manager::test::make_compressor_with;
use crate::compression_manager::CompressionManagerArc;
use crate::{
@@ -72,11 +70,8 @@ pub mod test {
shard::{GetArgs, StoreArgs, StoreResult},
CompressionPolicy,
};
use rstest::rstest;
use super::Shard;
pub async fn make_shard_with(compressor: CompressionManagerArc) -> super::Shard {
let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
super::Shard::open(0, conn, compressor).await.unwrap()
@@ -228,109 +223,4 @@ pub mod test {
assert_eq!(get_result.content_type, content_type);
assert_eq!(get_result.data, data);
}
#[tokio::test]
async fn test_compression_hint() {
let shard = make_shard().await;
let data = "hello, world!".as_bytes();
let sha256 = Sha256::from_bytes(data);
let store_result = shard
.store(StoreArgs {
sha256,
content_type: "text/plain".to_string(),
data: data.into(),
compression_hint: Some("hint1".to_string()),
})
.await
.unwrap();
assert!(matches!(store_result, StoreResult::Created { .. }));
let results = shard.samples_for_hint_name("hint1", 10).await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].sha256, sha256);
assert_eq!(results[0].data, data);
}
async fn store_random(shard: &Shard, hint: &str) -> (Sha256, Vec<u8>) {
let data = (0..1024).map(|_| rand::random::<u8>()).collect::<Vec<_>>();
let sha256 = Sha256::from_bytes(&data);
let store_result = shard
.store(StoreArgs {
sha256,
content_type: "text/plain".to_string(),
data: data.clone().into(),
compression_hint: Some(hint.to_string()),
})
.await
.unwrap();
assert!(matches!(store_result, StoreResult::Created { .. }));
(sha256, data)
}
// TODO - remove this test
// #[tokio::test]
async fn test_compression_hint_limits() {
let get_keys_set = |hash_map: &HashMap<Sha256, Vec<u8>>| {
hash_map.keys().copied().collect::<HashSet<_>>()
};
let insert_num = 50;
let sample_num = 10;
let shard = make_shard().await;
let mut hint1 = HashMap::new();
let mut hint2 = HashMap::new();
for _ in 0..insert_num {
let (a, b) = store_random(&shard, "hint1").await;
hint1.insert(a, b);
}
for _ in 0..insert_num {
let (a, b) = store_random(&shard, "hint2").await;
hint2.insert(a, b);
}
assert_eq!(hint1.len(), insert_num);
assert_eq!(hint2.len(), insert_num);
let hint1samples = shard
.samples_for_hint_name("hint1", sample_num)
.await
.unwrap()
.into_iter()
.map(|r| (r.sha256, r.data))
.collect::<HashMap<_, _>>();
let hint2samples = shard
.samples_for_hint_name("hint2", sample_num)
.await
.unwrap()
.into_iter()
.map(|r| (r.sha256, r.data))
.collect::<HashMap<_, _>>();
let hint1_keys = get_keys_set(&hint1);
let hint2_keys = get_keys_set(&hint2);
let hint1samples_keys = get_keys_set(&hint1samples);
let hint2samples_keys = get_keys_set(&hint2samples);
assert_eq!(hint1_keys.len(), insert_num);
assert_eq!(hint2_keys.len(), insert_num);
assert!(hint1_keys.is_disjoint(&hint2_keys));
assert!(
hint1samples_keys.is_disjoint(&hint2samples_keys),
"hint1: {:?}, hint2: {:?}",
hint1samples_keys,
hint2samples_keys
);
assert_eq!(hint1samples.len(), sample_num);
assert_eq!(hint2samples.len(), sample_num);
assert_eq!(hint2samples_keys.len(), sample_num);
assert!(hint1_keys.is_superset(&hint1samples_keys));
assert!(hint2_keys.is_superset(&hint2samples_keys));
}
}

View File

@@ -1,10 +1,5 @@
use crate::{
sha256::Sha256,
shard::{SampleForHintResult, Shard},
AsyncBoxError,
};
use std::{collections::HashSet, sync::Arc};
use tokio::task::JoinSet;
use crate::{sha256::Sha256, shard::Shard};
use std::sync::Arc;
pub type ShardsArc = Arc<Shards>;
@@ -30,36 +25,6 @@ impl Shards {
pub fn len(&self) -> usize {
self.0.len()
}
pub async fn hint_names(&self) -> Result<HashSet<String>, AsyncBoxError> {
let mut hint_names = HashSet::new();
for shard in self.iter() {
hint_names.extend(shard.hint_names().await?);
}
Ok(hint_names)
}
pub async fn samples_for_hint_name(
&self,
hint_name: &str,
limit: usize,
) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
let mut tasks = JoinSet::new();
for shard in self.0.iter() {
let shard = shard.clone();
let hint_name = hint_name.to_owned();
tasks.spawn(async move { shard.samples_for_hint_name(&hint_name, limit).await });
}
let mut hints: Vec<SampleForHintResult> = Vec::new();
while let Some(result) = tasks.join_next().await {
let result = result??;
hints.extend(result);
}
Ok(hints)
}
}
#[cfg(test)]

View File

@@ -1,4 +1,3 @@
use super::ZstdDictId;
use crate::AsyncBoxError;
use rusqlite::{
types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, ValueRef},
@@ -9,13 +8,11 @@ use rusqlite::{
pub enum CompressionId {
None,
Zstd,
ZstdDict(ZstdDictId),
Brotli,
}
const NONE_PREFIX: &str = "none";
const ZSTD_PREFIX: &str = "zstd";
const ZSTD_DICT_PREFIX: &str = "zstd_dict";
const BROTLI_PREFIX: &str = "brotli";
impl CompressionId {
@@ -23,31 +20,17 @@ impl CompressionId {
match self {
CompressionId::None => NONE_PREFIX,
CompressionId::Zstd => ZSTD_PREFIX,
CompressionId::ZstdDict(_) => ZSTD_DICT_PREFIX,
CompressionId::Brotli => BROTLI_PREFIX,
}
}
fn from_str(id_as_str: &str) -> Result<Self, AsyncBoxError> {
match id_as_str {
NONE_PREFIX => return Ok(CompressionId::None),
ZSTD_PREFIX => return Ok(CompressionId::Zstd),
BROTLI_PREFIX => return Ok(CompressionId::Brotli),
_ => {}
};
let parse_int = |s: &str| -> Result<i64, AsyncBoxError> {
Ok(s.parse()
.map_err(|e| format!("invalid i64 ({}): {}", id_as_str, e))?)
};
if let Some(int_str) = id_as_str.strip_prefix(ZSTD_DICT_PREFIX) {
return Ok(CompressionId::ZstdDict(ZstdDictId(parse_int(
&int_str[1..],
)?)));
NONE_PREFIX => Ok(CompressionId::None),
ZSTD_PREFIX => Ok(CompressionId::Zstd),
BROTLI_PREFIX => Ok(CompressionId::Brotli),
_ => Err(format!("invalid DictId: {}", id_as_str).into()),
}
Err(format!("invalid DictId: {}", id_as_str).into())
}
}
@@ -56,7 +39,6 @@ impl ToString for CompressionId {
match self {
CompressionId::None => NONE_PREFIX.to_string(),
CompressionId::Zstd => ZSTD_PREFIX.to_string(),
CompressionId::ZstdDict(id) => format!("{}:{}", ZSTD_DICT_PREFIX, id.0),
CompressionId::Brotli => BROTLI_PREFIX.to_string(),
}
}
@@ -64,12 +46,7 @@ impl ToString for CompressionId {
impl ToSql for CompressionId {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
let prefix = self.prefix();
if let CompressionId::ZstdDict(id) = self {
Ok(ToSqlOutput::from(format!("{}:{}", prefix, id.0)))
} else {
prefix.to_sql()
}
self.prefix().to_sql()
}
}
@@ -88,7 +65,6 @@ mod tests {
#[rstest(
case(CompressionId::None, "none"),
case(CompressionId::Zstd, "zstd"),
case(CompressionId::ZstdDict(ZstdDictId(42)), "zstd_dict:42"),
case(CompressionId::Brotli, "brotli")
)]
#[test]

View File

@@ -1,9 +1,5 @@
mod compression_id;
mod entry_id;
mod utc_date_time;
mod zstd_dict_id;
pub use compression_id::CompressionId;
pub use entry_id::EntryId;
pub use utc_date_time::UtcDateTime;
pub use zstd_dict_id::ZstdDictId;

View File

@@ -1,31 +0,0 @@
use rusqlite::{
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
use std::fmt::Display;
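// Newtype over the i64 primary key of the zstd_dictionaries table.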
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, PartialOrd, Ord)]
pub struct ZstdDictId(pub i64);
impl From<i64> for ZstdDictId {
fn from(id: i64) -> Self {
Self(id)
}
}
impl Display for ZstdDictId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "zstd_dict:{}", self.0)
}
}
impl FromSql for ZstdDictId {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
Ok(ZstdDictId(value.as_i64()?))
}
}
impl ToSql for ZstdDictId {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
self.0.to_sql()
}
}

View File

@@ -1,216 +0,0 @@
use crate::{sql_types::ZstdDictId, AsyncBoxError};
use ouroboros::self_referencing;
use rusqlite::{
types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, Value, ValueRef},
Error::ToSqlConversionFailure,
ToSql,
};
use serde::{Deserialize, Serialize};
use std::{io, sync::Arc};
use zstd::dict::{DecoderDictionary, EncoderDictionary};
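// Maximum size of a trained dictionary passed to zstd::dict::from_samples (2 MiB).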
const ENCODER_DICT_SIZE: usize = 2 * 1024 * 1024;
pub const ZSTD_LEVEL: i32 = 9;
pub type ZstdDictArc = Arc<ZstdDict>;
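// The prepared encoder/decoder dictionaries borrow from dict_bytes, so ouroboros
// bundles the owned bytes and both dictionaries into one self-referential struct.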
#[self_referencing]
pub struct ZstdEncoder {
level: i32,
dict_bytes: Vec<u8>,
#[borrows(dict_bytes)]
#[not_covariant]
encoder_dict: EncoderDictionary<'this>,
#[borrows(dict_bytes)]
#[not_covariant]
decoder_dict: DecoderDictionary<'this>,
}
impl Serialize for ZstdEncoder {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
use serde::ser::SerializeTuple;
// serialize level and dict_bytes as a tuple
let mut state = serializer.serialize_tuple(2)?;
state.serialize_element(&self.level())?;
state.serialize_element(self.dict_bytes())?;
state.end()
}
}
impl<'de> Deserialize<'de> for ZstdEncoder {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
use serde::de::SeqAccess;
struct ZstdEncoderVisitor;
impl<'de> serde::de::Visitor<'de> for ZstdEncoderVisitor {
type Value = ZstdEncoder;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a tuple of (i32, Vec<u8>)")
}
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let level = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
let dict_bytes = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
Ok(ZstdEncoder::from_dict_bytes(level, dict_bytes))
}
}
deserializer.deserialize_tuple(2, ZstdEncoderVisitor)
}
}
impl ToSql for ZstdEncoder {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
let json = serde_json::to_string(self).map_err(|e| ToSqlConversionFailure(Box::new(e)))?;
Ok(ToSqlOutput::Owned(Value::Text(json)))
}
}
impl FromSql for ZstdEncoder {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
let json = value.as_str()?;
serde_json::from_str(json).map_err(|e| FromSqlError::Other(e.into()))
}
}
impl ZstdEncoder {
pub fn from_samples(level: i32, samples: Vec<&[u8]>) -> Self {
let dict_bytes = zstd::dict::from_samples(&samples, ENCODER_DICT_SIZE).unwrap();
Self::from_dict_bytes(level, dict_bytes)
}
pub fn from_dict_bytes(level: i32, dict_bytes: Vec<u8>) -> Self {
ZstdEncoderBuilder {
level,
dict_bytes,
encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
}
.build()
}
pub fn level(&self) -> i32 {
*self.borrow_level()
}
pub fn dict_bytes(&self) -> &[u8] {
self.borrow_dict_bytes()
}
}
pub struct ZstdDict {
id: ZstdDictId,
name: String,
encoder: ZstdEncoder,
}
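// Manual PartialEq: the prepared dictionaries are not comparable, so equality is
// defined over id, name, level, and the raw dictionary bytes.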
impl PartialEq for ZstdDict {
fn eq(&self, other: &Self) -> bool {
self.id() == other.id()
&& self.name() == other.name()
&& self.level() == other.level()
&& self.dict_bytes() == other.dict_bytes()
}
}
impl std::fmt::Debug for ZstdDict {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ZstdDict")
.field("name", &self.name())
.field("level", &self.level())
.field("dict_bytes.len", &self.dict_bytes().len())
.finish()
}
}
impl ZstdDict {
pub fn new(id: ZstdDictId, name: String, encoder: ZstdEncoder) -> Self {
Self { id, name, encoder }
}
pub fn id(&self) -> ZstdDictId {
self.id
}
pub fn name(&self) -> &str {
&self.name
}
pub fn level(&self) -> i32 {
*self.encoder.borrow_level()
}
pub fn dict_bytes(&self) -> &[u8] {
self.encoder.borrow_dict_bytes()
}
pub fn compress<DataRef: AsRef<[u8]>>(&self, data: DataRef) -> Result<Vec<u8>, AsyncBoxError> {
let as_ref = data.as_ref();
let mut wrapper = io::Cursor::new(as_ref);
let mut out_buffer = Vec::with_capacity(as_ref.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
self.encoder.with_encoder_dict(|encoder_dict| {
let mut encoder =
zstd::stream::Encoder::with_prepared_dictionary(&mut output_wrapper, encoder_dict)?;
io::copy(&mut wrapper, &mut encoder)?;
encoder.finish()
})?;
Ok(out_buffer)
}
pub fn decompress<DataRef: AsRef<[u8]>>(
&self,
data: DataRef,
) -> Result<Vec<u8>, AsyncBoxError> {
let as_ref = data.as_ref();
let mut wrapper = io::Cursor::new(as_ref);
let mut out_buffer = Vec::with_capacity(as_ref.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
self.encoder.with_decoder_dict(|decoder_dict| {
let mut decoder =
zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, decoder_dict)?;
io::copy(&mut decoder, &mut output_wrapper)
})?;
Ok(out_buffer)
}
}
#[cfg(test)]
pub mod test {
use crate::sql_types::ZstdDictId;
use super::ZstdEncoder;
pub fn make_zstd_dict(id: ZstdDictId, name: &str) -> super::ZstdDict {
super::ZstdDict::new(
id,
name.to_owned(),
ZstdEncoder::from_samples(
5,
vec![
"hello, world",
"this is a test",
"of the emergency broadcast system",
]
.into_iter()
.chain(vec!["foo", "bar", "baz"].repeat(100))
.map(|s| s.as_bytes())
.collect(),
),
)
}
#[test]
fn test_zstd_dict_basics() {
let zstd_dict = make_zstd_dict(1.into(), "dict1");
let compressed = zstd_dict.compress(b"hello world").unwrap();
let decompressed = zstd_dict.decompress(&compressed).unwrap();
assert_eq!(decompressed, b"hello world");
}
}
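For reference, the dictionary round-trip deleted by this commit can be reproduced with the zstd crate alone. Below is a minimal, hypothetical sketch using zstd's bulk API (zstd::bulk::Compressor::with_dictionary / zstd::bulk::Decompressor::with_dictionary) rather than the stream API with prepared dictionaries that ZstdDict used; the sample data and the 16 KiB dictionary cap are illustrative, not from this repository:

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Train a dictionary from repetitive samples, as the deleted
    // ZstdEncoder::from_samples did (it capped the dictionary at 2 MiB).
    let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
    let dict_bytes = zstd::dict::from_samples(&samples, 16 * 1024)?;

    // Compress and decompress with the trained dictionary at level 9
    // (the removed ZSTD_LEVEL constant).
    let data = b"hello world, this is a test of a sort of long string";
    let compressed = zstd::bulk::Compressor::with_dictionary(9, &dict_bytes)?.compress(data)?;
    let decompressed = zstd::bulk::Decompressor::with_dictionary(&dict_bytes)?
        .decompress(&compressed, data.len())?;

    assert_eq!(decompressed, data);
    assert!(compressed.len() < data.len());
    Ok(())
}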