diff --git a/Cargo.lock b/Cargo.lock index 9ca36bf..b273425 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -281,6 +281,7 @@ dependencies = [ "clap", "futures", "hex", + "humansize", "kdam", "ouroboros", "rand", @@ -290,10 +291,12 @@ dependencies = [ "serde", "serde_json", "sha2", + "tabled", "tokio", "tokio-rusqlite", "tracing", "tracing-subscriber", + "walkdir", "zstd", ] @@ -312,6 +315,12 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "bytecount" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" + [[package]] name = "bytes" version = "1.6.0" @@ -805,6 +814,15 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humansize" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7" +dependencies = [ + "libm", +] + [[package]] name = "hyper" version = "1.3.1" @@ -972,6 +990,12 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + [[package]] name = "libsqlite3-sys" version = "0.28.0" @@ -1207,6 +1231,17 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "papergrid" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ad43c07024ef767f9160710b3a6773976194758c7919b17e63b863db0bdf7fb" +dependencies = [ + "bytecount", + "fnv", + "unicode-width", +] + [[package]] name = "parking_lot" version = "0.12.1" @@ -1552,6 +1587,15 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.23" @@ -1734,6 +1778,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", + "quote", "unicode-ident", ] @@ -1781,6 +1826,30 @@ dependencies = [ "libc", ] +[[package]] +name = "tabled" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c998b0c8b921495196a48aabaf1901ff28be0760136e31604f7967b0792050e" +dependencies = [ + "papergrid", + "tabled_derive", + "unicode-width", +] + +[[package]] +name = "tabled_derive" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c138f99377e5d653a371cdad263615634cfc8467685dfe8e73e2b8e98f44b17" +dependencies = [ + "heck 0.4.1", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "tempfile" version 
= "3.10.1" @@ -2047,6 +2116,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-width" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" + [[package]] name = "url" version = "2.5.0" @@ -2088,6 +2163,16 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -2195,6 +2280,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 4aaa1c8..a844892 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,10 @@ path = "src/main.rs" name = "load-test" path = "load_test/main.rs" +[[bin]] +name = "fixture-inserter" +path = "fixture_inserter/main.rs" + [dependencies] axum = { version = "0.7.5", features = ["macros"] } axum_typed_multipart = "0.11.1" @@ -32,6 +36,9 @@ reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] } hex = "0.4.3" zstd = { version = "0.13.1", features = ["experimental"] } ouroboros = "0.18.3" +humansize = "2.1.3" +walkdir = "2.5.0" +tabled = "0.15.0" [dev-dependencies] rstest = "0.19.0" diff --git a/fixture_inserter/main.rs b/fixture_inserter/main.rs new file mode 100644 index 0000000..64d6052 --- /dev/null +++ b/fixture_inserter/main.rs @@ -0,0 +1,107 @@ +use std::{ + error::Error, + sync::{Arc, Mutex}, +}; + +use clap::{arg, Parser}; +use kdam::BarExt; +use kdam::{tqdm, Bar}; +use rand::prelude::SliceRandom; +use reqwest::blocking::{multipart, Client}; +use walkdir::WalkDir; + +#[derive(Parser, Debug, Clone)] +struct Args { + #[arg(long)] + hint_name: String, + #[arg(long)] + fixture_dir: String, + #[arg(long)] + limit: Option, + #[arg(long)] + num_threads: Option, +} + +fn main() -> Result<(), Box> { + let args = Args::parse(); + let hint_name = args.hint_name; + let num_threads = args.num_threads.unwrap_or(1); + let pb = Arc::new(Mutex::new(tqdm!())); + + let mut entries = WalkDir::new(&args.fixture_dir) + .into_iter() + .filter_map(Result::ok) + .filter(|entry| entry.file_type().is_file()) + .filter_map(|entry| { + let path = entry.path(); + let name = path.to_str()?.to_string(); + Some(name) + }) + .collect::>(); + + // shuffle the entries + entries.shuffle(&mut rand::thread_rng()); + + // if there's a limit, drop the rest + if let Some(limit) = args.limit { + entries.truncate(limit); + } + + let entry_slices = entries.chunks(entries.len() / num_threads); + + let join_handles = entry_slices.map(|entry_slice| { + let hint_name = hint_name.clone(); + let pb = pb.clone(); + let entry_slice = entry_slice.to_vec(); + std::thread::spawn(move || { + let client = Client::new(); + for entry in entry_slice.into_iter() { + store_file(&client, &hint_name, &entry, 
pb.clone()).unwrap();
+            }
+        })
+    });
+
+    for join_handle in join_handles {
+        join_handle.join().unwrap();
+    }
+    Ok(())
+}
+
+fn store_file(
+    client: &Client,
+    hint_name: &str,
+    file_name: &str,
+    pb: Arc<Mutex<Bar>>,
+) -> Result<(), Box<dyn Error>> {
+    let file_bytes = std::fs::read(file_name)?;
+
+    let content_type = if file_name.ends_with(".html") {
+        "text/html"
+    } else if file_name.ends_with(".json") {
+        "application/json"
+    } else if file_name.ends_with(".jpg") || file_name.ends_with(".jpeg") {
+        "image/jpeg"
+    } else if file_name.ends_with(".png") {
+        "image/png"
+    } else if file_name.ends_with(".txt") {
+        "text/plain"
+    } else if file_name.ends_with(".kindle.images") {
+        "application/octet-stream"
+    } else {
+        "text/html"
+    };
+
+    let form = multipart::Form::new()
+        .text("content_type", content_type)
+        .text("compression_hint", hint_name.to_string())
+        .part("data", multipart::Part::bytes(file_bytes));
+
+    let _ = client
+        .post("http://localhost:7692/store")
+        .multipart(form)
+        .send();
+
+    let mut pb = pb.lock().unwrap();
+    pb.update(1)?;
+    Ok(())
+}
diff --git a/load_test/main.rs b/load_test/main.rs
index 66b76b5..33c983a 100644
--- a/load_test/main.rs
+++ b/load_test/main.rs
@@ -43,14 +43,23 @@ fn run_loop(pb: Arc<Mutex<Bar>>, args: Args) -> Result<(), Box<dyn Error>> {
diff --git a/src/compression_stats.rs b/src/compression_stats.rs
--- /dev/null
+++ b/src/compression_stats.rs
+    id_ordering: Vec<DictId>,
+}
+
+impl CompressionStats {
+    pub fn for_id_mut(&mut self, dict_id: DictId) -> &mut CompressionStat {
+        self.id_to_stat_map.entry(dict_id).or_insert_with_key(|_| {
+            self.id_ordering.push(dict_id);
+            self.id_ordering.sort();
+            Default::default()
+        })
+    }
+
+    pub fn add_entry(
+        &mut self,
+        dict_id: DictId,
+        uncompressed_size: usize,
+        compressed_size: usize,
+    ) -> &CompressionStat {
+        let stat = self.for_id_mut(dict_id);
+        stat.num_entries += 1;
+        stat.uncompressed_size += uncompressed_size;
+        stat.compressed_size += compressed_size;
+        stat
+    }
+
+    pub fn iter(&self) -> impl Iterator<Item = (DictId, &CompressionStat)> {
+        self.id_ordering
+            .iter()
+            .flat_map(|id| self.id_to_stat_map.get(id).map(|stat| (*id, stat)))
+    }
+}
diff --git a/src/compressor.rs b/src/compressor.rs
index e086f39..34a89ae 100644
--- a/src/compressor.rs
+++ b/src/compressor.rs
@@ -1,11 +1,13 @@
 use std::{collections::HashMap, sync::Arc};
 
