zstd dict refactorings

Dylan Knutson
2024-05-11 14:36:50 -04:00
parent 6deb909c43
commit 3f44344ac0
18 changed files with 933 additions and 226 deletions

Cargo.lock generated

@@ -281,6 +281,7 @@ dependencies = [
"clap",
"futures",
"hex",
"humansize",
"kdam",
"ouroboros",
"rand",
@@ -290,10 +291,12 @@ dependencies = [
"serde",
"serde_json",
"sha2",
"tabled",
"tokio",
"tokio-rusqlite",
"tracing",
"tracing-subscriber",
"walkdir",
"zstd",
]
@@ -312,6 +315,12 @@ version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "bytecount"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"
[[package]]
name = "bytes"
version = "1.6.0"
@@ -805,6 +814,15 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "humansize"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7"
dependencies = [
"libm",
]
[[package]]
name = "hyper"
version = "1.3.1"
@@ -972,6 +990,12 @@ version = "0.2.153"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "libsqlite3-sys"
version = "0.28.0"
@@ -1207,6 +1231,17 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "papergrid"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ad43c07024ef767f9160710b3a6773976194758c7919b17e63b863db0bdf7fb"
dependencies = [
"bytecount",
"fnv",
"unicode-width",
]
[[package]]
name = "parking_lot"
version = "0.12.1"
@@ -1552,6 +1587,15 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.23"
@@ -1734,6 +1778,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
@@ -1781,6 +1826,30 @@ dependencies = [
"libc",
]
[[package]]
name = "tabled"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c998b0c8b921495196a48aabaf1901ff28be0760136e31604f7967b0792050e"
dependencies = [
"papergrid",
"tabled_derive",
"unicode-width",
]
[[package]]
name = "tabled_derive"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c138f99377e5d653a371cdad263615634cfc8467685dfe8e73e2b8e98f44b17"
dependencies = [
"heck 0.4.1",
"proc-macro-error",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "tempfile"
version = "3.10.1"
@@ -2047,6 +2116,12 @@ dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-width"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6"
[[package]]
name = "url"
version = "2.5.0"
@@ -2088,6 +2163,16 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
@@ -2195,6 +2280,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"

View File

@@ -12,6 +12,10 @@ path = "src/main.rs"
name = "load-test"
path = "load_test/main.rs"
[[bin]]
name = "fixture-inserter"
path = "fixture_inserter/main.rs"
[dependencies]
axum = { version = "0.7.5", features = ["macros"] }
axum_typed_multipart = "0.11.1"
@@ -32,6 +36,9 @@ reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
hex = "0.4.3"
zstd = { version = "0.13.1", features = ["experimental"] }
ouroboros = "0.18.3"
humansize = "2.1.3"
walkdir = "2.5.0"
tabled = "0.15.0"
[dev-dependencies]
rstest = "0.19.0"

fixture_inserter/main.rs Normal file

@@ -0,0 +1,107 @@
use std::{
error::Error,
sync::{Arc, Mutex},
};
use clap::{arg, Parser};
use kdam::BarExt;
use kdam::{tqdm, Bar};
use rand::prelude::SliceRandom;
use reqwest::blocking::{multipart, Client};
use walkdir::WalkDir;
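// Bulk-inserts fixture files into a locally running store: walks --fixture-dir,
// shuffles the files, optionally truncates to --limit, and POSTs each file to the
// local /store endpoint with --hint-name as its compression hint.
// Example invocation (illustrative values):
//   fixture-inserter --hint-name pages --fixture-dir ./fixtures --num-threads 4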
#[derive(Parser, Debug, Clone)]
struct Args {
#[arg(long)]
hint_name: String,
#[arg(long)]
fixture_dir: String,
#[arg(long)]
limit: Option<usize>,
#[arg(long)]
num_threads: Option<usize>,
}
fn main() -> Result<(), Box<dyn Error>> {
let args = Args::parse();
let hint_name = args.hint_name;
let num_threads = args.num_threads.unwrap_or(1);
let pb = Arc::new(Mutex::new(tqdm!()));
let mut entries = WalkDir::new(&args.fixture_dir)
.into_iter()
.filter_map(Result::ok)
.filter(|entry| entry.file_type().is_file())
.filter_map(|entry| {
let path = entry.path();
let name = path.to_str()?.to_string();
Some(name)
})
.collect::<Vec<_>>();
// shuffle the entries
entries.shuffle(&mut rand::thread_rng());
// if there's a limit, drop the rest
if let Some(limit) = args.limit {
entries.truncate(limit);
}
// split the work into at most `num_threads` chunks; avoid a zero chunk size when
// there are fewer entries than threads
let entry_slices = entries.chunks(entries.len().div_ceil(num_threads).max(1));
let join_handles = entry_slices
.map(|entry_slice| {
let hint_name = hint_name.clone();
let pb = pb.clone();
let entry_slice = entry_slice.to_vec();
std::thread::spawn(move || {
let client = Client::new();
for entry in entry_slice.into_iter() {
store_file(&client, &hint_name, &entry, pb.clone()).unwrap();
}
})
})
// collect eagerly: `map` is lazy, so without this the join loop below would spawn
// and join one thread at a time instead of running the uploads in parallel
.collect::<Vec<_>>();
for join_handle in join_handles {
join_handle.join().unwrap();
}
Ok(())
}
fn store_file(
client: &Client,
hint_name: &str,
file_name: &str,
pb: Arc<Mutex<Bar>>,
) -> Result<(), Box<dyn Error>> {
let file_bytes = std::fs::read(file_name)?;
let content_type = if file_name.ends_with(".html") {
"text/html"
} else if file_name.ends_with(".json") {
"application/json"
} else if file_name.ends_with(".jpg") || file_name.ends_with(".jpeg") {
"image/jpeg"
} else if file_name.ends_with(".png") {
"image/png"
} else if file_name.ends_with(".txt") {
"text/plain"
} else if file_name.ends_with(".kindle.images") {
"application/octet-stream"
} else {
"text/html"
};
let form = multipart::Form::new()
.text("content_type", content_type)
.text("compression_hint", hint_name.to_string())
.part("data", multipart::Part::bytes(file_bytes));
let _ = client
.post("http://localhost:7692/store")
.multipart(form)
.send();
let mut pb = pb.lock().unwrap();
pb.update(1)?;
Ok(())
}

View File

@@ -43,14 +43,23 @@ fn run_loop(pb: Arc<Mutex<Bar>>, args: Args) -> Result<(), Box<dyn std::error::E
let mut rng = rand::thread_rng();
let mut rand_data = vec![0u8; args.file_size];
rng.fill(&mut rand_data[..]);
let hints = vec![
"foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred", "plugh",
"xyzzy", "thud",
];
loop {
// tweak a byte in the data
let idx = rng.gen_range(0..rand_data.len());
rand_data[idx] = rng.gen();
let hint = hints[rng.gen_range(0..hints.len())];
let form = reqwest::blocking::multipart::Form::new()
.text("content_type", "text/plain")
.part(
"compression_hint",
reqwest::blocking::multipart::Part::text(hint),
)
.part(
"data",
reqwest::blocking::multipart::Part::bytes(rand_data.clone()),

src/compression_stats.rs Normal file

@@ -0,0 +1,45 @@
use std::collections::HashMap;
use crate::sql_types::DictId;
#[derive(Default, Copy, Clone)]
pub struct CompressionStat {
pub num_entries: usize,
pub uncompressed_size: usize,
pub compressed_size: usize,
}
#[derive(Default)]
pub struct CompressionStats {
id_to_stat_map: HashMap<DictId, CompressionStat>,
id_ordering: Vec<DictId>,
}
impl CompressionStats {
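/// Returns the stat for `dict_id`, inserting a zeroed entry (and recording the id
/// in `id_ordering`, kept sorted) the first time an id is seen.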
pub fn for_id_mut(&mut self, dict_id: DictId) -> &mut CompressionStat {
self.id_to_stat_map.entry(dict_id).or_insert_with_key(|_| {
self.id_ordering.push(dict_id);
self.id_ordering.sort();
Default::default()
})
}
pub fn add_entry(
&mut self,
dict_id: DictId,
uncompressed_size: usize,
compressed_size: usize,
) -> &CompressionStat {
let stat = self.for_id_mut(dict_id);
stat.num_entries += 1;
stat.uncompressed_size += uncompressed_size;
stat.compressed_size += compressed_size;
stat
}
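/// Iterates stats in ascending `DictId` order, as maintained by `for_id_mut`.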
pub fn iter(&self) -> impl Iterator<Item = (DictId, &CompressionStat)> {
self.id_ordering
.iter()
.flat_map(|id| self.id_to_stat_map.get(id).map(|stat| (*id, stat)))
}
}

View File

@@ -1,11 +1,13 @@
use std::{collections::HashMap, sync::Arc};
use rand::Rng;
use tokio::sync::RwLock;
use crate::{
compressible_data::CompressibleData,
sql_types::{CompressionId, ZstdDictId},
zstd_dict::{ZstdDict, ZstdDictArc},
compression_stats::{CompressionStat, CompressionStats},
sql_types::{DictId, ZstdDictId},
zstd_dict::{ZstdDict, ZstdDictArc, ZSTD_LEVEL},
AsyncBoxError, CompressionPolicy,
};
@@ -13,15 +15,18 @@ pub type CompressorArc = Arc<RwLock<Compressor>>;
pub struct Compressor {
zstd_dict_by_id: HashMap<ZstdDictId, ZstdDictArc>,
zstd_dict_by_name: HashMap<String, ZstdDictArc>,
zstd_dict_by_name: HashMap<String, Vec<ZstdDictArc>>,
compression_policy: CompressionPolicy,
compression_stats: CompressionStats,
}
impl Compressor {
pub fn new(compression_policy: CompressionPolicy) -> Self {
Self {
zstd_dict_by_id: HashMap::new(),
zstd_dict_by_name: HashMap::new(),
compression_policy,
compression_stats: CompressionStats::default(),
}
}
}
@@ -39,22 +44,40 @@ impl Compressor {
Arc::new(RwLock::new(self))
}
pub fn add(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
self.check(zstd_dict.id(), zstd_dict.name())?;
pub fn add_stat(&mut self, id: DictId, stat: CompressionStat) {
*self.compression_stats.for_id_mut(id) = stat;
}
pub fn add_zstd_dict(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
self.check(zstd_dict.id())?;
let zstd_dict = Arc::new(zstd_dict);
self.zstd_dict_by_id
.insert(zstd_dict.id(), zstd_dict.clone());
self.zstd_dict_by_name
.insert(zstd_dict.name().to_string(), zstd_dict.clone());
let name = zstd_dict.name();
if let Some(by_name_list) = self.zstd_dict_by_name.get_mut(name) {
by_name_list.push(zstd_dict.clone());
} else {
self.zstd_dict_by_name
.insert(name.to_string(), vec![zstd_dict.clone()]);
}
Ok(zstd_dict)
}
pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> {
self.zstd_dict_by_id.get(&id).map(|d| &**d)
}
pub fn by_name(&self, name: &str) -> Option<&ZstdDict> {
self.zstd_dict_by_name.get(name).map(|d| &**d)
if let Some(by_name_list) = self.zstd_dict_by_name.get(name) {
// take a random element
Some(&by_name_list[rand::thread_rng().gen_range(0..by_name_list.len())])
} else {
None
}
}
pub fn names(&self) -> impl Iterator<Item = &String> {
self.zstd_dict_by_name.keys()
}
@@ -64,7 +87,7 @@ impl Compressor {
hint: Option<&str>,
content_type: &str,
data: Data,
) -> Result<(CompressionId, CompressibleData)> {
) -> Result<(DictId, CompressibleData)> {
let data = data.into();
let should_compress = match self.compression_policy {
CompressionPolicy::None => false,
@@ -74,18 +97,15 @@ impl Compressor {
let generic_compress = || -> Result<_> {
Ok((
CompressionId::ZstdGeneric,
zstd::stream::encode_all(data.as_ref(), 3)?.into(),
DictId::ZstdGeneric,
zstd::stream::encode_all(data.as_ref(), ZSTD_LEVEL)?.into(),
))
};
if should_compress {
let (id, compressed): (_, CompressibleData) = if let Some(hint) = hint {
if let Some(dict) = self.by_name(hint) {
(
CompressionId::ZstdDictId(dict.id()),
dict.compress(&data)?.into(),
)
(DictId::Zstd(dict.id()), dict.compress(&data)?.into())
} else {
generic_compress()?
}
@@ -98,37 +118,42 @@ impl Compressor {
}
}
Ok((CompressionId::None, data))
Ok((DictId::None, data))
}
pub fn decompress<Data: Into<CompressibleData>>(
&self,
compression_id: CompressionId,
dict_id: DictId,
data: Data,
) -> Result<CompressibleData> {
let data = data.into();
match compression_id {
CompressionId::None => Ok(data),
CompressionId::ZstdDictId(id) => {
match dict_id {
DictId::None => Ok(data),
DictId::Zstd(id) => {
if let Some(dict) = self.by_id(id) {
Ok(CompressibleData::Vec(dict.decompress(data.as_ref())?))
} else {
Err(format!("zstd dictionary {:?} not found", id).into())
}
}
CompressionId::ZstdGeneric => Ok(CompressibleData::Vec(zstd::stream::decode_all(
DictId::ZstdGeneric => Ok(CompressibleData::Vec(zstd::stream::decode_all(
data.as_ref(),
)?)),
}
}
fn check(&self, id: ZstdDictId, name: &str) -> Result<()> {
pub fn stats_mut(&mut self) -> &mut CompressionStats {
&mut self.compression_stats
}
pub fn stats(&self) -> &CompressionStats {
&self.compression_stats
}
fn check(&self, id: ZstdDictId) -> Result<()> {
if self.zstd_dict_by_id.contains_key(&id) {
return Err(format!("zstd dictionary id {} already exists", id.0).into());
}
if self.zstd_dict_by_name.contains_key(name) {
return Err(format!("zstd dictionary name {} already exists", name).into());
}
Ok(())
}
}
@@ -146,6 +171,8 @@ fn auto_compressible_content_type(content_type: &str) -> bool {
#[cfg(test)]
pub mod test {
use std::collections::HashSet;
use rstest::rstest;
use super::*;
@@ -158,7 +185,7 @@ pub mod test {
pub fn make_compressor_with(compression_policy: CompressionPolicy) -> Compressor {
let mut compressor = Compressor::new(compression_policy);
let zstd_dict = make_zstd_dict(1.into(), "dict1");
compressor.add(zstd_dict).unwrap();
compressor.add_zstd_dict(zstd_dict).unwrap();
compressor
}
@@ -184,22 +211,22 @@ pub mod test {
) {
let compressor = make_compressor_with(compression_policy);
let data = b"hello, world!".to_vec();
let (compression_id, compressed) = compressor
let (dict_id, compressed) = compressor
.compress(Some("dict1"), content_type, data.clone())
.unwrap();
let data_uncompressed = compressor.decompress(compression_id, compressed).unwrap();
let data_uncompressed = compressor.decompress(dict_id, compressed).unwrap();
assert_eq!(data_uncompressed, data);
}
#[test]
fn test_skip_compressing_small_data() {
let compressor = make_compressor();
let data = b"hello, world".to_vec();
let (compression_id, compressed) = compressor
let data = b"hi!".to_vec();
let (dict_id, compressed) = compressor
.compress(Some("dict1"), "text/plain", data.clone())
.unwrap();
assert_eq!(compression_id, CompressionId::None);
assert_eq!(dict_id, DictId::None);
assert_eq!(compressed, data);
}
@@ -207,11 +234,25 @@ pub mod test {
fn test_compresses_longer_data() {
let compressor = make_compressor();
let data = vec![b'.'; 1024];
let (compression_id, compressed) = compressor
let (dict_id, compressed) = compressor
.compress(Some("dict1"), "text/plain", data.clone())
.unwrap();
assert_eq!(compression_id, CompressionId::ZstdDictId(1.into()));
assert_eq!(dict_id, DictId::Zstd(1.into()));
assert_ne!(compressed, data);
assert!(compressed.len() < data.len());
}
#[test]
fn test_multiple_dicts_same_name() {
let mut compressor = make_compressor();
let zstd_dict = make_zstd_dict(2.into(), "dict1");
compressor.add_zstd_dict(zstd_dict).unwrap();
let mut seen_ids = HashSet::new();
for _ in 0..1000 {
let zstd_dict = compressor.by_name("dict1").unwrap();
seen_ids.insert(zstd_dict.id());
}
assert_eq!(seen_ids, [1, 2].iter().copied().map(ZstdDictId).collect());
}
}

View File

@@ -1,4 +1,5 @@
mod compressible_data;
mod compression_stats;
mod compressor;
mod concat_lines;
mod handlers;
@@ -12,18 +13,22 @@ mod zstd_dict;
use crate::{manifest::Manifest, shards::Shards};
use axum::{
extract::DefaultBodyLimit,
routing::{get, post},
Extension, Router,
};
use clap::{Parser, ValueEnum};
use compressor::CompressorArc;
use core::fmt;
use futures::executor::block_on;
use shard::Shard;
use shards::ShardsArc;
use std::{error::Error, path::PathBuf, sync::Arc};
use std::{error::Error, path::PathBuf, sync::Arc, time::Instant};
use tabled::{settings::Style, Table, Tabled};
use tokio::{net::TcpListener, select, spawn};
use tokio_rusqlite::Connection;
use tracing::{debug, info};
use tracing_subscriber::fmt::{format::Writer, time::FormatTime};
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
@@ -63,9 +68,18 @@ pub fn into_tokio_rusqlite_err<E: Into<AsyncBoxError>>(e: E) -> tokio_rusqlite::
tokio_rusqlite::Error::Other(e.into())
}
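// Formats tracing timestamps as time elapsed since process start, e.g. "   12.345s".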
struct AppTimeFormatter(Instant);
impl FormatTime for AppTimeFormatter {
fn format_time(&self, w: &mut Writer<'_>) -> fmt::Result {
let e = self.0.elapsed();
write!(w, "{:5}.{:03}s", e.as_secs(), e.subsec_millis())
}
}
fn main() -> Result<(), AsyncBoxError> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_timer(AppTimeFormatter(Instant::now()))
.init();
let args = Args::parse();
@@ -80,13 +94,13 @@ fn main() -> Result<(), AsyncBoxError> {
}
// block on opening the manifest
let manifest = block_on(async {
let manifest = Arc::new(block_on(async {
Manifest::open(
Connection::open(db_path.join("manifest.sqlite")).await?,
num_shards,
)
.await
})?;
})?);
// max num_shards threads
let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -102,59 +116,183 @@ fn main() -> Result<(), AsyncBoxError> {
manifest.num_shards()
);
let mut shards_vec = vec![];
let mut num_entries = vec![];
for shard_id in 0..manifest.num_shards() {
let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
let shard = Shard::open(shard_id, shard_sqlite_conn, manifest.compressor()).await?;
info!(
"shard {} has {} entries",
shard.id(),
shard.num_entries().await?
);
num_entries.push(shard.num_entries().await?);
shards_vec.push(shard);
}
debug!(
"loaded shards with {} total entries <- {:?}",
num_entries.iter().sum::<usize>(),
num_entries
);
let compressor = manifest.compressor();
{
let compressor = compressor.read().await;
debug!(
"loaded compression dictionaries: {:?}",
compressor.names().collect::<Vec<_>>()
);
}
let shards = Arc::new(Shards::new(shards_vec).ok_or("num shards must be > 0")?);
let compressor = manifest.compressor();
let dict_loop_handle = spawn(dict_loop(manifest, shards.clone()));
let server_handle = spawn(server_loop(server, shards, compressor));
dict_loop_handle.await??;
server_handle.await??;
info!("server closed sqlite connections. bye!");
let join_handles = vec![
spawn(save_compression_stats_loop(manifest.clone())),
spawn(dict_stats_loop(manifest.compressor())),
spawn(new_dict_loop(manifest.clone(), shards.clone())),
spawn(server_loop(server, shards, compressor)),
];
for handle in join_handles {
handle.await??;
}
info!("saving compressor stats...");
manifest.save_compression_stats().await?;
info!("done. bye!");
Ok::<_, AsyncBoxError>(())
})?;
Ok(())
}
async fn dict_loop(manifest: Manifest, shards: ShardsArc) -> Result<(), AsyncBoxError> {
async fn save_compression_stats_loop(manifest: Arc<Manifest>) -> Result<(), AsyncBoxError> {
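// Flush the in-memory compression stats to the manifest database once a minute,
// until a shutdown signal is received.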
loop {
let mut hint_names = shards.hint_names().await?;
manifest.save_compression_stats().await?;
select! {
_ = tokio::time::sleep(tokio::time::Duration::from_secs(60)) => {}
_ = crate::shutdown_signal::shutdown_signal() => {
info!("persist_dict_stats: shutdown signal received");
return Ok(());
}
}
}
}
async fn dict_stats_loop(compressor: CompressorArc) -> Result<(), AsyncBoxError> {
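// Every five seconds, render the per-dictionary compression stats as a table and
// emit it line by line through the debug log.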
fn humanized_bytes_str(number: &usize) -> String {
humansize::format_size(*number, humansize::BINARY)
}
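// Renders compressed/uncompressed bytes as a ratio string, e.g. "0.42 x";
// "(n/a)" when nothing has been stored for the dictionary yet.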
fn compression_ratio_str(row: &DictStatTableRow) -> String {
if row.uncompressed_size == 0 {
"(n/a)".to_string()
} else {
format!(
"{:.2} x",
row.compressed_size as f64 / row.uncompressed_size as f64
)
}
}
#[derive(Tabled, Default)]
struct DictStatTableRow {
name: String,
#[tabled(rename = "# entries")]
num_entries: usize,
#[tabled(
rename = "compression ratio",
display_with("compression_ratio_str", self)
)]
_ratio: (),
#[tabled(rename = "uncompressed bytes", display_with = "humanized_bytes_str")]
uncompressed_size: usize,
#[tabled(rename = "compressed bytes", display_with = "humanized_bytes_str")]
compressed_size: usize,
}
loop {
{
let compressor = compressor.read().await;
let stats = compressor.stats();
Table::new(stats.iter().map(|(id, stat)| {
let name = match id {
sql_types::DictId::None => "none",
sql_types::DictId::ZstdGeneric => "zstd_generic",
sql_types::DictId::Zstd(id) => compressor
.by_id(id)
.map(|d| d.name())
.unwrap_or("(missing dict)"),
}
.to_string();
DictStatTableRow {
name,
num_entries: stat.num_entries,
uncompressed_size: stat.uncompressed_size,
compressed_size: stat.compressed_size,
..Default::default()
}
}))
.with(Style::rounded())
.to_string()
.lines()
.for_each(|line| debug!("{}", line));
}
select! {
_ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => {}
_ = crate::shutdown_signal::shutdown_signal() => {
info!("dict_stats: shutdown signal received");
return Ok(());
}
}
}
}
async fn new_dict_loop(manifest: Arc<Manifest>, shards: ShardsArc) -> Result<(), AsyncBoxError> {
loop {
let hint_names = shards.hint_names().await?;
let mut new_hint_names = hint_names.clone();
{
// find what hint names don't have a corresponding dictionary
let compressor = manifest.compressor();
let compressor = compressor.read().await;
compressor.names().for_each(|name| {
hint_names.remove(name);
new_hint_names.remove(name);
});
}
for hint_name in hint_names {
debug!("creating dictionary for {}", hint_name);
let samples = shards.samples_for_hint_name(&hint_name, 10).await?;
let samples_bytes = samples
.iter()
.map(|s| s.data.as_ref())
.collect::<Vec<&[u8]>>();
manifest
.insert_dict_from_samples(hint_name, samples_bytes)
for hint_name in new_hint_names {
let samples = shards.samples_for_hint_name(&hint_name, 100).await?;
let num_samples = samples.len();
let bytes_samples = samples.iter().map(|s| s.data.len()).sum::<usize>();
let bytes_samples_human = humansize::format_size(bytes_samples, humansize::BINARY);
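// Require a minimal training corpus (at least 10 samples and 1 KiB of sample data)
// before building a dictionary for this hint.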
if num_samples < 10 || bytes_samples < 1024 {
debug!(
"skipping dictionary for {} - not enough samples: ({} samples, {})",
hint_name, num_samples, bytes_samples_human
);
continue;
}
debug!("building dictionary for {}", hint_name);
let now = chrono::Utc::now();
let ref_samples = samples.iter().map(|s| s.data.as_ref()).collect::<Vec<_>>();
let dict = manifest
.insert_zstd_dict_from_samples(&hint_name, ref_samples)
.await?;
let duration = chrono::Utc::now() - now;
let bytes_dict_human =
humansize::format_size(dict.dict_bytes().len(), humansize::BINARY);
debug!(
"built dictionary {} ({}) in {}ms: {} samples / {} sample bytes, {} dict size",
dict.id(),
hint_name,
duration.num_milliseconds(),
num_samples,
bytes_samples_human,
bytes_dict_human,
);
}
select! {
_ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {}
_ = crate::shutdown_signal::shutdown_signal() => {
info!("dict loop: shutdown signal received");
info!("new_dict_loop: shutdown signal received");
return Ok(());
}
}
@@ -171,7 +309,9 @@ async fn server_loop(
.route("/get/:sha256", get(handlers::get_handler::get_handler))
.route("/info", get(handlers::info_handler::info_handler))
.layer(Extension(shards))
.layer(Extension(compressor));
.layer(Extension(compressor))
// limit the max size of a request body to 100 MiB
.layer(DefaultBodyLimit::max(1024 * 1024 * 100));
axum::serve(server, app.into_make_service())
.with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())

View File

@@ -1,16 +1,18 @@
mod manifest_key;
use std::{error::Error, sync::Arc};
use std::sync::Arc;
use rusqlite::params;
use tokio::sync::RwLock;
use tokio_rusqlite::Connection;
use tracing::{debug, info};
use crate::{
compression_stats::{CompressionStat, CompressionStats},
compressor::{Compressor, CompressorArc},
concat_lines,
sql_types::ZstdDictId,
zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder},
concat_lines, into_tokio_rusqlite_err,
sql_types::{DictId, ZstdDictId},
zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder, ZSTD_LEVEL},
AsyncBoxError,
};
@@ -35,60 +37,97 @@ impl Manifest {
self.compressor.clone()
}
pub async fn insert_dict_from_samples<Str: Into<String>>(
pub async fn save_compression_stats(&self) -> Result<(), AsyncBoxError> {
let compressor = self.compressor();
self.conn
.call(move |conn| {
let compressor = compressor.blocking_read();
let compression_stats = compressor.stats();
save_compression_stats(conn, compression_stats).map_err(into_tokio_rusqlite_err)?;
Ok(())
})
.await?;
Ok(())
}
pub async fn insert_zstd_dict_from_samples<Str: Into<String>>(
&self,
name: Str,
samples: Vec<&[u8]>,
) -> Result<ZstdDictArc, AsyncBoxError> {
let name = name.into();
let encoder = ZstdEncoder::from_samples(3, samples);
let encoder = ZstdEncoder::from_samples(ZSTD_LEVEL, samples);
let zstd_dict = self
.conn
.call(move |conn| {
let level = 3;
let mut stmt = conn.prepare(concat_lines!(
"INSERT INTO dictionaries (name, level, dict)",
"VALUES (?, ?, ?)",
"RETURNING id"
"INSERT INTO zstd_dictionaries (name, encoder_json)",
"VALUES (?, ?)",
"RETURNING zstd_dict_id"
))?;
let dict_id =
stmt.query_row(params![name, level, encoder.dict_bytes()], |row| row.get(0))?;
Ok(ZstdDict::new(dict_id, name, encoder))
let zstd_dict_id = stmt.query_row(params![name, encoder], |row| row.get(0))?;
Ok(ZstdDict::new(zstd_dict_id, name, encoder))
})
.await?;
let mut compressor = self.compressor.write().await;
compressor.add(zstd_dict)
compressor.add_stat(DictId::Zstd(zstd_dict.id()), CompressionStat::default());
compressor.add_zstd_dict(zstd_dict)
}
}
async fn initialize(
conn: Connection,
num_shards: Option<usize>,
) -> Result<Manifest, Box<dyn Error + Send + Sync>> {
async fn migrate_manifest(conn: &Connection) -> Result<(), tokio_rusqlite::Error> {
conn.call(|conn| {
conn.execute(
"CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)",
[],
)?;
conn.execute(
concat_lines!(
"CREATE TABLE IF NOT EXISTS zstd_dictionaries (",
" zstd_dict_id INTEGER PRIMARY KEY AUTOINCREMENT,",
" name TEXT NOT NULL,",
" encoder_json TEXT NOT NULL",
")"
),
[],
)?;
conn.execute(
concat_lines!(
"CREATE TABLE IF NOT EXISTS compression_stats (",
" dict_id TEXT PRIMARY KEY,",
" num_entries INTEGER NOT NULL DEFAULT 0,",
" uncompressed_size INTEGER NOT NULL DEFAULT 0,",
" compressed_size INTEGER NOT NULL DEFAULT 0",
")"
),
[],
)?;
// seed compression_stats rows for the built-in dict ids (none, zstd_generic)
conn.execute(
concat_lines!(
"INSERT OR IGNORE INTO compression_stats (dict_id)",
"VALUES (?), (?)"
),
params![DictId::None, DictId::ZstdGeneric],
)?;
Ok(())
})
.await
}
async fn load_and_store_num_shards(
conn: &Connection,
requested_num_shards: Option<usize>,
) -> Result<usize, AsyncBoxError> {
let stored_num_shards: Option<usize> = conn
.call(|conn| {
conn.execute(
"CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)",
[],
)?;
conn.execute(
concat_lines!(
"CREATE TABLE IF NOT EXISTS dictionaries (",
" id INTEGER PRIMARY KEY AUTOINCREMENT,",
" level INTEGER NOT NULL,",
" name TEXT NOT NULL,",
" dict BLOB NOT NULL",
")"
),
[],
)?;
Ok(get_manifest_key(conn, NumShards)?)
})
.call(|conn| Ok(get_manifest_key(conn, NumShards)?))
.await?;
let num_shards = match (stored_num_shards, num_shards) {
Ok(match (stored_num_shards, requested_num_shards) {
(Some(stored_num_shards), Some(num_shards)) => {
if stored_num_shards == num_shards {
num_shards
@@ -115,29 +154,113 @@ async fn initialize(
// existing database, use loaded num_shards
num_shards
}
};
})
}
type DictRow = (ZstdDictId, String, i32, Vec<u8>);
let rows: Vec<DictRow> = conn
async fn load_zstd_dicts(
conn: &Connection,
compressor: &mut Compressor,
) -> Result<(), AsyncBoxError> {
type ZstdDictRow = (ZstdDictId, String, ZstdEncoder);
let rows: Vec<ZstdDictRow> = conn
.call(|conn| {
let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
let mut rows = vec![];
for r in stmt.query_map([], |row| DictRow::try_from(row))? {
rows.push(r?);
}
let rows = conn
.prepare(concat_lines!(
"SELECT",
" zstd_dict_id, name, encoder_json",
"FROM zstd_dictionaries",
))?
.query_map([], |row| ZstdDictRow::try_from(row))?
.collect::<Result<_, _>>()?;
Ok(rows)
})
.await?;
let mut compressor = Compressor::default();
for (dict_id, name, level, dict_bytes) in rows {
compressor.add(ZstdDict::new(
dict_id,
name,
ZstdEncoder::from_dict_bytes(level, dict_bytes),
))?;
debug!("loaded {} zstd dictionaries from manifest", rows.len());
for (zstd_dict_id, name, encoder) in rows {
compressor.add_zstd_dict(ZstdDict::new(zstd_dict_id, name, encoder))?;
}
let compressor = compressor.into_arc();
Ok(())
}
async fn load_compression_stats(
conn: &Connection,
compressor: &mut Compressor,
) -> Result<(), AsyncBoxError> {
type CompressionStatRow = (DictId, usize, usize, usize);
let rows: Vec<CompressionStatRow> = conn
.call(|conn| {
let rows = conn
.prepare(concat_lines!(
"SELECT",
" dict_id, num_entries, uncompressed_size, compressed_size",
"FROM compression_stats",
))?
.query_map([], |row| CompressionStatRow::try_from(row))?
.collect::<Result<_, _>>()?;
Ok(rows)
})
.await?;
debug!("loaded {} compression stats from manifest", rows.len());
for (dict_id, num_entries, uncompressed_size, compressed_size) in rows {
compressor.add_stat(
dict_id,
CompressionStat {
num_entries,
uncompressed_size,
compressed_size,
},
);
}
Ok(())
}
fn save_compression_stats(
conn: &rusqlite::Connection,
compression_stats: &CompressionStats,
) -> Result<(), AsyncBoxError> {
let mut num_stats = 0;
for (id, stat) in compression_stats.iter() {
num_stats += 1;
conn.execute(
concat_lines!(
"INSERT INTO compression_stats",
" (dict_id, num_entries, uncompressed_size, compressed_size)",
"VALUES (?, ?, ?, ?)",
"ON CONFLICT (dict_id) DO UPDATE SET",
" num_entries = excluded.num_entries,",
" uncompressed_size = excluded.uncompressed_size,",
" compressed_size = excluded.compressed_size"
),
params![
id,
stat.num_entries,
stat.uncompressed_size,
stat.compressed_size
],
)?;
}
info!("saved {} compressor stats", num_stats);
Ok(())
}
async fn load_compressor(conn: &Connection) -> Result<CompressorArc, AsyncBoxError> {
let mut compressor = Compressor::default();
load_zstd_dicts(conn, &mut compressor).await?;
load_compression_stats(conn, &mut compressor).await?;
Ok(compressor.into_arc())
}
async fn initialize(
conn: Connection,
requested_num_shards: Option<usize>,
) -> Result<Manifest, AsyncBoxError> {
migrate_manifest(&conn).await?;
let num_shards = load_and_store_num_shards(&conn, requested_num_shards).await?;
let compressor = load_compressor(&conn).await?;
Ok(Manifest {
conn,
num_shards,
@@ -149,14 +272,23 @@ async fn initialize(
mod tests {
use super::*;
#[tokio::test]
async fn test_manifest_compression_stats_loading() {
let conn = Connection::open_in_memory().await.unwrap();
let manifest = initialize(conn, Some(4)).await.unwrap();
let compressor = manifest.compressor();
let compressor = compressor.read().await;
assert_eq!(compressor.stats().iter().count(), 2);
}
#[tokio::test]
async fn test_manifest() {
let conn = Connection::open_in_memory().await.unwrap();
let manifest = initialize(conn, Some(3)).await.unwrap();
let manifest = initialize(conn, Some(4)).await.unwrap();
let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
let zstd_dict = manifest
.insert_dict_from_samples("test", samples)
.insert_zstd_dict_from_samples("test", samples)
.await
.unwrap();

View File

@@ -1,12 +1,10 @@
use crate::{
compressible_data::CompressibleData,
into_tokio_rusqlite_err,
sql_types::{CompressionId, UtcDateTime},
sql_types::{DictId, UtcDateTime},
AsyncBoxError,
};
use super::*;
pub struct GetArgs {
pub sha256: Sha256,
}
@@ -23,18 +21,12 @@ impl Shard {
pub async fn get(&self, args: GetArgs) -> Result<Option<GetResult>, AsyncBoxError> {
let maybe_row = self
.conn
.call(move |conn| {
get_compressed_row(conn, &args.sha256).map_err(into_tokio_rusqlite_err)
})
.await
.map_err(|e| {
error!("get failed: {}", e);
Box::new(e)
})?;
.call(move |conn| Ok(get_compressed_row(conn, &args.sha256)?))
.await?;
if let Some((content_type, stored_size, created_at, compression_id, data)) = maybe_row {
if let Some((content_type, stored_size, created_at, dict_id, data)) = maybe_row {
let compressor = self.compressor.read().await;
let data = compressor.decompress(compression_id, data)?;
let data = compressor.decompress(dict_id, data)?;
Ok(Some(GetResult {
sha256: args.sha256,
content_type,
@@ -48,13 +40,13 @@ impl Shard {
}
}
type CompressedRowResult = (String, usize, UtcDateTime, CompressionId, Vec<u8>);
type CompressedRowResult = (String, usize, UtcDateTime, DictId, Vec<u8>);
fn get_compressed_row(
conn: &mut rusqlite::Connection,
sha256: &Sha256,
) -> Result<Option<CompressedRowResult>, rusqlite::Error> {
conn.query_row(
"SELECT content_type, compressed_size, created_at, compression_id, data
"SELECT content_type, compressed_size, created_at, dict_id, data
FROM entries
WHERE sha256 = ?",
params![sha256],

View File

@@ -11,12 +11,7 @@ impl Shard {
ensure_schema_versions_table(conn)?;
let schema_rows = load_schema_rows(conn)?;
if let Some((version, date_time)) = schema_rows.first() {
debug!(
"shard {}: latest schema version: {} @ {}",
shard_id, version, date_time
);
if let Some((version, _)) = schema_rows.first() {
if *version == 1 {
// no-op
} else {
@@ -70,7 +65,7 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err
" id INTEGER PRIMARY KEY AUTOINCREMENT,",
" sha256 BLOB NOT NULL,",
" content_type TEXT NOT NULL,",
" compression_id INTEGER NOT NULL,",
" dict_id TEXT NOT NULL,",
" uncompressed_size INTEGER NOT NULL,",
" compressed_size INTEGER NOT NULL,",
" data BLOB NOT NULL,",
@@ -101,7 +96,7 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err
conn.execute(
concat_lines!(
"CREATE INDEX IF NOT EXISTS compression_hints_name_idx",
"CREATE UNIQUE INDEX IF NOT EXISTS compression_hints_name_idx",
"ON compression_hints (name, ordering)",
),
[],

View File

@@ -1,7 +1,7 @@
use crate::{
compressible_data::CompressibleData,
into_tokio_rusqlite_err,
sql_types::{CompressionId, UtcDateTime},
concat_lines, into_tokio_rusqlite_err,
sql_types::{DictId, EntryId, UtcDateTime},
AsyncBoxError,
};
@@ -50,18 +50,25 @@ impl Shard {
let uncompressed_size = data.len();
let (compression_id, data) = {
let (dict_id, data) = {
let compressor = self.compressor.read().await;
compressor.compress(compression_hint.as_deref(), &content_type, data)?
};
{
let mut compressor = self.compressor.write().await;
compressor
.stats_mut()
.add_entry(dict_id, uncompressed_size, data.len());
}
self.conn
.call(move |conn| {
insert(
conn,
&sha256,
content_type,
compression_id,
dict_id,
uncompressed_size,
data,
compression_hint,
@@ -78,7 +85,11 @@ fn find_with_sha256(
sha256: &Sha256,
) -> Result<Option<StoreResult>, rusqlite::Error> {
conn.query_row(
"SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?",
concat_lines!(
"SELECT uncompressed_size, compressed_size, created_at",
"FROM entries",
"WHERE sha256 = ?"
),
params![sha256],
|row| {
Ok(StoreResult::Exists {
@@ -95,7 +106,7 @@ fn insert(
conn: &mut rusqlite::Connection,
sha256: &Sha256,
content_type: String,
compression_id: CompressionId,
dict_id: DictId,
uncompressed_size: usize,
data: CompressibleData,
compression_hint: Option<String>,
@@ -103,33 +114,37 @@ fn insert(
let created_at = UtcDateTime::now();
let compressed_size = data.len();
let entry_id: i64 = conn.query_row(
"INSERT INTO entries
(sha256, content_type, compression_id, uncompressed_size, compressed_size, data, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
RETURNING id
",
let tx = conn.transaction()?;
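// Insert the entry and its optional compression_hints row atomically in one transaction.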
let entry_id: EntryId = tx.query_row(
concat_lines!(
"INSERT INTO entries",
" (sha256, content_type, dict_id, uncompressed_size, compressed_size, data, created_at)",
"VALUES (?, ?, ?, ?, ?, ?, ?)",
"RETURNING id"
),
params![
sha256,
content_type,
compression_id,
dict_id,
uncompressed_size,
compressed_size,
data.as_ref(),
created_at,
],
|row| row.get(0)
|row| row.get(0),
)?;
if let Some(compression_hint) = compression_hint {
let rand_ordering = rand::random::<i64>();
conn.execute(
"INSERT INTO compression_hints
(name, ordering, entry_id)
VALUES (?, ?, ?)",
tx.execute(
concat_lines!(
"INSERT INTO compression_hints (name, ordering, entry_id)",
"VALUES (?, ?, ?)"
),
params![compression_hint, rand_ordering, entry_id],
)?;
}
tx.commit()?;
Ok(StoreResult::Created {
stored_size: compressed_size,

View File

@@ -21,4 +21,4 @@ use axum::body::Bytes;
use rusqlite::{params, types::FromSql, OptionalExtension};
use tokio_rusqlite::Connection;
use tracing::{debug, error};
use tracing::debug;

View File

@@ -1,61 +0,0 @@
use rusqlite::{
types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, Value::Integer, ValueRef},
Error::ToSqlConversionFailure,
ToSql,
};
use crate::AsyncBoxError;
use super::ZstdDictId;
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub enum CompressionId {
None,
ZstdGeneric,
ZstdDictId(ZstdDictId),
}
impl FromSql for CompressionId {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
Ok(match value.as_i64()? {
-1 => CompressionId::None,
-2 => CompressionId::ZstdGeneric,
id => CompressionId::ZstdDictId(ZstdDictId(id)),
})
}
}
impl ToSql for CompressionId {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
let value = match self {
CompressionId::None => -1,
CompressionId::ZstdGeneric => -2,
CompressionId::ZstdDictId(ZstdDictId(id)) => *id,
};
Ok(ToSqlOutput::Owned(Integer(value)))
}
}
impl FromSql for ZstdDictId {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
match value.as_i64()? {
id @ (-1 | -2) => Err(FromSqlError::Other(invalid_zstd_dict_id_err(id))),
id => Ok(ZstdDictId(id)),
}
}
}
impl ToSql for ZstdDictId {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
let value = match self.0 {
id @ (-1 | -2) => return Err(ToSqlConversionFailure(invalid_zstd_dict_id_err(id))),
id => id,
};
Ok(ToSqlOutput::Owned(Integer(value)))
}
}
fn invalid_zstd_dict_id_err(id: i64) -> AsyncBoxError {
format!("Invalid ZstdDictId: {}", id).into()
}

src/sql_types/dict_id.rs Normal file

@@ -0,0 +1,75 @@
use rusqlite::{
types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
use super::ZstdDictId;
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, PartialOrd, Ord)]
pub enum DictId {
None,
ZstdGeneric,
Zstd(ZstdDictId),
}
impl DictId {
fn name_prefix(&self) -> &str {
match self {
DictId::None => "none",
DictId::ZstdGeneric => "zstd_generic",
DictId::Zstd(_) => "zstd",
}
}
}
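// DictId is stored in SQLite as a human-readable TEXT tag: "none", "zstd_generic",
// or "zstd:<id>" for dictionary-specific ids.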
impl ToSql for DictId {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
let prefix = self.name_prefix();
if let DictId::Zstd(id) = self {
Ok(ToSqlOutput::from(format!("{}:{}", prefix, id.0)))
} else {
prefix.to_sql()
}
}
}
impl FromSql for DictId {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
let s = value.as_str()?;
let dict_id = if s == "none" {
DictId::None
} else if s == "zstd_generic" {
DictId::ZstdGeneric
} else if let Some(id_str) = s.strip_prefix("zstd:") {
let id = id_str
.parse()
.map_err(|e| FromSqlError::Other(format!("invalid ZstdDictId: {}", e).into()))?;
DictId::Zstd(ZstdDictId::column_result(ValueRef::Integer(id))?)
} else {
return Err(FromSqlError::Other(format!("invalid DictId: {}", s).into()));
};
Ok(dict_id)
}
}
#[cfg(test)]
mod tests {
use rusqlite::types::Value;
use super::*;
#[test]
fn test_dict_id() {
let id = DictId::Zstd(ZstdDictId(42));
let id_str = "zstd:42";
let sql_to_id = DictId::column_result(ValueRef::Text(id_str.as_bytes()));
assert_eq!(id, sql_to_id.unwrap());
let id_to_sql = match id.to_sql() {
Ok(ToSqlOutput::Owned(Value::Text(text))) => text,
_ => panic!("unexpected ToSqlOutput: {:?}", id.to_sql()),
};
assert_eq!(id_str, id_to_sql);
}
}

src/sql_types/entry_id.rs Normal file

@@ -0,0 +1,24 @@
use rusqlite::{
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub struct EntryId(pub i64);
impl From<i64> for EntryId {
fn from(id: i64) -> Self {
Self(id)
}
}
impl FromSql for EntryId {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
Ok(value.as_i64()?.into())
}
}
impl ToSql for EntryId {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
self.0.to_sql()
}
}

View File

@@ -1,7 +1,9 @@
mod compression_id;
mod dict_id;
mod entry_id;
mod utc_date_time;
mod zstd_dict_id;
pub use compression_id::CompressionId;
pub use dict_id::DictId;
pub use entry_id::EntryId;
pub use utc_date_time::UtcDateTime;
pub use zstd_dict_id::ZstdDictId;

View File

@@ -1,7 +1,32 @@
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
use std::fmt::Display;
use rusqlite::{
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, PartialOrd, Ord)]
pub struct ZstdDictId(pub i64);
impl From<i64> for ZstdDictId {
fn from(id: i64) -> Self {
Self(id)
}
}
impl Display for ZstdDictId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "zstd_dict:{}", self.0)
}
}
impl FromSql for ZstdDictId {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
Ok(ZstdDictId(value.as_i64()?))
}
}
impl ToSql for ZstdDictId {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
self.0.to_sql()
}
}

View File

@@ -1,9 +1,18 @@
use ouroboros::self_referencing;
use rusqlite::{
types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, Value, ValueRef},
Error::ToSqlConversionFailure,
ToSql,
};
use serde::{Deserialize, Serialize};
use std::{io, sync::Arc};
use zstd::dict::{DecoderDictionary, EncoderDictionary};
use crate::{sql_types::ZstdDictId, AsyncBoxError};
const ENCODER_DICT_SIZE: usize = 2 * 1024 * 1024;
pub const ZSTD_LEVEL: i32 = 9;
pub type ZstdDictArc = Arc<ZstdDict>;
#[self_referencing]
@@ -20,9 +29,62 @@ pub struct ZstdEncoder {
decoder_dict: DecoderDictionary<'this>,
}
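// ZstdEncoder round-trips through serde as a (level, dict_bytes) tuple; the ToSql /
// FromSql impls below persist that tuple as JSON in the `encoder_json` column.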
impl Serialize for ZstdEncoder {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
use serde::ser::SerializeTuple;
// serialize level and dict_bytes as a tuple
let mut state = serializer.serialize_tuple(2)?;
state.serialize_element(&self.level())?;
state.serialize_element(self.dict_bytes())?;
state.end()
}
}
impl<'de> Deserialize<'de> for ZstdEncoder {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
use serde::de::SeqAccess;
struct ZstdEncoderVisitor;
impl<'de> serde::de::Visitor<'de> for ZstdEncoderVisitor {
type Value = ZstdEncoder;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a tuple of (i32, Vec<u8>)")
}
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let level = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
let dict_bytes = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
Ok(ZstdEncoder::from_dict_bytes(level, dict_bytes))
}
}
deserializer.deserialize_tuple(2, ZstdEncoderVisitor)
}
}
impl ToSql for ZstdEncoder {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
let json = serde_json::to_string(self).map_err(|e| ToSqlConversionFailure(Box::new(e)))?;
Ok(ToSqlOutput::Owned(Value::Text(json)))
}
}
impl FromSql for ZstdEncoder {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
let json = value.as_str()?;
serde_json::from_str(json).map_err(|e| FromSqlError::Other(e.into()))
}
}
impl ZstdEncoder {
pub fn from_samples(level: i32, samples: Vec<&[u8]>) -> Self {
let dict_bytes = zstd::dict::from_samples(&samples, 1024 * 1024).unwrap();
let dict_bytes = zstd::dict::from_samples(&samples, ENCODER_DICT_SIZE).unwrap();
Self::from_dict_bytes(level, dict_bytes)
}
@@ -36,6 +98,9 @@ impl ZstdEncoder {
.build()
}
pub fn level(&self) -> i32 {
*self.borrow_level()
}
pub fn dict_bytes(&self) -> &[u8] {
self.borrow_dict_bytes()
}
@@ -128,7 +193,7 @@ pub mod test {
id,
name.to_owned(),
ZstdEncoder::from_samples(
3,
5,
vec![
"hello, world",
"this is a test",