clippy, concat_lines

This commit is contained in:
Dylan Knutson
2024-05-06 09:05:36 -07:00
parent 83b4dacede
commit 6deb909c43
12 changed files with 243 additions and 169 deletions

View File

@@ -35,3 +35,7 @@ ouroboros = "0.18.3"
[dev-dependencies] [dev-dependencies]
rstest = "0.19.0" rstest = "0.19.0"
[lints.rust]
unsafe_code = "forbid"
unused_must_use = "forbid"

View File

@@ -39,29 +39,14 @@ impl Compressor {
Arc::new(RwLock::new(self)) Arc::new(RwLock::new(self))
} }
fn _add_from_samples<Str: Into<String>>( pub fn add(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
&mut self, self.check(zstd_dict.id(), zstd_dict.name())?;
id: ZstdDictId, let zstd_dict = Arc::new(zstd_dict);
name: Str, self.zstd_dict_by_id
samples: Vec<&[u8]>, .insert(zstd_dict.id(), zstd_dict.clone());
) -> Result<ZstdDictArc> { self.zstd_dict_by_name
let name = name.into(); .insert(zstd_dict.name().to_string(), zstd_dict.clone());
self.check(id, &name)?; Ok(zstd_dict)
let zstd_dict = ZstdDict::from_samples(id, name, 3, samples)?;
Ok(self.add(zstd_dict))
}
pub fn add_from_bytes<Str: Into<String>>(
&mut self,
id: ZstdDictId,
name: Str,
level: i32,
dict_bytes: Vec<u8>,
) -> Result<ZstdDictArc> {
let name = name.into();
self.check(id, &name)?;
let zstd_dict = ZstdDict::from_dict_bytes(id, name, level, dict_bytes);
Ok(self.add(zstd_dict))
} }
pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> { pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> {
@@ -70,6 +55,9 @@ impl Compressor {
pub fn by_name(&self, name: &str) -> Option<&ZstdDict> { pub fn by_name(&self, name: &str) -> Option<&ZstdDict> {
self.zstd_dict_by_name.get(name).map(|d| &**d) self.zstd_dict_by_name.get(name).map(|d| &**d)
} }
pub fn names(&self) -> impl Iterator<Item = &String> {
self.zstd_dict_by_name.keys()
}
pub fn compress<Data: Into<CompressibleData>>( pub fn compress<Data: Into<CompressibleData>>(
&self, &self,
@@ -143,15 +131,6 @@ impl Compressor {
} }
Ok(()) Ok(())
} }
fn add(&mut self, zstd_dict: ZstdDict) -> ZstdDictArc {
let zstd_dict = Arc::new(zstd_dict);
self.zstd_dict_by_id
.insert(zstd_dict.id(), zstd_dict.clone());
self.zstd_dict_by_name
.insert(zstd_dict.name().to_string(), zstd_dict.clone());
zstd_dict
}
} }
fn auto_compressible_content_type(content_type: &str) -> bool { fn auto_compressible_content_type(content_type: &str) -> bool {
@@ -179,7 +158,7 @@ pub mod test {
pub fn make_compressor_with(compression_policy: CompressionPolicy) -> Compressor { pub fn make_compressor_with(compression_policy: CompressionPolicy) -> Compressor {
let mut compressor = Compressor::new(compression_policy); let mut compressor = Compressor::new(compression_policy);
let zstd_dict = make_zstd_dict(1.into(), "dict1"); let zstd_dict = make_zstd_dict(1.into(), "dict1");
compressor.add(zstd_dict); compressor.add(zstd_dict).unwrap();
compressor compressor
} }

7
src/concat_lines.rs Normal file
View File

@@ -0,0 +1,7 @@
// like concat, but adds a newline between each expression
#[macro_export]
macro_rules! concat_lines {
($($e:expr),* $(,)?) => {
concat!($($e, "\n"),*)
};
}

View File

@@ -1,5 +1,6 @@
mod compressible_data; mod compressible_data;
mod compressor; mod compressor;
mod concat_lines;
mod handlers; mod handlers;
mod manifest; mod manifest;
mod sha256; mod sha256;
@@ -22,7 +23,7 @@ use shards::ShardsArc;
use std::{error::Error, path::PathBuf, sync::Arc}; use std::{error::Error, path::PathBuf, sync::Arc};
use tokio::{net::TcpListener, select, spawn}; use tokio::{net::TcpListener, select, spawn};
use tokio_rusqlite::Connection; use tokio_rusqlite::Connection;
use tracing::info; use tracing::{debug, info};
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
@@ -117,7 +118,7 @@ fn main() -> Result<(), AsyncBoxError> {
let compressor = manifest.compressor(); let compressor = manifest.compressor();
let dict_loop_handle = spawn(dict_loop(manifest, shards.clone())); let dict_loop_handle = spawn(dict_loop(manifest, shards.clone()));
let server_handle = spawn(server_loop(server, shards, compressor)); let server_handle = spawn(server_loop(server, shards, compressor));
dict_loop_handle.await?; dict_loop_handle.await??;
server_handle.await??; server_handle.await??;
info!("server closed sqlite connections. bye!"); info!("server closed sqlite connections. bye!");
Ok::<_, AsyncBoxError>(()) Ok::<_, AsyncBoxError>(())
@@ -126,18 +127,35 @@ fn main() -> Result<(), AsyncBoxError> {
Ok(()) Ok(())
} }
async fn dict_loop(manifest: Manifest, shards: ShardsArc) { async fn dict_loop(manifest: Manifest, shards: ShardsArc) -> Result<(), AsyncBoxError> {
loop { loop {
info!("dict loop: running"); let mut hint_names = shards.hint_names().await?;
let compressor = manifest.compressor(); {
let _compressor = compressor.read().await; // find what hint names don't have a corresponding dictionary
for _shard in shards.iter() {} let compressor = manifest.compressor();
let compressor = compressor.read().await;
compressor.names().for_each(|name| {
hint_names.remove(name);
});
}
for hint_name in hint_names {
debug!("creating dictionary for {}", hint_name);
let samples = shards.samples_for_hint_name(&hint_name, 10).await?;
let samples_bytes = samples
.iter()
.map(|s| s.data.as_ref())
.collect::<Vec<&[u8]>>();
manifest
.insert_dict_from_samples(hint_name, samples_bytes)
.await?;
}
select! { select! {
_ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {} _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {}
_ = crate::shutdown_signal::shutdown_signal() => { _ = crate::shutdown_signal::shutdown_signal() => {
info!("dict loop: shutdown signal received"); info!("dict loop: shutdown signal received");
break; return Ok(());
} }
} }
} }

View File

@@ -8,8 +8,9 @@ use tokio_rusqlite::Connection;
use crate::{ use crate::{
compressor::{Compressor, CompressorArc}, compressor::{Compressor, CompressorArc},
into_tokio_rusqlite_err, concat_lines,
zstd_dict::ZstdDictArc, sql_types::ZstdDictId,
zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder},
AsyncBoxError, AsyncBoxError,
}; };
@@ -21,17 +22,11 @@ pub struct Manifest {
compressor: Arc<RwLock<Compressor>>, compressor: Arc<RwLock<Compressor>>,
} }
pub type ManifestArc = Arc<Manifest>;
impl Manifest { impl Manifest {
pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, AsyncBoxError> { pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, AsyncBoxError> {
initialize(conn, num_shards).await initialize(conn, num_shards).await
} }
pub fn into_arc(self) -> ManifestArc {
Arc::new(self)
}
pub fn num_shards(&self) -> usize { pub fn num_shards(&self) -> usize {
self.num_shards self.num_shards
} }
@@ -40,35 +35,29 @@ impl Manifest {
self.compressor.clone() self.compressor.clone()
} }
pub async fn insert_zstd_dict_from_samples<Str: Into<String>>( pub async fn insert_dict_from_samples<Str: Into<String>>(
&self, &self,
name: Str, name: Str,
samples: Vec<&[u8]>, samples: Vec<&[u8]>,
) -> Result<ZstdDictArc, Box<dyn Error>> { ) -> Result<ZstdDictArc, AsyncBoxError> {
let name = name.into(); let name = name.into();
let dict_bytes = zstd::dict::from_samples( let encoder = ZstdEncoder::from_samples(3, samples);
&samples,
1024 * 1024, // 1MB max dictionary size
)?;
let compressor = self.compressor.clone();
let zstd_dict = self let zstd_dict = self
.conn .conn
.call(move |conn| { .call(move |conn| {
let level = 3; let level = 3;
let mut stmt = conn.prepare( let mut stmt = conn.prepare(concat_lines!(
"INSERT INTO dictionaries (name, level, dict) "INSERT INTO dictionaries (name, level, dict)",
VALUES (?, ?, ?) "VALUES (?, ?, ?)",
RETURNING id", "RETURNING id"
)?; ))?;
let dict_id = stmt.query_row(params![name, level, dict_bytes], |row| row.get(0))?; let dict_id =
let mut compressor = compressor.blocking_write(); stmt.query_row(params![name, level, encoder.dict_bytes()], |row| row.get(0))?;
let zstd_dict = compressor Ok(ZstdDict::new(dict_id, name, encoder))
.add_from_bytes(dict_id, name, level, dict_bytes)
.map_err(into_tokio_rusqlite_err)?;
Ok(zstd_dict)
}) })
.await?; .await?;
Ok(zstd_dict) let mut compressor = self.compressor.write().await;
compressor.add(zstd_dict)
} }
} }
@@ -84,12 +73,14 @@ async fn initialize(
)?; )?;
conn.execute( conn.execute(
"CREATE TABLE IF NOT EXISTS dictionaries ( concat_lines!(
id INTEGER PRIMARY KEY AUTOINCREMENT, "CREATE TABLE IF NOT EXISTS dictionaries (",
level INTEGER NOT NULL, " id INTEGER PRIMARY KEY AUTOINCREMENT,",
name TEXT NOT NULL, " level INTEGER NOT NULL,",
dict BLOB NOT NULL " name TEXT NOT NULL,",
)", " dict BLOB NOT NULL",
")"
),
[], [],
)?; )?;
@@ -126,17 +117,12 @@ async fn initialize(
} }
}; };
let rows = conn type DictRow = (ZstdDictId, String, i32, Vec<u8>);
let rows: Vec<DictRow> = conn
.call(|conn| { .call(|conn| {
let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?; let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
let mut rows = vec![]; let mut rows = vec![];
for r in stmt.query_map([], |row| { for r in stmt.query_map([], |row| DictRow::try_from(row))? {
let id = row.get(0)?;
let name: String = row.get(1)?;
let level: i32 = row.get(2)?;
let dict: Vec<u8> = row.get(3)?;
Ok((id, name, level, dict))
})? {
rows.push(r?); rows.push(r?);
} }
Ok(rows) Ok(rows)
@@ -144,8 +130,12 @@ async fn initialize(
.await?; .await?;
let mut compressor = Compressor::default(); let mut compressor = Compressor::default();
for (id, name, level, dict_bytes) in rows { for (dict_id, name, level, dict_bytes) in rows {
compressor.add_from_bytes(id, name, level, dict_bytes)?; compressor.add(ZstdDict::new(
dict_id,
name,
ZstdEncoder::from_dict_bytes(level, dict_bytes),
))?;
} }
let compressor = compressor.into_arc(); let compressor = compressor.into_arc();
Ok(Manifest { Ok(Manifest {
@@ -166,7 +156,7 @@ mod tests {
let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100]; let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
let zstd_dict = manifest let zstd_dict = manifest
.insert_zstd_dict_from_samples("test", samples) .insert_dict_from_samples("test", samples)
.await .await
.unwrap(); .unwrap();

View File

@@ -0,0 +1,16 @@
use crate::AsyncBoxError;
use super::Shard;
impl Shard {
pub async fn hint_names(&self) -> Result<Vec<String>, AsyncBoxError> {
self.conn
.call(|conn| {
let mut stmt = conn.prepare("SELECT DISTINCT name FROM compression_hints")?;
let rows = stmt.query_map([], |row| row.get(0))?;
Ok(rows.collect::<Result<Vec<_>, _>>()?)
})
.await
.map_err(Into::into)
}
}

View File

@@ -1,4 +1,4 @@
use crate::AsyncBoxError; use crate::{concat_lines, AsyncBoxError};
use super::*; use super::*;
@@ -37,17 +37,23 @@ impl Shard {
fn ensure_schema_versions_table(conn: &rusqlite::Connection) -> Result<usize, rusqlite::Error> { fn ensure_schema_versions_table(conn: &rusqlite::Connection) -> Result<usize, rusqlite::Error> {
conn.execute( conn.execute(
"CREATE TABLE IF NOT EXISTS schema_version ( concat_lines!(
version INTEGER PRIMARY KEY, "CREATE TABLE IF NOT EXISTS schema_version (",
created_at TEXT NOT NULL " version INTEGER PRIMARY KEY,",
)", " created_at TEXT NOT NULL",
")"
),
[], [],
) )
} }
fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, rusqlite::Error> { fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, rusqlite::Error> {
let mut stmt = conn let mut stmt = conn.prepare(concat_lines!(
.prepare("SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1")?; "SELECT version, created_at",
"FROM schema_version",
"ORDER BY version",
"DESC LIMIT 1"
))?;
let rows = stmt.query_map([], |row| { let rows = stmt.query_map([], |row| {
let version = row.get(0)?; let version = row.get(0)?;
let created_at = row.get(1)?; let created_at = row.get(1)?;
@@ -59,35 +65,45 @@ fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, r
fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> { fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
debug!("migrating to version 1"); debug!("migrating to version 1");
conn.execute( conn.execute(
"CREATE TABLE IF NOT EXISTS entries ( concat_lines!(
id INTEGER PRIMARY KEY AUTOINCREMENT, "CREATE TABLE IF NOT EXISTS entries (",
sha256 BLOB NOT NULL, " id INTEGER PRIMARY KEY AUTOINCREMENT,",
content_type TEXT NOT NULL, " sha256 BLOB NOT NULL,",
compression_id INTEGER NOT NULL, " content_type TEXT NOT NULL,",
uncompressed_size INTEGER NOT NULL, " compression_id INTEGER NOT NULL,",
compressed_size INTEGER NOT NULL, " uncompressed_size INTEGER NOT NULL,",
data BLOB NOT NULL, " compressed_size INTEGER NOT NULL,",
created_at TEXT NOT NULL " data BLOB NOT NULL,",
)", " created_at TEXT NOT NULL",
")"
),
[], [],
)?; )?;
conn.execute( conn.execute(
"CREATE INDEX IF NOT EXISTS entries_sha256_idx ON entries (sha256)", concat_lines!(
"CREATE INDEX IF NOT EXISTS entries_sha256_idx",
"ON entries (sha256)"
),
[], [],
)?; )?;
conn.execute( conn.execute(
"CREATE TABLE IF NOT EXISTS compression_hints ( concat_lines!(
name TEXT NOT NULL, "CREATE TABLE IF NOT EXISTS compression_hints (",
ordering INTEGER NOT NULL, " name TEXT NOT NULL,",
entry_id INTEGER NOT NULL " ordering INTEGER NOT NULL,",
)", " entry_id INTEGER NOT NULL",
")"
),
[], [],
)?; )?;
conn.execute( conn.execute(
"CREATE INDEX IF NOT EXISTS compression_hints_name_idx ON compression_hints (name, ordering)", concat_lines!(
"CREATE INDEX IF NOT EXISTS compression_hints_name_idx",
"ON compression_hints (name, ordering)",
),
[], [],
)?; )?;

View File

@@ -10,12 +10,12 @@ pub struct SampleForHintResult {
} }
impl Shard { impl Shard {
pub async fn samples_for_hint( pub async fn samples_for_hint_name(
&self, &self,
compression_hint: &str, hint_name: &str,
limit: usize, limit: usize,
) -> Result<Vec<SampleForHintResult>, AsyncBoxError> { ) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
let compression_hint = compression_hint.to_owned(); let hint_name = hint_name.to_owned();
let result = self let result = self
.conn .conn
.call(move |conn| { .call(move |conn| {
@@ -24,7 +24,7 @@ impl Shard {
SELECT entry_id FROM compression_hints WHERE name = ?1 ORDER BY ordering SELECT entry_id FROM compression_hints WHERE name = ?1 ORDER BY ordering
) LIMIT ?2", ) LIMIT ?2",
)?; )?;
let rows = stmt.query_map(params![compression_hint, limit], |row| { let rows = stmt.query_map(params![hint_name, limit], |row| {
let sha256: Sha256 = row.get(0)?; let sha256: Sha256 = row.get(0)?;
let data: Vec<u8> = row.get(1)?; let data: Vec<u8> = row.get(1)?;
Ok(SampleForHintResult { sha256, data }) Ok(SampleForHintResult { sha256, data })

View File

@@ -1,4 +1,5 @@
mod fn_get; mod fn_get;
mod fn_hint_names;
mod fn_migrate; mod fn_migrate;
mod fn_samples_for_hint; mod fn_samples_for_hint;
mod fn_store; mod fn_store;
@@ -6,6 +7,7 @@ pub mod shard_error;
mod shard_struct; mod shard_struct;
pub use fn_get::{GetArgs, GetResult}; pub use fn_get::{GetArgs, GetResult};
pub use fn_samples_for_hint::SampleForHintResult;
pub use fn_store::{StoreArgs, StoreResult}; pub use fn_store::{StoreArgs, StoreResult};
pub use shard_struct::Shard; pub use shard_struct::Shard;

View File

@@ -246,7 +246,7 @@ pub mod test {
.unwrap(); .unwrap();
assert!(matches!(store_result, StoreResult::Created { .. })); assert!(matches!(store_result, StoreResult::Created { .. }));
let results = shard.samples_for_hint("hint1", 10).await.unwrap(); let results = shard.samples_for_hint_name("hint1", 10).await.unwrap();
assert_eq!(results.len(), 1); assert_eq!(results.len(), 1);
assert_eq!(results[0].sha256, sha256); assert_eq!(results[0].sha256, sha256);
assert_eq!(results[0].data, data); assert_eq!(results[0].data, data);
@@ -298,7 +298,7 @@ pub mod test {
assert_eq!(hint2.len(), insert_num); assert_eq!(hint2.len(), insert_num);
let hint1samples = shard let hint1samples = shard
.samples_for_hint("hint1", sample_num) .samples_for_hint_name("hint1", sample_num)
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()
@@ -306,7 +306,7 @@ pub mod test {
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
let hint2samples = shard let hint2samples = shard
.samples_for_hint("hint2", sample_num) .samples_for_hint_name("hint2", sample_num)
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()

View File

@@ -1,17 +1,23 @@
use std::sync::Arc; use std::{collections::HashSet, sync::Arc};
use crate::{sha256::Sha256, shard::Shard}; use tokio::task::JoinSet;
use crate::{
sha256::Sha256,
shard::{SampleForHintResult, Shard},
AsyncBoxError,
};
pub type ShardsArc = Arc<Shards>; pub type ShardsArc = Arc<Shards>;
#[derive(Clone)] #[derive(Clone)]
pub struct Shards(Vec<Shard>); pub struct Shards(Vec<Arc<Shard>>);
impl Shards { impl Shards {
pub fn new(shards: Vec<Shard>) -> Option<Self> { pub fn new(shards: Vec<Shard>) -> Option<Self> {
if shards.is_empty() { if shards.is_empty() {
return None; return None;
} }
Some(Self(shards)) Some(Self(shards.into_iter().map(Arc::new).collect()))
} }
pub fn shard_for(&self, sha256: &Sha256) -> &Shard { pub fn shard_for(&self, sha256: &Sha256) -> &Shard {
@@ -19,13 +25,43 @@ impl Shards {
&self.0[shard_id] &self.0[shard_id]
} }
pub fn iter(&self) -> std::slice::Iter<'_, Shard> { pub fn iter(&self) -> impl Iterator<Item = &Shard> {
self.0.iter() self.0.iter().map(|shard| shard.as_ref())
} }
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.0.len() self.0.len()
} }
pub async fn hint_names(&self) -> Result<HashSet<String>, AsyncBoxError> {
let mut hint_names = HashSet::new();
for shard in self.iter() {
hint_names.extend(shard.hint_names().await?);
}
Ok(hint_names)
}
pub async fn samples_for_hint_name(
&self,
hint_name: &str,
limit: usize,
) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
let mut tasks = JoinSet::new();
for shard in self.0.iter() {
let shard = shard.clone();
let hint_name = hint_name.to_owned();
tasks.spawn(async move { shard.samples_for_hint_name(&hint_name, limit).await });
}
let mut hints: Vec<SampleForHintResult> = Vec::new();
while let Some(result) = tasks.join_next().await {
let result = result??;
hints.extend(result);
}
Ok(hints)
}
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -1,5 +1,5 @@
use ouroboros::self_referencing; use ouroboros::self_referencing;
use std::{error::Error, io, sync::Arc}; use std::{io, sync::Arc};
use zstd::dict::{DecoderDictionary, EncoderDictionary}; use zstd::dict::{DecoderDictionary, EncoderDictionary};
use crate::{sql_types::ZstdDictId, AsyncBoxError}; use crate::{sql_types::ZstdDictId, AsyncBoxError};
@@ -7,9 +7,7 @@ use crate::{sql_types::ZstdDictId, AsyncBoxError};
pub type ZstdDictArc = Arc<ZstdDict>; pub type ZstdDictArc = Arc<ZstdDict>;
#[self_referencing] #[self_referencing]
pub struct ZstdDict { pub struct ZstdEncoder {
id: crate::sql_types::ZstdDictId,
name: String,
level: i32, level: i32,
dict_bytes: Vec<u8>, dict_bytes: Vec<u8>,
@@ -22,6 +20,33 @@ pub struct ZstdDict {
decoder_dict: DecoderDictionary<'this>, decoder_dict: DecoderDictionary<'this>,
} }
impl ZstdEncoder {
pub fn from_samples(level: i32, samples: Vec<&[u8]>) -> Self {
let dict_bytes = zstd::dict::from_samples(&samples, 1024 * 1024).unwrap();
Self::from_dict_bytes(level, dict_bytes)
}
pub fn from_dict_bytes(level: i32, dict_bytes: Vec<u8>) -> Self {
ZstdEncoderBuilder {
level,
dict_bytes,
encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
}
.build()
}
pub fn dict_bytes(&self) -> &[u8] {
self.borrow_dict_bytes()
}
}
pub struct ZstdDict {
id: ZstdDictId,
name: String,
encoder: ZstdEncoder,
}
impl PartialEq for ZstdDict { impl PartialEq for ZstdDict {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.id() == other.id() self.id() == other.id()
@@ -42,42 +67,21 @@ impl std::fmt::Debug for ZstdDict {
} }
impl ZstdDict { impl ZstdDict {
pub fn from_samples( pub fn new(id: ZstdDictId, name: String, encoder: ZstdEncoder) -> Self {
id: ZstdDictId, Self { id, name, encoder }
name: String,
level: i32,
samples: Vec<&[u8]>,
) -> Result<Self, Box<dyn Error + Send + Sync>> {
let dict_bytes = zstd::dict::from_samples(
&samples,
1024 * 1024, // 1MB max dictionary size
)?;
Ok(Self::from_dict_bytes(id, name, level, dict_bytes))
}
pub fn from_dict_bytes(id: ZstdDictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
ZstdDictBuilder {
id,
name,
level,
dict_bytes,
encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
}
.build()
} }
pub fn id(&self) -> ZstdDictId { pub fn id(&self) -> ZstdDictId {
*self.borrow_id() self.id
} }
pub fn name(&self) -> &str { pub fn name(&self) -> &str {
self.borrow_name() &self.name
} }
pub fn level(&self) -> i32 { pub fn level(&self) -> i32 {
*self.borrow_level() *self.encoder.borrow_level()
} }
pub fn dict_bytes(&self) -> &[u8] { pub fn dict_bytes(&self) -> &[u8] {
self.borrow_dict_bytes() self.encoder.borrow_dict_bytes()
} }
pub fn compress<DataRef: AsRef<[u8]>>(&self, data: DataRef) -> Result<Vec<u8>, AsyncBoxError> { pub fn compress<DataRef: AsRef<[u8]>>(&self, data: DataRef) -> Result<Vec<u8>, AsyncBoxError> {
@@ -86,7 +90,7 @@ impl ZstdDict {
let mut out_buffer = Vec::with_capacity(as_ref.len()); let mut out_buffer = Vec::with_capacity(as_ref.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer); let mut output_wrapper = io::Cursor::new(&mut out_buffer);
self.with_encoder_dict(|encoder_dict| { self.encoder.with_encoder_dict(|encoder_dict| {
let mut encoder = let mut encoder =
zstd::stream::Encoder::with_prepared_dictionary(&mut output_wrapper, encoder_dict)?; zstd::stream::Encoder::with_prepared_dictionary(&mut output_wrapper, encoder_dict)?;
io::copy(&mut wrapper, &mut encoder)?; io::copy(&mut wrapper, &mut encoder)?;
@@ -104,7 +108,7 @@ impl ZstdDict {
let mut out_buffer = Vec::with_capacity(as_ref.len()); let mut out_buffer = Vec::with_capacity(as_ref.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer); let mut output_wrapper = io::Cursor::new(&mut out_buffer);
self.with_decoder_dict(|decoder_dict| { self.encoder.with_decoder_dict(|decoder_dict| {
let mut decoder = let mut decoder =
zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, decoder_dict)?; zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, decoder_dict)?;
io::copy(&mut decoder, &mut output_wrapper) io::copy(&mut decoder, &mut output_wrapper)
@@ -117,27 +121,29 @@ impl ZstdDict {
pub mod test { pub mod test {
use crate::sql_types::ZstdDictId; use crate::sql_types::ZstdDictId;
use super::ZstdEncoder;
pub fn make_zstd_dict(id: ZstdDictId, name: &str) -> super::ZstdDict { pub fn make_zstd_dict(id: ZstdDictId, name: &str) -> super::ZstdDict {
super::ZstdDict::from_dict_bytes( super::ZstdDict::new(
id, id,
name.to_owned(), name.to_owned(),
3, ZstdEncoder::from_samples(
vec![ 3,
"hello, world", vec![
"this is a test", "hello, world",
"of the emergency broadcast system", "this is a test",
] "of the emergency broadcast system",
.into_iter() ]
.chain(vec!["foo", "bar", "baz"].repeat(100)) .into_iter()
.map(|s| s.as_bytes().to_owned()) .chain(vec!["foo", "bar", "baz"].repeat(100))
.flat_map(|s| s.into_iter()) .map(|s| s.as_bytes())
.collect(), .collect(),
),
) )
} }
#[test] #[test]
fn test_zstd_dict() { fn test_zstd_dict_basics() {
let dict_bytes = vec![1, 2, 3, 4];
let zstd_dict = make_zstd_dict(1.into(), "dict1"); let zstd_dict = make_zstd_dict(1.into(), "dict1");
let compressed = zstd_dict.compress(b"hello world").unwrap(); let compressed = zstd_dict.compress(b"hello world").unwrap();
let decompressed = zstd_dict.decompress(&compressed).unwrap(); let decompressed = zstd_dict.decompress(&compressed).unwrap();