+use rand::Rng;
 use tokio::sync::RwLock;
 
 use crate::{
     compressible_data::CompressibleData,
-    sql_types::{CompressionId, ZstdDictId},
-    zstd_dict::{ZstdDict, ZstdDictArc},
+    compression_stats::{CompressionStat, CompressionStats},
+    sql_types::{DictId, ZstdDictId},
+    zstd_dict::{ZstdDict, ZstdDictArc, ZSTD_LEVEL},
     AsyncBoxError, CompressionPolicy,
 };
 
@@ -13,15 +15,18 @@
 pub type CompressorArc = Arc<RwLock<Compressor>>;
 
 pub struct Compressor {
     zstd_dict_by_id: HashMap<ZstdDictId, ZstdDictArc>,
-    zstd_dict_by_name: HashMap<String, ZstdDictArc>,
+    zstd_dict_by_name: HashMap<String, Vec<ZstdDictArc>>,
     compression_policy: CompressionPolicy,
+    compression_stats: CompressionStats,
 }
+
 impl Compressor {
     pub fn new(compression_policy: CompressionPolicy) -> Self {
         Self {
             zstd_dict_by_id: HashMap::new(),
             zstd_dict_by_name: HashMap::new(),
             compression_policy,
+            compression_stats: CompressionStats::default(),
         }
     }
 }
@@ -39,22 +44,40 @@ impl Compressor {
         Arc::new(RwLock::new(self))
     }
 
