clippy, concat_lines
This commit is contained in:
@@ -35,3 +35,7 @@ ouroboros = "0.18.3"
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = "0.19.0"
|
||||
|
||||
[lints.rust]
|
||||
unsafe_code = "forbid"
|
||||
unused_must_use = "forbid"
|
||||
|
||||
@@ -39,29 +39,14 @@ impl Compressor {
|
||||
Arc::new(RwLock::new(self))
|
||||
}
|
||||
|
||||
fn _add_from_samples<Str: Into<String>>(
|
||||
&mut self,
|
||||
id: ZstdDictId,
|
||||
name: Str,
|
||||
samples: Vec<&[u8]>,
|
||||
) -> Result<ZstdDictArc> {
|
||||
let name = name.into();
|
||||
self.check(id, &name)?;
|
||||
let zstd_dict = ZstdDict::from_samples(id, name, 3, samples)?;
|
||||
Ok(self.add(zstd_dict))
|
||||
}
|
||||
|
||||
pub fn add_from_bytes<Str: Into<String>>(
|
||||
&mut self,
|
||||
id: ZstdDictId,
|
||||
name: Str,
|
||||
level: i32,
|
||||
dict_bytes: Vec<u8>,
|
||||
) -> Result<ZstdDictArc> {
|
||||
let name = name.into();
|
||||
self.check(id, &name)?;
|
||||
let zstd_dict = ZstdDict::from_dict_bytes(id, name, level, dict_bytes);
|
||||
Ok(self.add(zstd_dict))
|
||||
pub fn add(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
|
||||
self.check(zstd_dict.id(), zstd_dict.name())?;
|
||||
let zstd_dict = Arc::new(zstd_dict);
|
||||
self.zstd_dict_by_id
|
||||
.insert(zstd_dict.id(), zstd_dict.clone());
|
||||
self.zstd_dict_by_name
|
||||
.insert(zstd_dict.name().to_string(), zstd_dict.clone());
|
||||
Ok(zstd_dict)
|
||||
}
|
||||
|
||||
pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> {
|
||||
@@ -70,6 +55,9 @@ impl Compressor {
|
||||
pub fn by_name(&self, name: &str) -> Option<&ZstdDict> {
|
||||
self.zstd_dict_by_name.get(name).map(|d| &**d)
|
||||
}
|
||||
pub fn names(&self) -> impl Iterator<Item = &String> {
|
||||
self.zstd_dict_by_name.keys()
|
||||
}
|
||||
|
||||
pub fn compress<Data: Into<CompressibleData>>(
|
||||
&self,
|
||||
@@ -143,15 +131,6 @@ impl Compressor {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn add(&mut self, zstd_dict: ZstdDict) -> ZstdDictArc {
|
||||
let zstd_dict = Arc::new(zstd_dict);
|
||||
self.zstd_dict_by_id
|
||||
.insert(zstd_dict.id(), zstd_dict.clone());
|
||||
self.zstd_dict_by_name
|
||||
.insert(zstd_dict.name().to_string(), zstd_dict.clone());
|
||||
zstd_dict
|
||||
}
|
||||
}
|
||||
|
||||
fn auto_compressible_content_type(content_type: &str) -> bool {
|
||||
@@ -179,7 +158,7 @@ pub mod test {
|
||||
pub fn make_compressor_with(compression_policy: CompressionPolicy) -> Compressor {
|
||||
let mut compressor = Compressor::new(compression_policy);
|
||||
let zstd_dict = make_zstd_dict(1.into(), "dict1");
|
||||
compressor.add(zstd_dict);
|
||||
compressor.add(zstd_dict).unwrap();
|
||||
compressor
|
||||
}
|
||||
|
||||
|
||||
7
src/concat_lines.rs
Normal file
7
src/concat_lines.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
// like concat, but adds a newline between each expression
|
||||
#[macro_export]
|
||||
macro_rules! concat_lines {
|
||||
($($e:expr),* $(,)?) => {
|
||||
concat!($($e, "\n"),*)
|
||||
};
|
||||
}
|
||||
34
src/main.rs
34
src/main.rs
@@ -1,5 +1,6 @@
|
||||
mod compressible_data;
|
||||
mod compressor;
|
||||
mod concat_lines;
|
||||
mod handlers;
|
||||
mod manifest;
|
||||
mod sha256;
|
||||
@@ -22,7 +23,7 @@ use shards::ShardsArc;
|
||||
use std::{error::Error, path::PathBuf, sync::Arc};
|
||||
use tokio::{net::TcpListener, select, spawn};
|
||||
use tokio_rusqlite::Connection;
|
||||
use tracing::info;
|
||||
use tracing::{debug, info};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version, about, long_about = None)]
|
||||
@@ -117,7 +118,7 @@ fn main() -> Result<(), AsyncBoxError> {
|
||||
let compressor = manifest.compressor();
|
||||
let dict_loop_handle = spawn(dict_loop(manifest, shards.clone()));
|
||||
let server_handle = spawn(server_loop(server, shards, compressor));
|
||||
dict_loop_handle.await?;
|
||||
dict_loop_handle.await??;
|
||||
server_handle.await??;
|
||||
info!("server closed sqlite connections. bye!");
|
||||
Ok::<_, AsyncBoxError>(())
|
||||
@@ -126,18 +127,35 @@ fn main() -> Result<(), AsyncBoxError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn dict_loop(manifest: Manifest, shards: ShardsArc) {
|
||||
async fn dict_loop(manifest: Manifest, shards: ShardsArc) -> Result<(), AsyncBoxError> {
|
||||
loop {
|
||||
info!("dict loop: running");
|
||||
let compressor = manifest.compressor();
|
||||
let _compressor = compressor.read().await;
|
||||
for _shard in shards.iter() {}
|
||||
let mut hint_names = shards.hint_names().await?;
|
||||
{
|
||||
// find what hint names don't have a corresponding dictionary
|
||||
let compressor = manifest.compressor();
|
||||
let compressor = compressor.read().await;
|
||||
compressor.names().for_each(|name| {
|
||||
hint_names.remove(name);
|
||||
});
|
||||
}
|
||||
|
||||
for hint_name in hint_names {
|
||||
debug!("creating dictionary for {}", hint_name);
|
||||
let samples = shards.samples_for_hint_name(&hint_name, 10).await?;
|
||||
let samples_bytes = samples
|
||||
.iter()
|
||||
.map(|s| s.data.as_ref())
|
||||
.collect::<Vec<&[u8]>>();
|
||||
manifest
|
||||
.insert_dict_from_samples(hint_name, samples_bytes)
|
||||
.await?;
|
||||
}
|
||||
|
||||
select! {
|
||||
_ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {}
|
||||
_ = crate::shutdown_signal::shutdown_signal() => {
|
||||
info!("dict loop: shutdown signal received");
|
||||
break;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,9 @@ use tokio_rusqlite::Connection;
|
||||
|
||||
use crate::{
|
||||
compressor::{Compressor, CompressorArc},
|
||||
into_tokio_rusqlite_err,
|
||||
zstd_dict::ZstdDictArc,
|
||||
concat_lines,
|
||||
sql_types::ZstdDictId,
|
||||
zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder},
|
||||
AsyncBoxError,
|
||||
};
|
||||
|
||||
@@ -21,17 +22,11 @@ pub struct Manifest {
|
||||
compressor: Arc<RwLock<Compressor>>,
|
||||
}
|
||||
|
||||
pub type ManifestArc = Arc<Manifest>;
|
||||
|
||||
impl Manifest {
|
||||
pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, AsyncBoxError> {
|
||||
initialize(conn, num_shards).await
|
||||
}
|
||||
|
||||
pub fn into_arc(self) -> ManifestArc {
|
||||
Arc::new(self)
|
||||
}
|
||||
|
||||
pub fn num_shards(&self) -> usize {
|
||||
self.num_shards
|
||||
}
|
||||
@@ -40,35 +35,29 @@ impl Manifest {
|
||||
self.compressor.clone()
|
||||
}
|
||||
|
||||
pub async fn insert_zstd_dict_from_samples<Str: Into<String>>(
|
||||
pub async fn insert_dict_from_samples<Str: Into<String>>(
|
||||
&self,
|
||||
name: Str,
|
||||
samples: Vec<&[u8]>,
|
||||
) -> Result<ZstdDictArc, Box<dyn Error>> {
|
||||
) -> Result<ZstdDictArc, AsyncBoxError> {
|
||||
let name = name.into();
|
||||
let dict_bytes = zstd::dict::from_samples(
|
||||
&samples,
|
||||
1024 * 1024, // 1MB max dictionary size
|
||||
)?;
|
||||
let compressor = self.compressor.clone();
|
||||
let encoder = ZstdEncoder::from_samples(3, samples);
|
||||
let zstd_dict = self
|
||||
.conn
|
||||
.call(move |conn| {
|
||||
let level = 3;
|
||||
let mut stmt = conn.prepare(
|
||||
"INSERT INTO dictionaries (name, level, dict)
|
||||
VALUES (?, ?, ?)
|
||||
RETURNING id",
|
||||
)?;
|
||||
let dict_id = stmt.query_row(params![name, level, dict_bytes], |row| row.get(0))?;
|
||||
let mut compressor = compressor.blocking_write();
|
||||
let zstd_dict = compressor
|
||||
.add_from_bytes(dict_id, name, level, dict_bytes)
|
||||
.map_err(into_tokio_rusqlite_err)?;
|
||||
Ok(zstd_dict)
|
||||
let mut stmt = conn.prepare(concat_lines!(
|
||||
"INSERT INTO dictionaries (name, level, dict)",
|
||||
"VALUES (?, ?, ?)",
|
||||
"RETURNING id"
|
||||
))?;
|
||||
let dict_id =
|
||||
stmt.query_row(params![name, level, encoder.dict_bytes()], |row| row.get(0))?;
|
||||
Ok(ZstdDict::new(dict_id, name, encoder))
|
||||
})
|
||||
.await?;
|
||||
Ok(zstd_dict)
|
||||
let mut compressor = self.compressor.write().await;
|
||||
compressor.add(zstd_dict)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,12 +73,14 @@ async fn initialize(
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS dictionaries (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
level INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
dict BLOB NOT NULL
|
||||
)",
|
||||
concat_lines!(
|
||||
"CREATE TABLE IF NOT EXISTS dictionaries (",
|
||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,",
|
||||
" level INTEGER NOT NULL,",
|
||||
" name TEXT NOT NULL,",
|
||||
" dict BLOB NOT NULL",
|
||||
")"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
@@ -126,17 +117,12 @@ async fn initialize(
|
||||
}
|
||||
};
|
||||
|
||||
let rows = conn
|
||||
type DictRow = (ZstdDictId, String, i32, Vec<u8>);
|
||||
let rows: Vec<DictRow> = conn
|
||||
.call(|conn| {
|
||||
let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
|
||||
let mut rows = vec![];
|
||||
for r in stmt.query_map([], |row| {
|
||||
let id = row.get(0)?;
|
||||
let name: String = row.get(1)?;
|
||||
let level: i32 = row.get(2)?;
|
||||
let dict: Vec<u8> = row.get(3)?;
|
||||
Ok((id, name, level, dict))
|
||||
})? {
|
||||
for r in stmt.query_map([], |row| DictRow::try_from(row))? {
|
||||
rows.push(r?);
|
||||
}
|
||||
Ok(rows)
|
||||
@@ -144,8 +130,12 @@ async fn initialize(
|
||||
.await?;
|
||||
|
||||
let mut compressor = Compressor::default();
|
||||
for (id, name, level, dict_bytes) in rows {
|
||||
compressor.add_from_bytes(id, name, level, dict_bytes)?;
|
||||
for (dict_id, name, level, dict_bytes) in rows {
|
||||
compressor.add(ZstdDict::new(
|
||||
dict_id,
|
||||
name,
|
||||
ZstdEncoder::from_dict_bytes(level, dict_bytes),
|
||||
))?;
|
||||
}
|
||||
let compressor = compressor.into_arc();
|
||||
Ok(Manifest {
|
||||
@@ -166,7 +156,7 @@ mod tests {
|
||||
|
||||
let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
|
||||
let zstd_dict = manifest
|
||||
.insert_zstd_dict_from_samples("test", samples)
|
||||
.insert_dict_from_samples("test", samples)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
16
src/shard/fn_hint_names.rs
Normal file
16
src/shard/fn_hint_names.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use crate::AsyncBoxError;
|
||||
|
||||
use super::Shard;
|
||||
|
||||
impl Shard {
|
||||
pub async fn hint_names(&self) -> Result<Vec<String>, AsyncBoxError> {
|
||||
self.conn
|
||||
.call(|conn| {
|
||||
let mut stmt = conn.prepare("SELECT DISTINCT name FROM compression_hints")?;
|
||||
let rows = stmt.query_map([], |row| row.get(0))?;
|
||||
Ok(rows.collect::<Result<Vec<_>, _>>()?)
|
||||
})
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::AsyncBoxError;
|
||||
use crate::{concat_lines, AsyncBoxError};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -37,17 +37,23 @@ impl Shard {
|
||||
|
||||
fn ensure_schema_versions_table(conn: &rusqlite::Connection) -> Result<usize, rusqlite::Error> {
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS schema_version (
|
||||
version INTEGER PRIMARY KEY,
|
||||
created_at TEXT NOT NULL
|
||||
)",
|
||||
concat_lines!(
|
||||
"CREATE TABLE IF NOT EXISTS schema_version (",
|
||||
" version INTEGER PRIMARY KEY,",
|
||||
" created_at TEXT NOT NULL",
|
||||
")"
|
||||
),
|
||||
[],
|
||||
)
|
||||
}
|
||||
|
||||
fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, rusqlite::Error> {
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1")?;
|
||||
let mut stmt = conn.prepare(concat_lines!(
|
||||
"SELECT version, created_at",
|
||||
"FROM schema_version",
|
||||
"ORDER BY version",
|
||||
"DESC LIMIT 1"
|
||||
))?;
|
||||
let rows = stmt.query_map([], |row| {
|
||||
let version = row.get(0)?;
|
||||
let created_at = row.get(1)?;
|
||||
@@ -59,35 +65,45 @@ fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, r
|
||||
fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
|
||||
debug!("migrating to version 1");
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS entries (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
sha256 BLOB NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
compression_id INTEGER NOT NULL,
|
||||
uncompressed_size INTEGER NOT NULL,
|
||||
compressed_size INTEGER NOT NULL,
|
||||
data BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
)",
|
||||
concat_lines!(
|
||||
"CREATE TABLE IF NOT EXISTS entries (",
|
||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,",
|
||||
" sha256 BLOB NOT NULL,",
|
||||
" content_type TEXT NOT NULL,",
|
||||
" compression_id INTEGER NOT NULL,",
|
||||
" uncompressed_size INTEGER NOT NULL,",
|
||||
" compressed_size INTEGER NOT NULL,",
|
||||
" data BLOB NOT NULL,",
|
||||
" created_at TEXT NOT NULL",
|
||||
")"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS entries_sha256_idx ON entries (sha256)",
|
||||
concat_lines!(
|
||||
"CREATE INDEX IF NOT EXISTS entries_sha256_idx",
|
||||
"ON entries (sha256)"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS compression_hints (
|
||||
name TEXT NOT NULL,
|
||||
ordering INTEGER NOT NULL,
|
||||
entry_id INTEGER NOT NULL
|
||||
)",
|
||||
concat_lines!(
|
||||
"CREATE TABLE IF NOT EXISTS compression_hints (",
|
||||
" name TEXT NOT NULL,",
|
||||
" ordering INTEGER NOT NULL,",
|
||||
" entry_id INTEGER NOT NULL",
|
||||
")"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS compression_hints_name_idx ON compression_hints (name, ordering)",
|
||||
concat_lines!(
|
||||
"CREATE INDEX IF NOT EXISTS compression_hints_name_idx",
|
||||
"ON compression_hints (name, ordering)",
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
|
||||
@@ -10,12 +10,12 @@ pub struct SampleForHintResult {
|
||||
}
|
||||
|
||||
impl Shard {
|
||||
pub async fn samples_for_hint(
|
||||
pub async fn samples_for_hint_name(
|
||||
&self,
|
||||
compression_hint: &str,
|
||||
hint_name: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
|
||||
let compression_hint = compression_hint.to_owned();
|
||||
let hint_name = hint_name.to_owned();
|
||||
let result = self
|
||||
.conn
|
||||
.call(move |conn| {
|
||||
@@ -24,7 +24,7 @@ impl Shard {
|
||||
SELECT entry_id FROM compression_hints WHERE name = ?1 ORDER BY ordering
|
||||
) LIMIT ?2",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![compression_hint, limit], |row| {
|
||||
let rows = stmt.query_map(params![hint_name, limit], |row| {
|
||||
let sha256: Sha256 = row.get(0)?;
|
||||
let data: Vec<u8> = row.get(1)?;
|
||||
Ok(SampleForHintResult { sha256, data })
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
mod fn_get;
|
||||
mod fn_hint_names;
|
||||
mod fn_migrate;
|
||||
mod fn_samples_for_hint;
|
||||
mod fn_store;
|
||||
@@ -6,6 +7,7 @@ pub mod shard_error;
|
||||
mod shard_struct;
|
||||
|
||||
pub use fn_get::{GetArgs, GetResult};
|
||||
pub use fn_samples_for_hint::SampleForHintResult;
|
||||
pub use fn_store::{StoreArgs, StoreResult};
|
||||
pub use shard_struct::Shard;
|
||||
|
||||
|
||||
@@ -246,7 +246,7 @@ pub mod test {
|
||||
.unwrap();
|
||||
assert!(matches!(store_result, StoreResult::Created { .. }));
|
||||
|
||||
let results = shard.samples_for_hint("hint1", 10).await.unwrap();
|
||||
let results = shard.samples_for_hint_name("hint1", 10).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].sha256, sha256);
|
||||
assert_eq!(results[0].data, data);
|
||||
@@ -298,7 +298,7 @@ pub mod test {
|
||||
assert_eq!(hint2.len(), insert_num);
|
||||
|
||||
let hint1samples = shard
|
||||
.samples_for_hint("hint1", sample_num)
|
||||
.samples_for_hint_name("hint1", sample_num)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
@@ -306,7 +306,7 @@ pub mod test {
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let hint2samples = shard
|
||||
.samples_for_hint("hint2", sample_num)
|
||||
.samples_for_hint_name("hint2", sample_num)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
use std::sync::Arc;
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use crate::{sha256::Sha256, shard::Shard};
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use crate::{
|
||||
sha256::Sha256,
|
||||
shard::{SampleForHintResult, Shard},
|
||||
AsyncBoxError,
|
||||
};
|
||||
|
||||
pub type ShardsArc = Arc<Shards>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Shards(Vec<Shard>);
|
||||
pub struct Shards(Vec<Arc<Shard>>);
|
||||
impl Shards {
|
||||
pub fn new(shards: Vec<Shard>) -> Option<Self> {
|
||||
if shards.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(Self(shards))
|
||||
Some(Self(shards.into_iter().map(Arc::new).collect()))
|
||||
}
|
||||
|
||||
pub fn shard_for(&self, sha256: &Sha256) -> &Shard {
|
||||
@@ -19,13 +25,43 @@ impl Shards {
|
||||
&self.0[shard_id]
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> std::slice::Iter<'_, Shard> {
|
||||
self.0.iter()
|
||||
pub fn iter(&self) -> impl Iterator<Item = &Shard> {
|
||||
self.0.iter().map(|shard| shard.as_ref())
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
pub async fn hint_names(&self) -> Result<HashSet<String>, AsyncBoxError> {
|
||||
let mut hint_names = HashSet::new();
|
||||
for shard in self.iter() {
|
||||
hint_names.extend(shard.hint_names().await?);
|
||||
}
|
||||
Ok(hint_names)
|
||||
}
|
||||
|
||||
pub async fn samples_for_hint_name(
|
||||
&self,
|
||||
hint_name: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
for shard in self.0.iter() {
|
||||
let shard = shard.clone();
|
||||
let hint_name = hint_name.to_owned();
|
||||
tasks.spawn(async move { shard.samples_for_hint_name(&hint_name, limit).await });
|
||||
}
|
||||
|
||||
let mut hints: Vec<SampleForHintResult> = Vec::new();
|
||||
while let Some(result) = tasks.join_next().await {
|
||||
let result = result??;
|
||||
hints.extend(result);
|
||||
}
|
||||
|
||||
Ok(hints)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
100
src/zstd_dict.rs
100
src/zstd_dict.rs
@@ -1,5 +1,5 @@
|
||||
use ouroboros::self_referencing;
|
||||
use std::{error::Error, io, sync::Arc};
|
||||
use std::{io, sync::Arc};
|
||||
use zstd::dict::{DecoderDictionary, EncoderDictionary};
|
||||
|
||||
use crate::{sql_types::ZstdDictId, AsyncBoxError};
|
||||
@@ -7,9 +7,7 @@ use crate::{sql_types::ZstdDictId, AsyncBoxError};
|
||||
pub type ZstdDictArc = Arc<ZstdDict>;
|
||||
|
||||
#[self_referencing]
|
||||
pub struct ZstdDict {
|
||||
id: crate::sql_types::ZstdDictId,
|
||||
name: String,
|
||||
pub struct ZstdEncoder {
|
||||
level: i32,
|
||||
dict_bytes: Vec<u8>,
|
||||
|
||||
@@ -22,6 +20,33 @@ pub struct ZstdDict {
|
||||
decoder_dict: DecoderDictionary<'this>,
|
||||
}
|
||||
|
||||
impl ZstdEncoder {
|
||||
pub fn from_samples(level: i32, samples: Vec<&[u8]>) -> Self {
|
||||
let dict_bytes = zstd::dict::from_samples(&samples, 1024 * 1024).unwrap();
|
||||
Self::from_dict_bytes(level, dict_bytes)
|
||||
}
|
||||
|
||||
pub fn from_dict_bytes(level: i32, dict_bytes: Vec<u8>) -> Self {
|
||||
ZstdEncoderBuilder {
|
||||
level,
|
||||
dict_bytes,
|
||||
encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
|
||||
decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
|
||||
}
|
||||
.build()
|
||||
}
|
||||
|
||||
pub fn dict_bytes(&self) -> &[u8] {
|
||||
self.borrow_dict_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ZstdDict {
|
||||
id: ZstdDictId,
|
||||
name: String,
|
||||
encoder: ZstdEncoder,
|
||||
}
|
||||
|
||||
impl PartialEq for ZstdDict {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.id() == other.id()
|
||||
@@ -42,42 +67,21 @@ impl std::fmt::Debug for ZstdDict {
|
||||
}
|
||||
|
||||
impl ZstdDict {
|
||||
pub fn from_samples(
|
||||
id: ZstdDictId,
|
||||
name: String,
|
||||
level: i32,
|
||||
samples: Vec<&[u8]>,
|
||||
) -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||
let dict_bytes = zstd::dict::from_samples(
|
||||
&samples,
|
||||
1024 * 1024, // 1MB max dictionary size
|
||||
)?;
|
||||
Ok(Self::from_dict_bytes(id, name, level, dict_bytes))
|
||||
}
|
||||
|
||||
pub fn from_dict_bytes(id: ZstdDictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
|
||||
ZstdDictBuilder {
|
||||
id,
|
||||
name,
|
||||
level,
|
||||
dict_bytes,
|
||||
encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
|
||||
decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
|
||||
}
|
||||
.build()
|
||||
pub fn new(id: ZstdDictId, name: String, encoder: ZstdEncoder) -> Self {
|
||||
Self { id, name, encoder }
|
||||
}
|
||||
|
||||
pub fn id(&self) -> ZstdDictId {
|
||||
*self.borrow_id()
|
||||
self.id
|
||||
}
|
||||
pub fn name(&self) -> &str {
|
||||
self.borrow_name()
|
||||
&self.name
|
||||
}
|
||||
pub fn level(&self) -> i32 {
|
||||
*self.borrow_level()
|
||||
*self.encoder.borrow_level()
|
||||
}
|
||||
pub fn dict_bytes(&self) -> &[u8] {
|
||||
self.borrow_dict_bytes()
|
||||
self.encoder.borrow_dict_bytes()
|
||||
}
|
||||
|
||||
pub fn compress<DataRef: AsRef<[u8]>>(&self, data: DataRef) -> Result<Vec<u8>, AsyncBoxError> {
|
||||
@@ -86,7 +90,7 @@ impl ZstdDict {
|
||||
let mut out_buffer = Vec::with_capacity(as_ref.len());
|
||||
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
|
||||
|
||||
self.with_encoder_dict(|encoder_dict| {
|
||||
self.encoder.with_encoder_dict(|encoder_dict| {
|
||||
let mut encoder =
|
||||
zstd::stream::Encoder::with_prepared_dictionary(&mut output_wrapper, encoder_dict)?;
|
||||
io::copy(&mut wrapper, &mut encoder)?;
|
||||
@@ -104,7 +108,7 @@ impl ZstdDict {
|
||||
let mut out_buffer = Vec::with_capacity(as_ref.len());
|
||||
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
|
||||
|
||||
self.with_decoder_dict(|decoder_dict| {
|
||||
self.encoder.with_decoder_dict(|decoder_dict| {
|
||||
let mut decoder =
|
||||
zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, decoder_dict)?;
|
||||
io::copy(&mut decoder, &mut output_wrapper)
|
||||
@@ -117,27 +121,29 @@ impl ZstdDict {
|
||||
pub mod test {
|
||||
use crate::sql_types::ZstdDictId;
|
||||
|
||||
use super::ZstdEncoder;
|
||||
|
||||
pub fn make_zstd_dict(id: ZstdDictId, name: &str) -> super::ZstdDict {
|
||||
super::ZstdDict::from_dict_bytes(
|
||||
super::ZstdDict::new(
|
||||
id,
|
||||
name.to_owned(),
|
||||
3,
|
||||
vec![
|
||||
"hello, world",
|
||||
"this is a test",
|
||||
"of the emergency broadcast system",
|
||||
]
|
||||
.into_iter()
|
||||
.chain(vec!["foo", "bar", "baz"].repeat(100))
|
||||
.map(|s| s.as_bytes().to_owned())
|
||||
.flat_map(|s| s.into_iter())
|
||||
.collect(),
|
||||
ZstdEncoder::from_samples(
|
||||
3,
|
||||
vec![
|
||||
"hello, world",
|
||||
"this is a test",
|
||||
"of the emergency broadcast system",
|
||||
]
|
||||
.into_iter()
|
||||
.chain(vec!["foo", "bar", "baz"].repeat(100))
|
||||
.map(|s| s.as_bytes())
|
||||
.collect(),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zstd_dict() {
|
||||
let dict_bytes = vec![1, 2, 3, 4];
|
||||
fn test_zstd_dict_basics() {
|
||||
let zstd_dict = make_zstd_dict(1.into(), "dict1");
|
||||
let compressed = zstd_dict.compress(b"hello world").unwrap();
|
||||
let decompressed = zstd_dict.decompress(&compressed).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user