Compare commits

...

10 Commits

Author  SHA1  Message  Date
Dylan Knutson  0bb8a333e4  add dockerfile  2024-06-12 14:03:25 -07:00
Dylan Knutson  eef6811625  remove zstd dicts & hints  2024-06-12 13:39:49 -07:00
Dylan Knutson  f44d884761  fix policy setting  2024-06-12 13:28:16 -07:00
Dylan Knutson  adbccc97a5  brotli generic compression  2024-06-04 19:53:05 -07:00
Dylan Knutson  3f44344ac0  zstd dict refactorings  2024-05-11 14:36:50 -04:00
Dylan Knutson  6deb909c43  clippy, concat_lines  2024-05-06 09:05:44 -07:00
Dylan Knutson  83b4dacede  tests for hints  2024-05-05 23:02:14 -07:00
Dylan Knutson  b0955c9c64  move compression into Shard field  2024-05-05 21:03:31 -07:00
Dylan Knutson  a3b550526e  store compression hint name  2024-05-05 20:52:21 -07:00
Dylan Knutson  bd2de7cfac  zstd_dict_id and such  2024-05-05 19:05:41 -07:00
37 changed files with 1579 additions and 574 deletions

132
Cargo.lock generated

@@ -44,6 +44,21 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd"
[[package]]
name = "alloc-no-stdlib"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3"
[[package]]
name = "alloc-stdlib"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece"
dependencies = [
"alloc-no-stdlib",
]
[[package]]
name = "allocator-api2"
version = "0.2.18"
@@ -277,11 +292,14 @@ version = "0.1.0"
dependencies = [
"axum",
"axum_typed_multipart",
"brotli",
"chrono",
"clap",
"futures",
"hex",
"humansize",
"kdam",
"num_cpus",
"ouroboros",
"rand",
"reqwest",
@@ -290,10 +308,12 @@ dependencies = [
"serde",
"serde_json",
"sha2",
"tabled",
"tokio",
"tokio-rusqlite",
"tracing",
"tracing-subscriber",
"walkdir",
"zstd",
]
@@ -306,12 +326,39 @@ dependencies = [
"generic-array",
]
[[package]]
name = "brotli"
version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
"brotli-decompressor",
]
[[package]]
name = "brotli-decompressor"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6221fe77a248b9117d431ad93761222e1cf8ff282d9d1d5d9f53d6299a1cf76"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
]
[[package]]
name = "bumpalo"
version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "bytecount"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"
[[package]]
name = "bytes"
version = "1.6.0"
@@ -805,6 +852,15 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "humansize"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7"
dependencies = [
"libm",
]
[[package]]
name = "hyper"
version = "1.3.1"
@@ -972,6 +1028,12 @@ version = "0.2.153"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "libsqlite3-sys"
version = "0.28.0"
@@ -1207,6 +1269,17 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "papergrid"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ad43c07024ef767f9160710b3a6773976194758c7919b17e63b863db0bdf7fb"
dependencies = [
"bytecount",
"fnv",
"unicode-width",
]
[[package]]
name = "parking_lot"
version = "0.12.1"
@@ -1552,6 +1625,15 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.23"
@@ -1734,6 +1816,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
@@ -1781,6 +1864,30 @@ dependencies = [
"libc",
]
[[package]]
name = "tabled"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c998b0c8b921495196a48aabaf1901ff28be0760136e31604f7967b0792050e"
dependencies = [
"papergrid",
"tabled_derive",
"unicode-width",
]
[[package]]
name = "tabled_derive"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c138f99377e5d653a371cdad263615634cfc8467685dfe8e73e2b8e98f44b17"
dependencies = [
"heck 0.4.1",
"proc-macro-error",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "tempfile"
version = "3.10.1"
@@ -2047,6 +2154,12 @@ dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-width"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6"
[[package]]
name = "url"
version = "2.5.0"
@@ -2088,6 +2201,16 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
@@ -2195,6 +2318,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"

Cargo.toml

@@ -12,6 +12,10 @@ path = "src/main.rs"
name = "load-test"
path = "load_test/main.rs"
[[bin]]
name = "fixture-inserter"
path = "fixture_inserter/main.rs"
[dependencies]
axum = { version = "0.7.5", features = ["macros"] }
axum_typed_multipart = "0.11.1"
@@ -32,6 +36,15 @@ reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
hex = "0.4.3"
zstd = { version = "0.13.1", features = ["experimental"] }
ouroboros = "0.18.3"
humansize = "2.1.3"
walkdir = "2.5.0"
tabled = "0.15.0"
brotli = "6.0.0"
num_cpus = "1.16.0"
[dev-dependencies]
rstest = "0.19.0"
[lints.rust]
unsafe_code = "forbid"
unused_must_use = "forbid"

14
Dockerfile Normal file

@@ -0,0 +1,14 @@
FROM rust:1.77 as builder
RUN mkdir /build
WORKDIR /build
COPY Cargo.toml Cargo.lock ./
COPY src ./src
RUN --mount=type=cache,target=/usr/local/cargo/registry \
--mount=type=cache,target=/build/target \
cargo build --release --bin blob-store-app && \
cp target/release/blob-store-app /blob-store-app
FROM debian:bookworm-slim
RUN apt-get update && apt-get install -qqy libsqlite3-0
COPY --from=builder /blob-store-app /usr/local/bin/blob-store-server
ENTRYPOINT ["blob-store-server", "--db-path", "/data", "--bind", "0.0.0.0"]

107
fixture_inserter/main.rs Normal file

@@ -0,0 +1,107 @@
use clap::{arg, Parser};
use kdam::BarExt;
use kdam::{tqdm, Bar};
use rand::prelude::SliceRandom;
use reqwest::blocking::{multipart, Client};
use std::{
error::Error,
sync::{Arc, Mutex},
};
use walkdir::WalkDir;
#[derive(Parser, Debug, Clone)]
struct Args {
#[arg(long)]
hint_name: String,
#[arg(long)]
fixture_dir: String,
#[arg(long)]
limit: Option<usize>,
#[arg(long)]
num_threads: Option<usize>,
}
fn main() -> Result<(), Box<dyn Error>> {
let args = Args::parse();
let hint_name = args.hint_name;
let num_threads = args.num_threads.unwrap_or(1);
let pb = Arc::new(Mutex::new(tqdm!()));
let mut entries = WalkDir::new(&args.fixture_dir)
.into_iter()
.filter_map(Result::ok)
.filter(|entry| entry.file_type().is_file())
.filter_map(|entry| {
let path = entry.path();
let name = path.to_str()?.to_string();
Some(name)
})
.collect::<Vec<_>>();
// shuffle the entries
entries.shuffle(&mut rand::thread_rng());
// if there's a limit, drop the rest
if let Some(limit) = args.limit {
entries.truncate(limit);
}
let entry_slices = entries.chunks(entries.len() / num_threads);
println!("entry_slices: {:?}", entry_slices.len());
let join_handles = entry_slices.map(|entry_slice| {
let hint_name = hint_name.clone();
let pb = pb.clone();
let entry_slice = entry_slice.to_vec();
std::thread::spawn(move || {
let client = Client::new();
for entry in entry_slice.into_iter() {
store_file(&client, &hint_name, &entry, pb.clone()).unwrap();
}
})
});
for join_handle in join_handles {
join_handle.join().unwrap();
}
Ok(())
}
fn store_file(
client: &Client,
hint_name: &str,
file_name: &str,
pb: Arc<Mutex<Bar>>,
) -> Result<(), Box<dyn Error>> {
let file_bytes = std::fs::read(file_name)?;
let content_type = if file_name.ends_with(".html") {
"text/html"
} else if file_name.ends_with(".json") {
"application/json"
} else if file_name.ends_with(".jpg") || file_name.ends_with(".jpeg") {
"image/jpeg"
} else if file_name.ends_with(".png") {
"image/png"
} else if file_name.ends_with(".txt") {
"text/plain"
} else if file_name.ends_with(".kindle.images") {
"application/octet-stream"
} else {
"text/html"
};
let form = multipart::Form::new()
.text("content_type", content_type)
.text("compression_hint", hint_name.to_string())
.part("data", multipart::Part::bytes(file_bytes));
let _ = client
.post("http://localhost:7692/store")
.multipart(form)
.send();
let mut pb = pb.lock().unwrap();
pb.update(1)?;
Ok(())
}

load_test/main.rs

@@ -1,12 +1,11 @@
use std::sync::Arc;
use std::sync::Mutex;
use clap::Parser;
use kdam::tqdm;
use kdam::Bar;
use kdam::BarExt;
use rand::Rng;
use reqwest::StatusCode;
use std::sync::Arc;
use std::sync::Mutex;
#[derive(Parser, Debug, Clone)]
#[command(version, about, long_about = None)]
@@ -43,14 +42,23 @@ fn run_loop(pb: Arc<Mutex<Bar>>, args: Args) -> Result<(), Box<dyn std::error::E
let mut rng = rand::thread_rng();
let mut rand_data = vec![0u8; args.file_size];
rng.fill(&mut rand_data[..]);
let hints = vec![
"foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred", "plugh",
"xyzzy", "thud",
];
loop {
// tweak a byte in the data
let idx = rng.gen_range(0..rand_data.len());
rand_data[idx] = rng.gen();
let hint = hints[rng.gen_range(0..hints.len())];
let form = reqwest::blocking::multipart::Form::new()
.text("content_type", "text/plain")
.part(
"compression_hint",
reqwest::blocking::multipart::Part::text(hint),
)
.part(
"data",
reqwest::blocking::multipart::Part::bytes(rand_data.clone()),

18
src/app_time_formatter.rs Normal file

@@ -0,0 +1,18 @@
use std::time::Instant;
use tracing_subscriber::fmt::{format::Writer, time::FormatTime};
pub struct AppTimeFormatter(Instant);
impl AppTimeFormatter {
pub fn new() -> Self {
Self(Instant::now())
}
}
impl FormatTime for AppTimeFormatter {
fn format_time(&self, w: &mut Writer<'_>) -> std::fmt::Result {
let e = self.0.elapsed();
write!(w, "{:5}.{:03}s", e.as_secs(), e.subsec_millis())
}
}

67
src/compressible_data.rs Normal file

@@ -0,0 +1,67 @@
use axum::{body::Bytes, response::IntoResponse};
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum CompressibleData {
Bytes(Bytes),
Vec(Vec<u8>),
}
impl CompressibleData {
pub fn len(&self) -> usize {
match self {
CompressibleData::Bytes(b) => b.len(),
CompressibleData::Vec(v) => v.len(),
}
}
}
impl AsRef<[u8]> for CompressibleData {
fn as_ref(&self) -> &[u8] {
match self {
CompressibleData::Bytes(b) => b.as_ref(),
CompressibleData::Vec(v) => v.as_ref(),
}
}
}
impl From<Bytes> for CompressibleData {
fn from(b: Bytes) -> Self {
CompressibleData::Bytes(b)
}
}
impl From<Vec<u8>> for CompressibleData {
fn from(v: Vec<u8>) -> Self {
CompressibleData::Vec(v)
}
}
impl IntoResponse for CompressibleData {
fn into_response(self) -> axum::response::Response {
match self {
CompressibleData::Bytes(b) => b.into_response(),
CompressibleData::Vec(v) => v.into_response(),
}
}
}
impl PartialEq<&[u8]> for CompressibleData {
fn eq(&self, other: &&[u8]) -> bool {
let as_ref = self.as_ref();
as_ref == *other
}
}
impl PartialEq<Vec<u8>> for CompressibleData {
fn eq(&self, other: &Vec<u8>) -> bool {
let as_ref = self.as_ref();
as_ref == other.as_slice()
}
}
impl PartialEq<Bytes> for CompressibleData {
fn eq(&self, other: &Bytes) -> bool {
let as_ref = self.as_ref();
as_ref == other.as_ref()
}
}

222
src/compression_manager.rs Normal file

@@ -0,0 +1,222 @@
use crate::{
compressible_data::CompressibleData,
compression_stats::{CompressionStat, CompressionStats},
compressor::{BrotliGenericCompressor, Compressor, ZstdGenericCompressor},
sql_types::CompressionId,
AsyncBoxError, CompressionPolicy,
};
use std::sync::Arc;
use tokio::sync::{RwLock, RwLockReadGuard};
pub type CompressionManagerArc = Arc<RwLock<CompressionManager>>;
pub type CompressionStatsArc = Arc<RwLock<CompressionStats>>;
pub struct CompressionManager {
compression_policy: CompressionPolicy,
compression_stats: CompressionStatsArc,
}
impl CompressionManager {
pub fn new(compression_policy: CompressionPolicy) -> Self {
Self {
compression_policy,
compression_stats: Arc::new(RwLock::new(CompressionStats::default())),
}
}
}
impl Default for CompressionManager {
fn default() -> Self {
Self::new(CompressionPolicy::Auto)
}
}
type Result<T> = std::result::Result<T, AsyncBoxError>;
impl CompressionManager {
pub fn into_arc(self) -> CompressionManagerArc {
Arc::new(RwLock::new(self))
}
pub fn set_compression_policy(&mut self, compression_policy: CompressionPolicy) {
self.compression_policy = compression_policy;
}
pub async fn set_stat(&mut self, id: CompressionId, stat: CompressionStat) {
let mut compression_stats = self.compression_stats.write().await;
*compression_stats.for_id_mut(id) = stat;
}
pub async fn compress<Data: Into<CompressibleData>>(
&self,
content_type: &str,
data: Data,
) -> Result<(CompressionId, CompressibleData)> {
let data = data.into();
match self.compression_policy {
CompressionPolicy::Auto => self.compress_auto(content_type, data).await,
CompressionPolicy::None => Ok((CompressionId::None, data)),
CompressionPolicy::ForceZstd => self.compress_with_id(CompressionId::Zstd, &data).await,
CompressionPolicy::ForceBrotli => {
self.compress_with_id(CompressionId::Brotli, &data).await
}
}
}
async fn compress_auto(
&self,
content_type: &str,
data: CompressibleData,
) -> Result<(CompressionId, CompressibleData)> {
if auto_compressible_content_type(content_type) {
let mut compressed_refs = vec![];
for id in [CompressionId::Zstd, CompressionId::Brotli] {
let compressed = self.compress_with_id(id, &data).await?;
compressed_refs.push((id, compressed.1.len(), compressed.1));
}
compressed_refs.push((CompressionId::None, data.len(), data));
// find the smallest compressed data
let (compression_id, _, data) = compressed_refs
.into_iter()
.min_by_key(|(_, len, _)| *len)
.unwrap();
Ok((compression_id, data))
} else {
Ok((CompressionId::None, data))
}
}
async fn compress_with_id(
&self,
compression_id: CompressionId,
data: &CompressibleData,
) -> Result<(CompressionId, CompressibleData)> {
let compressor = self.get_compressor(compression_id)?;
// get current time
let now = std::time::Instant::now();
let compressed_data = compressor.compress(data)?;
let mut stats = self.compression_stats.write().await;
stats.add_entry(
compression_id,
data.len(),
compressed_data.len(),
now.elapsed(),
);
Ok((compression_id, compressed_data))
}
fn get_compressor(&self, compression_id: CompressionId) -> Result<Box<dyn Compressor>> {
Ok(match compression_id {
CompressionId::Zstd => Box::new(ZstdGenericCompressor),
CompressionId::Brotli => Box::new(BrotliGenericCompressor),
CompressionId::None => return Err("no compressor for CompressionId::None".into()),
})
}
pub fn decompress<Data: Into<CompressibleData>>(
&self,
compression_id: CompressionId,
data: Data,
) -> Result<CompressibleData> {
let data = data.into();
match compression_id {
CompressionId::None => Ok(data),
CompressionId::Zstd => Ok(CompressibleData::Vec(zstd::stream::decode_all(
data.as_ref(),
)?)),
CompressionId::Brotli => {
let mut decompressed = vec![];
let mut reader = std::io::Cursor::new(data.as_ref());
brotli::BrotliDecompress(&mut reader, &mut decompressed)?;
Ok(CompressibleData::Vec(decompressed))
}
}
}
pub async fn stats(&self) -> RwLockReadGuard<CompressionStats> {
self.compression_stats.read().await
}
}
fn auto_compressible_content_type(content_type: &str) -> bool {
[
"text/",
"application/xml",
"application/json",
"application/javascript",
]
.iter()
.any(|ct| content_type.starts_with(ct))
}
#[cfg(test)]
pub mod test {
use super::*;
use rstest::rstest;
pub fn make_compressor() -> CompressionManager {
make_compressor_with(CompressionPolicy::Auto)
}
pub fn make_compressor_with(compression_policy: CompressionPolicy) -> CompressionManager {
CompressionManager::new(compression_policy)
}
#[test]
fn test_auto_compressible_content_type() {
assert!(auto_compressible_content_type("text/plain"));
assert!(auto_compressible_content_type("application/xml"));
assert!(auto_compressible_content_type("application/json"));
assert!(auto_compressible_content_type("application/javascript"));
assert!(!auto_compressible_content_type("image/png"));
}
#[rstest]
#[tokio::test]
async fn test_compression_policies(
#[values(
CompressionPolicy::Auto,
CompressionPolicy::None,
CompressionPolicy::ForceZstd
)]
compression_policy: CompressionPolicy,
#[values("text/plain", "application/json", "image/png")] content_type: &str,
) {
let compressor = make_compressor_with(compression_policy);
let data = b"hello, world!".to_vec();
let (compression_id, compressed) = compressor
.compress(content_type, data.clone())
.await
.unwrap();
let data_uncompressed = compressor.decompress(compression_id, compressed).unwrap();
assert_eq!(data_uncompressed, data);
}
#[tokio::test]
async fn test_skip_compressing_small_data() {
let compressor = make_compressor();
let data = b"hi!".to_vec();
let (compression_id, compressed) = compressor
.compress("text/plain", data.clone())
.await
.unwrap();
assert_eq!(compression_id, CompressionId::None);
assert_eq!(compressed, data);
}
#[tokio::test]
async fn test_compresses_longer_data() {
let compressor = make_compressor();
let data = vec![b'.'; 1024];
let (compression_id, compressed) = compressor
.compress("text/plain", data.clone())
.await
.unwrap();
assert_eq!(compression_id, CompressionId::Brotli);
assert_ne!(compressed, data);
assert!(compressed.len() < data.len());
}
}

52
src/compression_stats.rs Normal file

@@ -0,0 +1,52 @@
use crate::sql_types::CompressionId;
use std::collections::HashMap;
#[derive(Default, Copy, Clone)]
pub struct CompressionStat {
pub num_entries: usize,
pub uncompressed_size: usize,
pub compressed_size: usize,
pub elapsed: std::time::Duration,
}
#[derive(Default, Clone)]
pub struct CompressionStats {
order: Vec<CompressionId>,
id_to_stat: HashMap<CompressionId, CompressionStat>,
}
impl CompressionStats {
pub fn for_id_mut(&mut self, id: CompressionId) -> &mut CompressionStat {
self.id_to_stat.entry(id).or_insert_with_key(|_| {
self.order.push(id);
self.order.sort();
Default::default()
})
}
pub fn add_entry(
&mut self,
id: CompressionId,
uncompressed_size: usize,
compressed_size: usize,
elapsed: std::time::Duration,
) -> &CompressionStat {
let stat = self.for_id_mut(id);
stat.num_entries += 1;
stat.uncompressed_size += uncompressed_size;
stat.compressed_size += compressed_size;
stat.elapsed += elapsed;
stat
}
#[cfg(test)]
pub fn len(&self) -> usize {
self.order.len()
}
pub fn iter(&self) -> impl Iterator<Item = (CompressionId, &CompressionStat)> {
self.order
.iter()
.flat_map(|id| self.id_to_stat.get(id).map(|stat| (*id, stat)))
}
}

src/compressor/brotli_generic_compressor.rs

@@ -0,0 +1,20 @@
use super::compressor_trait::Compressor;
use crate::{compressible_data::CompressibleData, AsyncBoxError};
use std::io::{Read, Write};
pub struct BrotliGenericCompressor;
impl Compressor for BrotliGenericCompressor {
fn compress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
let mut writer = brotli::CompressorWriter::new(Vec::new(), 4096, 7, 0);
writer.write_all(data.as_ref())?;
writer.flush()?;
Ok(CompressibleData::Vec(writer.into_inner()))
}
fn decompress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
let mut reader = brotli::Decompressor::new(data.as_ref(), 4096);
let mut decompressed = Vec::new();
reader.read_to_end(&mut decompressed)?;
Ok(CompressibleData::Vec(decompressed))
}
}

src/compressor/compressor_trait.rs

@@ -0,0 +1,38 @@
use crate::{compressible_data::CompressibleData, AsyncBoxError};
pub trait Compressor: Send {
fn compress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError>;
fn decompress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError>;
}
#[cfg(test)]
pub(self) mod test {
use super::*;
use crate::compressor::*;
use rstest::*;
fn random_bytes(len: usize) -> Vec<u8> {
(0..len).map(|_| rand::random::<u8>()).collect()
}
#[rstest]
#[case(BrotliGenericCompressor)]
#[case(ZstdGenericCompressor)]
fn test_compressor(
#[case] compressor: impl Compressor,
#[values(
vec![],
"foobar".as_bytes().to_vec(),
vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
vec![0; 1024],
random_bytes(10),
random_bytes(1024)
)]
data_vec: Vec<u8>,
) {
let data = CompressibleData::Vec(data_vec);
let compressed = compressor.compress(&data).unwrap();
let decompressed = compressor.decompress(&compressed).unwrap();
assert_eq!(decompressed, data);
}
}

7
src/compressor/mod.rs Normal file

@@ -0,0 +1,7 @@
mod brotli_generic_compressor;
mod compressor_trait;
mod zstd_generic_compressor;
pub use brotli_generic_compressor::BrotliGenericCompressor;
pub use compressor_trait::Compressor;
pub use zstd_generic_compressor::ZstdGenericCompressor;

src/compressor/zstd_generic_compressor.rs

@@ -0,0 +1,20 @@
use super::compressor_trait::Compressor;
use crate::{compressible_data::CompressibleData, AsyncBoxError};
use std::io::{Read, Write};
pub struct ZstdGenericCompressor;
impl Compressor for ZstdGenericCompressor {
fn compress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
let mut writer = zstd::stream::write::Encoder::new(Vec::new(), 11)?;
// writer.set_parameter(zstd::zstd_safe::CParameter::WindowLog(24))?;
writer.write_all(data.as_ref())?;
Ok(CompressibleData::Vec(writer.finish()?))
}
fn decompress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
let mut reader = zstd::stream::read::Decoder::new(data.as_ref())?;
let mut decompressed = Vec::new();
reader.read_to_end(&mut decompressed)?;
Ok(CompressibleData::Vec(decompressed))
}
}

7
src/concat_lines.rs Normal file

@@ -0,0 +1,7 @@
// like concat, but adds a newline between each expression
#[macro_export]
macro_rules! concat_lines {
($($e:expr),* $(,)?) => {
concat!($($e, "\n"),*)
};
}
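The new concat_lines! macro above joins its arguments with newlines at compile time. A minimal, hypothetical usage sketch (not part of the changeset; the macro definition is repeated here so the snippet stands alone):

macro_rules! concat_lines {
    ($($e:expr),* $(,)?) => {
        concat!($($e, "\n"),*)
    };
}

fn main() {
    // Expands at compile time into a single string literal, with a newline
    // appended after every fragment (including the last one).
    let sql = concat_lines!("SELECT version, created_at", "FROM schema_version");
    assert_eq!(sql, "SELECT version, created_at\nFROM schema_version\n");
}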

src/handlers/get_handler.rs

@@ -1,16 +1,21 @@
use crate::{sha256::Sha256, shard::GetResult, shards::Shards};
use crate::{
sha256::Sha256,
shard::{GetArgs, GetResult},
shards::Shards,
AsyncBoxError,
};
use axum::{
extract::Path,
http::{header, HeaderMap, HeaderValue, StatusCode},
response::IntoResponse,
Extension, Json,
};
use std::{collections::HashMap, error::Error};
use std::{collections::HashMap, sync::Arc};
pub enum GetResponse {
MissingSha256,
InvalidSha256 { message: String },
InternalError { error: Box<dyn Error> },
InternalError { error: AsyncBoxError },
NotFound,
Found { get_result: GetResult },
}
@@ -69,7 +74,7 @@ fn make_found_response(
Err(e) => return GetResponse::from(e).into_response(),
};
let created_at = match HeaderValue::from_str(&created_at.to_rfc3339()) {
let created_at = match HeaderValue::from_str(&created_at.to_string()) {
Ok(created_at) => created_at,
Err(e) => return GetResponse::from(e).into_response(),
};
@@ -98,7 +103,7 @@ fn make_found_response(
(StatusCode::OK, headers, data).into_response()
}
impl<E: Into<Box<dyn Error>>> From<E> for GetResponse {
impl<E: Into<AsyncBoxError>> From<E> for GetResponse {
fn from(error: E) -> Self {
GetResponse::InternalError {
error: error.into(),
@@ -109,7 +114,7 @@ impl<E: Into<Box<dyn Error>>> From<E> for GetResponse {
#[axum::debug_handler]
pub async fn get_handler(
Path(params): Path<HashMap<String, String>>,
Extension(shards): Extension<Shards>,
Extension(shards): Extension<Arc<Shards>>,
) -> GetResponse {
let sha256_str = match params.get("sha256") {
Some(sha256_str) => sha256_str.clone(),
@@ -128,7 +133,7 @@ pub async fn get_handler(
};
let shard = shards.shard_for(&sha256);
let get_result = match shard.get(sha256).await {
let get_result = match shard.get(GetArgs { sha256 }).await {
Ok(get_result) => get_result,
Err(e) => return e.into(),
};
@@ -141,7 +146,9 @@ pub async fn get_handler(
#[cfg(test)]
mod test {
use crate::{sha256::Sha256, shard::GetResult, shards::test::make_shards};
use crate::{
sha256::Sha256, shard::GetResult, shards::test::make_shards, sql_types::UtcDateTime,
};
use axum::{extract::Path, response::IntoResponse, Extension};
use std::collections::HashMap;
@@ -150,10 +157,10 @@ mod test {
#[tokio::test]
async fn test_get_invalid_sha256() {
let shards = Extension(make_shards().await);
let response = super::get_handler(Path(HashMap::new()), shards.clone()).await;
assert!(matches!(response, super::GetResponse::MissingSha256 { .. }));
let shards = Extension(make_shards().await);
let response = super::get_handler(
Path(HashMap::from([(String::from("sha256"), String::from(""))])),
shards.clone(),
@@ -174,8 +181,8 @@ mod test {
#[test]
fn test_get_response_found_into_response() {
let data = "hello, world!";
let sha256 = Sha256::from_bytes(data.as_bytes());
let data = "hello, world!".as_bytes().to_owned();
let sha256 = Sha256::from_bytes(&data);
let sha256_str = sha256.hex_string();
let created_at = "2022-03-04T08:12:34+00:00";
let response = GetResponse::Found {
@@ -183,9 +190,7 @@ mod test {
sha256,
content_type: "text/plain".to_string(),
stored_size: 12345,
created_at: chrono::DateTime::parse_from_rfc3339(created_at)
.unwrap()
.to_utc(),
created_at: UtcDateTime::from_string(created_at).unwrap(),
data: data.into(),
},
}

src/handlers/info_handler.rs

@@ -1,6 +1,6 @@
use crate::shards::Shards;
use axum::{http::StatusCode, Extension, Json};
use std::sync::Arc;
use tracing::error;
#[derive(serde::Serialize)]
@@ -19,7 +19,7 @@ pub struct ShardInfo {
#[axum::debug_handler]
pub async fn info_handler(
Extension(shards): Extension<Shards>,
Extension(shards): Extension<Arc<Shards>>,
) -> Result<(StatusCode, Json<InfoResponse>), StatusCode> {
let mut shard_infos = vec![];
let mut total_db_size_bytes = 0;

src/handlers/store_handler.rs

@@ -1,12 +1,11 @@
use crate::{
sha256::Sha256,
shard::{StoreArgs, StoreResult},
shards::Shards,
shards::ShardsArc,
};
use axum::http::StatusCode;
use axum::{body::Bytes, response::IntoResponse, Extension, Json};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use axum::http::StatusCode;
use serde::Serialize;
use tracing::error;
@@ -59,7 +58,7 @@ impl From<StoreResult> for StoreResponse {
} => StoreResponse::Created {
stored_size,
data_size,
created_at: created_at.to_rfc3339(),
created_at: created_at.to_string(),
},
StoreResult::Exists {
stored_size,
@@ -68,7 +67,7 @@ impl From<StoreResult> for StoreResponse {
} => StoreResponse::Exists {
stored_size,
data_size,
created_at: created_at.to_rfc3339(),
created_at: created_at.to_string(),
},
}
}
@@ -82,7 +81,7 @@ impl IntoResponse for StoreResponse {
#[axum::debug_handler]
pub async fn store_handler(
Extension(shards): Extension<Shards>,
Extension(shards): Extension<ShardsArc>,
TypedMultipart(request): TypedMultipart<StoreRequest>,
) -> StoreResponse {
let sha256 = Sha256::from_bytes(&request.data.contents);
@@ -118,20 +117,22 @@ pub async fn store_handler(
#[cfg(test)]
pub mod test {
use crate::{compression_manager::test::make_compressor_with, shards::test::make_shards_with};
use super::*;
use crate::{shards::test::make_shards_with_compression, UseCompression};
use crate::CompressionPolicy;
use axum::body::Bytes;
use axum_typed_multipart::FieldData;
use rstest::rstest;
async fn send_request<D: Into<Bytes>>(
compression_policy: CompressionPolicy,
sha256: Option<Sha256>,
content_type: &str,
use_compression: UseCompression,
data: D,
) -> StoreResponse {
store_handler(
Extension(make_shards_with_compression(use_compression).await),
Extension(make_shards_with(make_compressor_with(compression_policy).into_arc()).await),
TypedMultipart(StoreRequest {
sha256: sha256.map(|s| s.hex_string()),
content_type: content_type.to_string(),
@@ -146,8 +147,9 @@ pub mod test {
#[tokio::test]
async fn test_store_handler() {
let result = send_request(None, "text/plain", UseCompression::Auto, "hello, world!").await;
assert_eq!(result.status_code(), StatusCode::CREATED);
let result =
send_request(CompressionPolicy::Auto, None, "text/plain", "hello, world!").await;
assert_eq!(result.status_code(), StatusCode::CREATED, "{:?}", result);
assert!(matches!(result, StoreResponse::Created { .. }));
}
@@ -156,9 +158,9 @@ pub mod test {
let not_hello_world = Sha256::from_bytes("goodbye, planet!".as_bytes());
let hello_world = Sha256::from_bytes("hello, world!".as_bytes());
let result = send_request(
CompressionPolicy::Auto,
Some(not_hello_world),
"text/plain",
UseCompression::Auto,
"hello, world!",
)
.await;
@@ -175,9 +177,9 @@ pub mod test {
async fn test_store_handler_matching_sha256() {
let hello_world = Sha256::from_bytes("hello, world!".as_bytes());
let result = send_request(
CompressionPolicy::Auto,
Some(hello_world),
"text/plain",
UseCompression::Auto,
"hello, world!",
)
.await;
@@ -193,21 +195,21 @@ pub mod test {
}
#[rstest]
// textual should be compressed by default
#[case("text/plain", UseCompression::Auto, make_assert_lt(1024))]
#[case("text/plain", UseCompression::Zstd, make_assert_lt(1024))]
#[case("text/plain", UseCompression::None, make_assert_eq(1024))]
// images, etc should not be compressed by default
#[case("image/jpg", UseCompression::Auto, make_assert_eq(1024))]
#[case("image/jpg", UseCompression::Zstd, make_assert_lt(1024))]
#[case("image/jpg", UseCompression::None, make_assert_eq(1024))]
// textual should be compressed if 'auto'
#[case("text/plain", CompressionPolicy::Auto, make_assert_lt(1024))]
#[case("text/plain", CompressionPolicy::ForceZstd, make_assert_lt(1024))]
#[case("text/plain", CompressionPolicy::None, make_assert_eq(1024))]
// images, etc should not be compressed if 'auto'
#[case("image/jpg", CompressionPolicy::Auto, make_assert_eq(1024))]
#[case("image/jpg", CompressionPolicy::ForceZstd, make_assert_lt(1024))]
#[case("image/jpg", CompressionPolicy::None, make_assert_eq(1024))]
#[tokio::test]
async fn test_compressible_data<F: Fn(usize)>(
#[case] content_type: &str,
#[case] use_compression: UseCompression,
#[case] compression_policy: CompressionPolicy,
#[case] assert_stored_size: F,
) {
let result = send_request(None, content_type, use_compression, vec![0; 1024]).await;
let result = send_request(compression_policy, None, content_type, vec![0; 1024]).await;
assert_eq!(result.status_code(), StatusCode::CREATED);
match result {
StoreResponse::Created {

11
src/into_arc.rs Normal file

@@ -0,0 +1,11 @@
use std::sync::Arc;
pub trait IntoArc {
fn into_arc(self) -> Arc<Self>;
}
impl<T> IntoArc for T {
fn into_arc(self) -> Arc<Self> {
Arc::new(self)
}
}
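The blanket impl above gives any sized value a postfix .into_arc(). A small, hypothetical illustration (not part of the changeset; the trait and impl are repeated from src/into_arc.rs so the snippet stands alone):

use std::sync::Arc;

pub trait IntoArc {
    fn into_arc(self) -> Arc<Self>;
}

impl<T> IntoArc for T {
    fn into_arc(self) -> Arc<Self> {
        Arc::new(self)
    }
}

fn main() {
    // Postfix form reads left to right instead of nesting Arc::new(..) around the value.
    let shared: Arc<Vec<u8>> = vec![1, 2, 3].into_arc();
    assert_eq!(shared.len(), 3);
}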

src/loops/axum_server_loop.rs

@@ -0,0 +1,30 @@
use axum::{
extract::DefaultBodyLimit,
routing::{get, post},
Extension, Router,
};
use tokio::net::TcpListener;
use crate::{
compression_manager::CompressionManagerArc, handlers, shards::ShardsArc, AsyncBoxError,
};
pub async fn axum_server_loop(
server: TcpListener,
shards: ShardsArc,
compression_manager: CompressionManagerArc,
) -> Result<(), AsyncBoxError> {
let app = Router::new()
.route("/store", post(handlers::store_handler::store_handler))
.route("/get/:sha256", get(handlers::get_handler::get_handler))
.route("/info", get(handlers::info_handler::info_handler))
.layer(Extension(shards))
.layer(Extension(compression_manager))
// limit the max size of a request body to 1000mb
.layer(DefaultBodyLimit::max(1024 * 1024 * 1000));
axum::serve(server, app.into_make_service())
.with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
.await?;
Ok(())
}

src/loops/dict_stats_printer_loop.rs

@@ -0,0 +1,77 @@
use tabled::{settings::Style, Table, Tabled};
use tokio::select;
use tracing::{debug, info};
use crate::{compression_manager::CompressionManagerArc, AsyncBoxError};
pub async fn dict_stats_printer_loop(
compressor: CompressionManagerArc,
) -> Result<(), AsyncBoxError> {
fn humanized_bytes_str(number: &usize) -> String {
humansize::format_size(*number, humansize::BINARY)
}
fn humanized_bytes_per_sec(number: &usize) -> String {
format!("{}/sec", humansize::format_size(*number, humansize::BINARY))
}
fn compression_ratio_str(row: &DictStatTableRow) -> String {
if row.uncompressed_size == 0 {
"(n/a)".to_string()
} else {
format!(
"{:.2} x",
row.compressed_size as f64 / row.uncompressed_size as f64
)
}
}
#[derive(Tabled, Default)]
struct DictStatTableRow {
name: String,
#[tabled(rename = "# entries")]
num_entries: usize,
#[tabled(
rename = "compression ratio",
display_with("compression_ratio_str", self)
)]
_ratio: (),
#[tabled(rename = "uncompressed bytes", display_with = "humanized_bytes_str")]
uncompressed_size: usize,
#[tabled(rename = "compressed bytes", display_with = "humanized_bytes_str")]
compressed_size: usize,
#[tabled(rename = "bytes/sec", display_with = "humanized_bytes_per_sec")]
bytes_per_sec: usize,
}
loop {
{
let compressor = compressor.read().await;
let stats = compressor.stats().await;
Table::new(stats.iter().map(|(id, stat)| {
let name = id.to_string();
DictStatTableRow {
name,
num_entries: stat.num_entries,
uncompressed_size: stat.uncompressed_size,
compressed_size: stat.compressed_size,
bytes_per_sec: (stat.uncompressed_size as f64 / stat.elapsed.as_secs_f64())
as usize,
..Default::default()
}
}))
.with(Style::rounded())
.to_string()
.lines()
.for_each(|line| debug!("{}", line));
}
select! {
_ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => {}
_ = crate::shutdown_signal::shutdown_signal() => {
info!("dict_stats: shutdown signal received");
return Ok(());
}
}
}
}

7
src/loops/mod.rs Normal file

@@ -0,0 +1,7 @@
mod axum_server_loop;
mod dict_stats_printer_loop;
mod save_compression_stats_loop;
pub use axum_server_loop::axum_server_loop;
pub use dict_stats_printer_loop::dict_stats_printer_loop;
pub use save_compression_stats_loop::save_compression_stats_loop;

src/loops/save_compression_stats_loop.rs

@@ -0,0 +1,19 @@
use std::sync::Arc;
use tokio::select;
use tracing::info;
use crate::{manifest::Manifest, AsyncBoxError};
pub async fn save_compression_stats_loop(manifest: Arc<Manifest>) -> Result<(), AsyncBoxError> {
loop {
manifest.save_compression_stats().await?;
select! {
_ = tokio::time::sleep(tokio::time::Duration::from_secs(60)) => {}
_ = crate::shutdown_signal::shutdown_signal() => {
info!("persist_dict_stats: shutdown signal received");
return Ok(());
}
}
}
}

src/main.rs

@@ -1,22 +1,28 @@
mod app_time_formatter;
mod compressible_data;
mod compression_manager;
mod compression_stats;
mod compressor;
mod concat_lines;
mod handlers;
mod into_arc;
mod loops;
mod manifest;
mod sha256;
mod shard;
mod shards;
mod shutdown_signal;
mod sql_types;
use crate::{manifest::Manifest, shards::Shards};
use axum::{
routing::{get, post},
Extension, Router,
};
use app_time_formatter::AppTimeFormatter;
use clap::{Parser, ValueEnum};
use futures::executor::block_on;
use shard::Shard;
use std::{error::Error, path::PathBuf};
use tokio::net::TcpListener;
use std::{error::Error, path::PathBuf, sync::Arc};
use tokio::{net::TcpListener, spawn};
use tokio_rusqlite::Connection;
use tracing::info;
use tracing::{debug, info};
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
@@ -39,38 +45,62 @@ struct Args {
/// How to compress stored data
#[arg(short, long, default_value = "auto")]
compression: UseCompression,
compression: CompressionPolicy,
}
#[derive(Default, PartialEq, Debug, Copy, Clone, ValueEnum)]
pub enum UseCompression {
pub enum CompressionPolicy {
#[default]
Auto,
None,
Zstd,
ForceZstd,
ForceBrotli,
}
fn main() -> Result<(), Box<dyn Error>> {
pub type AsyncBoxError = Box<dyn Error + Send + Sync + 'static>;
pub fn into_tokio_rusqlite_err<E: Into<AsyncBoxError>>(e: E) -> tokio_rusqlite::Error {
tokio_rusqlite::Error::Other(e.into())
}
fn main() -> Result<(), AsyncBoxError> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_timer(AppTimeFormatter::new())
.with_target(false)
.init();
let args = Args::parse();
let db_path = PathBuf::from(&args.db_path);
let num_shards = args.shards;
if db_path.is_file() {
return Err("db_path must be a directory".into());
}
if !db_path.is_dir() {
std::fs::create_dir_all(&db_path)?;
}
// block on opening the manifest
let manifest = block_on(async {
Manifest::open(
let manifest = Arc::new(block_on(async {
let manifest = Manifest::open(
Connection::open(db_path.join("manifest.sqlite")).await?,
num_shards,
)
.await
})?;
.await?;
manifest
.compression_manager()
.write()
.await
.set_compression_policy(args.compression);
Ok::<_, AsyncBoxError>(manifest)
})?);
// max num_shards threads
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(manifest.num_shards())
.worker_threads(16)
.enable_all()
.build()?;
@@ -82,38 +112,38 @@ fn main() -> Result<(), Box<dyn Error>> {
manifest.num_shards()
);
let mut shards_vec = vec![];
let mut num_entries = vec![];
for shard_id in 0..manifest.num_shards() {
let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?;
info!(
"shard {} has {} entries",
shard.id(),
shard.num_entries().await?
);
let shard =
Shard::open(shard_id, shard_sqlite_conn, manifest.compression_manager()).await?;
num_entries.push(shard.num_entries().await?);
shards_vec.push(shard);
}
debug!(
"loaded shards with {} total entries <- {:?}",
num_entries.iter().sum::<usize>(),
num_entries
);
let shards = Shards::new(shards_vec).ok_or("num shards must be > 0")?;
server_loop(server, shards.clone()).await?;
info!("shutting down server...");
shards.close_all().await?;
info!("server closed sqlite connections. bye!");
Ok::<_, Box<dyn Error>>(())
let compression_manager = manifest.compression_manager();
let shards = Arc::new(Shards::new(shards_vec).ok_or("num shards must be > 0")?);
let join_handles = vec![
spawn(loops::save_compression_stats_loop(manifest.clone())),
spawn(loops::dict_stats_printer_loop(
manifest.compression_manager(),
)),
spawn(loops::axum_server_loop(server, shards, compression_manager)),
];
for handle in join_handles {
handle.await??;
}
info!("saving compressor stats...");
manifest.save_compression_stats().await?;
info!("done. bye!");
Ok::<_, AsyncBoxError>(())
})?;
Ok(())
}
async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn Error>> {
let app = Router::new()
.route("/store", post(handlers::store_handler::store_handler))
.route("/get/:sha256", get(handlers::get_handler::get_handler))
.route("/info", get(handlers::info_handler::info_handler))
.layer(Extension(shards));
axum::serve(server, app.into_make_service())
.with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
.await?;
Ok(())
}

src/manifest/dict_id.rs

@@ -1,9 +0,0 @@
use rusqlite::types::FromSql;
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub struct DictId(i64);
impl FromSql for DictId {
fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
Ok(DictId(value.as_i64()?))
}
}

View File

@@ -1,121 +1,113 @@
mod dict_id;
mod manifest_key;
mod zstd_dict;
use std::{collections::HashMap, error::Error, sync::Arc};
use rusqlite::params;
use tokio_rusqlite::Connection;
use crate::shards::Shards;
use self::{
dict_id::DictId,
manifest_key::{get_manifest_key, set_manifest_key, NumShards},
zstd_dict::ZstdDict,
use self::manifest_key::{get_manifest_key, set_manifest_key, NumShards};
use crate::{
compression_manager::{CompressionManager, CompressionManagerArc},
compression_stats::{CompressionStat, CompressionStats},
concat_lines, into_tokio_rusqlite_err,
sql_types::CompressionId,
AsyncBoxError,
};
pub type ZstdDictArc = Arc<ZstdDict>;
use rusqlite::params;
use std::{sync::Arc, time::Duration};
use tokio::sync::RwLock;
use tokio_rusqlite::Connection;
use tracing::{debug, info};
pub struct Manifest {
conn: Connection,
num_shards: usize,
zstd_dict_by_id: HashMap<DictId, ZstdDictArc>,
zstd_dict_by_name: HashMap<String, ZstdDictArc>,
compressor: Arc<RwLock<CompressionManager>>,
}
impl Manifest {
pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, Box<dyn Error>> {
pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, AsyncBoxError> {
initialize(conn, num_shards).await
}
pub fn num_shards(&self) -> usize {
self.num_shards
}
async fn train_zstd_dict_with_tag(
&mut self,
_name: &str,
_shards: Shards,
) -> Result<ZstdDictArc, Box<dyn Error>> {
// let mut queries = vec![];
// for shard in shards.iter() {
// queries.push(shard.entries_for_tag(name));
// }
todo!();
pub fn compression_manager(&self) -> CompressionManagerArc {
self.compressor.clone()
}
async fn create_zstd_dict_from_samples(
&mut self,
name: &str,
samples: Vec<&[u8]>,
) -> Result<ZstdDictArc, Box<dyn Error>> {
if self.zstd_dict_by_name.contains_key(name) {
return Err(format!("dictionary {} already exists", name).into());
}
pub async fn save_compression_stats(&self) -> Result<(), AsyncBoxError> {
let compressor_arc = self.compression_manager();
let compressor = compressor_arc.read().await;
let compression_stats = compressor.stats().await.clone();
let level = 3;
let dict_bytes = zstd::dict::from_samples(
&samples,
1024 * 1024, // 1MB max dictionary size
)
.unwrap();
let name_copy = name.to_string();
let (dict_id, zstd_dict) = self
.conn
self.conn
.call(move |conn| {
let mut stmt = conn.prepare(
"INSERT INTO dictionaries (name, level, dict)
VALUES (?, ?, ?)
RETURNING id",
)?;
let dict_id =
stmt.query_row(params![name_copy, level, dict_bytes], |row| row.get(0))?;
let zstd_dict = Arc::new(ZstdDict::create(dict_id, name_copy, level, dict_bytes));
Ok((dict_id, zstd_dict))
save_compression_stats(conn, compression_stats).map_err(into_tokio_rusqlite_err)?;
Ok(())
})
.await?;
self.zstd_dict_by_id.insert(dict_id, zstd_dict.clone());
self.zstd_dict_by_name
.insert(name.to_string(), zstd_dict.clone());
Ok(zstd_dict)
}
fn get_dictionary_by_id(&self, id: DictId) -> Option<&ZstdDict> {
self.zstd_dict_by_id.get(&id).map(|d| &**d)
}
fn get_dictionary_by_name(&self, name: &str) -> Option<&ZstdDict> {
self.zstd_dict_by_name.get(name).map(|d| &**d)
Ok(())
}
}
async fn initialize(
conn: Connection,
num_shards: Option<usize>,
) -> Result<Manifest, Box<dyn Error>> {
async fn migrate_manifest(conn: &Connection) -> Result<(), tokio_rusqlite::Error> {
conn.call(|conn| {
conn.execute(
"CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)",
[],
)?;
conn.execute(
concat_lines!(
"CREATE TABLE IF NOT EXISTS zstd_dictionaries (",
" zstd_dict_id INTEGER PRIMARY KEY AUTOINCREMENT,",
" name TEXT NOT NULL,",
" encoder_json TEXT NOT NULL",
")"
),
[],
)?;
conn.execute(
concat_lines!(
"CREATE TABLE IF NOT EXISTS compression_stats (",
" dict_id TEXT PRIMARY KEY,",
" num_entries INTEGER NOT NULL DEFAULT 0,",
" uncompressed_size INTEGER NOT NULL DEFAULT 0,",
" compressed_size INTEGER NOT NULL DEFAULT 0,",
" elapsed INTEGER NOT NULL DEFAULT 0",
")"
),
[],
)?;
// insert the default compression stats
let default_compression_ids = [
CompressionId::None,
CompressionId::Zstd,
CompressionId::Brotli,
];
for compression_id in default_compression_ids.iter() {
conn.execute(
concat_lines!(
"INSERT OR IGNORE INTO compression_stats (dict_id)",
"VALUES (?)"
),
params![compression_id],
)?;
}
Ok(())
})
.await
}
async fn load_and_store_num_shards(
conn: &Connection,
requested_num_shards: Option<usize>,
) -> Result<usize, AsyncBoxError> {
let stored_num_shards: Option<usize> = conn
.call(|conn| {
conn.execute(
"CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)",
[],
)?;
conn.execute(
"CREATE TABLE IF NOT EXISTS dictionaries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
level INTEGER NOT NULL,
name TEXT NOT NULL,
dict BLOB NOT NULL
)",
[],
)?;
Ok(get_manifest_key(conn, NumShards)?)
})
.call(|conn| Ok(get_manifest_key(conn, NumShards)?))
.await?;
let num_shards = match (stored_num_shards, num_shards) {
Ok(match (stored_num_shards, requested_num_shards) {
(Some(stored_num_shards), Some(num_shards)) => {
if stored_num_shards == num_shards {
num_shards
@@ -142,38 +134,94 @@ async fn initialize(
// existing database, use loaded num_shards
num_shards
}
};
})
}
let rows = conn
async fn load_compression_stats(
conn: &Connection,
compressor: &mut CompressionManager,
) -> Result<(), AsyncBoxError> {
type CompressionStatRow = (CompressionId, usize, usize, usize, u64);
let rows: Vec<CompressionStatRow> = conn
.call(|conn| {
let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
let mut rows = vec![];
for r in stmt.query_map([], |row| {
let id = row.get(0)?;
let name: String = row.get(1)?;
let level: i32 = row.get(2)?;
let dict: Vec<u8> = row.get(3)?;
Ok((id, name, level, dict))
})? {
rows.push(r?);
}
let rows = conn
.prepare(concat_lines!(
"SELECT",
" dict_id, num_entries, uncompressed_size, compressed_size, elapsed",
"FROM compression_stats",
))?
.query_map([], |row| CompressionStatRow::try_from(row))?
.collect::<Result<_, _>>()?;
Ok(rows)
})
.await?;
let mut zstd_dicts_by_id = HashMap::new();
let mut zstd_dicts_by_name = HashMap::new();
for (id, name, level, dict_bytes) in rows {
let zstd_dict = Arc::new(ZstdDict::create(id, name.clone(), level, dict_bytes));
zstd_dicts_by_id.insert(id, zstd_dict.clone());
zstd_dicts_by_name.insert(name, zstd_dict);
debug!("loaded {} compression stats from manifest", rows.len());
for (dict_id, num_entries, uncompressed_size, compressed_size, elapsed) in rows {
compressor
.set_stat(
dict_id,
CompressionStat {
num_entries,
uncompressed_size,
compressed_size,
elapsed: Duration::from_millis(elapsed),
},
)
.await;
}
Ok(())
}
fn save_compression_stats(
conn: &rusqlite::Connection,
compression_stats: CompressionStats,
) -> Result<(), AsyncBoxError> {
let mut num_stats = 0;
for (id, stat) in compression_stats.iter() {
num_stats += 1;
conn.execute(
concat_lines!(
"INSERT INTO compression_stats",
" (dict_id, num_entries, uncompressed_size, compressed_size, elapsed)",
"VALUES (?, ?, ?, ?, ?)",
"ON CONFLICT (dict_id) DO UPDATE SET",
" num_entries = excluded.num_entries,",
" uncompressed_size = excluded.uncompressed_size,",
" compressed_size = excluded.compressed_size,",
" elapsed = excluded.elapsed"
),
params![
id,
stat.num_entries,
stat.uncompressed_size,
stat.compressed_size,
stat.elapsed.as_millis() as u64,
],
)?;
}
info!("saved {} compressor stats", num_stats);
Ok(())
}
async fn load_compressor(conn: &Connection) -> Result<CompressionManagerArc, AsyncBoxError> {
let mut compressor = CompressionManager::default();
load_compression_stats(conn, &mut compressor).await?;
Ok(compressor.into_arc())
}
async fn initialize(
conn: Connection,
requested_num_shards: Option<usize>,
) -> Result<Manifest, AsyncBoxError> {
migrate_manifest(&conn).await?;
let num_shards = load_and_store_num_shards(&conn, requested_num_shards).await?;
let compressor = load_compressor(&conn).await?;
Ok(Manifest {
conn,
num_shards,
zstd_dict_by_id: zstd_dicts_by_id,
zstd_dict_by_name: zstd_dicts_by_name,
compressor,
})
}
@@ -182,30 +230,13 @@ mod tests {
use super::*;
#[tokio::test]
async fn test_manifest() {
async fn test_manifest_compression_stats_loading() {
let conn = Connection::open_in_memory().await.unwrap();
let mut manifest = initialize(conn, Some(3)).await.unwrap();
let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
let zstd_dict = manifest
.create_zstd_dict_from_samples("test", samples)
.await
.unwrap();
// test that indexes are created correctly
assert_eq!(
zstd_dict.as_ref(),
manifest.get_dictionary_by_id(zstd_dict.id()).unwrap()
);
assert_eq!(
zstd_dict.as_ref(),
manifest.get_dictionary_by_name(zstd_dict.name()).unwrap()
);
let data = b"hello world, this is a test of a sort of long string";
let compressed = zstd_dict.compress(data).unwrap();
let decompressed = zstd_dict.decompress(&compressed).unwrap();
assert_eq!(decompressed, data);
assert!(data.len() > compressed.len());
let manifest = initialize(conn, Some(4)).await.unwrap();
let compressor = manifest.compression_manager();
let compressor = compressor.read().await;
let stats = compressor.stats().await;
assert_eq!(stats.len(), 3);
assert_eq!(stats.iter().count(), 3);
}
}

src/manifest/zstd_dict.rs

@@ -1,94 +0,0 @@
use super::dict_id::DictId;
use ouroboros::self_referencing;
use std::{error::Error, io};
use zstd::dict::{DecoderDictionary, EncoderDictionary};
#[self_referencing]
pub struct ZstdDict {
id: DictId,
name: String,
level: i32,
dict_bytes: Vec<u8>,
#[borrows(dict_bytes)]
#[not_covariant]
encoder_dict: EncoderDictionary<'this>,
#[borrows(dict_bytes)]
#[not_covariant]
decoder_dict: DecoderDictionary<'this>,
}
impl PartialEq for ZstdDict {
fn eq(&self, other: &Self) -> bool {
self.id() == other.id()
&& self.name() == other.name()
&& self.level() == other.level()
&& self.dict_bytes() == other.dict_bytes()
}
}
impl std::fmt::Debug for ZstdDict {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ZstdDict")
.field("id", &self.id())
.field("name", &self.name())
.field("level", &self.level())
.field("dict_bytes.len", &self.dict_bytes().len())
.finish()
}
}
impl ZstdDict {
pub fn create(id: DictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
ZstdDictBuilder {
id,
name,
level,
dict_bytes,
encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
}
.build()
}
pub fn id(&self) -> DictId {
self.with_id(|id| *id)
}
pub fn name(&self) -> &str {
self.borrow_name()
}
pub fn level(&self) -> i32 {
*self.borrow_level()
}
pub fn dict_bytes(&self) -> &[u8] {
self.borrow_dict_bytes()
}
pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let mut wrapper = io::Cursor::new(data);
let mut out_buffer = Vec::with_capacity(data.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
self.with_encoder_dict(|encoder_dict| {
let mut encoder =
zstd::stream::Encoder::with_prepared_dictionary(&mut output_wrapper, encoder_dict)?;
io::copy(&mut wrapper, &mut encoder)?;
encoder.finish()
})?;
Ok(out_buffer)
}
pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let mut wrapper = io::Cursor::new(data);
let mut out_buffer = Vec::with_capacity(data.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
self.with_decoder_dict(|decoder_dict| {
let mut decoder =
zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, decoder_dict)?;
io::copy(&mut decoder, &mut output_wrapper)
})?;
Ok(out_buffer)
}
}

src/sha256.rs

@@ -1,10 +1,13 @@
use rusqlite::{
types::{FromSql, ToSqlOutput},
ToSql,
};
use sha2::Digest;
use std::{
error::Error,
fmt::{Display, LowerHex},
};
use sha2::Digest;
#[derive(Debug)]
struct Sha256Error {
message: String,
@@ -17,7 +20,7 @@ impl Display for Sha256Error {
}
impl Error for Sha256Error {}
#[derive(Clone, Copy, PartialEq, Eq, Default)]
#[derive(Clone, Copy, PartialEq, Eq, Default, Hash)]
pub struct Sha256([u8; 32]);
impl Sha256 {
pub fn from_hex_string(hex: &str) -> Result<Self, Box<dyn Error>> {
@@ -47,6 +50,20 @@ impl Sha256 {
}
}
impl ToSql for Sha256 {
fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput> {
Ok(ToSqlOutput::Borrowed(rusqlite::types::ValueRef::Blob(
&self.0,
)))
}
}
impl FromSql for Sha256 {
fn column_result(value: rusqlite::types::ValueRef) -> rusqlite::types::FromSqlResult<Self> {
let bytes = <[u8; 32]>::column_result(value)?;
Ok(Self(bytes))
}
}
impl LowerHex for Sha256 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for byte in self.0.iter() {

View File

@@ -1,73 +1,55 @@
use std::io::Read;
use super::*;
use crate::{
compressible_data::CompressibleData,
sql_types::{CompressionId, UtcDateTime},
AsyncBoxError,
};
pub struct GetArgs {
pub sha256: Sha256,
}
pub struct GetResult {
pub sha256: Sha256,
pub content_type: String,
pub stored_size: usize,
pub created_at: UtcDateTime,
pub data: Vec<u8>,
pub data: CompressibleData,
}
impl Shard {
pub async fn get(&self, sha256: Sha256) -> Result<Option<GetResult>, Box<dyn Error>> {
self.conn
.call(move |conn| get_impl(conn, sha256))
.await
.map_err(|e| {
error!("get failed: {}", e);
e.into()
})
pub async fn get(&self, args: GetArgs) -> Result<Option<GetResult>, AsyncBoxError> {
let maybe_row = self
.conn
.call(move |conn| Ok(get_compressed_row(conn, &args.sha256)?))
.await?;
if let Some((content_type, stored_size, created_at, dict_id, data)) = maybe_row {
let compressor = self.compression_manager.read().await;
let data = compressor.decompress(dict_id, data)?;
Ok(Some(GetResult {
sha256: args.sha256,
content_type,
stored_size,
created_at,
data,
}))
} else {
Ok(None)
}
}
}
fn get_impl(
type CompressedRowResult = (String, usize, UtcDateTime, CompressionId, Vec<u8>);
fn get_compressed_row(
conn: &mut rusqlite::Connection,
sha256: Sha256,
) -> Result<Option<GetResult>, tokio_rusqlite::Error> {
let maybe_row = conn
.query_row(
"SELECT content_type, compressed_size, created_at, compression, data
sha256: &Sha256,
) -> Result<Option<CompressedRowResult>, rusqlite::Error> {
conn.query_row(
"SELECT content_type, compressed_size, created_at, dict_id, data
FROM entries
WHERE sha256 = ?",
params![sha256.hex_string()],
|row| {
let content_type = row.get(0)?;
let stored_size = row.get(1)?;
let created_at = parse_created_at_str(row.get(2)?)?;
let compression = row.get(3)?;
let data: Vec<u8> = row.get(4)?;
Ok((content_type, stored_size, created_at, compression, data))
},
)
.optional()
.map_err(into_tokio_rusqlite_err)?;
let row = match maybe_row {
Some(row) => row,
None => return Ok(None),
};
let (content_type, stored_size, created_at, compression, data) = row;
let data = match compression {
Compression::None => data,
Compression::Zstd => {
let mut decoder =
zstd::Decoder::new(data.as_slice()).map_err(into_tokio_rusqlite_err)?;
let mut decompressed = vec![];
decoder
.read_to_end(&mut decompressed)
.map_err(into_tokio_rusqlite_err)?;
decompressed
}
};
Ok(Some(GetResult {
sha256,
content_type,
stored_size,
created_at,
data,
}))
params![sha256],
|row| CompressedRowResult::try_from(row),
)
.optional()
}

View File

@@ -1,7 +1,8 @@
use super::*;
use crate::{concat_lines, AsyncBoxError};
impl Shard {
pub(super) async fn migrate(&self) -> Result<(), Box<dyn Error>> {
pub(super) async fn migrate(&self) -> Result<(), AsyncBoxError> {
let shard_id = self.id();
// create tables, indexes, etc
self.conn
@@ -9,12 +10,7 @@ impl Shard {
ensure_schema_versions_table(conn)?;
let schema_rows = load_schema_rows(conn)?;
if let Some((version, date_time)) = schema_rows.first() {
debug!(
"shard {}: latest schema version: {} @ {}",
shard_id, version, date_time
);
if let Some((version, _)) = schema_rows.first() {
if *version == 1 {
// no-op
} else {
@@ -35,17 +31,23 @@ impl Shard {
fn ensure_schema_versions_table(conn: &rusqlite::Connection) -> Result<usize, rusqlite::Error> {
conn.execute(
"CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER PRIMARY KEY,
created_at TEXT NOT NULL
)",
concat_lines!(
"CREATE TABLE IF NOT EXISTS schema_version (",
" version INTEGER PRIMARY KEY,",
" created_at TEXT NOT NULL",
")"
),
[],
)
}
fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, rusqlite::Error> {
let mut stmt = conn
.prepare("SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1")?;
let mut stmt = conn.prepare(concat_lines!(
"SELECT version, created_at",
"FROM schema_version",
"ORDER BY version",
"DESC LIMIT 1"
))?;
let rows = stmt.query_map([], |row| {
let version = row.get(0)?;
let created_at = row.get(1)?;
@@ -57,15 +59,45 @@ fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, r
fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
debug!("migrating to version 1");
conn.execute(
"CREATE TABLE IF NOT EXISTS entries (
sha256 BLOB PRIMARY KEY,
content_type TEXT NOT NULL,
compression INTEGER NOT NULL,
uncompressed_size INTEGER NOT NULL,
compressed_size INTEGER NOT NULL,
data BLOB NOT NULL,
created_at TEXT NOT NULL
)",
concat_lines!(
"CREATE TABLE IF NOT EXISTS entries (",
" id INTEGER PRIMARY KEY AUTOINCREMENT,",
" sha256 BLOB NOT NULL,",
" content_type TEXT NOT NULL,",
" dict_id TEXT NOT NULL,",
" uncompressed_size INTEGER NOT NULL,",
" compressed_size INTEGER NOT NULL,",
" data BLOB NOT NULL,",
" created_at TEXT NOT NULL",
")"
),
[],
)?;
conn.execute(
concat_lines!(
"CREATE INDEX IF NOT EXISTS entries_sha256_idx",
"ON entries (sha256)"
),
[],
)?;
conn.execute(
concat_lines!(
"CREATE TABLE IF NOT EXISTS compression_hints (",
" name TEXT NOT NULL,",
" ordering INTEGER NOT NULL,",
" entry_id INTEGER NOT NULL",
")"
),
[],
)?;
conn.execute(
concat_lines!(
"CREATE UNIQUE INDEX IF NOT EXISTS compression_hints_name_idx",
"ON compression_hints (name, ordering)",
),
[],
)?;
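The schema SQL above is assembled with a concat_lines! macro imported from the crate root but not defined in this diff. A minimal sketch of what such a macro could look like, assuming it joins its string literals with newlines at compile time; the real definition may differ.

    // Hypothetical definition for illustration only.
    // Each literal gets a trailing newline, so multi-line SQL stays readable
    // while still producing a single &'static str.
    macro_rules! concat_lines {
        ($($line:tt),+ $(,)?) => {
            concat!($($line, "\n"),+)
        };
    }

    // concat_lines!("SELECT version, created_at", "FROM schema_version")
    // would expand to "SELECT version, created_at\nFROM schema_version\n".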

View File

@@ -1,4 +1,10 @@
use super::*;
use crate::{
compressible_data::CompressibleData,
concat_lines, into_tokio_rusqlite_err,
sql_types::{CompressionId, UtcDateTime},
AsyncBoxError,
};
#[derive(PartialEq, Debug)]
pub enum StoreResult {
@@ -22,86 +28,96 @@ pub struct StoreArgs {
}
impl Shard {
pub async fn store(&self, store_args: StoreArgs) -> Result<StoreResult, Box<dyn Error>> {
let use_compression = self.use_compression;
pub async fn store(
&self,
StoreArgs {
sha256,
data,
content_type,
}: StoreArgs,
) -> Result<StoreResult, AsyncBoxError> {
let existing_entry = self
.conn
.call(move |conn| find_with_sha256(conn, &sha256).map_err(into_tokio_rusqlite_err))
.await?;
if let Some(entry) = existing_entry {
return Ok(entry);
}
let uncompressed_size = data.len();
let (dict_id, data) = {
let compression_manager = self.compression_manager.read().await;
compression_manager.compress(&content_type, data).await?
};
self.conn
.call(move |conn| store(conn, use_compression, store_args))
.await
.map_err(|e| {
error!("store failed: {}", e);
e.into()
.call(move |conn| {
insert(
conn,
&sha256,
content_type,
dict_id,
uncompressed_size,
data,
)
.map_err(into_tokio_rusqlite_err)
})
.await
.map_err(|e| e.into())
}
}
fn store(
fn find_with_sha256(
conn: &mut rusqlite::Connection,
use_compression: UseCompression,
StoreArgs {
sha256,
content_type,
data,
}: StoreArgs,
) -> Result<StoreResult, tokio_rusqlite::Error> {
let sha256 = sha256.hex_string();
sha256: &Sha256,
) -> Result<Option<StoreResult>, rusqlite::Error> {
conn.query_row(
concat_lines!(
"SELECT uncompressed_size, compressed_size, created_at",
"FROM entries",
"WHERE sha256 = ?"
),
params![sha256],
|row| {
Ok(StoreResult::Exists {
stored_size: row.get(0)?,
data_size: row.get(1)?,
created_at: row.get(2)?,
})
},
)
.optional()
}
// check for existing entry
let maybe_existing: Option<StoreResult> = conn
.query_row(
"SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?",
params![sha256],
|row| {
Ok(StoreResult::Exists {
stored_size: row.get(0)?,
data_size: row.get(1)?,
created_at: parse_created_at_str(row.get(2)?)?,
})
},
)
.optional()?;
if let Some(existing) = maybe_existing {
return Ok(existing);
}
let created_at = chrono::Utc::now();
let uncompressed_size = data.len();
let tmp_data_holder;
let use_compression = match use_compression {
UseCompression::None => false,
UseCompression::Auto => auto_compressible_content_type(&content_type),
UseCompression::Zstd => true,
};
let (compression, data) = if use_compression {
tmp_data_holder = zstd::encode_all(&data[..], 0).map_err(into_tokio_rusqlite_err)?;
if tmp_data_holder.len() < data.len() {
(Compression::Zstd, &tmp_data_holder[..])
} else {
(Compression::None, &data[..])
}
} else {
(Compression::None, &data[..])
};
fn insert(
conn: &mut rusqlite::Connection,
sha256: &Sha256,
content_type: String,
dict_id: CompressionId,
uncompressed_size: usize,
data: CompressibleData,
) -> Result<StoreResult, rusqlite::Error> {
let created_at = UtcDateTime::now();
let compressed_size = data.len();
conn.execute(
"INSERT INTO entries
(sha256, content_type, compression, uncompressed_size, compressed_size, data, created_at)
VALUES
(?, ?, ?, ?, ?, ?, ?)
",
params![
sha256,
content_type,
compression,
uncompressed_size,
compressed_size,
data,
created_at.to_rfc3339(),
],
)?;
concat_lines!(
"INSERT INTO entries",
" (sha256, content_type, dict_id, uncompressed_size, compressed_size, data, created_at)",
"VALUES (?, ?, ?, ?, ?, ?, ?)",
),
params![
sha256,
content_type,
dict_id,
uncompressed_size,
compressed_size,
data.as_ref(),
created_at,
],
)?;
Ok(StoreResult::Created {
stored_size: compressed_size,
@@ -109,14 +125,3 @@ fn store(
created_at,
})
}
fn auto_compressible_content_type(content_type: &str) -> bool {
[
"text/",
"application/xml",
"application/json",
"application/javascript",
]
.iter()
.any(|ct| content_type.starts_with(ct))
}
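As the tests later in this diff suggest, the new store takes a StoreArgs struct (with Default covering its remaining fields) and reports whether the entry was created or already existed. A short usage sketch; the store_text helper and the Vec<u8>-to-data conversion are illustrative only.

    async fn store_text(shard: &Shard, body: &str) -> Result<(), AsyncBoxError> {
        let sha256 = Sha256::from_bytes(body.as_bytes());
        let store_result = shard
            .store(StoreArgs {
                sha256,
                content_type: "text/plain".to_string(),
                data: body.as_bytes().to_vec().into(),
                ..Default::default()
            })
            .await?;
        match store_result {
            StoreResult::Created { stored_size, data_size, .. } => {
                println!("created: {} bytes, {} stored", data_size, stored_size)
            }
            StoreResult::Exists { created_at, .. } => {
                println!("already stored at {}", created_at.to_string())
            }
        }
        Ok(())
    }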

View File

@@ -1,58 +1,19 @@
mod fn_get;
mod fn_migrate;
mod fn_store;
mod shard;
pub mod shard_error;
mod shard_struct;
pub use fn_get::GetResult;
pub use fn_get::{GetArgs, GetResult};
pub use fn_store::{StoreArgs, StoreResult};
pub use shard::Shard;
pub use shard_struct::Shard;
#[cfg(test)]
pub mod test {
pub use super::shard::test::*;
pub use super::shard_struct::test::*;
}
use crate::{sha256::Sha256, shard::shard_error::ShardError, UseCompression};
use crate::{sha256::Sha256, shard::shard_error::ShardError};
use axum::body::Bytes;
use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
use std::error::Error;
use rusqlite::{params, types::FromSql, OptionalExtension};
use tokio_rusqlite::Connection;
use tracing::{debug, error};
pub type UtcDateTime = chrono::DateTime<chrono::Utc>;
#[derive(Debug, PartialEq, Eq)]
enum Compression {
None,
Zstd,
}
impl ToSql for Compression {
fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
match self {
Compression::None => 0.to_sql(),
Compression::Zstd => 1.to_sql(),
}
}
}
impl FromSql for Compression {
fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
match value.as_i64()? {
0 => Ok(Compression::None),
1 => Ok(Compression::Zstd),
_ => Err(rusqlite::types::FromSqlError::InvalidType),
}
}
}
fn parse_created_at_str(created_at_str: String) -> Result<UtcDateTime, rusqlite::Error> {
let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str)
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
Ok(parsed.to_utc())
}
fn into_tokio_rusqlite_err<E: Into<Box<dyn Error + Send + Sync + 'static>>>(
e: E,
) -> tokio_rusqlite::Error {
tokio_rusqlite::Error::Other(e.into())
}
use tracing::debug;
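This module previously carried its own Box<dyn Error> plumbing; its functions now return a crate-level AsyncBoxError alias that is not defined anywhere in this section. A plausible shape for it, assuming it mirrors the Send + Sync + 'static bound the removed into_tokio_rusqlite_err helper required:

    // Assumed shape only; the real alias lives elsewhere in the crate.
    pub type AsyncBoxError = Box<dyn std::error::Error + Send + Sync + 'static>;

The Send + Sync bound matters because these errors cross await points and the tokio_rusqlite worker-thread boundary.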

View File

@@ -1,36 +1,33 @@
use super::*;
use crate::{compression_manager::CompressionManagerArc, AsyncBoxError};
#[derive(Clone)]
pub struct Shard {
pub(super) id: usize,
pub(super) conn: Connection,
pub(super) use_compression: UseCompression,
pub(super) compression_manager: CompressionManagerArc,
}
impl Shard {
pub async fn open(
id: usize,
use_compression: UseCompression,
conn: Connection,
) -> Result<Self, Box<dyn Error>> {
compression_manager: CompressionManagerArc,
) -> Result<Self, AsyncBoxError> {
let shard = Self {
id,
use_compression,
conn,
compression_manager,
};
shard.migrate().await?;
Ok(shard)
}
pub async fn close(self) -> Result<(), Box<dyn Error>> {
self.conn.close().await.map_err(|e| e.into())
}
pub fn id(&self) -> usize {
self.id
}
pub async fn db_size_bytes(&self) -> Result<usize, Box<dyn Error>> {
pub async fn db_size_bytes(&self) -> Result<usize, AsyncBoxError> {
self.query_single_row(
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
)
@@ -40,7 +37,7 @@ impl Shard {
async fn query_single_row<T: FromSql + Send + 'static>(
&self,
query: &'static str,
) -> Result<T, Box<dyn Error>> {
) -> Result<T, AsyncBoxError> {
self.conn
.call(move |conn| {
let value: T = conn.query_row(query, [], |row| row.get(0))?;
@@ -50,7 +47,7 @@ impl Shard {
.map_err(|e| e.into())
}
pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
pub async fn num_entries(&self) -> Result<usize, AsyncBoxError> {
get_num_entries(&self.conn).await.map_err(|e| e.into())
}
}
@@ -65,18 +62,23 @@ async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Err
#[cfg(test)]
pub mod test {
use crate::compression_manager::test::make_compressor_with;
use crate::compression_manager::CompressionManagerArc;
use crate::{
compression_manager::test::make_compressor,
sha256::Sha256,
shard::{GetArgs, StoreArgs, StoreResult},
CompressionPolicy,
};
use rstest::rstest;
use super::{StoreResult, UseCompression};
use crate::{sha256::Sha256, shard::StoreArgs};
pub async fn make_shard_with_compression(use_compression: UseCompression) -> super::Shard {
pub async fn make_shard_with(compressor: CompressionManagerArc) -> super::Shard {
let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
super::Shard::open(0, use_compression, conn).await.unwrap()
super::Shard::open(0, conn, compressor).await.unwrap()
}
pub async fn make_shard() -> super::Shard {
make_shard_with_compression(UseCompression::Auto).await
make_shard_with(make_compressor().into_arc()).await
}
#[tokio::test]
@@ -97,7 +99,7 @@ pub mod test {
async fn test_not_found_get() {
let shard = make_shard().await;
let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
let get_result = shard.get(sha256).await.unwrap();
let get_result = shard.get(GetArgs { sha256 }).await.unwrap();
assert!(get_result.is_none());
}
@@ -111,16 +113,17 @@ pub mod test {
sha256,
content_type: "text/plain".to_string(),
data: data.into(),
..Default::default()
})
.await
.unwrap();
match store_result {
StoreResult::Created {
stored_size,
stored_size: _,
data_size,
created_at,
} => {
assert_eq!(stored_size, data.len());
// assert_eq!(stored_size, data.len());
assert_eq!(data_size, data.len());
assert!(created_at > chrono::Utc::now() - chrono::Duration::seconds(1));
}
@@ -128,7 +131,7 @@ pub mod test {
}
assert_eq!(shard.num_entries().await.unwrap(), 1);
let get_result = shard.get(sha256).await.unwrap().unwrap();
let get_result = shard.get(GetArgs { sha256 }).await.unwrap().unwrap();
assert_eq!(get_result.content_type, "text/plain");
assert_eq!(get_result.data, data);
assert_eq!(get_result.stored_size, data.len());
@@ -145,6 +148,7 @@ pub mod test {
sha256,
content_type: "text/plain".to_string(),
data: data.into(),
..Default::default()
})
.await
.unwrap();
@@ -168,6 +172,7 @@ pub mod test {
sha256,
content_type: "text/plain".to_string(),
data: data.into(),
..Default::default()
})
.await
.unwrap();
@@ -185,12 +190,16 @@ pub mod test {
#[rstest]
#[tokio::test]
async fn test_compression_store_get(
#[values(UseCompression::Auto, UseCompression::None, UseCompression::Zstd)]
use_compression: UseCompression,
#[values(
CompressionPolicy::Auto,
CompressionPolicy::None,
CompressionPolicy::ForceZstd
)]
compression_policy: CompressionPolicy,
#[values(true, false)] incompressible_data: bool,
#[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
) {
let shard = make_shard_with_compression(use_compression).await;
let shard = make_shard_with(make_compressor_with(compression_policy).into_arc()).await;
let mut data = vec![b'.'; 1024];
if incompressible_data {
for byte in data.iter_mut() {
@@ -204,12 +213,13 @@ pub mod test {
sha256,
content_type: content_type.clone(),
data: data.clone().into(),
..Default::default()
})
.await
.unwrap();
assert!(matches!(store_result, StoreResult::Created { .. }));
let get_result = shard.get(sha256).await.unwrap().unwrap();
let get_result = shard.get(GetArgs { sha256 }).await.unwrap().unwrap();
assert_eq!(get_result.content_type, content_type);
assert_eq!(get_result.data, data);
}
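Shard::open now takes the shared CompressionManagerArc instead of a UseCompression flag. A sketch of opening a shard outside the test helpers; opening by file path and the open_shard helper are assumptions for illustration, not something shown in this diff.

    async fn open_shard(
        id: usize,
        path: &str,
        compressor: CompressionManagerArc,
    ) -> Result<Shard, AsyncBoxError> {
        // Open the SQLite file, then run migrations via Shard::open.
        let conn = tokio_rusqlite::Connection::open(path).await?;
        Shard::open(id, conn, compressor).await
    }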

View File

@@ -1,14 +1,16 @@
use crate::{sha256::Sha256, shard::Shard};
use std::error::Error;
use std::sync::Arc;
pub type ShardsArc = Arc<Shards>;
#[derive(Clone)]
pub struct Shards(Vec<Shard>);
pub struct Shards(Vec<Arc<Shard>>);
impl Shards {
pub fn new(shards: Vec<Shard>) -> Option<Self> {
if shards.is_empty() {
return None;
}
Some(Self(shards))
Some(Self(shards.into_iter().map(Arc::new).collect()))
}
pub fn shard_for(&self, sha256: &Sha256) -> &Shard {
@@ -16,15 +18,8 @@ impl Shards {
&self.0[shard_id]
}
pub async fn close_all(self) -> Result<(), Box<dyn Error>> {
for shard in self.0 {
shard.close().await?;
}
Ok(())
}
pub fn iter(&self) -> std::slice::Iter<'_, Shard> {
self.0.iter()
pub fn iter(&self) -> impl Iterator<Item = &Shard> {
self.0.iter().map(|shard| shard.as_ref())
}
pub fn len(&self) -> usize {
@@ -34,15 +29,19 @@ impl Shards {
#[cfg(test)]
pub mod test {
use crate::{shard::test::make_shard_with_compression, UseCompression};
use std::sync::Arc;
use super::Shards;
use crate::{
compression_manager::{test::make_compressor, CompressionManagerArc},
shard::test::make_shard_with,
};
pub async fn make_shards_with_compression(use_compression: UseCompression) -> Shards {
Shards::new(vec![make_shard_with_compression(use_compression).await]).unwrap()
use super::{Shards, ShardsArc};
pub async fn make_shards() -> ShardsArc {
make_shards_with(make_compressor().into_arc()).await
}
pub async fn make_shards() -> Shards {
make_shards_with_compression(UseCompression::Auto).await
pub async fn make_shards_with(compressor: CompressionManagerArc) -> ShardsArc {
Arc::new(Shards::new(vec![make_shard_with(compressor.clone()).await]).unwrap())
}
}
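Shards now hands out Arc-wrapped shards, so routing a request is a cheap borrow. A minimal sketch of routing a lookup by hash, reusing the GetArgs and GetResult types from earlier in this diff:

    async fn lookup(
        shards: &ShardsArc,
        sha256: Sha256,
    ) -> Result<Option<GetResult>, AsyncBoxError> {
        // shard_for picks the shard that owns this hash; get then decompresses the row.
        let shard = shards.shard_for(&sha256);
        shard.get(GetArgs { sha256 }).await
    }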

View File

@@ -0,0 +1,85 @@
use crate::AsyncBoxError;
use rusqlite::{
types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, PartialOrd, Ord)]
pub enum CompressionId {
None,
Zstd,
Brotli,
}
const NONE_PREFIX: &str = "none";
const ZSTD_PREFIX: &str = "zstd";
const BROTLI_PREFIX: &str = "brotli";
impl CompressionId {
pub fn prefix(&self) -> &'static str {
match self {
CompressionId::None => NONE_PREFIX,
CompressionId::Zstd => ZSTD_PREFIX,
CompressionId::Brotli => BROTLI_PREFIX,
}
}
fn from_str(id_as_str: &str) -> Result<Self, AsyncBoxError> {
match id_as_str {
NONE_PREFIX => Ok(CompressionId::None),
ZSTD_PREFIX => Ok(CompressionId::Zstd),
BROTLI_PREFIX => Ok(CompressionId::Brotli),
_ => Err(format!("invalid DictId: {}", id_as_str).into()),
}
}
}
impl ToString for CompressionId {
fn to_string(&self) -> String {
match self {
CompressionId::None => NONE_PREFIX.to_string(),
CompressionId::Zstd => ZSTD_PREFIX.to_string(),
CompressionId::Brotli => BROTLI_PREFIX.to_string(),
}
}
}
impl ToSql for CompressionId {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
self.prefix().to_sql()
}
}
impl FromSql for CompressionId {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
let s = value.as_str()?;
CompressionId::from_str(s).map_err(FromSqlError::Other)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
#[rstest(
case(CompressionId::None, "none"),
case(CompressionId::Zstd, "zstd"),
case(CompressionId::Brotli, "brotli")
)]
#[test]
fn test_dict_id(#[case] id: CompressionId, #[case] id_str: &str) {
use rusqlite::types::Value;
let sql_to_id = CompressionId::column_result(ValueRef::Text(id_str.as_bytes()));
assert_eq!(id, sql_to_id.unwrap());
let id_to_sql = match id.to_sql() {
Ok(ToSqlOutput::Borrowed(ValueRef::Text(text))) => text.to_owned(),
Ok(ToSqlOutput::Owned(Value::Text(text))) => text.as_bytes().to_owned(),
_ => panic!("unexpected ToSqlOutput: {:?}", id.to_sql()),
};
let id_to_sql = std::str::from_utf8(&id_to_sql).unwrap();
assert_eq!(id_str, id_to_sql);
}
}
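One small design note: CompressionId implements ToString directly, whereas clippy usually suggests implementing Display, which provides to_string() for free and lets the id be used in format strings. A sketch of that alternative:

    impl std::fmt::Display for CompressionId {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Reuse the same stable prefix that ToSql writes to the database.
            f.write_str(self.prefix())
        }
    }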

src/sql_types/entry_id.rs (new file, +24 lines)
View File

@@ -0,0 +1,24 @@
use rusqlite::{
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub struct EntryId(pub i64);
impl From<i64> for EntryId {
fn from(id: i64) -> Self {
Self(id)
}
}
impl FromSql for EntryId {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
Ok(value.as_i64()?.into())
}
}
impl ToSql for EntryId {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
self.0.to_sql()
}
}
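EntryId is a thin newtype over the AUTOINCREMENT id column added in the version-1 migration. A usage sketch reading it back via the FromSql impl above; the entry_id_for helper is illustrative, and the column names follow the entries schema from fn_migrate.rs.

    use rusqlite::OptionalExtension;

    fn entry_id_for(
        conn: &rusqlite::Connection,
        sha256: &Sha256,
    ) -> rusqlite::Result<Option<EntryId>> {
        conn.query_row(
            "SELECT id FROM entries WHERE sha256 = ?",
            rusqlite::params![sha256],
            |row| row.get::<_, EntryId>(0),
        )
        .optional()
    }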

src/sql_types/mod.rs (new file, +5 lines)
View File

@@ -0,0 +1,5 @@
mod compression_id;
mod utc_date_time;
pub use compression_id::CompressionId;
pub use utc_date_time::UtcDateTime;

View File

@@ -0,0 +1,51 @@
use chrono::DateTime;
use rusqlite::{
types::{FromSql, FromSqlError, ToSqlOutput, ValueRef},
Result, ToSql,
};
#[derive(PartialEq, Debug, PartialOrd)]
pub struct UtcDateTime(DateTime<chrono::Utc>);
impl UtcDateTime {
pub fn now() -> Self {
Self(chrono::Utc::now())
}
#[cfg(test)]
pub fn from_string(s: &str) -> Result<Self, chrono::ParseError> {
Ok(Self(DateTime::parse_from_rfc3339(s)?.to_utc()))
}
}
impl ToString for UtcDateTime {
fn to_string(&self) -> String {
self.0.to_rfc3339()
}
}
impl PartialEq<DateTime<chrono::Utc>> for UtcDateTime {
fn eq(&self, other: &DateTime<chrono::Utc>) -> bool {
self.0 == *other
}
}
impl PartialOrd<DateTime<chrono::Utc>> for UtcDateTime {
fn partial_cmp(&self, other: &DateTime<chrono::Utc>) -> Option<std::cmp::Ordering> {
self.0.partial_cmp(other)
}
}
impl ToSql for UtcDateTime {
fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
Ok(ToSqlOutput::from(self.0.to_rfc3339()))
}
}
impl FromSql for UtcDateTime {
fn column_result(value: ValueRef<'_>) -> Result<Self, FromSqlError> {
let parsed = DateTime::parse_from_rfc3339(value.as_str()?)
.map_err(|e| FromSqlError::Other(e.into()))?;
Ok(UtcDateTime(parsed.to_utc()))
}
}
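UtcDateTime stores timestamps as RFC 3339 TEXT, so a value written through ToSql parses back through FromSql. A small round-trip sketch against a throwaway in-memory table; the table name t is arbitrary, and exact equality assumes chrono's default RFC 3339 formatting keeps full sub-second precision.

    fn roundtrip() -> rusqlite::Result<()> {
        let conn = rusqlite::Connection::open_in_memory()?;
        conn.execute("CREATE TABLE t (created_at TEXT NOT NULL)", [])?;

        let written = UtcDateTime::now();
        conn.execute("INSERT INTO t (created_at) VALUES (?)", rusqlite::params![written])?;

        // FromSql parses the RFC 3339 string back into a UtcDateTime.
        let read: UtcDateTime = conn.query_row("SELECT created_at FROM t", [], |row| row.get(0))?;
        assert_eq!(read, written);
        Ok(())
    }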