-    pub fn add(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
-        self.check(zstd_dict.id(), zstd_dict.name())?;
+    pub fn add_stat(&mut self, id: DictId, stat: CompressionStat) {
+        *self.compression_stats.for_id_mut(id) = stat;
+    }
+
+    pub fn add_zstd_dict(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
+        self.check(zstd_dict.id())?;
+
         let zstd_dict = Arc::new(zstd_dict);
+
         self.zstd_dict_by_id
             .insert(zstd_dict.id(), zstd_dict.clone());
-        self.zstd_dict_by_name
-
.insert(zstd_dict.name().to_string(), zstd_dict.clone()); + + let name = zstd_dict.name(); + if let Some(by_name_list) = self.zstd_dict_by_name.get_mut(name) { + by_name_list.push(zstd_dict.clone()); + } else { + self.zstd_dict_by_name + .insert(name.to_string(), vec![zstd_dict.clone()]); + } Ok(zstd_dict) } pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> { self.zstd_dict_by_id.get(&id).map(|d| &**d) } + pub fn by_name(&self, name: &str) -> Option<&ZstdDict> { - self.zstd_dict_by_name.get(name).map(|d| &**d) + if let Some(by_name_list) = self.zstd_dict_by_name.get(name) { + // take a random element + Some(&by_name_list[rand::thread_rng().gen_range(0..by_name_list.len())]) + } else { + None + } } + pub fn names(&self) -> impl Iterator { self.zstd_dict_by_name.keys() } @@ -64,7 +87,7 @@ impl Compressor { hint: Option<&str>, content_type: &str, data: Data, - ) -> Result<(CompressionId, CompressibleData)> { + ) -> Result<(DictId, CompressibleData)> { let data = data.into(); let should_compress = match self.compression_policy { CompressionPolicy::None => false, @@ -74,18 +97,15 @@ impl Compressor { let generic_compress = || -> Result<_> { Ok(( - CompressionId::ZstdGeneric, - zstd::stream::encode_all(data.as_ref(), 3)?.into(), + DictId::ZstdGeneric, + zstd::stream::encode_all(data.as_ref(), ZSTD_LEVEL)?.into(), )) }; if should_compress { let (id, compressed): (_, CompressibleData) = if let Some(hint) = hint { if let Some(dict) = self.by_name(hint) { - ( - CompressionId::ZstdDictId(dict.id()), - dict.compress(&data)?.into(), - ) + (DictId::Zstd(dict.id()), dict.compress(&data)?.into()) } else { generic_compress()? } @@ -98,37 +118,42 @@ impl Compressor { } } - Ok((CompressionId::None, data)) + Ok((DictId::None, data)) } pub fn decompress>( &self, - compression_id: CompressionId, + dict_id: DictId, data: Data, ) -> Result { let data = data.into(); - match compression_id { - CompressionId::None => Ok(data), - CompressionId::ZstdDictId(id) => { + match dict_id { + DictId::None => Ok(data), + DictId::Zstd(id) => { if let Some(dict) = self.by_id(id) { Ok(CompressibleData::Vec(dict.decompress(data.as_ref())?)) } else { Err(format!("zstd dictionary {:?} not found", id).into()) } } - CompressionId::ZstdGeneric => Ok(CompressibleData::Vec(zstd::stream::decode_all( + DictId::ZstdGeneric => Ok(CompressibleData::Vec(zstd::stream::decode_all( data.as_ref(), )?)), } } - fn check(&self, id: ZstdDictId, name: &str) -> Result<()> { + pub fn stats_mut(&mut self) -> &mut CompressionStats { + &mut self.compression_stats + } + + pub fn stats(&self) -> &CompressionStats { + &self.compression_stats + } + + fn check(&self, id: ZstdDictId) -> Result<()> { if self.zstd_dict_by_id.contains_key(&id) { return Err(format!("zstd dictionary id {} already exists", id.0).into()); } - if self.zstd_dict_by_name.contains_key(name) { - return Err(format!("zstd dictionary name {} already exists", name).into()); - } Ok(()) } } @@ -146,6 +171,8 @@ fn auto_compressible_content_type(content_type: &str) -> bool { #[cfg(test)] pub mod test { + use std::collections::HashSet; + use rstest::rstest; use super::*; @@ -158,7 +185,7 @@ pub mod test { pub fn make_compressor_with(compression_policy: CompressionPolicy) -> Compressor { let mut compressor = Compressor::new(compression_policy); let zstd_dict = make_zstd_dict(1.into(), "dict1"); - compressor.add(zstd_dict).unwrap(); + compressor.add_zstd_dict(zstd_dict).unwrap(); compressor } @@ -184,22 +211,22 @@ pub mod test { ) { let compressor = make_compressor_with(compression_policy); let 
data = b"hello, world!".to_vec(); - let (compression_id, compressed) = compressor + let (dict_id, compressed) = compressor .compress(Some("dict1"), content_type, data.clone()) .unwrap(); - let data_uncompressed = compressor.decompress(compression_id, compressed).unwrap(); + let data_uncompressed = compressor.decompress(dict_id, compressed).unwrap(); assert_eq!(data_uncompressed, data); } #[test] fn test_skip_compressing_small_data() { let compressor = make_compressor(); - let data = b"hello, world".to_vec(); - let (compression_id, compressed) = compressor + let data = b"hi!".to_vec(); + let (dict_id, compressed) = compressor .compress(Some("dict1"), "text/plain", data.clone()) .unwrap(); - assert_eq!(compression_id, CompressionId::None); + assert_eq!(dict_id, DictId::None); assert_eq!(compressed, data); } @@ -207,11 +234,25 @@ pub mod test { fn test_compresses_longer_data() { let compressor = make_compressor(); let data = vec![b'.'; 1024]; - let (compression_id, compressed) = compressor + let (dict_id, compressed) = compressor .compress(Some("dict1"), "text/plain", data.clone()) .unwrap(); - assert_eq!(compression_id, CompressionId::ZstdDictId(1.into())); + assert_eq!(dict_id, DictId::Zstd(1.into())); assert_ne!(compressed, data); assert!(compressed.len() < data.len()); } + + #[test] + fn test_multiple_dicts_same_name() { + let mut compressor = make_compressor(); + let zstd_dict = make_zstd_dict(2.into(), "dict1"); + compressor.add_zstd_dict(zstd_dict).unwrap(); + + let mut seen_ids = HashSet::new(); + for _ in 0..1000 { + let zstd_dict = compressor.by_name("dict1").unwrap(); + seen_ids.insert(zstd_dict.id()); + } + assert_eq!(seen_ids, [1, 2].iter().copied().map(ZstdDictId).collect()); + } } diff --git a/src/main.rs b/src/main.rs index 63f3dd2..d010be7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ mod compressible_data; +mod compression_stats; mod compressor; mod concat_lines; mod handlers; @@ -12,18 +13,22 @@ mod zstd_dict; use crate::{manifest::Manifest, shards::Shards}; use axum::{ + extract::DefaultBodyLimit, routing::{get, post}, Extension, Router, }; use clap::{Parser, ValueEnum}; use compressor::CompressorArc; +use core::fmt; use futures::executor::block_on; use shard::Shard; use shards::ShardsArc; -use std::{error::Error, path::PathBuf, sync::Arc}; +use std::{error::Error, path::PathBuf, sync::Arc, time::Instant}; +use tabled::{settings::Style, Table, Tabled}; use tokio::{net::TcpListener, select, spawn}; use tokio_rusqlite::Connection; use tracing::{debug, info}; +use tracing_subscriber::fmt::{format::Writer, time::FormatTime}; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] @@ -63,9 +68,18 @@ pub fn into_tokio_rusqlite_err>(e: E) -> tokio_rusqlite:: tokio_rusqlite::Error::Other(e.into()) } +struct AppTimeFormatter(Instant); +impl FormatTime for AppTimeFormatter { + fn format_time(&self, w: &mut Writer<'_>) -> fmt::Result { + let e = self.0.elapsed(); + write!(w, "{:5}.{:03}s", e.as_secs(), e.subsec_millis()) + } +} + fn main() -> Result<(), AsyncBoxError> { tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) + .with_timer(AppTimeFormatter(Instant::now())) .init(); let args = Args::parse(); @@ -80,13 +94,13 @@ fn main() -> Result<(), AsyncBoxError> { } // block on opening the manifest - let manifest = block_on(async { + let manifest = Arc::new(block_on(async { Manifest::open( Connection::open(db_path.join("manifest.sqlite")).await?, num_shards, ) .await - })?; + })?); // max num_shards threads let runtime = 
tokio::runtime::Builder::new_multi_thread() @@ -102,59 +116,183 @@ fn main() -> Result<(), AsyncBoxError> { manifest.num_shards() ); let mut shards_vec = vec![]; + let mut num_entries = vec![]; for shard_id in 0..manifest.num_shards() { let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id)); let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?; let shard = Shard::open(shard_id, shard_sqlite_conn, manifest.compressor()).await?; - info!( - "shard {} has {} entries", - shard.id(), - shard.num_entries().await? - ); + num_entries.push(shard.num_entries().await?); shards_vec.push(shard); } + debug!( + "loaded shards with {} total entries <- {:?}", + num_entries.iter().sum::(), + num_entries + ); + + let compressor = manifest.compressor(); + { + let compressor = compressor.read().await; + debug!( + "loaded compression dictionaries: {:?}", + compressor.names().collect::>() + ); + } let shards = Arc::new(Shards::new(shards_vec).ok_or("num shards must be > 0")?); - let compressor = manifest.compressor(); - let dict_loop_handle = spawn(dict_loop(manifest, shards.clone())); - let server_handle = spawn(server_loop(server, shards, compressor)); - dict_loop_handle.await??; - server_handle.await??; - info!("server closed sqlite connections. bye!"); + let join_handles = vec![ + spawn(save_compression_stats_loop(manifest.clone())), + spawn(dict_stats_loop(manifest.compressor())), + spawn(new_dict_loop(manifest.clone(), shards.clone())), + spawn(server_loop(server, shards, compressor)), + ]; + for handle in join_handles { + handle.await??; + } + info!("saving compressor stats..."); + manifest.save_compression_stats().await?; + info!("done. bye!"); Ok::<_, AsyncBoxError>(()) })?; Ok(()) } -async fn dict_loop(manifest: Manifest, shards: ShardsArc) -> Result<(), AsyncBoxError> { +async fn save_compression_stats_loop(manifest: Arc) -> Result<(), AsyncBoxError> { loop { - let mut hint_names = shards.hint_names().await?; + manifest.save_compression_stats().await?; + select! 
{ + _ = tokio::time::sleep(tokio::time::Duration::from_secs(60)) => {} + _ = crate::shutdown_signal::shutdown_signal() => { + info!("persist_dict_stats: shutdown signal received"); + return Ok(()); + } + } + } +} + +async fn dict_stats_loop(compressor: CompressorArc) -> Result<(), AsyncBoxError> { + fn humanized_bytes_str(number: &usize) -> String { + humansize::format_size(*number, humansize::BINARY) + } + fn compression_ratio_str(row: &DictStatTableRow) -> String { + if row.uncompressed_size == 0 { + "(n/a)".to_string() + } else { + format!( + "{:.2} x", + row.compressed_size as f64 / row.uncompressed_size as f64 + ) + } + } + + #[derive(Tabled, Default)] + struct DictStatTableRow { + name: String, + #[tabled(rename = "# entries")] + num_entries: usize, + #[tabled( + rename = "compression ratio", + display_with("compression_ratio_str", self) + )] + _ratio: (), + #[tabled(rename = "uncompressed bytes", display_with = "humanized_bytes_str")] + uncompressed_size: usize, + #[tabled(rename = "compressed bytes", display_with = "humanized_bytes_str")] + compressed_size: usize, + } + + loop { + { + let compressor = compressor.read().await; + let stats = compressor.stats(); + + Table::new(stats.iter().map(|(id, stat)| { + let name = match id { + sql_types::DictId::None => "none", + sql_types::DictId::ZstdGeneric => "zstd_generic", + sql_types::DictId::Zstd(id) => compressor + .by_id(id) + .map(|d| d.name()) + .unwrap_or("(missing dict)"), + } + .to_string(); + + DictStatTableRow { + name, + num_entries: stat.num_entries, + uncompressed_size: stat.uncompressed_size, + compressed_size: stat.compressed_size, + ..Default::default() + } + })) + .with(Style::rounded()) + .to_string() + .lines() + .for_each(|line| debug!("{}", line)); + } + + select! { + _ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => {} + _ = crate::shutdown_signal::shutdown_signal() => { + info!("dict_stats: shutdown signal received"); + return Ok(()); + } + } + } +} + +async fn new_dict_loop(manifest: Arc, shards: ShardsArc) -> Result<(), AsyncBoxError> { + loop { + let hint_names = shards.hint_names().await?; + let mut new_hint_names = hint_names.clone(); { // find what hint names don't have a corresponding dictionary let compressor = manifest.compressor(); let compressor = compressor.read().await; compressor.names().for_each(|name| { - hint_names.remove(name); + new_hint_names.remove(name); }); } - for hint_name in hint_names { - debug!("creating dictionary for {}", hint_name); - let samples = shards.samples_for_hint_name(&hint_name, 10).await?; - let samples_bytes = samples - .iter() - .map(|s| s.data.as_ref()) - .collect::>(); - manifest - .insert_dict_from_samples(hint_name, samples_bytes) + for hint_name in new_hint_names { + let samples = shards.samples_for_hint_name(&hint_name, 100).await?; + let num_samples = samples.len(); + let bytes_samples = samples.iter().map(|s| s.data.len()).sum::(); + let bytes_samples_human = humansize::format_size(bytes_samples, humansize::BINARY); + if num_samples < 10 || bytes_samples < 1024 { + debug!( + "skipping dictionary for {} - not enough samples: ({} samples, {})", + hint_name, num_samples, bytes_samples_human + ); + continue; + } + + debug!("building dictionary for {}", hint_name); + let now = chrono::Utc::now(); + let ref_samples = samples.iter().map(|s| s.data.as_ref()).collect::>(); + let dict = manifest + .insert_zstd_dict_from_samples(&hint_name, ref_samples) .await?; + + let duration = chrono::Utc::now() - now; + let bytes_dict_human = + 
humansize::format_size(dict.dict_bytes().len(), humansize::BINARY); + debug!( + "built dictionary {} ({}) in {}ms: {} samples / {} sample bytes, {} dict size", + dict.id(), + hint_name, + duration.num_milliseconds(), + num_samples, + bytes_samples_human, + bytes_dict_human, + ); } select! { _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {} _ = crate::shutdown_signal::shutdown_signal() => { - info!("dict loop: shutdown signal received"); + info!("new_dict_loop: shutdown signal received"); return Ok(()); } } @@ -171,7 +309,9 @@ async fn server_loop( .route("/get/:sha256", get(handlers::get_handler::get_handler)) .route("/info", get(handlers::info_handler::info_handler)) .layer(Extension(shards)) - .layer(Extension(compressor)); + .layer(Extension(compressor)) + // limit the max size of a request body to 100mb + .layer(DefaultBodyLimit::max(1024 * 1024 * 100)); axum::serve(server, app.into_make_service()) .with_graceful_shutdown(crate::shutdown_signal::shutdown_signal()) diff --git a/src/manifest/mod.rs b/src/manifest/mod.rs index 0c10be3..eff377f 100644 --- a/src/manifest/mod.rs +++ b/src/manifest/mod.rs @@ -1,16 +1,18 @@ mod manifest_key; -use std::{error::Error, sync::Arc}; +use std::sync::Arc; use rusqlite::params; use tokio::sync::RwLock; use tokio_rusqlite::Connection; +use tracing::{debug, info}; use crate::{ + compression_stats::{CompressionStat, CompressionStats}, compressor::{Compressor, CompressorArc}, - concat_lines, - sql_types::ZstdDictId, - zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder}, + concat_lines, into_tokio_rusqlite_err, + sql_types::{DictId, ZstdDictId}, + zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder, ZSTD_LEVEL}, AsyncBoxError, }; @@ -35,60 +37,97 @@ impl Manifest { self.compressor.clone() } - pub async fn insert_dict_from_samples>( + pub async fn save_compression_stats(&self) -> Result<(), AsyncBoxError> { + let compressor = self.compressor(); + self.conn + .call(move |conn| { + let compressor = compressor.blocking_read(); + let compression_stats = compressor.stats(); + save_compression_stats(conn, compression_stats).map_err(into_tokio_rusqlite_err)?; + Ok(()) + }) + .await?; + Ok(()) + } + + pub async fn insert_zstd_dict_from_samples>( &self, name: Str, samples: Vec<&[u8]>, ) -> Result { let name = name.into(); - let encoder = ZstdEncoder::from_samples(3, samples); + let encoder = ZstdEncoder::from_samples(ZSTD_LEVEL, samples); let zstd_dict = self .conn .call(move |conn| { - let level = 3; let mut stmt = conn.prepare(concat_lines!( - "INSERT INTO dictionaries (name, level, dict)", - "VALUES (?, ?, ?)", - "RETURNING id" + "INSERT INTO zstd_dictionaries (name, encoder_json)", + "VALUES (?, ?)", + "RETURNING zstd_dict_id" ))?; - let dict_id = - stmt.query_row(params![name, level, encoder.dict_bytes()], |row| row.get(0))?; - Ok(ZstdDict::new(dict_id, name, encoder)) + let zstd_dict_id = stmt.query_row(params![name, encoder], |row| row.get(0))?; + Ok(ZstdDict::new(zstd_dict_id, name, encoder)) }) .await?; let mut compressor = self.compressor.write().await; - compressor.add(zstd_dict) + compressor.add_stat(DictId::Zstd(zstd_dict.id()), CompressionStat::default()); + compressor.add_zstd_dict(zstd_dict) } } -async fn initialize( - conn: Connection, - num_shards: Option, -) -> Result> { +async fn migrate_manifest(conn: &Connection) -> Result<(), tokio_rusqlite::Error> { + conn.call(|conn| { + conn.execute( + "CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)", + [], + )?; + + conn.execute( + concat_lines!( + "CREATE TABLE IF NOT EXISTS 
zstd_dictionaries (", + " zstd_dict_id INTEGER PRIMARY KEY AUTOINCREMENT,", + " name TEXT NOT NULL,", + " encoder_json TEXT NOT NULL", + ")" + ), + [], + )?; + + conn.execute( + concat_lines!( + "CREATE TABLE IF NOT EXISTS compression_stats (", + " dict_id TEXT PRIMARY KEY,", + " num_entries INTEGER NOT NULL DEFAULT 0,", + " uncompressed_size INTEGER NOT NULL DEFAULT 0,", + " compressed_size INTEGER NOT NULL DEFAULT 0", + ")" + ), + [], + )?; + + // insert the default dictionaries (none, zstd_generic) + conn.execute( + concat_lines!( + "INSERT OR IGNORE INTO compression_stats (dict_id)", + "VALUES (?), (?)" + ), + params![DictId::None, DictId::ZstdGeneric], + )?; + + Ok(()) + }) + .await +} + +async fn load_and_store_num_shards( + conn: &Connection, + requested_num_shards: Option, +) -> Result { let stored_num_shards: Option = conn - .call(|conn| { - conn.execute( - "CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)", - [], - )?; - - conn.execute( - concat_lines!( - "CREATE TABLE IF NOT EXISTS dictionaries (", - " id INTEGER PRIMARY KEY AUTOINCREMENT,", - " level INTEGER NOT NULL,", - " name TEXT NOT NULL,", - " dict BLOB NOT NULL", - ")" - ), - [], - )?; - - Ok(get_manifest_key(conn, NumShards)?) - }) + .call(|conn| Ok(get_manifest_key(conn, NumShards)?)) .await?; - let num_shards = match (stored_num_shards, num_shards) { + Ok(match (stored_num_shards, requested_num_shards) { (Some(stored_num_shards), Some(num_shards)) => { if stored_num_shards == num_shards { num_shards @@ -115,29 +154,113 @@ async fn initialize( // existing database, use loaded num_shards num_shards } - }; + }) +} - type DictRow = (ZstdDictId, String, i32, Vec); - let rows: Vec = conn +async fn load_zstd_dicts( + conn: &Connection, + compressor: &mut Compressor, +) -> Result<(), AsyncBoxError> { + type ZstdDictRow = (ZstdDictId, String, ZstdEncoder); + let rows: Vec = conn .call(|conn| { - let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?; - let mut rows = vec![]; - for r in stmt.query_map([], |row| DictRow::try_from(row))? { - rows.push(r?); - } + let rows = conn + .prepare(concat_lines!( + "SELECT", + " zstd_dict_id, name, encoder_json", + "FROM zstd_dictionaries", + ))? + .query_map([], |row| ZstdDictRow::try_from(row))? + .collect::>()?; Ok(rows) }) .await?; - let mut compressor = Compressor::default(); - for (dict_id, name, level, dict_bytes) in rows { - compressor.add(ZstdDict::new( - dict_id, - name, - ZstdEncoder::from_dict_bytes(level, dict_bytes), - ))?; + debug!("loaded {} zstd dictionaries from manifest", rows.len()); + for (zstd_dict_id, name, encoder) in rows { + compressor.add_zstd_dict(ZstdDict::new(zstd_dict_id, name, encoder))?; } - let compressor = compressor.into_arc(); + + Ok(()) +} + +async fn load_compression_stats( + conn: &Connection, + compressor: &mut Compressor, +) -> Result<(), AsyncBoxError> { + type CompressionStatRow = (DictId, usize, usize, usize); + let rows: Vec = conn + .call(|conn| { + let rows = conn + .prepare(concat_lines!( + "SELECT", + " dict_id, num_entries, uncompressed_size, compressed_size", + "FROM compression_stats", + ))? + .query_map([], |row| CompressionStatRow::try_from(row))? 
+ .collect::>()?; + Ok(rows) + }) + .await?; + + debug!("loaded {} compression stats from manifest", rows.len()); + for (dict_id, num_entries, uncompressed_size, compressed_size) in rows { + compressor.add_stat( + dict_id, + CompressionStat { + num_entries, + uncompressed_size, + compressed_size, + }, + ); + } + + Ok(()) +} + +fn save_compression_stats( + conn: &rusqlite::Connection, + compression_stats: &CompressionStats, +) -> Result<(), AsyncBoxError> { + let mut num_stats = 0; + for (id, stat) in compression_stats.iter() { + num_stats += 1; + conn.execute( + concat_lines!( + "INSERT INTO compression_stats", + " (dict_id, num_entries, uncompressed_size, compressed_size)", + "VALUES (?, ?, ?, ?)", + "ON CONFLICT (dict_id) DO UPDATE SET", + " num_entries = excluded.num_entries,", + " uncompressed_size = excluded.uncompressed_size,", + " compressed_size = excluded.compressed_size" + ), + params![ + id, + stat.num_entries, + stat.uncompressed_size, + stat.compressed_size + ], + )?; + } + info!("saved {} compressor stats", num_stats); + Ok(()) +} + +async fn load_compressor(conn: &Connection) -> Result { + let mut compressor = Compressor::default(); + load_zstd_dicts(conn, &mut compressor).await?; + load_compression_stats(conn, &mut compressor).await?; + Ok(compressor.into_arc()) +} + +async fn initialize( + conn: Connection, + requested_num_shards: Option, +) -> Result { + migrate_manifest(&conn).await?; + let num_shards = load_and_store_num_shards(&conn, requested_num_shards).await?; + let compressor = load_compressor(&conn).await?; Ok(Manifest { conn, num_shards, @@ -149,14 +272,23 @@ async fn initialize( mod tests { use super::*; + #[tokio::test] + async fn test_manifest_compression_stats_loading() { + let conn = Connection::open_in_memory().await.unwrap(); + let manifest = initialize(conn, Some(4)).await.unwrap(); + let compressor = manifest.compressor(); + let compressor = compressor.read().await; + assert_eq!(compressor.stats().iter().count(), 2); + } + #[tokio::test] async fn test_manifest() { let conn = Connection::open_in_memory().await.unwrap(); - let manifest = initialize(conn, Some(3)).await.unwrap(); + let manifest = initialize(conn, Some(4)).await.unwrap(); let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100]; let zstd_dict = manifest - .insert_dict_from_samples("test", samples) + .insert_zstd_dict_from_samples("test", samples) .await .unwrap(); diff --git a/src/shard/fn_get.rs b/src/shard/fn_get.rs index 1899b48..8c53f35 100644 --- a/src/shard/fn_get.rs +++ b/src/shard/fn_get.rs @@ -1,12 +1,10 @@ use crate::{ compressible_data::CompressibleData, - into_tokio_rusqlite_err, - sql_types::{CompressionId, UtcDateTime}, + sql_types::{DictId, UtcDateTime}, AsyncBoxError, }; use super::*; - pub struct GetArgs { pub sha256: Sha256, } @@ -23,18 +21,12 @@ impl Shard { pub async fn get(&self, args: GetArgs) -> Result, AsyncBoxError> { let maybe_row = self .conn - .call(move |conn| { - get_compressed_row(conn, &args.sha256).map_err(into_tokio_rusqlite_err) - }) - .await - .map_err(|e| { - error!("get failed: {}", e); - Box::new(e) - })?; + .call(move |conn| Ok(get_compressed_row(conn, &args.sha256)?)) + .await?; - if let Some((content_type, stored_size, created_at, compression_id, data)) = maybe_row { + if let Some((content_type, stored_size, created_at, dict_id, data)) = maybe_row { let compressor = self.compressor.read().await; - let data = compressor.decompress(compression_id, data)?; + let data = compressor.decompress(dict_id, data)?; Ok(Some(GetResult { sha256: 
args.sha256, content_type, @@ -48,13 +40,13 @@ impl Shard { } } -type CompressedRowResult = (String, usize, UtcDateTime, CompressionId, Vec); +type CompressedRowResult = (String, usize, UtcDateTime, DictId, Vec); fn get_compressed_row( conn: &mut rusqlite::Connection, sha256: &Sha256, ) -> Result, rusqlite::Error> { conn.query_row( - "SELECT content_type, compressed_size, created_at, compression_id, data + "SELECT content_type, compressed_size, created_at, dict_id, data FROM entries WHERE sha256 = ?", params![sha256], diff --git a/src/shard/fn_migrate.rs b/src/shard/fn_migrate.rs index 1513bd9..8587855 100644 --- a/src/shard/fn_migrate.rs +++ b/src/shard/fn_migrate.rs @@ -11,12 +11,7 @@ impl Shard { ensure_schema_versions_table(conn)?; let schema_rows = load_schema_rows(conn)?; - if let Some((version, date_time)) = schema_rows.first() { - debug!( - "shard {}: latest schema version: {} @ {}", - shard_id, version, date_time - ); - + if let Some((version, _)) = schema_rows.first() { if *version == 1 { // no-op } else { @@ -70,7 +65,7 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err " id INTEGER PRIMARY KEY AUTOINCREMENT,", " sha256 BLOB NOT NULL,", " content_type TEXT NOT NULL,", - " compression_id INTEGER NOT NULL,", + " dict_id TEXT NOT NULL,", " uncompressed_size INTEGER NOT NULL,", " compressed_size INTEGER NOT NULL,", " data BLOB NOT NULL,", @@ -101,7 +96,7 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err conn.execute( concat_lines!( - "CREATE INDEX IF NOT EXISTS compression_hints_name_idx", + "CREATE UNIQUE INDEX IF NOT EXISTS compression_hints_name_idx", "ON compression_hints (name, ordering)", ), [], diff --git a/src/shard/fn_store.rs b/src/shard/fn_store.rs index 1349c63..f1ad5c1 100644 --- a/src/shard/fn_store.rs +++ b/src/shard/fn_store.rs @@ -1,7 +1,7 @@ use crate::{ compressible_data::CompressibleData, - into_tokio_rusqlite_err, - sql_types::{CompressionId, UtcDateTime}, + concat_lines, into_tokio_rusqlite_err, + sql_types::{DictId, EntryId, UtcDateTime}, AsyncBoxError, }; @@ -50,18 +50,25 @@ impl Shard { let uncompressed_size = data.len(); - let (compression_id, data) = { + let (dict_id, data) = { let compressor = self.compressor.read().await; compressor.compress(compression_hint.as_deref(), &content_type, data)? }; + { + let mut compressor = self.compressor.write().await; + compressor + .stats_mut() + .add_entry(dict_id, uncompressed_size, data.len()); + } + self.conn .call(move |conn| { insert( conn, &sha256, content_type, - compression_id, + dict_id, uncompressed_size, data, compression_hint, @@ -78,7 +85,11 @@ fn find_with_sha256( sha256: &Sha256, ) -> Result, rusqlite::Error> { conn.query_row( - "SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?", + concat_lines!( + "SELECT uncompressed_size, compressed_size, created_at", + "FROM entries", + "WHERE sha256 = ?" + ), params![sha256], |row| { Ok(StoreResult::Exists { @@ -95,7 +106,7 @@ fn insert( conn: &mut rusqlite::Connection, sha256: &Sha256, content_type: String, - compression_id: CompressionId, + dict_id: DictId, uncompressed_size: usize, data: CompressibleData, compression_hint: Option, @@ -103,33 +114,37 @@ fn insert( let created_at = UtcDateTime::now(); let compressed_size = data.len(); - let entry_id: i64 = conn.query_row( - "INSERT INTO entries - (sha256, content_type, compression_id, uncompressed_size, compressed_size, data, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?) 
- RETURNING id - ", + let tx = conn.transaction()?; + let entry_id: EntryId = tx.query_row( + concat_lines!( + "INSERT INTO entries", + " (sha256, content_type, dict_id, uncompressed_size, compressed_size, data, created_at)", + "VALUES (?, ?, ?, ?, ?, ?, ?)", + "RETURNING id" + ), params![ sha256, content_type, - compression_id, + dict_id, uncompressed_size, compressed_size, data.as_ref(), created_at, ], - |row| row.get(0) + |row| row.get(0), )?; if let Some(compression_hint) = compression_hint { let rand_ordering = rand::random::(); - conn.execute( - "INSERT INTO compression_hints - (name, ordering, entry_id) - VALUES (?, ?, ?)", + tx.execute( + concat_lines!( + "INSERT INTO compression_hints (name, ordering, entry_id)", + "VALUES (?, ?, ?)" + ), params![compression_hint, rand_ordering, entry_id], )?; } + tx.commit()?; Ok(StoreResult::Created { stored_size: compressed_size, diff --git a/src/shard/mod.rs b/src/shard/mod.rs index 62d8851..b89650d 100644 --- a/src/shard/mod.rs +++ b/src/shard/mod.rs @@ -21,4 +21,4 @@ use axum::body::Bytes; use rusqlite::{params, types::FromSql, OptionalExtension}; use tokio_rusqlite::Connection; -use tracing::{debug, error}; +use tracing::debug; diff --git a/src/sql_types/compression_id.rs b/src/sql_types/compression_id.rs deleted file mode 100644 index 0d7a341..0000000 --- a/src/sql_types/compression_id.rs +++ /dev/null @@ -1,61 +0,0 @@ -use rusqlite::{ - types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, Value::Integer, ValueRef}, - Error::ToSqlConversionFailure, - ToSql, -}; - -use crate::AsyncBoxError; - -use super::ZstdDictId; - -#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)] -pub enum CompressionId { - None, - ZstdGeneric, - ZstdDictId(ZstdDictId), -} - -impl FromSql for CompressionId { - fn column_result(value: ValueRef<'_>) -> FromSqlResult { - Ok(match value.as_i64()? { - -1 => CompressionId::None, - -2 => CompressionId::ZstdGeneric, - id => CompressionId::ZstdDictId(ZstdDictId(id)), - }) - } -} - -impl ToSql for CompressionId { - fn to_sql(&self) -> rusqlite::Result> { - let value = match self { - CompressionId::None => -1, - CompressionId::ZstdGeneric => -2, - CompressionId::ZstdDictId(ZstdDictId(id)) => *id, - }; - Ok(ToSqlOutput::Owned(Integer(value))) - } -} - -impl FromSql for ZstdDictId { - fn column_result(value: ValueRef<'_>) -> FromSqlResult { - match value.as_i64()? 
{
-            id @ (-1 | -2) => Err(FromSqlError::Other(invalid_zstd_dict_id_err(id))),
-            id => Ok(ZstdDictId(id)),
-        }
-    }
-}
-
-impl ToSql for ZstdDictId {
-    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
-        let value = match self.0 {
-            id @ (-1 | -2) => return Err(ToSqlConversionFailure(invalid_zstd_dict_id_err(id))),
-            id => id,
-        };
-
-        Ok(ToSqlOutput::Owned(Integer(value)))
-    }
-}
-
-fn invalid_zstd_dict_id_err(id: i64) -> AsyncBoxError {
-    format!("Invalid ZstdDictId: {}", id).into()
-}
diff --git a/src/sql_types/dict_id.rs b/src/sql_types/dict_id.rs
new file mode 100644
index 0000000..565b38d
--- /dev/null
+++ b/src/sql_types/dict_id.rs
@@ -0,0 +1,75 @@
+use rusqlite::{
+    types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, ValueRef},
+    ToSql,
+};
+
+use super::ZstdDictId;
+
+#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, PartialOrd, Ord)]
+pub enum DictId {
+    None,
+    ZstdGeneric,
+    Zstd(ZstdDictId),
+}
+
+impl DictId {
+    fn name_prefix(&self) -> &str {
+        match self {
+            DictId::None => "none",
+            DictId::ZstdGeneric => "zstd_generic",
+            DictId::Zstd(_) => "zstd",
+        }
+    }
+}
+
+impl ToSql for DictId {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
+        let prefix = self.name_prefix();
+        if let DictId::Zstd(id) = self {
+            Ok(ToSqlOutput::from(format!("{}:{}", prefix, id.0)))
+        } else {
+            prefix.to_sql()
+        }
+    }
+}
+
+impl FromSql for DictId {
+    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
+        let s = value.as_str()?;
+        let dict_id = if s == "none" {
+            DictId::None
+        } else if s == "zstd_generic" {
+            DictId::ZstdGeneric
+        } else if let Some(id_str) = s.strip_prefix("zstd:") {
+            let id = id_str
+                .parse()
+                .map_err(|e| FromSqlError::Other(format!("invalid ZstdDictId: {}", e).into()))?;
+            DictId::Zstd(ZstdDictId::column_result(ValueRef::Integer(id))?)
+        } else {
+            return Err(FromSqlError::Other(format!("invalid DictId: {}", s).into()));
+        };
+        Ok(dict_id)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use rusqlite::types::Value;
+
+    use super::*;
+
+    #[test]
+    fn test_dict_id() {
+        let id = DictId::Zstd(ZstdDictId(42));
+        let id_str = "zstd:42";
+
+        let sql_to_id = DictId::column_result(ValueRef::Text(id_str.as_bytes()));
+        assert_eq!(id, sql_to_id.unwrap());
+
+        let id_to_sql = match id.to_sql() {
+            Ok(ToSqlOutput::Owned(Value::Text(text))) => text,
+            _ => panic!("unexpected ToSqlOutput: {:?}", id.to_sql()),
+        };
+        assert_eq!(id_str, id_to_sql);
+    }
+}
diff --git a/src/sql_types/entry_id.rs b/src/sql_types/entry_id.rs
new file mode 100644
index 0000000..8f41d78
--- /dev/null
+++ b/src/sql_types/entry_id.rs
@@ -0,0 +1,24 @@
+use rusqlite::{
+    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+    ToSql,
+};
+
+#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
+pub struct EntryId(pub i64);
+impl From<i64> for EntryId {
+    fn from(id: i64) -> Self {
+        Self(id)
+    }
+}
+
+impl FromSql for EntryId {
+    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
+        Ok(value.as_i64()?.into())
+    }
+}
+
+impl ToSql for EntryId {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
+        self.0.to_sql()
+    }
+}
diff --git a/src/sql_types/mod.rs b/src/sql_types/mod.rs
index c53ffa2..6020296 100644
--- a/src/sql_types/mod.rs
+++ b/src/sql_types/mod.rs
@@ -1,7 +1,9 @@
-mod compression_id;
+mod dict_id;
+mod entry_id;
 mod utc_date_time;
 mod zstd_dict_id;
 
