zstd dict refactorings
Cargo.lock (generated)
@@ -281,6 +281,7 @@ dependencies = [
 "clap",
 "futures",
 "hex",
 "humansize",
 "kdam",
 "ouroboros",
 "rand",
@@ -290,10 +291,12 @@ dependencies = [
 "serde",
 "serde_json",
 "sha2",
 "tabled",
 "tokio",
 "tokio-rusqlite",
 "tracing",
 "tracing-subscriber",
 "walkdir",
 "zstd",
]

@@ -312,6 +315,12 @@ version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"

[[package]]
name = "bytecount"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"

[[package]]
name = "bytes"
version = "1.6.0"
@@ -805,6 +814,15 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"

[[package]]
name = "humansize"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7"
dependencies = [
 "libm",
]

[[package]]
name = "hyper"
version = "1.3.1"
@@ -972,6 +990,12 @@ version = "0.2.153"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"

[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"

[[package]]
name = "libsqlite3-sys"
version = "0.28.0"
@@ -1207,6 +1231,17 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"

[[package]]
name = "papergrid"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ad43c07024ef767f9160710b3a6773976194758c7919b17e63b863db0bdf7fb"
dependencies = [
 "bytecount",
 "fnv",
 "unicode-width",
]

[[package]]
name = "parking_lot"
version = "0.12.1"
@@ -1552,6 +1587,15 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"

[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
 "winapi-util",
]

[[package]]
name = "schannel"
version = "0.1.23"
@@ -1734,6 +1778,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
 "proc-macro2",
 "quote",
 "unicode-ident",
]

@@ -1781,6 +1826,30 @@ dependencies = [
 "libc",
]

[[package]]
name = "tabled"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c998b0c8b921495196a48aabaf1901ff28be0760136e31604f7967b0792050e"
dependencies = [
 "papergrid",
 "tabled_derive",
 "unicode-width",
]

[[package]]
name = "tabled_derive"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c138f99377e5d653a371cdad263615634cfc8467685dfe8e73e2b8e98f44b17"
dependencies = [
 "heck 0.4.1",
 "proc-macro-error",
 "proc-macro2",
 "quote",
 "syn 1.0.109",
]

[[package]]
name = "tempfile"
version = "3.10.1"
@@ -2047,6 +2116,12 @@ dependencies = [
 "tinyvec",
]

[[package]]
name = "unicode-width"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6"

[[package]]
name = "url"
version = "2.5.0"
@@ -2088,6 +2163,16 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"

[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
 "same-file",
 "winapi-util",
]

[[package]]
name = "want"
version = "0.3.1"
@@ -2195,6 +2280,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"

[[package]]
name = "winapi-util"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
dependencies = [
 "windows-sys 0.52.0",
]

[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"

@@ -12,6 +12,10 @@ path = "src/main.rs"
name = "load-test"
path = "load_test/main.rs"

[[bin]]
name = "fixture-inserter"
path = "fixture_inserter/main.rs"

[dependencies]
axum = { version = "0.7.5", features = ["macros"] }
axum_typed_multipart = "0.11.1"
@@ -32,6 +36,9 @@ reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
hex = "0.4.3"
zstd = { version = "0.13.1", features = ["experimental"] }
ouroboros = "0.18.3"
humansize = "2.1.3"
walkdir = "2.5.0"
tabled = "0.15.0"

[dev-dependencies]
rstest = "0.19.0"

fixture_inserter/main.rs (new file)
@@ -0,0 +1,107 @@
use std::{
    error::Error,
    sync::{Arc, Mutex},
};

use clap::{arg, Parser};
use kdam::BarExt;
use kdam::{tqdm, Bar};
use rand::prelude::SliceRandom;
use reqwest::blocking::{multipart, Client};
use walkdir::WalkDir;

#[derive(Parser, Debug, Clone)]
struct Args {
    #[arg(long)]
    hint_name: String,
    #[arg(long)]
    fixture_dir: String,
    #[arg(long)]
    limit: Option<usize>,
    #[arg(long)]
    num_threads: Option<usize>,
}

fn main() -> Result<(), Box<dyn Error>> {
    let args = Args::parse();
    let hint_name = args.hint_name;
    let num_threads = args.num_threads.unwrap_or(1);
    let pb = Arc::new(Mutex::new(tqdm!()));

    let mut entries = WalkDir::new(&args.fixture_dir)
        .into_iter()
        .filter_map(Result::ok)
        .filter(|entry| entry.file_type().is_file())
        .filter_map(|entry| {
            let path = entry.path();
            let name = path.to_str()?.to_string();
            Some(name)
        })
        .collect::<Vec<_>>();

    // shuffle the entries
    entries.shuffle(&mut rand::thread_rng());

    // if there's a limit, drop the rest
    if let Some(limit) = args.limit {
        entries.truncate(limit);
    }

    let entry_slices = entries.chunks(entries.len() / num_threads);

    let join_handles = entry_slices.map(|entry_slice| {
        let hint_name = hint_name.clone();
        let pb = pb.clone();
        let entry_slice = entry_slice.to_vec();
        std::thread::spawn(move || {
            let client = Client::new();
            for entry in entry_slice.into_iter() {
                store_file(&client, &hint_name, &entry, pb.clone()).unwrap();
            }
        })
    });

    for join_handle in join_handles {
        join_handle.join().unwrap();
    }
    Ok(())
}

fn store_file(
    client: &Client,
    hint_name: &str,
    file_name: &str,
    pb: Arc<Mutex<Bar>>,
) -> Result<(), Box<dyn Error>> {
    let file_bytes = std::fs::read(file_name)?;

    let content_type = if file_name.ends_with(".html") {
        "text/html"
    } else if file_name.ends_with(".json") {
        "application/json"
    } else if file_name.ends_with(".jpg") || file_name.ends_with(".jpeg") {
        "image/jpeg"
    } else if file_name.ends_with(".png") {
        "image/png"
    } else if file_name.ends_with(".txt") {
        "text/plain"
    } else if file_name.ends_with(".kindle.images") {
        "application/octet-stream"
    } else {
        "text/html"
    };

    let form = multipart::Form::new()
        .text("content_type", content_type)
        .text("compression_hint", hint_name.to_string())
        .part("data", multipart::Part::bytes(file_bytes));

    let _ = client
        .post("http://localhost:7692/store")
        .multipart(form)
        .send();

    let mut pb = pb.lock().unwrap();
    pb.update(1)?;
    Ok(())
}

@@ -43,14 +43,23 @@ fn run_loop(pb: Arc<Mutex<Bar>>, args: Args) -> Result<(), Box<dyn std::error::E
    let mut rng = rand::thread_rng();
    let mut rand_data = vec![0u8; args.file_size];
    rng.fill(&mut rand_data[..]);
    let hints = vec![
        "foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred", "plugh",
        "xyzzy", "thud",
    ];

    loop {
        // tweak a byte in the data
        let idx = rng.gen_range(0..rand_data.len());
        rand_data[idx] = rng.gen();
        let hint = hints[rng.gen_range(0..hints.len())];

        let form = reqwest::blocking::multipart::Form::new()
            .text("content_type", "text/plain")
            .part(
                "compression_hint",
                reqwest::blocking::multipart::Part::text(hint),
            )
            .part(
                "data",
                reqwest::blocking::multipart::Part::bytes(rand_data.clone()),

src/compression_stats.rs (new file)
@@ -0,0 +1,45 @@
use std::collections::HashMap;

use crate::sql_types::DictId;

#[derive(Default, Copy, Clone)]
pub struct CompressionStat {
    pub num_entries: usize,
    pub uncompressed_size: usize,
    pub compressed_size: usize,
}

#[derive(Default)]
pub struct CompressionStats {
    id_to_stat_map: HashMap<DictId, CompressionStat>,
    id_ordering: Vec<DictId>,
}

impl CompressionStats {
    pub fn for_id_mut(&mut self, dict_id: DictId) -> &mut CompressionStat {
        self.id_to_stat_map.entry(dict_id).or_insert_with_key(|_| {
            self.id_ordering.push(dict_id);
            self.id_ordering.sort();
            Default::default()
        })
    }

    pub fn add_entry(
        &mut self,
        dict_id: DictId,
        uncompressed_size: usize,
        compressed_size: usize,
    ) -> &CompressionStat {
        let stat = self.for_id_mut(dict_id);
        stat.num_entries += 1;
        stat.uncompressed_size += uncompressed_size;
        stat.compressed_size += compressed_size;
        stat
    }

    pub fn iter(&self) -> impl Iterator<Item = (DictId, &CompressionStat)> {
        self.id_ordering
            .iter()
            .flat_map(|id| self.id_to_stat_map.get(id).map(|stat| (*id, stat)))
    }
}

@@ -1,11 +1,13 @@
use std::{collections::HashMap, sync::Arc};

use rand::Rng;
use tokio::sync::RwLock;

use crate::{
    compressible_data::CompressibleData,
    sql_types::{CompressionId, ZstdDictId},
    zstd_dict::{ZstdDict, ZstdDictArc},
    compression_stats::{CompressionStat, CompressionStats},
    sql_types::{DictId, ZstdDictId},
    zstd_dict::{ZstdDict, ZstdDictArc, ZSTD_LEVEL},
    AsyncBoxError, CompressionPolicy,
};

@@ -13,15 +15,18 @@ pub type CompressorArc = Arc<RwLock<Compressor>>;

pub struct Compressor {
    zstd_dict_by_id: HashMap<ZstdDictId, ZstdDictArc>,
    zstd_dict_by_name: HashMap<String, ZstdDictArc>,
    zstd_dict_by_name: HashMap<String, Vec<ZstdDictArc>>,
    compression_policy: CompressionPolicy,
    compression_stats: CompressionStats,
}

impl Compressor {
    pub fn new(compression_policy: CompressionPolicy) -> Self {
        Self {
            zstd_dict_by_id: HashMap::new(),
            zstd_dict_by_name: HashMap::new(),
            compression_policy,
            compression_stats: CompressionStats::default(),
        }
    }
}
@@ -39,22 +44,40 @@ impl Compressor {
        Arc::new(RwLock::new(self))
    }

    pub fn add(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
        self.check(zstd_dict.id(), zstd_dict.name())?;
    pub fn add_stat(&mut self, id: DictId, stat: CompressionStat) {
        *self.compression_stats.for_id_mut(id) = stat;
    }

    pub fn add_zstd_dict(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
        self.check(zstd_dict.id())?;
        let zstd_dict = Arc::new(zstd_dict);

        self.zstd_dict_by_id
            .insert(zstd_dict.id(), zstd_dict.clone());

        let name = zstd_dict.name();
        if let Some(by_name_list) = self.zstd_dict_by_name.get_mut(name) {
            by_name_list.push(zstd_dict.clone());
        } else {
            self.zstd_dict_by_name
                .insert(zstd_dict.name().to_string(), zstd_dict.clone());
                .insert(name.to_string(), vec![zstd_dict.clone()]);
        }
        Ok(zstd_dict)
    }

    pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> {
        self.zstd_dict_by_id.get(&id).map(|d| &**d)
    }

    pub fn by_name(&self, name: &str) -> Option<&ZstdDict> {
        self.zstd_dict_by_name.get(name).map(|d| &**d)
        if let Some(by_name_list) = self.zstd_dict_by_name.get(name) {
            // take a random element
            Some(&by_name_list[rand::thread_rng().gen_range(0..by_name_list.len())])
        } else {
            None
        }
    }

    pub fn names(&self) -> impl Iterator<Item = &String> {
        self.zstd_dict_by_name.keys()
    }
@@ -64,7 +87,7 @@ impl Compressor {
        hint: Option<&str>,
        content_type: &str,
        data: Data,
    ) -> Result<(CompressionId, CompressibleData)> {
    ) -> Result<(DictId, CompressibleData)> {
        let data = data.into();
        let should_compress = match self.compression_policy {
            CompressionPolicy::None => false,
@@ -74,18 +97,15 @@ impl Compressor {

        let generic_compress = || -> Result<_> {
            Ok((
                CompressionId::ZstdGeneric,
                zstd::stream::encode_all(data.as_ref(), 3)?.into(),
                DictId::ZstdGeneric,
                zstd::stream::encode_all(data.as_ref(), ZSTD_LEVEL)?.into(),
            ))
        };

        if should_compress {
            let (id, compressed): (_, CompressibleData) = if let Some(hint) = hint {
                if let Some(dict) = self.by_name(hint) {
                    (
                        CompressionId::ZstdDictId(dict.id()),
                        dict.compress(&data)?.into(),
                    )
                    (DictId::Zstd(dict.id()), dict.compress(&data)?.into())
                } else {
                    generic_compress()?
                }
@@ -98,37 +118,42 @@ impl Compressor {
            }
        }

        Ok((CompressionId::None, data))
        Ok((DictId::None, data))
    }

    pub fn decompress<Data: Into<CompressibleData>>(
        &self,
        compression_id: CompressionId,
        dict_id: DictId,
        data: Data,
    ) -> Result<CompressibleData> {
        let data = data.into();
        match compression_id {
            CompressionId::None => Ok(data),
            CompressionId::ZstdDictId(id) => {
        match dict_id {
            DictId::None => Ok(data),
            DictId::Zstd(id) => {
                if let Some(dict) = self.by_id(id) {
                    Ok(CompressibleData::Vec(dict.decompress(data.as_ref())?))
                } else {
                    Err(format!("zstd dictionary {:?} not found", id).into())
                }
            }
            CompressionId::ZstdGeneric => Ok(CompressibleData::Vec(zstd::stream::decode_all(
            DictId::ZstdGeneric => Ok(CompressibleData::Vec(zstd::stream::decode_all(
                data.as_ref(),
            )?)),
        }
    }

    fn check(&self, id: ZstdDictId, name: &str) -> Result<()> {
    pub fn stats_mut(&mut self) -> &mut CompressionStats {
        &mut self.compression_stats
    }

    pub fn stats(&self) -> &CompressionStats {
        &self.compression_stats
    }

    fn check(&self, id: ZstdDictId) -> Result<()> {
        if self.zstd_dict_by_id.contains_key(&id) {
            return Err(format!("zstd dictionary id {} already exists", id.0).into());
        }
        if self.zstd_dict_by_name.contains_key(name) {
            return Err(format!("zstd dictionary name {} already exists", name).into());
        }
        Ok(())
    }
}
@@ -146,6 +171,8 @@ fn auto_compressible_content_type(content_type: &str) -> bool {

#[cfg(test)]
pub mod test {
    use std::collections::HashSet;

    use rstest::rstest;

    use super::*;
@@ -158,7 +185,7 @@ pub mod test {
    pub fn make_compressor_with(compression_policy: CompressionPolicy) -> Compressor {
        let mut compressor = Compressor::new(compression_policy);
        let zstd_dict = make_zstd_dict(1.into(), "dict1");
        compressor.add(zstd_dict).unwrap();
        compressor.add_zstd_dict(zstd_dict).unwrap();
        compressor
    }

@@ -184,22 +211,22 @@ pub mod test {
    ) {
        let compressor = make_compressor_with(compression_policy);
        let data = b"hello, world!".to_vec();
        let (compression_id, compressed) = compressor
        let (dict_id, compressed) = compressor
            .compress(Some("dict1"), content_type, data.clone())
            .unwrap();

        let data_uncompressed = compressor.decompress(compression_id, compressed).unwrap();
        let data_uncompressed = compressor.decompress(dict_id, compressed).unwrap();
        assert_eq!(data_uncompressed, data);
    }

    #[test]
    fn test_skip_compressing_small_data() {
        let compressor = make_compressor();
        let data = b"hello, world".to_vec();
        let (compression_id, compressed) = compressor
        let data = b"hi!".to_vec();
        let (dict_id, compressed) = compressor
            .compress(Some("dict1"), "text/plain", data.clone())
            .unwrap();
        assert_eq!(compression_id, CompressionId::None);
        assert_eq!(dict_id, DictId::None);
        assert_eq!(compressed, data);
    }

@@ -207,11 +234,25 @@ pub mod test {
    fn test_compresses_longer_data() {
        let compressor = make_compressor();
        let data = vec![b'.'; 1024];
        let (compression_id, compressed) = compressor
        let (dict_id, compressed) = compressor
            .compress(Some("dict1"), "text/plain", data.clone())
            .unwrap();
        assert_eq!(compression_id, CompressionId::ZstdDictId(1.into()));
        assert_eq!(dict_id, DictId::Zstd(1.into()));
        assert_ne!(compressed, data);
        assert!(compressed.len() < data.len());
    }

    #[test]
    fn test_multiple_dicts_same_name() {
        let mut compressor = make_compressor();
        let zstd_dict = make_zstd_dict(2.into(), "dict1");
        compressor.add_zstd_dict(zstd_dict).unwrap();

        let mut seen_ids = HashSet::new();
        for _ in 0..1000 {
            let zstd_dict = compressor.by_name("dict1").unwrap();
            seen_ids.insert(zstd_dict.id());
        }
        assert_eq!(seen_ids, [1, 2].iter().copied().map(ZstdDictId).collect());
    }
}

src/main.rs
@@ -1,4 +1,5 @@
mod compressible_data;
mod compression_stats;
mod compressor;
mod concat_lines;
mod handlers;
@@ -12,18 +13,22 @@ mod zstd_dict;

use crate::{manifest::Manifest, shards::Shards};
use axum::{
    extract::DefaultBodyLimit,
    routing::{get, post},
    Extension, Router,
};
use clap::{Parser, ValueEnum};
use compressor::CompressorArc;
use core::fmt;
use futures::executor::block_on;
use shard::Shard;
use shards::ShardsArc;
use std::{error::Error, path::PathBuf, sync::Arc};
use std::{error::Error, path::PathBuf, sync::Arc, time::Instant};
use tabled::{settings::Style, Table, Tabled};
use tokio::{net::TcpListener, select, spawn};
use tokio_rusqlite::Connection;
use tracing::{debug, info};
use tracing_subscriber::fmt::{format::Writer, time::FormatTime};

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
@@ -63,9 +68,18 @@ pub fn into_tokio_rusqlite_err<E: Into<AsyncBoxError>>(e: E) -> tokio_rusqlite::
    tokio_rusqlite::Error::Other(e.into())
}

struct AppTimeFormatter(Instant);
impl FormatTime for AppTimeFormatter {
    fn format_time(&self, w: &mut Writer<'_>) -> fmt::Result {
        let e = self.0.elapsed();
        write!(w, "{:5}.{:03}s", e.as_secs(), e.subsec_millis())
    }
}

fn main() -> Result<(), AsyncBoxError> {
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::DEBUG)
        .with_timer(AppTimeFormatter(Instant::now()))
        .init();

    let args = Args::parse();
@@ -80,13 +94,13 @@ fn main() -> Result<(), AsyncBoxError> {
    }

    // block on opening the manifest
    let manifest = block_on(async {
    let manifest = Arc::new(block_on(async {
        Manifest::open(
            Connection::open(db_path.join("manifest.sqlite")).await?,
            num_shards,
        )
        .await
    })?;
    })?);

    // max num_shards threads
    let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -102,59 +116,183 @@ fn main() -> Result<(), AsyncBoxError> {
            manifest.num_shards()
        );
        let mut shards_vec = vec![];
        let mut num_entries = vec![];
        for shard_id in 0..manifest.num_shards() {
            let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
            let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
            let shard = Shard::open(shard_id, shard_sqlite_conn, manifest.compressor()).await?;
            info!(
                "shard {} has {} entries",
                shard.id(),
                shard.num_entries().await?
            );
            num_entries.push(shard.num_entries().await?);
            shards_vec.push(shard);
        }
        debug!(
            "loaded shards with {} total entries <- {:?}",
            num_entries.iter().sum::<usize>(),
            num_entries
        );

        let compressor = manifest.compressor();
        {
            let compressor = compressor.read().await;
            debug!(
                "loaded compression dictionaries: {:?}",
                compressor.names().collect::<Vec<_>>()
            );
        }

        let shards = Arc::new(Shards::new(shards_vec).ok_or("num shards must be > 0")?);
        let compressor = manifest.compressor();
        let dict_loop_handle = spawn(dict_loop(manifest, shards.clone()));
        let server_handle = spawn(server_loop(server, shards, compressor));
        dict_loop_handle.await??;
        server_handle.await??;
        info!("server closed sqlite connections. bye!");
        let join_handles = vec![
            spawn(save_compression_stats_loop(manifest.clone())),
            spawn(dict_stats_loop(manifest.compressor())),
            spawn(new_dict_loop(manifest.clone(), shards.clone())),
            spawn(server_loop(server, shards, compressor)),
        ];
        for handle in join_handles {
            handle.await??;
        }
        info!("saving compressor stats...");
        manifest.save_compression_stats().await?;
        info!("done. bye!");
        Ok::<_, AsyncBoxError>(())
    })?;

    Ok(())
}

async fn dict_loop(manifest: Manifest, shards: ShardsArc) -> Result<(), AsyncBoxError> {
async fn save_compression_stats_loop(manifest: Arc<Manifest>) -> Result<(), AsyncBoxError> {
    loop {
        let mut hint_names = shards.hint_names().await?;
        manifest.save_compression_stats().await?;
        select! {
            _ = tokio::time::sleep(tokio::time::Duration::from_secs(60)) => {}
            _ = crate::shutdown_signal::shutdown_signal() => {
                info!("persist_dict_stats: shutdown signal received");
                return Ok(());
            }
        }
    }
}

async fn dict_stats_loop(compressor: CompressorArc) -> Result<(), AsyncBoxError> {
    fn humanized_bytes_str(number: &usize) -> String {
        humansize::format_size(*number, humansize::BINARY)
    }
    fn compression_ratio_str(row: &DictStatTableRow) -> String {
        if row.uncompressed_size == 0 {
            "(n/a)".to_string()
        } else {
            format!(
                "{:.2} x",
                row.compressed_size as f64 / row.uncompressed_size as f64
            )
        }
    }

    #[derive(Tabled, Default)]
    struct DictStatTableRow {
        name: String,
        #[tabled(rename = "# entries")]
        num_entries: usize,
        #[tabled(
            rename = "compression ratio",
            display_with("compression_ratio_str", self)
        )]
        _ratio: (),
        #[tabled(rename = "uncompressed bytes", display_with = "humanized_bytes_str")]
        uncompressed_size: usize,
        #[tabled(rename = "compressed bytes", display_with = "humanized_bytes_str")]
        compressed_size: usize,
    }

    loop {
        {
            let compressor = compressor.read().await;
            let stats = compressor.stats();

            Table::new(stats.iter().map(|(id, stat)| {
                let name = match id {
                    sql_types::DictId::None => "none",
                    sql_types::DictId::ZstdGeneric => "zstd_generic",
                    sql_types::DictId::Zstd(id) => compressor
                        .by_id(id)
                        .map(|d| d.name())
                        .unwrap_or("(missing dict)"),
                }
                .to_string();

                DictStatTableRow {
                    name,
                    num_entries: stat.num_entries,
                    uncompressed_size: stat.uncompressed_size,
                    compressed_size: stat.compressed_size,
                    ..Default::default()
                }
            }))
            .with(Style::rounded())
            .to_string()
            .lines()
            .for_each(|line| debug!("{}", line));
        }

        select! {
            _ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => {}
            _ = crate::shutdown_signal::shutdown_signal() => {
                info!("dict_stats: shutdown signal received");
                return Ok(());
            }
        }
    }
}

async fn new_dict_loop(manifest: Arc<Manifest>, shards: ShardsArc) -> Result<(), AsyncBoxError> {
    loop {
        let hint_names = shards.hint_names().await?;
        let mut new_hint_names = hint_names.clone();
        {
            // find what hint names don't have a corresponding dictionary
            let compressor = manifest.compressor();
            let compressor = compressor.read().await;
            compressor.names().for_each(|name| {
                hint_names.remove(name);
                new_hint_names.remove(name);
            });
        }

        for hint_name in hint_names {
            debug!("creating dictionary for {}", hint_name);
            let samples = shards.samples_for_hint_name(&hint_name, 10).await?;
            let samples_bytes = samples
                .iter()
                .map(|s| s.data.as_ref())
                .collect::<Vec<&[u8]>>();
            manifest
                .insert_dict_from_samples(hint_name, samples_bytes)
        for hint_name in new_hint_names {
            let samples = shards.samples_for_hint_name(&hint_name, 100).await?;
            let num_samples = samples.len();
            let bytes_samples = samples.iter().map(|s| s.data.len()).sum::<usize>();
            let bytes_samples_human = humansize::format_size(bytes_samples, humansize::BINARY);
            if num_samples < 10 || bytes_samples < 1024 {
                debug!(
                    "skipping dictionary for {} - not enough samples: ({} samples, {})",
                    hint_name, num_samples, bytes_samples_human
                );
                continue;
            }

            debug!("building dictionary for {}", hint_name);
            let now = chrono::Utc::now();
            let ref_samples = samples.iter().map(|s| s.data.as_ref()).collect::<Vec<_>>();
            let dict = manifest
                .insert_zstd_dict_from_samples(&hint_name, ref_samples)
                .await?;

            let duration = chrono::Utc::now() - now;
            let bytes_dict_human =
                humansize::format_size(dict.dict_bytes().len(), humansize::BINARY);
            debug!(
                "built dictionary {} ({}) in {}ms: {} samples / {} sample bytes, {} dict size",
                dict.id(),
                hint_name,
                duration.num_milliseconds(),
                num_samples,
                bytes_samples_human,
                bytes_dict_human,
            );
        }

        select! {
            _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {}
            _ = crate::shutdown_signal::shutdown_signal() => {
                info!("dict loop: shutdown signal received");
                info!("new_dict_loop: shutdown signal received");
                return Ok(());
            }
        }
@@ -171,7 +309,9 @@ async fn server_loop(
        .route("/get/:sha256", get(handlers::get_handler::get_handler))
        .route("/info", get(handlers::info_handler::info_handler))
        .layer(Extension(shards))
        .layer(Extension(compressor));
        .layer(Extension(compressor))
        // limit the max size of a request body to 100mb
        .layer(DefaultBodyLimit::max(1024 * 1024 * 100));

    axum::serve(server, app.into_make_service())
        .with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())

@@ -1,16 +1,18 @@
mod manifest_key;

use std::{error::Error, sync::Arc};
use std::sync::Arc;

use rusqlite::params;
use tokio::sync::RwLock;
use tokio_rusqlite::Connection;
use tracing::{debug, info};

use crate::{
    compression_stats::{CompressionStat, CompressionStats},
    compressor::{Compressor, CompressorArc},
    concat_lines,
    sql_types::ZstdDictId,
    zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder},
    concat_lines, into_tokio_rusqlite_err,
    sql_types::{DictId, ZstdDictId},
    zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder, ZSTD_LEVEL},
    AsyncBoxError,
};

@@ -35,38 +37,46 @@ impl Manifest {
        self.compressor.clone()
    }

    pub async fn insert_dict_from_samples<Str: Into<String>>(
    pub async fn save_compression_stats(&self) -> Result<(), AsyncBoxError> {
        let compressor = self.compressor();
        self.conn
            .call(move |conn| {
                let compressor = compressor.blocking_read();
                let compression_stats = compressor.stats();
                save_compression_stats(conn, compression_stats).map_err(into_tokio_rusqlite_err)?;
                Ok(())
            })
            .await?;
        Ok(())
    }

    pub async fn insert_zstd_dict_from_samples<Str: Into<String>>(
        &self,
        name: Str,
        samples: Vec<&[u8]>,
    ) -> Result<ZstdDictArc, AsyncBoxError> {
        let name = name.into();
        let encoder = ZstdEncoder::from_samples(3, samples);
        let encoder = ZstdEncoder::from_samples(ZSTD_LEVEL, samples);
        let zstd_dict = self
            .conn
            .call(move |conn| {
                let level = 3;
                let mut stmt = conn.prepare(concat_lines!(
                    "INSERT INTO dictionaries (name, level, dict)",
                    "VALUES (?, ?, ?)",
                    "RETURNING id"
                    "INSERT INTO zstd_dictionaries (name, encoder_json)",
                    "VALUES (?, ?)",
                    "RETURNING zstd_dict_id"
                ))?;
                let dict_id =
                    stmt.query_row(params![name, level, encoder.dict_bytes()], |row| row.get(0))?;
                Ok(ZstdDict::new(dict_id, name, encoder))
                let zstd_dict_id = stmt.query_row(params![name, encoder], |row| row.get(0))?;
                Ok(ZstdDict::new(zstd_dict_id, name, encoder))
            })
            .await?;
        let mut compressor = self.compressor.write().await;
        compressor.add(zstd_dict)
        compressor.add_stat(DictId::Zstd(zstd_dict.id()), CompressionStat::default());
        compressor.add_zstd_dict(zstd_dict)
    }
}

async fn initialize(
    conn: Connection,
    num_shards: Option<usize>,
) -> Result<Manifest, Box<dyn Error + Send + Sync>> {
    let stored_num_shards: Option<usize> = conn
        .call(|conn| {
async fn migrate_manifest(conn: &Connection) -> Result<(), tokio_rusqlite::Error> {
    conn.call(|conn| {
        conn.execute(
            "CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)",
            [],
@@ -74,21 +84,50 @@ async fn initialize(

        conn.execute(
            concat_lines!(
                "CREATE TABLE IF NOT EXISTS dictionaries (",
                " id INTEGER PRIMARY KEY AUTOINCREMENT,",
                " level INTEGER NOT NULL,",
                "CREATE TABLE IF NOT EXISTS zstd_dictionaries (",
                " zstd_dict_id INTEGER PRIMARY KEY AUTOINCREMENT,",
                " name TEXT NOT NULL,",
                " dict BLOB NOT NULL",
                " encoder_json TEXT NOT NULL",
                ")"
            ),
            [],
        )?;

        Ok(get_manifest_key(conn, NumShards)?)
        conn.execute(
            concat_lines!(
                "CREATE TABLE IF NOT EXISTS compression_stats (",
                " dict_id TEXT PRIMARY KEY,",
                " num_entries INTEGER NOT NULL DEFAULT 0,",
                " uncompressed_size INTEGER NOT NULL DEFAULT 0,",
                " compressed_size INTEGER NOT NULL DEFAULT 0",
                ")"
            ),
            [],
        )?;

        // insert the default dictionaries (none, zstd_generic)
        conn.execute(
            concat_lines!(
                "INSERT OR IGNORE INTO compression_stats (dict_id)",
                "VALUES (?), (?)"
            ),
            params![DictId::None, DictId::ZstdGeneric],
        )?;

        Ok(())
    })
    .await
}

async fn load_and_store_num_shards(
    conn: &Connection,
    requested_num_shards: Option<usize>,
) -> Result<usize, AsyncBoxError> {
    let stored_num_shards: Option<usize> = conn
        .call(|conn| Ok(get_manifest_key(conn, NumShards)?))
        .await?;

    let num_shards = match (stored_num_shards, num_shards) {
    Ok(match (stored_num_shards, requested_num_shards) {
        (Some(stored_num_shards), Some(num_shards)) => {
            if stored_num_shards == num_shards {
                num_shards
@@ -115,29 +154,113 @@ async fn initialize(
            // existing database, use loaded num_shards
            num_shards
        }
    };
    })
}

    type DictRow = (ZstdDictId, String, i32, Vec<u8>);
    let rows: Vec<DictRow> = conn
async fn load_zstd_dicts(
    conn: &Connection,
    compressor: &mut Compressor,
) -> Result<(), AsyncBoxError> {
    type ZstdDictRow = (ZstdDictId, String, ZstdEncoder);
    let rows: Vec<ZstdDictRow> = conn
        .call(|conn| {
            let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
            let mut rows = vec![];
            for r in stmt.query_map([], |row| DictRow::try_from(row))? {
                rows.push(r?);
            }
            let rows = conn
                .prepare(concat_lines!(
                    "SELECT",
                    " zstd_dict_id, name, encoder_json",
                    "FROM zstd_dictionaries",
                ))?
                .query_map([], |row| ZstdDictRow::try_from(row))?
                .collect::<Result<_, _>>()?;
            Ok(rows)
        })
        .await?;

    let mut compressor = Compressor::default();
    for (dict_id, name, level, dict_bytes) in rows {
        compressor.add(ZstdDict::new(
            dict_id,
            name,
            ZstdEncoder::from_dict_bytes(level, dict_bytes),
        ))?;
    debug!("loaded {} zstd dictionaries from manifest", rows.len());
    for (zstd_dict_id, name, encoder) in rows {
        compressor.add_zstd_dict(ZstdDict::new(zstd_dict_id, name, encoder))?;
    }
    let compressor = compressor.into_arc();

    Ok(())
}

async fn load_compression_stats(
    conn: &Connection,
    compressor: &mut Compressor,
) -> Result<(), AsyncBoxError> {
    type CompressionStatRow = (DictId, usize, usize, usize);
    let rows: Vec<CompressionStatRow> = conn
        .call(|conn| {
            let rows = conn
                .prepare(concat_lines!(
                    "SELECT",
                    " dict_id, num_entries, uncompressed_size, compressed_size",
                    "FROM compression_stats",
                ))?
                .query_map([], |row| CompressionStatRow::try_from(row))?
                .collect::<Result<_, _>>()?;
            Ok(rows)
        })
        .await?;

    debug!("loaded {} compression stats from manifest", rows.len());
    for (dict_id, num_entries, uncompressed_size, compressed_size) in rows {
        compressor.add_stat(
            dict_id,
            CompressionStat {
                num_entries,
                uncompressed_size,
                compressed_size,
            },
        );
    }

    Ok(())
}

fn save_compression_stats(
    conn: &rusqlite::Connection,
    compression_stats: &CompressionStats,
) -> Result<(), AsyncBoxError> {
    let mut num_stats = 0;
    for (id, stat) in compression_stats.iter() {
        num_stats += 1;
        conn.execute(
            concat_lines!(
                "INSERT INTO compression_stats",
                " (dict_id, num_entries, uncompressed_size, compressed_size)",
                "VALUES (?, ?, ?, ?)",
                "ON CONFLICT (dict_id) DO UPDATE SET",
                " num_entries = excluded.num_entries,",
                " uncompressed_size = excluded.uncompressed_size,",
                " compressed_size = excluded.compressed_size"
            ),
            params![
                id,
                stat.num_entries,
                stat.uncompressed_size,
                stat.compressed_size
            ],
        )?;
    }
    info!("saved {} compressor stats", num_stats);
    Ok(())
}

async fn load_compressor(conn: &Connection) -> Result<CompressorArc, AsyncBoxError> {
    let mut compressor = Compressor::default();
    load_zstd_dicts(conn, &mut compressor).await?;
    load_compression_stats(conn, &mut compressor).await?;
    Ok(compressor.into_arc())
}

async fn initialize(
    conn: Connection,
    requested_num_shards: Option<usize>,
) -> Result<Manifest, AsyncBoxError> {
    migrate_manifest(&conn).await?;
    let num_shards = load_and_store_num_shards(&conn, requested_num_shards).await?;
    let compressor = load_compressor(&conn).await?;
    Ok(Manifest {
        conn,
        num_shards,
@@ -149,14 +272,23 @@ async fn initialize(
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_manifest_compression_stats_loading() {
        let conn = Connection::open_in_memory().await.unwrap();
        let manifest = initialize(conn, Some(4)).await.unwrap();
        let compressor = manifest.compressor();
        let compressor = compressor.read().await;
        assert_eq!(compressor.stats().iter().count(), 2);
    }

    #[tokio::test]
    async fn test_manifest() {
        let conn = Connection::open_in_memory().await.unwrap();
        let manifest = initialize(conn, Some(3)).await.unwrap();
        let manifest = initialize(conn, Some(4)).await.unwrap();

        let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
        let zstd_dict = manifest
            .insert_dict_from_samples("test", samples)
            .insert_zstd_dict_from_samples("test", samples)
            .await
            .unwrap();

@@ -1,12 +1,10 @@
use crate::{
    compressible_data::CompressibleData,
    into_tokio_rusqlite_err,
    sql_types::{CompressionId, UtcDateTime},
    sql_types::{DictId, UtcDateTime},
    AsyncBoxError,
};

use super::*;

pub struct GetArgs {
    pub sha256: Sha256,
}
@@ -23,18 +21,12 @@ impl Shard {
    pub async fn get(&self, args: GetArgs) -> Result<Option<GetResult>, AsyncBoxError> {
        let maybe_row = self
            .conn
            .call(move |conn| {
                get_compressed_row(conn, &args.sha256).map_err(into_tokio_rusqlite_err)
            })
            .await
            .map_err(|e| {
                error!("get failed: {}", e);
                Box::new(e)
            })?;
            .call(move |conn| Ok(get_compressed_row(conn, &args.sha256)?))
            .await?;

        if let Some((content_type, stored_size, created_at, compression_id, data)) = maybe_row {
        if let Some((content_type, stored_size, created_at, dict_id, data)) = maybe_row {
            let compressor = self.compressor.read().await;
            let data = compressor.decompress(compression_id, data)?;
            let data = compressor.decompress(dict_id, data)?;
            Ok(Some(GetResult {
                sha256: args.sha256,
                content_type,
@@ -48,13 +40,13 @@ impl Shard {
    }
}

type CompressedRowResult = (String, usize, UtcDateTime, CompressionId, Vec<u8>);
type CompressedRowResult = (String, usize, UtcDateTime, DictId, Vec<u8>);
fn get_compressed_row(
    conn: &mut rusqlite::Connection,
    sha256: &Sha256,
) -> Result<Option<CompressedRowResult>, rusqlite::Error> {
    conn.query_row(
        "SELECT content_type, compressed_size, created_at, compression_id, data
        "SELECT content_type, compressed_size, created_at, dict_id, data
        FROM entries
        WHERE sha256 = ?",
        params![sha256],

@@ -11,12 +11,7 @@ impl Shard {
        ensure_schema_versions_table(conn)?;
        let schema_rows = load_schema_rows(conn)?;

        if let Some((version, date_time)) = schema_rows.first() {
            debug!(
                "shard {}: latest schema version: {} @ {}",
                shard_id, version, date_time
            );

        if let Some((version, _)) = schema_rows.first() {
            if *version == 1 {
                // no-op
            } else {
@@ -70,7 +65,7 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err
            " id INTEGER PRIMARY KEY AUTOINCREMENT,",
            " sha256 BLOB NOT NULL,",
            " content_type TEXT NOT NULL,",
            " compression_id INTEGER NOT NULL,",
            " dict_id TEXT NOT NULL,",
            " uncompressed_size INTEGER NOT NULL,",
            " compressed_size INTEGER NOT NULL,",
            " data BLOB NOT NULL,",
@@ -101,7 +96,7 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err

    conn.execute(
        concat_lines!(
            "CREATE INDEX IF NOT EXISTS compression_hints_name_idx",
            "CREATE UNIQUE INDEX IF NOT EXISTS compression_hints_name_idx",
            "ON compression_hints (name, ordering)",
        ),
        [],

@@ -1,7 +1,7 @@
use crate::{
    compressible_data::CompressibleData,
    into_tokio_rusqlite_err,
    sql_types::{CompressionId, UtcDateTime},
    concat_lines, into_tokio_rusqlite_err,
    sql_types::{DictId, EntryId, UtcDateTime},
    AsyncBoxError,
};

@@ -50,18 +50,25 @@ impl Shard {

        let uncompressed_size = data.len();

        let (compression_id, data) = {
        let (dict_id, data) = {
            let compressor = self.compressor.read().await;
            compressor.compress(compression_hint.as_deref(), &content_type, data)?
        };

        {
            let mut compressor = self.compressor.write().await;
            compressor
                .stats_mut()
                .add_entry(dict_id, uncompressed_size, data.len());
        }

        self.conn
            .call(move |conn| {
                insert(
                    conn,
                    &sha256,
                    content_type,
                    compression_id,
                    dict_id,
                    uncompressed_size,
                    data,
                    compression_hint,
@@ -78,7 +85,11 @@ fn find_with_sha256(
    sha256: &Sha256,
) -> Result<Option<StoreResult>, rusqlite::Error> {
    conn.query_row(
        "SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?",
        concat_lines!(
            "SELECT uncompressed_size, compressed_size, created_at",
            "FROM entries",
            "WHERE sha256 = ?"
        ),
        params![sha256],
        |row| {
            Ok(StoreResult::Exists {
@@ -95,7 +106,7 @@ fn insert(
    conn: &mut rusqlite::Connection,
    sha256: &Sha256,
    content_type: String,
    compression_id: CompressionId,
    dict_id: DictId,
    uncompressed_size: usize,
    data: CompressibleData,
    compression_hint: Option<String>,
@@ -103,33 +114,37 @@ fn insert(
    let created_at = UtcDateTime::now();
    let compressed_size = data.len();

    let entry_id: i64 = conn.query_row(
        "INSERT INTO entries
        (sha256, content_type, compression_id, uncompressed_size, compressed_size, data, created_at)
        VALUES (?, ?, ?, ?, ?, ?, ?)
        RETURNING id
        ",
    let tx = conn.transaction()?;
    let entry_id: EntryId = tx.query_row(
        concat_lines!(
            "INSERT INTO entries",
            " (sha256, content_type, dict_id, uncompressed_size, compressed_size, data, created_at)",
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            "RETURNING id"
        ),
        params![
            sha256,
            content_type,
            compression_id,
            dict_id,
            uncompressed_size,
            compressed_size,
            data.as_ref(),
            created_at,
        ],
        |row| row.get(0)
        |row| row.get(0),
    )?;

    if let Some(compression_hint) = compression_hint {
        let rand_ordering = rand::random::<i64>();
        conn.execute(
            "INSERT INTO compression_hints
            (name, ordering, entry_id)
            VALUES (?, ?, ?)",
        tx.execute(
            concat_lines!(
                "INSERT INTO compression_hints (name, ordering, entry_id)",
                "VALUES (?, ?, ?)"
            ),
            params![compression_hint, rand_ordering, entry_id],
        )?;
    }
    tx.commit()?;

    Ok(StoreResult::Created {
        stored_size: compressed_size,

@@ -21,4 +21,4 @@ use axum::body::Bytes;
use rusqlite::{params, types::FromSql, OptionalExtension};

use tokio_rusqlite::Connection;
use tracing::{debug, error};
use tracing::debug;

@@ -1,61 +0,0 @@
use rusqlite::{
    types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, Value::Integer, ValueRef},
    Error::ToSqlConversionFailure,
    ToSql,
};

use crate::AsyncBoxError;

use super::ZstdDictId;

#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub enum CompressionId {
    None,
    ZstdGeneric,
    ZstdDictId(ZstdDictId),
}

impl FromSql for CompressionId {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        Ok(match value.as_i64()? {
            -1 => CompressionId::None,
            -2 => CompressionId::ZstdGeneric,
            id => CompressionId::ZstdDictId(ZstdDictId(id)),
        })
    }
}

impl ToSql for CompressionId {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        let value = match self {
            CompressionId::None => -1,
            CompressionId::ZstdGeneric => -2,
            CompressionId::ZstdDictId(ZstdDictId(id)) => *id,
        };
        Ok(ToSqlOutput::Owned(Integer(value)))
    }
}

impl FromSql for ZstdDictId {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        match value.as_i64()? {
            id @ (-1 | -2) => Err(FromSqlError::Other(invalid_zstd_dict_id_err(id))),
            id => Ok(ZstdDictId(id)),
        }
    }
}

impl ToSql for ZstdDictId {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        let value = match self.0 {
            id @ (-1 | -2) => return Err(ToSqlConversionFailure(invalid_zstd_dict_id_err(id))),
            id => id,
        };

        Ok(ToSqlOutput::Owned(Integer(value)))
    }
}

fn invalid_zstd_dict_id_err(id: i64) -> AsyncBoxError {
    format!("Invalid ZstdDictId: {}", id).into()
}

src/sql_types/dict_id.rs (new file)
@@ -0,0 +1,75 @@
use rusqlite::{
    types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, ValueRef},
    ToSql,
};

use super::ZstdDictId;

#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, PartialOrd, Ord)]
pub enum DictId {
    None,
    ZstdGeneric,
    Zstd(ZstdDictId),
}

impl DictId {
    fn name_prefix(&self) -> &str {
        match self {
            DictId::None => "none",
            DictId::ZstdGeneric => "zstd_generic",
            DictId::Zstd(_) => "zstd",
        }
    }
}

impl ToSql for DictId {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        let prefix = self.name_prefix();
        if let DictId::Zstd(id) = self {
            Ok(ToSqlOutput::from(format!("{}:{}", prefix, id.0)))
        } else {
            prefix.to_sql()
        }
    }
}

impl FromSql for DictId {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        let s = value.as_str()?;
        let dict_id = if s == "none" {
            DictId::None
        } else if s == "zstd_generic" {
            DictId::ZstdGeneric
        } else if let Some(id_str) = s.strip_prefix("zstd:") {
            let id = id_str
                .parse()
                .map_err(|e| FromSqlError::Other(format!("invalid ZstdDictId: {}", e).into()))?;
            DictId::Zstd(ZstdDictId::column_result(ValueRef::Integer(id))?)
        } else {
            return Err(FromSqlError::Other(format!("invalid DictId: {}", s).into()));
        };
        Ok(dict_id)
    }
}

#[cfg(test)]
mod tests {
    use rusqlite::types::Value;

    use super::*;

    #[test]
    fn test_dict_id() {
        let id = DictId::Zstd(ZstdDictId(42));
        let id_str = "zstd:42";

        let sql_to_id = DictId::column_result(ValueRef::Text(id_str.as_bytes()));
        assert_eq!(id, sql_to_id.unwrap());

        let id_to_sql = match id.to_sql() {
            Ok(ToSqlOutput::Owned(Value::Text(text))) => text,
            _ => panic!("unexpected ToSqlOutput: {:?}", id.to_sql()),
        };
        assert_eq!(id_str, id_to_sql);
    }
}

src/sql_types/entry_id.rs (new file)
@@ -0,0 +1,24 @@
use rusqlite::{
    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
    ToSql,
};

#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub struct EntryId(pub i64);
impl From<i64> for EntryId {
    fn from(id: i64) -> Self {
        Self(id)
    }
}

impl FromSql for EntryId {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        Ok(value.as_i64()?.into())
    }
}

impl ToSql for EntryId {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        self.0.to_sql()
    }
}

@@ -1,7 +1,9 @@
mod compression_id;
mod dict_id;
mod entry_id;
mod utc_date_time;
mod zstd_dict_id;

pub use compression_id::CompressionId;
pub use dict_id::DictId;
pub use entry_id::EntryId;
pub use utc_date_time::UtcDateTime;
pub use zstd_dict_id::ZstdDictId;

@@ -1,7 +1,32 @@
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
use std::fmt::Display;

use rusqlite::{
    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
    ToSql,
};

#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, PartialOrd, Ord)]
pub struct ZstdDictId(pub i64);
impl From<i64> for ZstdDictId {
    fn from(id: i64) -> Self {
        Self(id)
    }
}

impl Display for ZstdDictId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "zstd_dict:{}", self.0)
    }
}

impl FromSql for ZstdDictId {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        Ok(ZstdDictId(value.as_i64()?))
    }
}

impl ToSql for ZstdDictId {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        self.0.to_sql()
    }
}

@@ -1,9 +1,18 @@
use ouroboros::self_referencing;
use rusqlite::{
    types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, Value, ValueRef},
    Error::ToSqlConversionFailure,
    ToSql,
};
use serde::{Deserialize, Serialize};
use std::{io, sync::Arc};
use zstd::dict::{DecoderDictionary, EncoderDictionary};

use crate::{sql_types::ZstdDictId, AsyncBoxError};

const ENCODER_DICT_SIZE: usize = 2 * 1024 * 1024;
pub const ZSTD_LEVEL: i32 = 9;

pub type ZstdDictArc = Arc<ZstdDict>;

#[self_referencing]
@@ -20,9 +29,62 @@ pub struct ZstdEncoder {
    decoder_dict: DecoderDictionary<'this>,
}

impl Serialize for ZstdEncoder {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeTuple;

        // serialize level and dict_bytes as a tuple
        let mut state = serializer.serialize_tuple(2)?;
        state.serialize_element(&self.level())?;
        state.serialize_element(self.dict_bytes())?;
        state.end()
    }
}

impl<'de> Deserialize<'de> for ZstdEncoder {
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::SeqAccess;

        struct ZstdEncoderVisitor;
        impl<'de> serde::de::Visitor<'de> for ZstdEncoderVisitor {
            type Value = ZstdEncoder;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a tuple of (i32, Vec<u8>)")
            }

            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
                let level = seq
                    .next_element()?
                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
                let dict_bytes = seq
                    .next_element()?
                    .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
                Ok(ZstdEncoder::from_dict_bytes(level, dict_bytes))
            }
        }

        deserializer.deserialize_tuple(2, ZstdEncoderVisitor)
    }
}

impl ToSql for ZstdEncoder {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        let json = serde_json::to_string(self).map_err(|e| ToSqlConversionFailure(Box::new(e)))?;
        Ok(ToSqlOutput::Owned(Value::Text(json)))
    }
}

impl FromSql for ZstdEncoder {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        let json = value.as_str()?;
        serde_json::from_str(json).map_err(|e| FromSqlError::Other(e.into()))
    }
}

impl ZstdEncoder {
    pub fn from_samples(level: i32, samples: Vec<&[u8]>) -> Self {
        let dict_bytes = zstd::dict::from_samples(&samples, 1024 * 1024).unwrap();
        let dict_bytes = zstd::dict::from_samples(&samples, ENCODER_DICT_SIZE).unwrap();
        Self::from_dict_bytes(level, dict_bytes)
    }

@@ -36,6 +98,9 @@ impl ZstdEncoder {
        .build()
    }

    pub fn level(&self) -> i32 {
        *self.borrow_level()
    }
    pub fn dict_bytes(&self) -> &[u8] {
        self.borrow_dict_bytes()
    }
@@ -128,7 +193,7 @@ pub mod test {
            id,
            name.to_owned(),
            ZstdEncoder::from_samples(
                3,
                5,
                vec![
                    "hello, world",
                    "this is a test",