-pub use compression_id::CompressionId;
+pub use dict_id::DictId;
+pub use entry_id::EntryId;
 pub use utc_date_time::UtcDateTime;
 pub use zstd_dict_id::ZstdDictId;
diff --git a/src/sql_types/zstd_dict_id.rs b/src/sql_types/zstd_dict_id.rs
index 8f1306a..48ecbd2 100644
--- a/src/sql_types/zstd_dict_id.rs
+++ b/src/sql_types/zstd_dict_id.rs
@@ -1,7 +1,32 @@
-#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
+use std::fmt::Display;
+
+use rusqlite::{
+    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+    ToSql,
+};
+
+#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, PartialOrd, Ord)]
 pub struct ZstdDictId(pub i64);
 impl From<i64> for ZstdDictId {
     fn from(id: i64) -> Self {
         Self(id)
     }
 }
+
+impl Display for ZstdDictId {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "zstd_dict:{}", self.0)
+    }
+}
+
+impl FromSql for ZstdDictId {
+    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
+        Ok(ZstdDictId(value.as_i64()?))
+    }
+}
+
+impl ToSql for ZstdDictId {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
+        self.0.to_sql()
+    }
+}
diff --git a/src/zstd_dict.rs b/src/zstd_dict.rs
index 505eed1..dfc3f7c 100644
--- a/src/zstd_dict.rs
+++ b/src/zstd_dict.rs
@@ -1,9 +1,18 @@
 use ouroboros::self_referencing;
+use rusqlite::{
+    types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, Value, ValueRef},
+    Error::ToSqlConversionFailure,
+    ToSql,
+};
+use serde::{Deserialize, Serialize};
 use std::{io, sync::Arc};
 use zstd::dict::{DecoderDictionary, EncoderDictionary};
 
 use crate::{sql_types::ZstdDictId, AsyncBoxError};
 
+const ENCODER_DICT_SIZE: usize = 2 * 1024 * 1024;
+pub const ZSTD_LEVEL: i32 = 9;
+
 pub type ZstdDictArc = Arc<ZstdDict>;
 
 #[self_referencing]
@@ -20,9 +29,62 @@ pub struct ZstdEncoder {
     decoder_dict: DecoderDictionary<'this>,
 }
 
+impl Serialize for ZstdEncoder {
+    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        use serde::ser::SerializeTuple;
+
+        // serialize level and dict_bytes as a tuple
+        let mut state = serializer.serialize_tuple(2)?;
+        state.serialize_element(&self.level())?;
+        state.serialize_element(self.dict_bytes())?;
+        state.end()
+    }
+}
+
+impl<'de> Deserialize<'de> for ZstdEncoder {
+    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
+        use serde::de::SeqAccess;
+
+        struct ZstdEncoderVisitor;
+        impl<'de> serde::de::Visitor<'de> for ZstdEncoderVisitor {
+            type Value = ZstdEncoder;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+                formatter.write_str("a tuple of (i32, Vec<u8>)")
+            }
+
+            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
+                let level = seq
+                    .next_element()?
+                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
+                let dict_bytes = seq
+                    .next_element()?
+                    .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
+                Ok(ZstdEncoder::from_dict_bytes(level, dict_bytes))
+            }
+        }
+
+        deserializer.deserialize_tuple(2, ZstdEncoderVisitor)
+    }
+}
+
+impl ToSql for ZstdEncoder {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
+        let json = serde_json::to_string(self).map_err(|e| ToSqlConversionFailure(Box::new(e)))?;
+        Ok(ToSqlOutput::Owned(Value::Text(json)))
+    }
+}
+
+impl FromSql for ZstdEncoder {
+    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
+        let json = value.as_str()?;
+        serde_json::from_str(json).map_err(|e| FromSqlError::Other(e.into()))
+    }
+}
+
 impl ZstdEncoder {
     pub fn from_samples(level: i32, samples: Vec<&[u8]>) -> Self {
-        let dict_bytes = zstd::dict::from_samples(&samples, 1024 * 1024).unwrap();
+        let dict_bytes = zstd::dict::from_samples(&samples, ENCODER_DICT_SIZE).unwrap();
         Self::from_dict_bytes(level, dict_bytes)
     }
 
@@ -36,6 +98,9 @@ impl ZstdEncoder {
         .build()
     }
 
+    pub fn level(&self) -> i32 {
+        *self.borrow_level()
+    }
     pub fn dict_bytes(&self) -> &[u8] {
         self.borrow_dict_bytes()
     }
@@ -128,7 +193,7 @@ pub mod test {
             id,
             name.to_owned(),
             ZstdEncoder::from_samples(
-                3,
+                5,
                 vec![
                     "hello, world",
                     "this is a test",