Compare commits
10 Commits
182584cbe9 ... 0bb8a333e4
| Author | SHA1 | Date |
|---|---|---|
|  | 0bb8a333e4 |  |
|  | eef6811625 |  |
|  | f44d884761 |  |
|  | adbccc97a5 |  |
|  | 3f44344ac0 |  |
|  | 6deb909c43 |  |
|  | 83b4dacede |  |
|  | b0955c9c64 |  |
|  | a3b550526e |  |
|  | bd2de7cfac |  |
132
Cargo.lock
generated
@@ -44,6 +44,21 @@ version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd"
|
||||
|
||||
[[package]]
|
||||
name = "alloc-no-stdlib"
|
||||
version = "2.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3"
|
||||
|
||||
[[package]]
|
||||
name = "alloc-stdlib"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece"
|
||||
dependencies = [
|
||||
"alloc-no-stdlib",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "allocator-api2"
|
||||
version = "0.2.18"
|
||||
@@ -277,11 +292,14 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"axum_typed_multipart",
|
||||
"brotli",
|
||||
"chrono",
|
||||
"clap",
|
||||
"futures",
|
||||
"hex",
|
||||
"humansize",
|
||||
"kdam",
|
||||
"num_cpus",
|
||||
"ouroboros",
|
||||
"rand",
|
||||
"reqwest",
|
||||
@@ -290,10 +308,12 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"tabled",
|
||||
"tokio",
|
||||
"tokio-rusqlite",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"walkdir",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
@@ -306,12 +326,39 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "brotli"
|
||||
version = "6.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b"
|
||||
dependencies = [
|
||||
"alloc-no-stdlib",
|
||||
"alloc-stdlib",
|
||||
"brotli-decompressor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "brotli-decompressor"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6221fe77a248b9117d431ad93761222e1cf8ff282d9d1d5d9f53d6299a1cf76"
|
||||
dependencies = [
|
||||
"alloc-no-stdlib",
|
||||
"alloc-stdlib",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
|
||||
|
||||
[[package]]
|
||||
name = "bytecount"
|
||||
version = "0.6.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.6.0"
|
||||
@@ -805,6 +852,15 @@ version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
|
||||
|
||||
[[package]]
|
||||
name = "humansize"
|
||||
version = "2.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7"
|
||||
dependencies = [
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.3.1"
|
||||
@@ -972,6 +1028,12 @@ version = "0.2.153"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||
|
||||
[[package]]
|
||||
name = "libsqlite3-sys"
|
||||
version = "0.28.0"
|
||||
@@ -1207,6 +1269,17 @@ version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
|
||||
|
||||
[[package]]
|
||||
name = "papergrid"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ad43c07024ef767f9160710b3a6773976194758c7919b17e63b863db0bdf7fb"
|
||||
dependencies = [
|
||||
"bytecount",
|
||||
"fnv",
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.1"
|
||||
@@ -1552,6 +1625,15 @@ version = "1.0.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.23"
|
||||
@@ -1734,6 +1816,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
@@ -1781,6 +1864,30 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tabled"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c998b0c8b921495196a48aabaf1901ff28be0760136e31604f7967b0792050e"
|
||||
dependencies = [
|
||||
"papergrid",
|
||||
"tabled_derive",
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tabled_derive"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c138f99377e5d653a371cdad263615634cfc8467685dfe8e73e2b8e98f44b17"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.10.1"
|
||||
@@ -2047,6 +2154,12 @@ dependencies = [
|
||||
"tinyvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6"
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.5.0"
|
||||
@@ -2088,6 +2201,16 @@ version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
|
||||
dependencies = [
|
||||
"same-file",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.1"
|
||||
@@ -2195,6 +2318,15 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
|
||||
dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
|
||||
13
Cargo.toml
@@ -12,6 +12,10 @@ path = "src/main.rs"
name = "load-test"
path = "load_test/main.rs"

[[bin]]
name = "fixture-inserter"
path = "fixture_inserter/main.rs"

[dependencies]
axum = { version = "0.7.5", features = ["macros"] }
axum_typed_multipart = "0.11.1"
@@ -32,6 +36,15 @@ reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
hex = "0.4.3"
zstd = { version = "0.13.1", features = ["experimental"] }
ouroboros = "0.18.3"
humansize = "2.1.3"
walkdir = "2.5.0"
tabled = "0.15.0"
brotli = "6.0.0"
num_cpus = "1.16.0"

[dev-dependencies]
rstest = "0.19.0"

[lints.rust]
unsafe_code = "forbid"
unused_must_use = "forbid"
14
Dockerfile
Normal file
@@ -0,0 +1,14 @@
FROM rust:1.77 as builder
RUN mkdir /build
WORKDIR /build
COPY Cargo.toml Cargo.lock ./
COPY src ./src
RUN --mount=type=cache,target=/usr/local/cargo/registry \
    --mount=type=cache,target=/build/target \
    cargo build --release --bin blob-store-app && \
    cp target/release/blob-store-app /blob-store-app

FROM debian:bookworm-slim
RUN apt-get update && apt-get install -qqy libsqlite3-0
COPY --from=builder /blob-store-app /usr/local/bin/blob-store-server
ENTRYPOINT ["blob-store-server", "--db-path", "/data", "--bind", "0.0.0.0"]
107
fixture_inserter/main.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use clap::{arg, Parser};
|
||||
use kdam::BarExt;
|
||||
use kdam::{tqdm, Bar};
|
||||
use rand::prelude::SliceRandom;
|
||||
use reqwest::blocking::{multipart, Client};
|
||||
use std::{
|
||||
error::Error,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use walkdir::WalkDir;
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
hint_name: String,
|
||||
#[arg(long)]
|
||||
fixture_dir: String,
|
||||
#[arg(long)]
|
||||
limit: Option<usize>,
|
||||
#[arg(long)]
|
||||
num_threads: Option<usize>,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
let args = Args::parse();
|
||||
let hint_name = args.hint_name;
|
||||
let num_threads = args.num_threads.unwrap_or(1);
|
||||
let pb = Arc::new(Mutex::new(tqdm!()));
|
||||
|
||||
let mut entries = WalkDir::new(&args.fixture_dir)
|
||||
.into_iter()
|
||||
.filter_map(Result::ok)
|
||||
.filter(|entry| entry.file_type().is_file())
|
||||
.filter_map(|entry| {
|
||||
let path = entry.path();
|
||||
let name = path.to_str()?.to_string();
|
||||
Some(name)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// shuffle the entries
|
||||
entries.shuffle(&mut rand::thread_rng());
|
||||
|
||||
// if there's a limit, drop the rest
|
||||
if let Some(limit) = args.limit {
|
||||
entries.truncate(limit);
|
||||
}
|
||||
|
||||
let entry_slices = entries.chunks(entries.len() / num_threads);
|
||||
println!("entry_slices: {:?}", entry_slices.len());
|
||||
|
||||
let join_handles = entry_slices.map(|entry_slice| {
|
||||
let hint_name = hint_name.clone();
|
||||
let pb = pb.clone();
|
||||
let entry_slice = entry_slice.to_vec();
|
||||
std::thread::spawn(move || {
|
||||
let client = Client::new();
|
||||
for entry in entry_slice.into_iter() {
|
||||
store_file(&client, &hint_name, &entry, pb.clone()).unwrap();
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
for join_handle in join_handles {
|
||||
join_handle.join().unwrap();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn store_file(
|
||||
client: &Client,
|
||||
hint_name: &str,
|
||||
file_name: &str,
|
||||
pb: Arc<Mutex<Bar>>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let file_bytes = std::fs::read(file_name)?;
|
||||
|
||||
let content_type = if file_name.ends_with(".html") {
|
||||
"text/html"
|
||||
} else if file_name.ends_with(".json") {
|
||||
"application/json"
|
||||
} else if file_name.ends_with(".jpg") || file_name.ends_with(".jpeg") {
|
||||
"image/jpeg"
|
||||
} else if file_name.ends_with(".png") {
|
||||
"image/png"
|
||||
} else if file_name.ends_with(".txt") {
|
||||
"text/plain"
|
||||
} else if file_name.ends_with(".kindle.images") {
|
||||
"application/octet-stream"
|
||||
} else {
|
||||
"text/html"
|
||||
};
|
||||
|
||||
let form = multipart::Form::new()
|
||||
.text("content_type", content_type)
|
||||
.text("compression_hint", hint_name.to_string())
|
||||
.part("data", multipart::Part::bytes(file_bytes));
|
||||
|
||||
let _ = client
|
||||
.post("http://localhost:7692/store")
|
||||
.multipart(form)
|
||||
.send();
|
||||
|
||||
let mut pb = pb.lock().unwrap();
|
||||
pb.update(1)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,12 +1,11 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use clap::Parser;
|
||||
use kdam::tqdm;
|
||||
use kdam::Bar;
|
||||
use kdam::BarExt;
|
||||
use rand::Rng;
|
||||
use reqwest::StatusCode;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(version, about, long_about = None)]
|
||||
@@ -43,14 +42,23 @@ fn run_loop(pb: Arc<Mutex<Bar>>, args: Args) -> Result<(), Box<dyn std::error::E
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut rand_data = vec![0u8; args.file_size];
|
||||
rng.fill(&mut rand_data[..]);
|
||||
let hints = vec![
|
||||
"foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred", "plugh",
|
||||
"xyzzy", "thud",
|
||||
];
|
||||
|
||||
loop {
|
||||
// tweak a byte in the data
|
||||
let idx = rng.gen_range(0..rand_data.len());
|
||||
rand_data[idx] = rng.gen();
|
||||
let hint = hints[rng.gen_range(0..hints.len())];
|
||||
|
||||
let form = reqwest::blocking::multipart::Form::new()
|
||||
.text("content_type", "text/plain")
|
||||
.part(
|
||||
"compression_hint",
|
||||
reqwest::blocking::multipart::Part::text(hint),
|
||||
)
|
||||
.part(
|
||||
"data",
|
||||
reqwest::blocking::multipart::Part::bytes(rand_data.clone()),
|
||||
|
||||
18
src/app_time_formatter.rs
Normal file
@@ -0,0 +1,18 @@
use std::time::Instant;

use tracing_subscriber::fmt::{format::Writer, time::FormatTime};

pub struct AppTimeFormatter(Instant);

impl AppTimeFormatter {
    pub fn new() -> Self {
        Self(Instant::now())
    }
}

impl FormatTime for AppTimeFormatter {
    fn format_time(&self, w: &mut Writer<'_>) -> std::fmt::Result {
        let e = self.0.elapsed();
        write!(w, "{:5}.{:03}s", e.as_secs(), e.subsec_millis())
    }
}
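For context, this formatter is installed on the tracing subscriber by the src/main.rs changes later in this diff:

```rust
// Quoted from the src/main.rs hunk below, shown here only for context.
tracing_subscriber::fmt()
    .with_max_level(tracing::Level::DEBUG)
    .with_timer(AppTimeFormatter::new())
    .with_target(false)
    .init();
```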
67
src/compressible_data.rs
Normal file
@@ -0,0 +1,67 @@
use axum::{body::Bytes, response::IntoResponse};

#[derive(Debug, Eq, PartialEq, Clone)]
pub enum CompressibleData {
    Bytes(Bytes),
    Vec(Vec<u8>),
}

impl CompressibleData {
    pub fn len(&self) -> usize {
        match self {
            CompressibleData::Bytes(b) => b.len(),
            CompressibleData::Vec(v) => v.len(),
        }
    }
}

impl AsRef<[u8]> for CompressibleData {
    fn as_ref(&self) -> &[u8] {
        match self {
            CompressibleData::Bytes(b) => b.as_ref(),
            CompressibleData::Vec(v) => v.as_ref(),
        }
    }
}

impl From<Bytes> for CompressibleData {
    fn from(b: Bytes) -> Self {
        CompressibleData::Bytes(b)
    }
}

impl From<Vec<u8>> for CompressibleData {
    fn from(v: Vec<u8>) -> Self {
        CompressibleData::Vec(v)
    }
}

impl IntoResponse for CompressibleData {
    fn into_response(self) -> axum::response::Response {
        match self {
            CompressibleData::Bytes(b) => b.into_response(),
            CompressibleData::Vec(v) => v.into_response(),
        }
    }
}

impl PartialEq<&[u8]> for CompressibleData {
    fn eq(&self, other: &&[u8]) -> bool {
        let as_ref = self.as_ref();
        as_ref == *other
    }
}

impl PartialEq<Vec<u8>> for CompressibleData {
    fn eq(&self, other: &Vec<u8>) -> bool {
        let as_ref = self.as_ref();
        as_ref == other.as_slice()
    }
}

impl PartialEq<Bytes> for CompressibleData {
    fn eq(&self, other: &Bytes) -> bool {
        let as_ref = self.as_ref();
        as_ref == other.as_ref()
    }
}
222
src/compression_manager.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
use crate::{
|
||||
compressible_data::CompressibleData,
|
||||
compression_stats::{CompressionStat, CompressionStats},
|
||||
compressor::{BrotliGenericCompressor, Compressor, ZstdGenericCompressor},
|
||||
sql_types::CompressionId,
|
||||
AsyncBoxError, CompressionPolicy,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{RwLock, RwLockReadGuard};
|
||||
|
||||
pub type CompressionManagerArc = Arc<RwLock<CompressionManager>>;
|
||||
pub type CompressionStatsArc = Arc<RwLock<CompressionStats>>;
|
||||
|
||||
pub struct CompressionManager {
|
||||
compression_policy: CompressionPolicy,
|
||||
compression_stats: CompressionStatsArc,
|
||||
}
|
||||
|
||||
impl CompressionManager {
|
||||
pub fn new(compression_policy: CompressionPolicy) -> Self {
|
||||
Self {
|
||||
compression_policy,
|
||||
compression_stats: Arc::new(RwLock::new(CompressionStats::default())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CompressionManager {
|
||||
fn default() -> Self {
|
||||
Self::new(CompressionPolicy::Auto)
|
||||
}
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, AsyncBoxError>;
|
||||
|
||||
impl CompressionManager {
|
||||
pub fn into_arc(self) -> CompressionManagerArc {
|
||||
Arc::new(RwLock::new(self))
|
||||
}
|
||||
|
||||
pub fn set_compression_policy(&mut self, compression_policy: CompressionPolicy) {
|
||||
self.compression_policy = compression_policy;
|
||||
}
|
||||
|
||||
pub async fn set_stat(&mut self, id: CompressionId, stat: CompressionStat) {
|
||||
let mut compression_stats = self.compression_stats.write().await;
|
||||
*compression_stats.for_id_mut(id) = stat;
|
||||
}
|
||||
|
||||
pub async fn compress<Data: Into<CompressibleData>>(
|
||||
&self,
|
||||
content_type: &str,
|
||||
data: Data,
|
||||
) -> Result<(CompressionId, CompressibleData)> {
|
||||
let data = data.into();
|
||||
|
||||
match self.compression_policy {
|
||||
CompressionPolicy::Auto => self.compress_auto(content_type, data).await,
|
||||
CompressionPolicy::None => Ok((CompressionId::None, data)),
|
||||
CompressionPolicy::ForceZstd => self.compress_with_id(CompressionId::Zstd, &data).await,
|
||||
CompressionPolicy::ForceBrotli => {
|
||||
self.compress_with_id(CompressionId::Brotli, &data).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn compress_auto(
|
||||
&self,
|
||||
content_type: &str,
|
||||
data: CompressibleData,
|
||||
) -> Result<(CompressionId, CompressibleData)> {
|
||||
if auto_compressible_content_type(content_type) {
|
||||
let mut compressed_refs = vec![];
|
||||
for id in [CompressionId::Zstd, CompressionId::Brotli] {
|
||||
let compressed = self.compress_with_id(id, &data).await?;
|
||||
compressed_refs.push((id, compressed.1.len(), compressed.1));
|
||||
}
|
||||
compressed_refs.push((CompressionId::None, data.len(), data));
|
||||
|
||||
// find the smallest compressed data
|
||||
let (compression_id, _, data) = compressed_refs
|
||||
.into_iter()
|
||||
.min_by_key(|(_, len, _)| *len)
|
||||
.unwrap();
|
||||
Ok((compression_id, data))
|
||||
} else {
|
||||
Ok((CompressionId::None, data))
|
||||
}
|
||||
}
|
||||
|
||||
async fn compress_with_id(
|
||||
&self,
|
||||
compression_id: CompressionId,
|
||||
data: &CompressibleData,
|
||||
) -> Result<(CompressionId, CompressibleData)> {
|
||||
let compressor = self.get_compressor(compression_id)?;
|
||||
// get current time
|
||||
let now = std::time::Instant::now();
|
||||
let compressed_data = compressor.compress(data)?;
|
||||
let mut stats = self.compression_stats.write().await;
|
||||
stats.add_entry(
|
||||
compression_id,
|
||||
data.len(),
|
||||
compressed_data.len(),
|
||||
now.elapsed(),
|
||||
);
|
||||
Ok((compression_id, compressed_data))
|
||||
}
|
||||
|
||||
fn get_compressor(&self, compression_id: CompressionId) -> Result<Box<dyn Compressor>> {
|
||||
Ok(match compression_id {
|
||||
CompressionId::Zstd => Box::new(ZstdGenericCompressor),
|
||||
CompressionId::Brotli => Box::new(BrotliGenericCompressor),
|
||||
CompressionId::None => return Err("no compressor for CompressionId::None".into()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn decompress<Data: Into<CompressibleData>>(
|
||||
&self,
|
||||
compression_id: CompressionId,
|
||||
data: Data,
|
||||
) -> Result<CompressibleData> {
|
||||
let data = data.into();
|
||||
match compression_id {
|
||||
CompressionId::None => Ok(data),
|
||||
CompressionId::Zstd => Ok(CompressibleData::Vec(zstd::stream::decode_all(
|
||||
data.as_ref(),
|
||||
)?)),
|
||||
CompressionId::Brotli => {
|
||||
let mut decompressed = vec![];
|
||||
let mut reader = std::io::Cursor::new(data.as_ref());
|
||||
brotli::BrotliDecompress(&mut reader, &mut decompressed)?;
|
||||
Ok(CompressibleData::Vec(decompressed))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stats(&self) -> RwLockReadGuard<CompressionStats> {
|
||||
self.compression_stats.read().await
|
||||
}
|
||||
}
|
||||
|
||||
fn auto_compressible_content_type(content_type: &str) -> bool {
|
||||
[
|
||||
"text/",
|
||||
"application/xml",
|
||||
"application/json",
|
||||
"application/javascript",
|
||||
]
|
||||
.iter()
|
||||
.any(|ct| content_type.starts_with(ct))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test {
|
||||
use super::*;
|
||||
use rstest::rstest;
|
||||
|
||||
pub fn make_compressor() -> CompressionManager {
|
||||
make_compressor_with(CompressionPolicy::Auto)
|
||||
}
|
||||
|
||||
pub fn make_compressor_with(compression_policy: CompressionPolicy) -> CompressionManager {
|
||||
CompressionManager::new(compression_policy)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auto_compressible_content_type() {
|
||||
assert!(auto_compressible_content_type("text/plain"));
|
||||
assert!(auto_compressible_content_type("application/xml"));
|
||||
assert!(auto_compressible_content_type("application/json"));
|
||||
assert!(auto_compressible_content_type("application/javascript"));
|
||||
assert!(!auto_compressible_content_type("image/png"));
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_compression_policies(
|
||||
#[values(
|
||||
CompressionPolicy::Auto,
|
||||
CompressionPolicy::None,
|
||||
CompressionPolicy::ForceZstd
|
||||
)]
|
||||
compression_policy: CompressionPolicy,
|
||||
#[values("text/plain", "application/json", "image/png")] content_type: &str,
|
||||
) {
|
||||
let compressor = make_compressor_with(compression_policy);
|
||||
let data = b"hello, world!".to_vec();
|
||||
let (compression_id, compressed) = compressor
|
||||
.compress(content_type, data.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let data_uncompressed = compressor.decompress(compression_id, compressed).unwrap();
|
||||
assert_eq!(data_uncompressed, data);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_skip_compressing_small_data() {
|
||||
let compressor = make_compressor();
|
||||
let data = b"hi!".to_vec();
|
||||
let (compression_id, compressed) = compressor
|
||||
.compress("text/plain", data.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(compression_id, CompressionId::None);
|
||||
assert_eq!(compressed, data);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_compresses_longer_data() {
|
||||
let compressor = make_compressor();
|
||||
let data = vec![b'.'; 1024];
|
||||
let (compression_id, compressed) = compressor
|
||||
.compress("text/plain", data.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(compression_id, CompressionId::Brotli);
|
||||
assert_ne!(compressed, data);
|
||||
assert!(compressed.len() < data.len());
|
||||
}
|
||||
}
|
||||
52
src/compression_stats.rs
Normal file
@@ -0,0 +1,52 @@
use crate::sql_types::CompressionId;
use std::collections::HashMap;

#[derive(Default, Copy, Clone)]
pub struct CompressionStat {
    pub num_entries: usize,
    pub uncompressed_size: usize,
    pub compressed_size: usize,
    pub elapsed: std::time::Duration,
}

#[derive(Default, Clone)]
pub struct CompressionStats {
    order: Vec<CompressionId>,
    id_to_stat: HashMap<CompressionId, CompressionStat>,
}

impl CompressionStats {
    pub fn for_id_mut(&mut self, id: CompressionId) -> &mut CompressionStat {
        self.id_to_stat.entry(id).or_insert_with_key(|_| {
            self.order.push(id);
            self.order.sort();
            Default::default()
        })
    }

    pub fn add_entry(
        &mut self,
        id: CompressionId,
        uncompressed_size: usize,
        compressed_size: usize,
        elapsed: std::time::Duration,
    ) -> &CompressionStat {
        let stat = self.for_id_mut(id);
        stat.num_entries += 1;
        stat.uncompressed_size += uncompressed_size;
        stat.compressed_size += compressed_size;
        stat.elapsed += elapsed;
        stat
    }

    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.order.len()
    }

    pub fn iter(&self) -> impl Iterator<Item = (CompressionId, &CompressionStat)> {
        self.order
            .iter()
            .flat_map(|id| self.id_to_stat.get(id).map(|stat| (*id, stat)))
    }
}
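A minimal sketch of how these counters accumulate (illustrative values; it assumes CompressionId, the hashable and orderable id type from src/sql_types.rs, which is not part of this diff, is in scope):

```rust
use std::time::Duration;

// Hypothetical usage, not part of the commit range.
fn demo(stats: &mut CompressionStats) {
    stats.add_entry(CompressionId::Zstd, 1024, 300, Duration::from_millis(2));
    let s = stats.add_entry(CompressionId::Zstd, 1024, 310, Duration::from_millis(3));
    // Repeated entries for the same id fold into one CompressionStat.
    assert_eq!(s.num_entries, 2);
    assert_eq!(s.uncompressed_size, 2048);
    // iter() walks the ids in sorted order with their accumulated totals.
    for (id, stat) in stats.iter() {
        println!("{}: {} -> {} bytes", id, stat.uncompressed_size, stat.compressed_size);
    }
}
```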
20
src/compressor/brotli_generic_compressor.rs
Normal file
@@ -0,0 +1,20 @@
use super::compressor_trait::Compressor;
use crate::{compressible_data::CompressibleData, AsyncBoxError};
use std::io::{Read, Write};

pub struct BrotliGenericCompressor;
impl Compressor for BrotliGenericCompressor {
    fn compress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
        let mut writer = brotli::CompressorWriter::new(Vec::new(), 4096, 7, 0);
        writer.write_all(data.as_ref())?;
        writer.flush()?;
        Ok(CompressibleData::Vec(writer.into_inner()))
    }

    fn decompress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
        let mut reader = brotli::Decompressor::new(data.as_ref(), 4096);
        let mut decompressed = Vec::new();
        reader.read_to_end(&mut decompressed)?;
        Ok(CompressibleData::Vec(decompressed))
    }
}
38
src/compressor/compressor_trait.rs
Normal file
@@ -0,0 +1,38 @@
use crate::{compressible_data::CompressibleData, AsyncBoxError};

pub trait Compressor: Send {
    fn compress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError>;
    fn decompress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError>;
}

#[cfg(test)]
pub(self) mod test {
    use super::*;
    use crate::compressor::*;
    use rstest::*;

    fn random_bytes(len: usize) -> Vec<u8> {
        (0..len).map(|_| rand::random::<u8>()).collect()
    }

    #[rstest]
    #[case(BrotliGenericCompressor)]
    #[case(ZstdGenericCompressor)]
    fn test_compressor(
        #[case] compressor: impl Compressor,
        #[values(
            vec![],
            "foobar".as_bytes().to_vec(),
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            vec![0; 1024],
            random_bytes(10),
            random_bytes(1024)
        )]
        data_vec: Vec<u8>,
    ) {
        let data = CompressibleData::Vec(data_vec);
        let compressed = compressor.compress(&data).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }
}
7
src/compressor/mod.rs
Normal file
@@ -0,0 +1,7 @@
mod brotli_generic_compressor;
mod compressor_trait;
mod zstd_generic_compressor;

pub use brotli_generic_compressor::BrotliGenericCompressor;
pub use compressor_trait::Compressor;
pub use zstd_generic_compressor::ZstdGenericCompressor;
20
src/compressor/zstd_generic_compressor.rs
Normal file
@@ -0,0 +1,20 @@
use super::compressor_trait::Compressor;
use crate::{compressible_data::CompressibleData, AsyncBoxError};
use std::io::{Read, Write};

pub struct ZstdGenericCompressor;
impl Compressor for ZstdGenericCompressor {
    fn compress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
        let mut writer = zstd::stream::write::Encoder::new(Vec::new(), 11)?;
        // writer.set_parameter(zstd::zstd_safe::CParameter::WindowLog(24))?;
        writer.write_all(data.as_ref())?;
        Ok(CompressibleData::Vec(writer.finish()?))
    }

    fn decompress(&self, data: &CompressibleData) -> Result<CompressibleData, AsyncBoxError> {
        let mut reader = zstd::stream::read::Decoder::new(data.as_ref())?;
        let mut decompressed = Vec::new();
        reader.read_to_end(&mut decompressed)?;
        Ok(CompressibleData::Vec(decompressed))
    }
}
7
src/concat_lines.rs
Normal file
@@ -0,0 +1,7 @@
// like concat, but adds a newline between each expression
#[macro_export]
macro_rules! concat_lines {
    ($($e:expr),* $(,)?) => {
        concat!($($e, "\n"),*)
    };
}
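A small sketch of what the macro produces at compile time (illustrative strings, assuming the exported macro is in scope):

```rust
// concat_lines!("a", "b") expands to concat!("a", "\n", "b", "\n"),
// i.e. the &'static str "a\nb\n", including a trailing newline.
fn demo() {
    let sql = concat_lines!(
        "CREATE TABLE IF NOT EXISTS t (",
        "  k TEXT PRIMARY KEY",
        ")",
    );
    assert_eq!(sql, "CREATE TABLE IF NOT EXISTS t (\n  k TEXT PRIMARY KEY\n)\n");
}
```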
@@ -1,16 +1,21 @@
|
||||
use crate::{sha256::Sha256, shard::GetResult, shards::Shards};
|
||||
use crate::{
|
||||
sha256::Sha256,
|
||||
shard::{GetArgs, GetResult},
|
||||
shards::Shards,
|
||||
AsyncBoxError,
|
||||
};
|
||||
use axum::{
|
||||
extract::Path,
|
||||
http::{header, HeaderMap, HeaderValue, StatusCode},
|
||||
response::IntoResponse,
|
||||
Extension, Json,
|
||||
};
|
||||
use std::{collections::HashMap, error::Error};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
pub enum GetResponse {
|
||||
MissingSha256,
|
||||
InvalidSha256 { message: String },
|
||||
InternalError { error: Box<dyn Error> },
|
||||
InternalError { error: AsyncBoxError },
|
||||
NotFound,
|
||||
Found { get_result: GetResult },
|
||||
}
|
||||
@@ -69,7 +74,7 @@ fn make_found_response(
|
||||
Err(e) => return GetResponse::from(e).into_response(),
|
||||
};
|
||||
|
||||
let created_at = match HeaderValue::from_str(&created_at.to_rfc3339()) {
|
||||
let created_at = match HeaderValue::from_str(&created_at.to_string()) {
|
||||
Ok(created_at) => created_at,
|
||||
Err(e) => return GetResponse::from(e).into_response(),
|
||||
};
|
||||
@@ -98,7 +103,7 @@ fn make_found_response(
|
||||
(StatusCode::OK, headers, data).into_response()
|
||||
}
|
||||
|
||||
impl<E: Into<Box<dyn Error>>> From<E> for GetResponse {
|
||||
impl<E: Into<AsyncBoxError>> From<E> for GetResponse {
|
||||
fn from(error: E) -> Self {
|
||||
GetResponse::InternalError {
|
||||
error: error.into(),
|
||||
@@ -109,7 +114,7 @@ impl<E: Into<Box<dyn Error>>> From<E> for GetResponse {
|
||||
#[axum::debug_handler]
|
||||
pub async fn get_handler(
|
||||
Path(params): Path<HashMap<String, String>>,
|
||||
Extension(shards): Extension<Shards>,
|
||||
Extension(shards): Extension<Arc<Shards>>,
|
||||
) -> GetResponse {
|
||||
let sha256_str = match params.get("sha256") {
|
||||
Some(sha256_str) => sha256_str.clone(),
|
||||
@@ -128,7 +133,7 @@ pub async fn get_handler(
|
||||
};
|
||||
|
||||
let shard = shards.shard_for(&sha256);
|
||||
let get_result = match shard.get(sha256).await {
|
||||
let get_result = match shard.get(GetArgs { sha256 }).await {
|
||||
Ok(get_result) => get_result,
|
||||
Err(e) => return e.into(),
|
||||
};
|
||||
@@ -141,7 +146,9 @@ pub async fn get_handler(
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::{sha256::Sha256, shard::GetResult, shards::test::make_shards};
|
||||
use crate::{
|
||||
sha256::Sha256, shard::GetResult, shards::test::make_shards, sql_types::UtcDateTime,
|
||||
};
|
||||
use axum::{extract::Path, response::IntoResponse, Extension};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@@ -150,10 +157,10 @@ mod test {
|
||||
#[tokio::test]
|
||||
async fn test_get_invalid_sha256() {
|
||||
let shards = Extension(make_shards().await);
|
||||
|
||||
let response = super::get_handler(Path(HashMap::new()), shards.clone()).await;
|
||||
assert!(matches!(response, super::GetResponse::MissingSha256 { .. }));
|
||||
|
||||
let shards = Extension(make_shards().await);
|
||||
let response = super::get_handler(
|
||||
Path(HashMap::from([(String::from("sha256"), String::from(""))])),
|
||||
shards.clone(),
|
||||
@@ -174,8 +181,8 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_get_response_found_into_response() {
|
||||
let data = "hello, world!";
|
||||
let sha256 = Sha256::from_bytes(data.as_bytes());
|
||||
let data = "hello, world!".as_bytes().to_owned();
|
||||
let sha256 = Sha256::from_bytes(&data);
|
||||
let sha256_str = sha256.hex_string();
|
||||
let created_at = "2022-03-04T08:12:34+00:00";
|
||||
let response = GetResponse::Found {
|
||||
@@ -183,9 +190,7 @@ mod test {
|
||||
sha256,
|
||||
content_type: "text/plain".to_string(),
|
||||
stored_size: 12345,
|
||||
created_at: chrono::DateTime::parse_from_rfc3339(created_at)
|
||||
.unwrap()
|
||||
.to_utc(),
|
||||
created_at: UtcDateTime::from_string(created_at).unwrap(),
|
||||
data: data.into(),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::shards::Shards;
|
||||
use axum::{http::StatusCode, Extension, Json};
|
||||
|
||||
use std::sync::Arc;
|
||||
use tracing::error;
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
@@ -19,7 +19,7 @@ pub struct ShardInfo {
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn info_handler(
|
||||
Extension(shards): Extension<Shards>,
|
||||
Extension(shards): Extension<Arc<Shards>>,
|
||||
) -> Result<(StatusCode, Json<InfoResponse>), StatusCode> {
|
||||
let mut shard_infos = vec![];
|
||||
let mut total_db_size_bytes = 0;
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
use crate::{
|
||||
sha256::Sha256,
|
||||
shard::{StoreArgs, StoreResult},
|
||||
shards::Shards,
|
||||
shards::ShardsArc,
|
||||
};
|
||||
use axum::http::StatusCode;
|
||||
use axum::{body::Bytes, response::IntoResponse, Extension, Json};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use serde::Serialize;
|
||||
use tracing::error;
|
||||
|
||||
@@ -59,7 +58,7 @@ impl From<StoreResult> for StoreResponse {
|
||||
} => StoreResponse::Created {
|
||||
stored_size,
|
||||
data_size,
|
||||
created_at: created_at.to_rfc3339(),
|
||||
created_at: created_at.to_string(),
|
||||
},
|
||||
StoreResult::Exists {
|
||||
stored_size,
|
||||
@@ -68,7 +67,7 @@ impl From<StoreResult> for StoreResponse {
|
||||
} => StoreResponse::Exists {
|
||||
stored_size,
|
||||
data_size,
|
||||
created_at: created_at.to_rfc3339(),
|
||||
created_at: created_at.to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -82,7 +81,7 @@ impl IntoResponse for StoreResponse {
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn store_handler(
|
||||
Extension(shards): Extension<Shards>,
|
||||
Extension(shards): Extension<ShardsArc>,
|
||||
TypedMultipart(request): TypedMultipart<StoreRequest>,
|
||||
) -> StoreResponse {
|
||||
let sha256 = Sha256::from_bytes(&request.data.contents);
|
||||
@@ -118,20 +117,22 @@ pub async fn store_handler(
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test {
|
||||
use crate::{compression_manager::test::make_compressor_with, shards::test::make_shards_with};
|
||||
|
||||
use super::*;
|
||||
use crate::{shards::test::make_shards_with_compression, UseCompression};
|
||||
use crate::CompressionPolicy;
|
||||
use axum::body::Bytes;
|
||||
use axum_typed_multipart::FieldData;
|
||||
use rstest::rstest;
|
||||
|
||||
async fn send_request<D: Into<Bytes>>(
|
||||
compression_policy: CompressionPolicy,
|
||||
sha256: Option<Sha256>,
|
||||
content_type: &str,
|
||||
use_compression: UseCompression,
|
||||
data: D,
|
||||
) -> StoreResponse {
|
||||
store_handler(
|
||||
Extension(make_shards_with_compression(use_compression).await),
|
||||
Extension(make_shards_with(make_compressor_with(compression_policy).into_arc()).await),
|
||||
TypedMultipart(StoreRequest {
|
||||
sha256: sha256.map(|s| s.hex_string()),
|
||||
content_type: content_type.to_string(),
|
||||
@@ -146,8 +147,9 @@ pub mod test {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_handler() {
|
||||
let result = send_request(None, "text/plain", UseCompression::Auto, "hello, world!").await;
|
||||
assert_eq!(result.status_code(), StatusCode::CREATED);
|
||||
let result =
|
||||
send_request(CompressionPolicy::Auto, None, "text/plain", "hello, world!").await;
|
||||
assert_eq!(result.status_code(), StatusCode::CREATED, "{:?}", result);
|
||||
assert!(matches!(result, StoreResponse::Created { .. }));
|
||||
}
|
||||
|
||||
@@ -156,9 +158,9 @@ pub mod test {
|
||||
let not_hello_world = Sha256::from_bytes("goodbye, planet!".as_bytes());
|
||||
let hello_world = Sha256::from_bytes("hello, world!".as_bytes());
|
||||
let result = send_request(
|
||||
CompressionPolicy::Auto,
|
||||
Some(not_hello_world),
|
||||
"text/plain",
|
||||
UseCompression::Auto,
|
||||
"hello, world!",
|
||||
)
|
||||
.await;
|
||||
@@ -175,9 +177,9 @@ pub mod test {
|
||||
async fn test_store_handler_matching_sha256() {
|
||||
let hello_world = Sha256::from_bytes("hello, world!".as_bytes());
|
||||
let result = send_request(
|
||||
CompressionPolicy::Auto,
|
||||
Some(hello_world),
|
||||
"text/plain",
|
||||
UseCompression::Auto,
|
||||
"hello, world!",
|
||||
)
|
||||
.await;
|
||||
@@ -193,21 +195,21 @@ pub mod test {
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
// textual should be compressed by default
|
||||
#[case("text/plain", UseCompression::Auto, make_assert_lt(1024))]
|
||||
#[case("text/plain", UseCompression::Zstd, make_assert_lt(1024))]
|
||||
#[case("text/plain", UseCompression::None, make_assert_eq(1024))]
|
||||
// images, etc should not be compressed by default
|
||||
#[case("image/jpg", UseCompression::Auto, make_assert_eq(1024))]
|
||||
#[case("image/jpg", UseCompression::Zstd, make_assert_lt(1024))]
|
||||
#[case("image/jpg", UseCompression::None, make_assert_eq(1024))]
|
||||
// textual should be compressed if 'auto'
|
||||
#[case("text/plain", CompressionPolicy::Auto, make_assert_lt(1024))]
|
||||
#[case("text/plain", CompressionPolicy::ForceZstd, make_assert_lt(1024))]
|
||||
#[case("text/plain", CompressionPolicy::None, make_assert_eq(1024))]
|
||||
// images, etc should not be compressed if 'auto'
|
||||
#[case("image/jpg", CompressionPolicy::Auto, make_assert_eq(1024))]
|
||||
#[case("image/jpg", CompressionPolicy::ForceZstd, make_assert_lt(1024))]
|
||||
#[case("image/jpg", CompressionPolicy::None, make_assert_eq(1024))]
|
||||
#[tokio::test]
|
||||
async fn test_compressible_data<F: Fn(usize)>(
|
||||
#[case] content_type: &str,
|
||||
#[case] use_compression: UseCompression,
|
||||
#[case] compression_policy: CompressionPolicy,
|
||||
#[case] assert_stored_size: F,
|
||||
) {
|
||||
let result = send_request(None, content_type, use_compression, vec![0; 1024]).await;
|
||||
let result = send_request(compression_policy, None, content_type, vec![0; 1024]).await;
|
||||
assert_eq!(result.status_code(), StatusCode::CREATED);
|
||||
match result {
|
||||
StoreResponse::Created {
|
||||
|
||||
11
src/into_arc.rs
Normal file
@@ -0,0 +1,11 @@
use std::sync::Arc;

pub trait IntoArc {
    fn into_arc(self) -> Arc<Self>;
}

impl<T> IntoArc for T {
    fn into_arc(self) -> Arc<Self> {
        Arc::new(self)
    }
}
30
src/loops/axum_server_loop.rs
Normal file
@@ -0,0 +1,30 @@
use axum::{
    extract::DefaultBodyLimit,
    routing::{get, post},
    Extension, Router,
};
use tokio::net::TcpListener;

use crate::{
    compression_manager::CompressionManagerArc, handlers, shards::ShardsArc, AsyncBoxError,
};

pub async fn axum_server_loop(
    server: TcpListener,
    shards: ShardsArc,
    compression_manager: CompressionManagerArc,
) -> Result<(), AsyncBoxError> {
    let app = Router::new()
        .route("/store", post(handlers::store_handler::store_handler))
        .route("/get/:sha256", get(handlers::get_handler::get_handler))
        .route("/info", get(handlers::info_handler::info_handler))
        .layer(Extension(shards))
        .layer(Extension(compression_manager))
        // limit the max size of a request body to 1000mb
        .layer(DefaultBodyLimit::max(1024 * 1024 * 1000));

    axum::serve(server, app.into_make_service())
        .with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
        .await?;
    Ok(())
}
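As a rough illustration of the routes this loop exposes, a client could query the /info endpoint as follows. This is a sketch, not part of the diff; it assumes the server is listening on localhost:7692, the port the fixture inserter in this change set posts to, and uses the reqwest "blocking" and "json" features already enabled in Cargo.toml.

```rust
// Hypothetical client-side check against the running server.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let info: serde_json::Value =
        reqwest::blocking::get("http://localhost:7692/info")?.json()?;
    println!("{info}");
    Ok(())
}
```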
77
src/loops/dict_stats_printer_loop.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use tabled::{settings::Style, Table, Tabled};
|
||||
use tokio::select;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{compression_manager::CompressionManagerArc, AsyncBoxError};
|
||||
|
||||
pub async fn dict_stats_printer_loop(
|
||||
compressor: CompressionManagerArc,
|
||||
) -> Result<(), AsyncBoxError> {
|
||||
fn humanized_bytes_str(number: &usize) -> String {
|
||||
humansize::format_size(*number, humansize::BINARY)
|
||||
}
|
||||
fn humanized_bytes_per_sec(number: &usize) -> String {
|
||||
format!("{}/sec", humansize::format_size(*number, humansize::BINARY))
|
||||
}
|
||||
fn compression_ratio_str(row: &DictStatTableRow) -> String {
|
||||
if row.uncompressed_size == 0 {
|
||||
"(n/a)".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"{:.2} x",
|
||||
row.compressed_size as f64 / row.uncompressed_size as f64
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Tabled, Default)]
|
||||
struct DictStatTableRow {
|
||||
name: String,
|
||||
#[tabled(rename = "# entries")]
|
||||
num_entries: usize,
|
||||
#[tabled(
|
||||
rename = "compression ratio",
|
||||
display_with("compression_ratio_str", self)
|
||||
)]
|
||||
_ratio: (),
|
||||
#[tabled(rename = "uncompressed bytes", display_with = "humanized_bytes_str")]
|
||||
uncompressed_size: usize,
|
||||
#[tabled(rename = "compressed bytes", display_with = "humanized_bytes_str")]
|
||||
compressed_size: usize,
|
||||
#[tabled(rename = "bytes/sec", display_with = "humanized_bytes_per_sec")]
|
||||
bytes_per_sec: usize,
|
||||
}
|
||||
|
||||
loop {
|
||||
{
|
||||
let compressor = compressor.read().await;
|
||||
let stats = compressor.stats().await;
|
||||
|
||||
Table::new(stats.iter().map(|(id, stat)| {
|
||||
let name = id.to_string();
|
||||
|
||||
DictStatTableRow {
|
||||
name,
|
||||
num_entries: stat.num_entries,
|
||||
uncompressed_size: stat.uncompressed_size,
|
||||
compressed_size: stat.compressed_size,
|
||||
bytes_per_sec: (stat.uncompressed_size as f64 / stat.elapsed.as_secs_f64())
|
||||
as usize,
|
||||
..Default::default()
|
||||
}
|
||||
}))
|
||||
.with(Style::rounded())
|
||||
.to_string()
|
||||
.lines()
|
||||
.for_each(|line| debug!("{}", line));
|
||||
}
|
||||
|
||||
select! {
|
||||
_ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => {}
|
||||
_ = crate::shutdown_signal::shutdown_signal() => {
|
||||
info!("dict_stats: shutdown signal received");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
7
src/loops/mod.rs
Normal file
@@ -0,0 +1,7 @@
mod axum_server_loop;
mod dict_stats_printer_loop;
mod save_compression_stats_loop;

pub use axum_server_loop::axum_server_loop;
pub use dict_stats_printer_loop::dict_stats_printer_loop;
pub use save_compression_stats_loop::save_compression_stats_loop;
19
src/loops/save_compression_stats_loop.rs
Normal file
@@ -0,0 +1,19 @@
use std::sync::Arc;

use tokio::select;
use tracing::info;

use crate::{manifest::Manifest, AsyncBoxError};

pub async fn save_compression_stats_loop(manifest: Arc<Manifest>) -> Result<(), AsyncBoxError> {
    loop {
        manifest.save_compression_stats().await?;
        select! {
            _ = tokio::time::sleep(tokio::time::Duration::from_secs(60)) => {}
            _ = crate::shutdown_signal::shutdown_signal() => {
                info!("persist_dict_stats: shutdown signal received");
                return Ok(());
            }
        }
    }
}
112
src/main.rs
@@ -1,22 +1,28 @@
|
||||
mod app_time_formatter;
|
||||
mod compressible_data;
|
||||
mod compression_manager;
|
||||
mod compression_stats;
|
||||
mod compressor;
|
||||
mod concat_lines;
|
||||
mod handlers;
|
||||
mod into_arc;
|
||||
mod loops;
|
||||
mod manifest;
|
||||
mod sha256;
|
||||
mod shard;
|
||||
mod shards;
|
||||
mod shutdown_signal;
|
||||
mod sql_types;
|
||||
|
||||
use crate::{manifest::Manifest, shards::Shards};
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
Extension, Router,
|
||||
};
|
||||
use app_time_formatter::AppTimeFormatter;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use futures::executor::block_on;
|
||||
use shard::Shard;
|
||||
use std::{error::Error, path::PathBuf};
|
||||
use tokio::net::TcpListener;
|
||||
use std::{error::Error, path::PathBuf, sync::Arc};
|
||||
use tokio::{net::TcpListener, spawn};
|
||||
use tokio_rusqlite::Connection;
|
||||
use tracing::info;
|
||||
use tracing::{debug, info};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version, about, long_about = None)]
|
||||
@@ -39,38 +45,62 @@ struct Args {
|
||||
|
||||
/// How to compress stored data
|
||||
#[arg(short, long, default_value = "auto")]
|
||||
compression: UseCompression,
|
||||
compression: CompressionPolicy,
|
||||
}
|
||||
|
||||
#[derive(Default, PartialEq, Debug, Copy, Clone, ValueEnum)]
|
||||
pub enum UseCompression {
|
||||
pub enum CompressionPolicy {
|
||||
#[default]
|
||||
Auto,
|
||||
None,
|
||||
Zstd,
|
||||
ForceZstd,
|
||||
ForceBrotli,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
pub type AsyncBoxError = Box<dyn Error + Send + Sync + 'static>;
|
||||
|
||||
pub fn into_tokio_rusqlite_err<E: Into<AsyncBoxError>>(e: E) -> tokio_rusqlite::Error {
|
||||
tokio_rusqlite::Error::Other(e.into())
|
||||
}
|
||||
|
||||
fn main() -> Result<(), AsyncBoxError> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::DEBUG)
|
||||
.with_timer(AppTimeFormatter::new())
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
let args = Args::parse();
|
||||
let db_path = PathBuf::from(&args.db_path);
|
||||
let num_shards = args.shards;
|
||||
|
||||
if db_path.is_file() {
|
||||
return Err("db_path must be a directory".into());
|
||||
}
|
||||
if !db_path.is_dir() {
|
||||
std::fs::create_dir_all(&db_path)?;
|
||||
}
|
||||
|
||||
// block on opening the manifest
|
||||
let manifest = block_on(async {
|
||||
Manifest::open(
|
||||
let manifest = Arc::new(block_on(async {
|
||||
let manifest = Manifest::open(
|
||||
Connection::open(db_path.join("manifest.sqlite")).await?,
|
||||
num_shards,
|
||||
)
|
||||
.await
|
||||
})?;
|
||||
.await?;
|
||||
|
||||
manifest
|
||||
.compression_manager()
|
||||
.write()
|
||||
.await
|
||||
.set_compression_policy(args.compression);
|
||||
|
||||
Ok::<_, AsyncBoxError>(manifest)
|
||||
})?);
|
||||
|
||||
// max num_shards threads
|
||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(manifest.num_shards())
|
||||
.worker_threads(16)
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
@@ -82,38 +112,38 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
manifest.num_shards()
|
||||
);
|
||||
let mut shards_vec = vec![];
|
||||
let mut num_entries = vec![];
|
||||
for shard_id in 0..manifest.num_shards() {
|
||||
let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
|
||||
let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
|
||||
let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?;
|
||||
info!(
|
||||
"shard {} has {} entries",
|
||||
shard.id(),
|
||||
shard.num_entries().await?
|
||||
);
|
||||
let shard =
|
||||
Shard::open(shard_id, shard_sqlite_conn, manifest.compression_manager()).await?;
|
||||
num_entries.push(shard.num_entries().await?);
|
||||
shards_vec.push(shard);
|
||||
}
|
||||
debug!(
|
||||
"loaded shards with {} total entries <- {:?}",
|
||||
num_entries.iter().sum::<usize>(),
|
||||
num_entries
|
||||
);
|
||||
|
||||
let shards = Shards::new(shards_vec).ok_or("num shards must be > 0")?;
|
||||
server_loop(server, shards.clone()).await?;
|
||||
info!("shutting down server...");
|
||||
shards.close_all().await?;
|
||||
info!("server closed sqlite connections. bye!");
|
||||
Ok::<_, Box<dyn Error>>(())
|
||||
let compression_manager = manifest.compression_manager();
|
||||
let shards = Arc::new(Shards::new(shards_vec).ok_or("num shards must be > 0")?);
|
||||
let join_handles = vec![
|
||||
spawn(loops::save_compression_stats_loop(manifest.clone())),
|
||||
spawn(loops::dict_stats_printer_loop(
|
||||
manifest.compression_manager(),
|
||||
)),
|
||||
spawn(loops::axum_server_loop(server, shards, compression_manager)),
|
||||
];
|
||||
for handle in join_handles {
|
||||
handle.await??;
|
||||
}
|
||||
info!("saving compressor stats...");
|
||||
manifest.save_compression_stats().await?;
|
||||
info!("done. bye!");
|
||||
Ok::<_, AsyncBoxError>(())
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn Error>> {
|
||||
let app = Router::new()
|
||||
.route("/store", post(handlers::store_handler::store_handler))
|
||||
.route("/get/:sha256", get(handlers::get_handler::get_handler))
|
||||
.route("/info", get(handlers::info_handler::info_handler))
|
||||
.layer(Extension(shards));
|
||||
|
||||
axum::serve(server, app.into_make_service())
|
||||
.with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,9 +0,0 @@
use rusqlite::types::FromSql;

#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub struct DictId(i64);
impl FromSql for DictId {
    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
        Ok(DictId(value.as_i64()?))
    }
}
@@ -1,121 +1,113 @@
|
||||
mod dict_id;
|
||||
mod manifest_key;
|
||||
mod zstd_dict;
|
||||
|
||||
use std::{collections::HashMap, error::Error, sync::Arc};
|
||||
|
||||
use rusqlite::params;
|
||||
use tokio_rusqlite::Connection;
|
||||
|
||||
use crate::shards::Shards;
|
||||
|
||||
use self::{
|
||||
dict_id::DictId,
|
||||
manifest_key::{get_manifest_key, set_manifest_key, NumShards},
|
||||
zstd_dict::ZstdDict,
|
||||
use self::manifest_key::{get_manifest_key, set_manifest_key, NumShards};
|
||||
use crate::{
|
||||
compression_manager::{CompressionManager, CompressionManagerArc},
|
||||
compression_stats::{CompressionStat, CompressionStats},
|
||||
concat_lines, into_tokio_rusqlite_err,
|
||||
sql_types::CompressionId,
|
||||
AsyncBoxError,
|
||||
};
|
||||
|
||||
pub type ZstdDictArc = Arc<ZstdDict>;
|
||||
use rusqlite::params;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_rusqlite::Connection;
|
||||
use tracing::{debug, info};
|
||||
|
||||
pub struct Manifest {
|
||||
conn: Connection,
|
||||
num_shards: usize,
|
||||
zstd_dict_by_id: HashMap<DictId, ZstdDictArc>,
|
||||
zstd_dict_by_name: HashMap<String, ZstdDictArc>,
|
||||
compressor: Arc<RwLock<CompressionManager>>,
|
||||
}
|
||||
|
||||
impl Manifest {
|
||||
pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, Box<dyn Error>> {
|
||||
pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, AsyncBoxError> {
|
||||
initialize(conn, num_shards).await
|
||||
}
|
||||
|
||||
pub fn num_shards(&self) -> usize {
|
||||
self.num_shards
|
||||
}
|
||||
|
||||
async fn train_zstd_dict_with_tag(
|
||||
&mut self,
|
||||
_name: &str,
|
||||
_shards: Shards,
|
||||
) -> Result<ZstdDictArc, Box<dyn Error>> {
|
||||
// let mut queries = vec![];
|
||||
// for shard in shards.iter() {
|
||||
// queries.push(shard.entries_for_tag(name));
|
||||
// }
|
||||
todo!();
|
||||
pub fn compression_manager(&self) -> CompressionManagerArc {
|
||||
self.compressor.clone()
|
||||
}
|
||||
|
||||
async fn create_zstd_dict_from_samples(
|
||||
&mut self,
|
||||
name: &str,
|
||||
samples: Vec<&[u8]>,
|
||||
) -> Result<ZstdDictArc, Box<dyn Error>> {
|
||||
if self.zstd_dict_by_name.contains_key(name) {
|
||||
return Err(format!("dictionary {} already exists", name).into());
|
||||
}
|
||||
pub async fn save_compression_stats(&self) -> Result<(), AsyncBoxError> {
|
||||
let compressor_arc = self.compression_manager();
|
||||
let compressor = compressor_arc.read().await;
|
||||
let compression_stats = compressor.stats().await.clone();
|
||||
|
||||
let level = 3;
|
||||
let dict_bytes = zstd::dict::from_samples(
|
||||
&samples,
|
||||
1024 * 1024, // 1MB max dictionary size
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let name_copy = name.to_string();
|
||||
let (dict_id, zstd_dict) = self
|
||||
.conn
|
||||
self.conn
|
||||
.call(move |conn| {
|
||||
let mut stmt = conn.prepare(
|
||||
"INSERT INTO dictionaries (name, level, dict)
|
||||
VALUES (?, ?, ?)
|
||||
RETURNING id",
|
||||
)?;
|
||||
let dict_id =
|
||||
stmt.query_row(params![name_copy, level, dict_bytes], |row| row.get(0))?;
|
||||
let zstd_dict = Arc::new(ZstdDict::create(dict_id, name_copy, level, dict_bytes));
|
||||
Ok((dict_id, zstd_dict))
|
||||
save_compression_stats(conn, compression_stats).map_err(into_tokio_rusqlite_err)?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
self.zstd_dict_by_id.insert(dict_id, zstd_dict.clone());
|
||||
self.zstd_dict_by_name
|
||||
.insert(name.to_string(), zstd_dict.clone());
|
||||
Ok(zstd_dict)
|
||||
}
|
||||
|
||||
fn get_dictionary_by_id(&self, id: DictId) -> Option<&ZstdDict> {
|
||||
self.zstd_dict_by_id.get(&id).map(|d| &**d)
|
||||
}
|
||||
fn get_dictionary_by_name(&self, name: &str) -> Option<&ZstdDict> {
|
||||
self.zstd_dict_by_name.get(name).map(|d| &**d)
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
conn: Connection,
|
||||
num_shards: Option<usize>,
|
||||
) -> Result<Manifest, Box<dyn Error>> {
|
||||
async fn migrate_manifest(conn: &Connection) -> Result<(), tokio_rusqlite::Error> {
|
||||
conn.call(|conn| {
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
concat_lines!(
|
||||
"CREATE TABLE IF NOT EXISTS zstd_dictionaries (",
|
||||
" zstd_dict_id INTEGER PRIMARY KEY AUTOINCREMENT,",
|
||||
" name TEXT NOT NULL,",
|
||||
" encoder_json TEXT NOT NULL",
|
||||
")"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
concat_lines!(
|
||||
"CREATE TABLE IF NOT EXISTS compression_stats (",
|
||||
" dict_id TEXT PRIMARY KEY,",
|
||||
" num_entries INTEGER NOT NULL DEFAULT 0,",
|
||||
" uncompressed_size INTEGER NOT NULL DEFAULT 0,",
|
||||
" compressed_size INTEGER NOT NULL DEFAULT 0,",
|
||||
" elapsed INTEGER NOT NULL DEFAULT 0",
|
||||
")"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
// insert the default compression stats
|
||||
let default_compression_ids = [
|
||||
CompressionId::None,
|
||||
CompressionId::Zstd,
|
||||
CompressionId::Brotli,
|
||||
];
|
||||
for compression_id in default_compression_ids.iter() {
|
||||
conn.execute(
|
||||
concat_lines!(
|
||||
"INSERT OR IGNORE INTO compression_stats (dict_id)",
|
||||
"VALUES (?)"
|
||||
),
|
||||
params![compression_id],
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn load_and_store_num_shards(
|
||||
conn: &Connection,
|
||||
requested_num_shards: Option<usize>,
|
||||
) -> Result<usize, AsyncBoxError> {
|
||||
let stored_num_shards: Option<usize> = conn
|
||||
.call(|conn| {
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS dictionaries (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
level INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
dict BLOB NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
Ok(get_manifest_key(conn, NumShards)?)
|
||||
})
|
||||
.call(|conn| Ok(get_manifest_key(conn, NumShards)?))
|
||||
.await?;
|
||||
|
||||
let num_shards = match (stored_num_shards, num_shards) {
|
||||
Ok(match (stored_num_shards, requested_num_shards) {
|
||||
(Some(stored_num_shards), Some(num_shards)) => {
|
||||
if stored_num_shards == num_shards {
|
||||
num_shards
|
||||
@@ -142,38 +134,94 @@ async fn initialize(
|
||||
// existing database, use loaded num_shards
|
||||
num_shards
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
|
||||
let rows = conn
|
||||
async fn load_compression_stats(
|
||||
conn: &Connection,
|
||||
compressor: &mut CompressionManager,
|
||||
) -> Result<(), AsyncBoxError> {
|
||||
type CompressionStatRow = (CompressionId, usize, usize, usize, u64);
|
||||
let rows: Vec<CompressionStatRow> = conn
|
||||
.call(|conn| {
|
||||
let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
|
||||
let mut rows = vec![];
|
||||
for r in stmt.query_map([], |row| {
|
||||
let id = row.get(0)?;
|
||||
let name: String = row.get(1)?;
|
||||
let level: i32 = row.get(2)?;
|
||||
let dict: Vec<u8> = row.get(3)?;
|
||||
Ok((id, name, level, dict))
|
||||
})? {
|
||||
rows.push(r?);
|
||||
}
|
||||
let rows = conn
|
||||
.prepare(concat_lines!(
|
||||
"SELECT",
|
||||
" dict_id, num_entries, uncompressed_size, compressed_size, elapsed",
|
||||
"FROM compression_stats",
|
||||
))?
|
||||
.query_map([], |row| CompressionStatRow::try_from(row))?
|
||||
.collect::<Result<_, _>>()?;
|
||||
Ok(rows)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut zstd_dicts_by_id = HashMap::new();
|
||||
let mut zstd_dicts_by_name = HashMap::new();
|
||||
for (id, name, level, dict_bytes) in rows {
|
||||
let zstd_dict = Arc::new(ZstdDict::create(id, name.clone(), level, dict_bytes));
|
||||
zstd_dicts_by_id.insert(id, zstd_dict.clone());
|
||||
zstd_dicts_by_name.insert(name, zstd_dict);
|
||||
debug!("loaded {} compression stats from manifest", rows.len());
|
||||
for (dict_id, num_entries, uncompressed_size, compressed_size, elapsed) in rows {
|
||||
compressor
|
||||
.set_stat(
|
||||
dict_id,
|
||||
CompressionStat {
|
||||
num_entries,
|
||||
uncompressed_size,
|
||||
compressed_size,
|
||||
elapsed: Duration::from_millis(elapsed),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn save_compression_stats(
|
||||
conn: &rusqlite::Connection,
|
||||
compression_stats: CompressionStats,
|
||||
) -> Result<(), AsyncBoxError> {
|
||||
let mut num_stats = 0;
|
||||
for (id, stat) in compression_stats.iter() {
|
||||
num_stats += 1;
|
||||
conn.execute(
|
||||
concat_lines!(
|
||||
"INSERT INTO compression_stats",
|
||||
" (dict_id, num_entries, uncompressed_size, compressed_size, elapsed)",
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
"ON CONFLICT (dict_id) DO UPDATE SET",
|
||||
" num_entries = excluded.num_entries,",
|
||||
" uncompressed_size = excluded.uncompressed_size,",
|
||||
" compressed_size = excluded.compressed_size,",
|
||||
" elapsed = excluded.elapsed"
|
||||
),
|
||||
params![
|
||||
id,
|
||||
stat.num_entries,
|
||||
stat.uncompressed_size,
|
||||
stat.compressed_size,
|
||||
stat.elapsed.as_millis() as u64,
|
||||
],
|
||||
)?;
|
||||
}
|
||||
info!("saved {} compressor stats", num_stats);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_compressor(conn: &Connection) -> Result<CompressionManagerArc, AsyncBoxError> {
|
||||
let mut compressor = CompressionManager::default();
|
||||
load_compression_stats(conn, &mut compressor).await?;
|
||||
Ok(compressor.into_arc())
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
conn: Connection,
|
||||
requested_num_shards: Option<usize>,
|
||||
) -> Result<Manifest, AsyncBoxError> {
|
||||
migrate_manifest(&conn).await?;
|
||||
let num_shards = load_and_store_num_shards(&conn, requested_num_shards).await?;
|
||||
let compressor = load_compressor(&conn).await?;
|
||||
Ok(Manifest {
|
||||
conn,
|
||||
num_shards,
|
||||
zstd_dict_by_id: zstd_dicts_by_id,
|
||||
zstd_dict_by_name: zstd_dicts_by_name,
|
||||
compressor,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -182,30 +230,13 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_manifest() {
|
||||
async fn test_manifest_compression_stats_loading() {
|
||||
let conn = Connection::open_in_memory().await.unwrap();
|
||||
let mut manifest = initialize(conn, Some(3)).await.unwrap();
|
||||
|
||||
let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
|
||||
let zstd_dict = manifest
|
||||
.create_zstd_dict_from_samples("test", samples)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// test that indexes are created correctly
|
||||
assert_eq!(
|
||||
zstd_dict.as_ref(),
|
||||
manifest.get_dictionary_by_id(zstd_dict.id()).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
zstd_dict.as_ref(),
|
||||
manifest.get_dictionary_by_name(zstd_dict.name()).unwrap()
|
||||
);
|
||||
|
||||
let data = b"hello world, this is a test of a sort of long string";
|
||||
let compressed = zstd_dict.compress(data).unwrap();
|
||||
let decompressed = zstd_dict.decompress(&compressed).unwrap();
|
||||
assert_eq!(decompressed, data);
|
||||
assert!(data.len() > compressed.len());
|
||||
let manifest = initialize(conn, Some(4)).await.unwrap();
|
||||
let compressor = manifest.compression_manager();
|
||||
let compressor = compressor.read().await;
|
||||
let stats = compressor.stats().await;
|
||||
assert_eq!(stats.len(), 3);
|
||||
assert_eq!(stats.iter().count(), 3);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,94 +0,0 @@
use super::dict_id::DictId;
use ouroboros::self_referencing;
use std::{error::Error, io};
use zstd::dict::{DecoderDictionary, EncoderDictionary};

#[self_referencing]
pub struct ZstdDict {
    id: DictId,
    name: String,
    level: i32,
    dict_bytes: Vec<u8>,

    #[borrows(dict_bytes)]
    #[not_covariant]
    encoder_dict: EncoderDictionary<'this>,

    #[borrows(dict_bytes)]
    #[not_covariant]
    decoder_dict: DecoderDictionary<'this>,
}

impl PartialEq for ZstdDict {
    fn eq(&self, other: &Self) -> bool {
        self.id() == other.id()
            && self.name() == other.name()
            && self.level() == other.level()
            && self.dict_bytes() == other.dict_bytes()
    }
}

impl std::fmt::Debug for ZstdDict {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ZstdDict")
            .field("id", &self.id())
            .field("name", &self.name())
            .field("level", &self.level())
            .field("dict_bytes.len", &self.dict_bytes().len())
            .finish()
    }
}

impl ZstdDict {
    pub fn create(id: DictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
        ZstdDictBuilder {
            id,
            name,
            level,
            dict_bytes,
            encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
            decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
        }
        .build()
    }

    pub fn id(&self) -> DictId {
        self.with_id(|id| *id)
    }
    pub fn name(&self) -> &str {
        self.borrow_name()
    }
    pub fn level(&self) -> i32 {
        *self.borrow_level()
    }
    pub fn dict_bytes(&self) -> &[u8] {
        self.borrow_dict_bytes()
    }

    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
        let mut wrapper = io::Cursor::new(data);
        let mut out_buffer = Vec::with_capacity(data.len());
        let mut output_wrapper = io::Cursor::new(&mut out_buffer);

        self.with_encoder_dict(|encoder_dict| {
            let mut encoder =
                zstd::stream::Encoder::with_prepared_dictionary(&mut output_wrapper, encoder_dict)?;
            io::copy(&mut wrapper, &mut encoder)?;
            encoder.finish()
        })?;
        Ok(out_buffer)
    }

    pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
        let mut wrapper = io::Cursor::new(data);
        let mut out_buffer = Vec::with_capacity(data.len());
        let mut output_wrapper = io::Cursor::new(&mut out_buffer);

        self.with_decoder_dict(|decoder_dict| {
            let mut decoder =
                zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, decoder_dict)?;
            io::copy(&mut decoder, &mut output_wrapper)
        })?;
        Ok(out_buffer)
    }
}

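For reference, a minimal, self-contained sketch of the prepared-dictionary round trip this file wrapped, using the zstd crate directly; the sample data and level are illustrative, and it uses the copy constructors instead of the new ones above to avoid depending on optional crate features:

use std::io;
use zstd::dict::{DecoderDictionary, EncoderDictionary};

fn main() -> io::Result<()> {
    // Train a small dictionary from repetitive samples (illustrative data only).
    let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
    let dict_bytes = zstd::dict::from_samples(&samples, 1024 * 1024)?;

    // Prepare the dictionaries once, as ZstdDict did in its builder closures.
    let encoder_dict = EncoderDictionary::copy(&dict_bytes, 3);
    let decoder_dict = DecoderDictionary::copy(&dict_bytes);

    // Compress with the prepared encoder dictionary.
    let data = b"hello world, this is a test of a sort of long string";
    let mut compressed = Vec::new();
    let mut encoder =
        zstd::stream::Encoder::with_prepared_dictionary(&mut compressed, &encoder_dict)?;
    io::copy(&mut io::Cursor::new(&data[..]), &mut encoder)?;
    encoder.finish()?;

    // Decompress with the matching decoder dictionary and check the round trip.
    let mut decompressed = Vec::new();
    let mut decoder =
        zstd::stream::Decoder::with_prepared_dictionary(io::Cursor::new(&compressed[..]), &decoder_dict)?;
    io::copy(&mut decoder, &mut decompressed)?;
    assert_eq!(&decompressed[..], &data[..]);
    Ok(())
}
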
@@ -1,10 +1,13 @@
use rusqlite::{
    types::{FromSql, ToSqlOutput},
    ToSql,
};
use sha2::Digest;
use std::{
    error::Error,
    fmt::{Display, LowerHex},
};

use sha2::Digest;

#[derive(Debug)]
struct Sha256Error {
    message: String,
@@ -17,7 +20,7 @@ impl Display for Sha256Error {
}
impl Error for Sha256Error {}

#[derive(Clone, Copy, PartialEq, Eq, Default)]
#[derive(Clone, Copy, PartialEq, Eq, Default, Hash)]
pub struct Sha256([u8; 32]);
impl Sha256 {
    pub fn from_hex_string(hex: &str) -> Result<Self, Box<dyn Error>> {
@@ -47,6 +50,20 @@ impl Sha256 {
    }
}

impl ToSql for Sha256 {
    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput> {
        Ok(ToSqlOutput::Borrowed(rusqlite::types::ValueRef::Blob(
            &self.0,
        )))
    }
}
impl FromSql for Sha256 {
    fn column_result(value: rusqlite::types::ValueRef) -> rusqlite::types::FromSqlResult<Self> {
        let bytes = <[u8; 32]>::column_result(value)?;
        Ok(Self(bytes))
    }
}

impl LowerHex for Sha256 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for byte in self.0.iter() {

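With the ToSql/FromSql impls added above, a Sha256 is stored as a raw 32-byte BLOB. A minimal usage sketch; the in-memory table is illustrative, and it assumes the crate's Sha256 type and its from_bytes constructor (seen in the shard tests) are in scope:

use rusqlite::{params, Connection};

fn sha256_blob_roundtrip() -> rusqlite::Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute("CREATE TABLE t (sha256 BLOB PRIMARY KEY)", [])?;

    // ToSql borrows the inner [u8; 32] as a BLOB value.
    let digest = Sha256::from_bytes("hello, world!".as_bytes());
    conn.execute("INSERT INTO t (sha256) VALUES (?)", params![digest])?;

    // FromSql reads the BLOB back through <[u8; 32]>::column_result.
    let loaded: Sha256 = conn.query_row("SELECT sha256 FROM t", [], |row| row.get(0))?;
    assert!(loaded == digest);
    Ok(())
}
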
@@ -1,73 +1,55 @@
|
||||
use std::io::Read;
|
||||
|
||||
use super::*;
|
||||
use crate::{
|
||||
compressible_data::CompressibleData,
|
||||
sql_types::{CompressionId, UtcDateTime},
|
||||
AsyncBoxError,
|
||||
};
|
||||
pub struct GetArgs {
|
||||
pub sha256: Sha256,
|
||||
}
|
||||
|
||||
pub struct GetResult {
|
||||
pub sha256: Sha256,
|
||||
pub content_type: String,
|
||||
pub stored_size: usize,
|
||||
pub created_at: UtcDateTime,
|
||||
pub data: Vec<u8>,
|
||||
pub data: CompressibleData,
|
||||
}
|
||||
|
||||
impl Shard {
|
||||
pub async fn get(&self, sha256: Sha256) -> Result<Option<GetResult>, Box<dyn Error>> {
|
||||
self.conn
|
||||
.call(move |conn| get_impl(conn, sha256))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("get failed: {}", e);
|
||||
e.into()
|
||||
})
|
||||
pub async fn get(&self, args: GetArgs) -> Result<Option<GetResult>, AsyncBoxError> {
|
||||
let maybe_row = self
|
||||
.conn
|
||||
.call(move |conn| Ok(get_compressed_row(conn, &args.sha256)?))
|
||||
.await?;
|
||||
|
||||
if let Some((content_type, stored_size, created_at, dict_id, data)) = maybe_row {
|
||||
let compressor = self.compression_manager.read().await;
|
||||
let data = compressor.decompress(dict_id, data)?;
|
||||
Ok(Some(GetResult {
|
||||
sha256: args.sha256,
|
||||
content_type,
|
||||
stored_size,
|
||||
created_at,
|
||||
data,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_impl(
|
||||
type CompressedRowResult = (String, usize, UtcDateTime, CompressionId, Vec<u8>);
|
||||
fn get_compressed_row(
|
||||
conn: &mut rusqlite::Connection,
|
||||
sha256: Sha256,
|
||||
) -> Result<Option<GetResult>, tokio_rusqlite::Error> {
|
||||
let maybe_row = conn
|
||||
.query_row(
|
||||
"SELECT content_type, compressed_size, created_at, compression, data
|
||||
sha256: &Sha256,
|
||||
) -> Result<Option<CompressedRowResult>, rusqlite::Error> {
|
||||
conn.query_row(
|
||||
"SELECT content_type, compressed_size, created_at, dict_id, data
|
||||
FROM entries
|
||||
WHERE sha256 = ?",
|
||||
params![sha256.hex_string()],
|
||||
|row| {
|
||||
let content_type = row.get(0)?;
|
||||
let stored_size = row.get(1)?;
|
||||
let created_at = parse_created_at_str(row.get(2)?)?;
|
||||
let compression = row.get(3)?;
|
||||
let data: Vec<u8> = row.get(4)?;
|
||||
Ok((content_type, stored_size, created_at, compression, data))
|
||||
},
|
||||
)
|
||||
.optional()
|
||||
.map_err(into_tokio_rusqlite_err)?;
|
||||
|
||||
let row = match maybe_row {
|
||||
Some(row) => row,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let (content_type, stored_size, created_at, compression, data) = row;
|
||||
let data = match compression {
|
||||
Compression::None => data,
|
||||
Compression::Zstd => {
|
||||
let mut decoder =
|
||||
zstd::Decoder::new(data.as_slice()).map_err(into_tokio_rusqlite_err)?;
|
||||
let mut decompressed = vec![];
|
||||
decoder
|
||||
.read_to_end(&mut decompressed)
|
||||
.map_err(into_tokio_rusqlite_err)?;
|
||||
decompressed
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Some(GetResult {
|
||||
sha256,
|
||||
content_type,
|
||||
stored_size,
|
||||
created_at,
|
||||
data,
|
||||
}))
|
||||
params![sha256],
|
||||
|row| CompressedRowResult::try_from(row),
|
||||
)
|
||||
.optional()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use super::*;
|
||||
use crate::{concat_lines, AsyncBoxError};
|
||||
|
||||
impl Shard {
|
||||
pub(super) async fn migrate(&self) -> Result<(), Box<dyn Error>> {
|
||||
pub(super) async fn migrate(&self) -> Result<(), AsyncBoxError> {
|
||||
let shard_id = self.id();
|
||||
// create tables, indexes, etc
|
||||
self.conn
|
||||
@@ -9,12 +10,7 @@ impl Shard {
|
||||
ensure_schema_versions_table(conn)?;
|
||||
let schema_rows = load_schema_rows(conn)?;
|
||||
|
||||
if let Some((version, date_time)) = schema_rows.first() {
|
||||
debug!(
|
||||
"shard {}: latest schema version: {} @ {}",
|
||||
shard_id, version, date_time
|
||||
);
|
||||
|
||||
if let Some((version, _)) = schema_rows.first() {
|
||||
if *version == 1 {
|
||||
// no-op
|
||||
} else {
|
||||
@@ -35,17 +31,23 @@ impl Shard {
|
||||
|
||||
fn ensure_schema_versions_table(conn: &rusqlite::Connection) -> Result<usize, rusqlite::Error> {
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS schema_version (
|
||||
version INTEGER PRIMARY KEY,
|
||||
created_at TEXT NOT NULL
|
||||
)",
|
||||
concat_lines!(
|
||||
"CREATE TABLE IF NOT EXISTS schema_version (",
|
||||
" version INTEGER PRIMARY KEY,",
|
||||
" created_at TEXT NOT NULL",
|
||||
")"
|
||||
),
|
||||
[],
|
||||
)
|
||||
}
|
||||
|
||||
fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, rusqlite::Error> {
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1")?;
|
||||
let mut stmt = conn.prepare(concat_lines!(
|
||||
"SELECT version, created_at",
|
||||
"FROM schema_version",
|
||||
"ORDER BY version",
|
||||
"DESC LIMIT 1"
|
||||
))?;
|
||||
let rows = stmt.query_map([], |row| {
|
||||
let version = row.get(0)?;
|
||||
let created_at = row.get(1)?;
|
||||
@@ -57,15 +59,45 @@ fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, r
|
||||
fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
|
||||
debug!("migrating to version 1");
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS entries (
|
||||
sha256 BLOB PRIMARY KEY,
|
||||
content_type TEXT NOT NULL,
|
||||
compression INTEGER NOT NULL,
|
||||
uncompressed_size INTEGER NOT NULL,
|
||||
compressed_size INTEGER NOT NULL,
|
||||
data BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
)",
|
||||
concat_lines!(
|
||||
"CREATE TABLE IF NOT EXISTS entries (",
|
||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,",
|
||||
" sha256 BLOB NOT NULL,",
|
||||
" content_type TEXT NOT NULL,",
|
||||
" dict_id TEXT NOT NULL,",
|
||||
" uncompressed_size INTEGER NOT NULL,",
|
||||
" compressed_size INTEGER NOT NULL,",
|
||||
" data BLOB NOT NULL,",
|
||||
" created_at TEXT NOT NULL",
|
||||
")"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
concat_lines!(
|
||||
"CREATE INDEX IF NOT EXISTS entries_sha256_idx",
|
||||
"ON entries (sha256)"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
concat_lines!(
|
||||
"CREATE TABLE IF NOT EXISTS compression_hints (",
|
||||
" name TEXT NOT NULL,",
|
||||
" ordering INTEGER NOT NULL,",
|
||||
" entry_id INTEGER NOT NULL",
|
||||
")"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
concat_lines!(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS compression_hints_name_idx",
|
||||
"ON compression_hints (name, ordering)",
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
compressible_data::CompressibleData,
|
||||
concat_lines, into_tokio_rusqlite_err,
|
||||
sql_types::{CompressionId, UtcDateTime},
|
||||
AsyncBoxError,
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
pub enum StoreResult {
|
||||
@@ -22,86 +28,96 @@ pub struct StoreArgs {
|
||||
}
|
||||
|
||||
impl Shard {
|
||||
pub async fn store(&self, store_args: StoreArgs) -> Result<StoreResult, Box<dyn Error>> {
|
||||
let use_compression = self.use_compression;
|
||||
pub async fn store(
|
||||
&self,
|
||||
StoreArgs {
|
||||
sha256,
|
||||
data,
|
||||
content_type,
|
||||
}: StoreArgs,
|
||||
) -> Result<StoreResult, AsyncBoxError> {
|
||||
let existing_entry = self
|
||||
.conn
|
||||
.call(move |conn| find_with_sha256(conn, &sha256).map_err(into_tokio_rusqlite_err))
|
||||
.await?;
|
||||
|
||||
if let Some(entry) = existing_entry {
|
||||
return Ok(entry);
|
||||
}
|
||||
|
||||
let uncompressed_size = data.len();
|
||||
|
||||
let (dict_id, data) = {
|
||||
let compression_manager = self.compression_manager.read().await;
|
||||
compression_manager.compress(&content_type, data).await?
|
||||
};
|
||||
|
||||
self.conn
|
||||
.call(move |conn| store(conn, use_compression, store_args))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("store failed: {}", e);
|
||||
e.into()
|
||||
.call(move |conn| {
|
||||
insert(
|
||||
conn,
|
||||
&sha256,
|
||||
content_type,
|
||||
dict_id,
|
||||
uncompressed_size,
|
||||
data,
|
||||
)
|
||||
.map_err(into_tokio_rusqlite_err)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
||||
fn store(
|
||||
fn find_with_sha256(
|
||||
conn: &mut rusqlite::Connection,
|
||||
use_compression: UseCompression,
|
||||
StoreArgs {
|
||||
sha256,
|
||||
content_type,
|
||||
data,
|
||||
}: StoreArgs,
|
||||
) -> Result<StoreResult, tokio_rusqlite::Error> {
|
||||
let sha256 = sha256.hex_string();
|
||||
sha256: &Sha256,
|
||||
) -> Result<Option<StoreResult>, rusqlite::Error> {
|
||||
conn.query_row(
|
||||
concat_lines!(
|
||||
"SELECT uncompressed_size, compressed_size, created_at",
|
||||
"FROM entries",
|
||||
"WHERE sha256 = ?"
|
||||
),
|
||||
params![sha256],
|
||||
|row| {
|
||||
Ok(StoreResult::Exists {
|
||||
stored_size: row.get(0)?,
|
||||
data_size: row.get(1)?,
|
||||
created_at: row.get(2)?,
|
||||
})
|
||||
},
|
||||
)
|
||||
.optional()
|
||||
}
|
||||
|
||||
// check for existing entry
|
||||
let maybe_existing: Option<StoreResult> = conn
|
||||
.query_row(
|
||||
"SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?",
|
||||
params![sha256],
|
||||
|row| {
|
||||
Ok(StoreResult::Exists {
|
||||
stored_size: row.get(0)?,
|
||||
data_size: row.get(1)?,
|
||||
created_at: parse_created_at_str(row.get(2)?)?,
|
||||
})
|
||||
},
|
||||
)
|
||||
.optional()?;
|
||||
|
||||
if let Some(existing) = maybe_existing {
|
||||
return Ok(existing);
|
||||
}
|
||||
|
||||
let created_at = chrono::Utc::now();
|
||||
let uncompressed_size = data.len();
|
||||
let tmp_data_holder;
|
||||
|
||||
let use_compression = match use_compression {
|
||||
UseCompression::None => false,
|
||||
UseCompression::Auto => auto_compressible_content_type(&content_type),
|
||||
UseCompression::Zstd => true,
|
||||
};
|
||||
|
||||
let (compression, data) = if use_compression {
|
||||
tmp_data_holder = zstd::encode_all(&data[..], 0).map_err(into_tokio_rusqlite_err)?;
|
||||
if tmp_data_holder.len() < data.len() {
|
||||
(Compression::Zstd, &tmp_data_holder[..])
|
||||
} else {
|
||||
(Compression::None, &data[..])
|
||||
}
|
||||
} else {
|
||||
(Compression::None, &data[..])
|
||||
};
|
||||
fn insert(
|
||||
conn: &mut rusqlite::Connection,
|
||||
sha256: &Sha256,
|
||||
content_type: String,
|
||||
dict_id: CompressionId,
|
||||
uncompressed_size: usize,
|
||||
data: CompressibleData,
|
||||
) -> Result<StoreResult, rusqlite::Error> {
|
||||
let created_at = UtcDateTime::now();
|
||||
let compressed_size = data.len();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO entries
|
||||
(sha256, content_type, compression, uncompressed_size, compressed_size, data, created_at)
|
||||
VALUES
|
||||
(?, ?, ?, ?, ?, ?, ?)
|
||||
",
|
||||
params![
|
||||
sha256,
|
||||
content_type,
|
||||
compression,
|
||||
uncompressed_size,
|
||||
compressed_size,
|
||||
data,
|
||||
created_at.to_rfc3339(),
|
||||
],
|
||||
)?;
|
||||
concat_lines!(
|
||||
"INSERT INTO entries",
|
||||
" (sha256, content_type, dict_id, uncompressed_size, compressed_size, data, created_at)",
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
),
|
||||
params![
|
||||
sha256,
|
||||
content_type,
|
||||
dict_id,
|
||||
uncompressed_size,
|
||||
compressed_size,
|
||||
data.as_ref(),
|
||||
created_at,
|
||||
],
|
||||
)?;
|
||||
|
||||
Ok(StoreResult::Created {
|
||||
stored_size: compressed_size,
|
||||
@@ -109,14 +125,3 @@ fn store(
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
fn auto_compressible_content_type(content_type: &str) -> bool {
|
||||
[
|
||||
"text/",
|
||||
"application/xml",
|
||||
"application/json",
|
||||
"application/javascript",
|
||||
]
|
||||
.iter()
|
||||
.any(|ct| content_type.starts_with(ct))
|
||||
}
|
||||
|
||||
@@ -1,58 +1,19 @@
|
||||
mod fn_get;
|
||||
mod fn_migrate;
|
||||
mod fn_store;
|
||||
mod shard;
|
||||
pub mod shard_error;
|
||||
mod shard_struct;
|
||||
|
||||
pub use fn_get::GetResult;
|
||||
pub use fn_get::{GetArgs, GetResult};
|
||||
pub use fn_store::{StoreArgs, StoreResult};
|
||||
pub use shard::Shard;
|
||||
pub use shard_struct::Shard;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test {
|
||||
pub use super::shard::test::*;
|
||||
pub use super::shard_struct::test::*;
|
||||
}
|
||||
|
||||
use crate::{sha256::Sha256, shard::shard_error::ShardError, UseCompression};
|
||||
use crate::{sha256::Sha256, shard::shard_error::ShardError};
|
||||
use axum::body::Bytes;
|
||||
use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
|
||||
use std::error::Error;
|
||||
use rusqlite::{params, types::FromSql, OptionalExtension};
|
||||
use tokio_rusqlite::Connection;
|
||||
use tracing::{debug, error};
|
||||
|
||||
pub type UtcDateTime = chrono::DateTime<chrono::Utc>;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum Compression {
|
||||
None,
|
||||
Zstd,
|
||||
}
|
||||
impl ToSql for Compression {
|
||||
fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
|
||||
match self {
|
||||
Compression::None => 0.to_sql(),
|
||||
Compression::Zstd => 1.to_sql(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl FromSql for Compression {
|
||||
fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
|
||||
match value.as_i64()? {
|
||||
0 => Ok(Compression::None),
|
||||
1 => Ok(Compression::Zstd),
|
||||
_ => Err(rusqlite::types::FromSqlError::InvalidType),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_created_at_str(created_at_str: String) -> Result<UtcDateTime, rusqlite::Error> {
|
||||
let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str)
|
||||
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
|
||||
Ok(parsed.to_utc())
|
||||
}
|
||||
|
||||
fn into_tokio_rusqlite_err<E: Into<Box<dyn Error + Send + Sync + 'static>>>(
|
||||
e: E,
|
||||
) -> tokio_rusqlite::Error {
|
||||
tokio_rusqlite::Error::Other(e.into())
|
||||
}
|
||||
use tracing::debug;
|
||||
|
||||
@@ -1,36 +1,33 @@
|
||||
use super::*;
|
||||
use crate::{compression_manager::CompressionManagerArc, AsyncBoxError};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Shard {
|
||||
pub(super) id: usize,
|
||||
pub(super) conn: Connection,
|
||||
pub(super) use_compression: UseCompression,
|
||||
pub(super) compression_manager: CompressionManagerArc,
|
||||
}
|
||||
|
||||
impl Shard {
|
||||
pub async fn open(
|
||||
id: usize,
|
||||
use_compression: UseCompression,
|
||||
conn: Connection,
|
||||
) -> Result<Self, Box<dyn Error>> {
|
||||
compression_manager: CompressionManagerArc,
|
||||
) -> Result<Self, AsyncBoxError> {
|
||||
let shard = Self {
|
||||
id,
|
||||
use_compression,
|
||||
conn,
|
||||
compression_manager,
|
||||
};
|
||||
shard.migrate().await?;
|
||||
Ok(shard)
|
||||
}
|
||||
|
||||
pub async fn close(self) -> Result<(), Box<dyn Error>> {
|
||||
self.conn.close().await.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
pub fn id(&self) -> usize {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub async fn db_size_bytes(&self) -> Result<usize, Box<dyn Error>> {
|
||||
pub async fn db_size_bytes(&self) -> Result<usize, AsyncBoxError> {
|
||||
self.query_single_row(
|
||||
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
|
||||
)
|
||||
@@ -40,7 +37,7 @@ impl Shard {
|
||||
async fn query_single_row<T: FromSql + Send + 'static>(
|
||||
&self,
|
||||
query: &'static str,
|
||||
) -> Result<T, Box<dyn Error>> {
|
||||
) -> Result<T, AsyncBoxError> {
|
||||
self.conn
|
||||
.call(move |conn| {
|
||||
let value: T = conn.query_row(query, [], |row| row.get(0))?;
|
||||
@@ -50,7 +47,7 @@ impl Shard {
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
|
||||
pub async fn num_entries(&self) -> Result<usize, AsyncBoxError> {
|
||||
get_num_entries(&self.conn).await.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
@@ -65,18 +62,23 @@ async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Err
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test {
|
||||
use crate::compression_manager::test::make_compressor_with;
|
||||
use crate::compression_manager::CompressionManagerArc;
|
||||
use crate::{
|
||||
compression_manager::test::make_compressor,
|
||||
sha256::Sha256,
|
||||
shard::{GetArgs, StoreArgs, StoreResult},
|
||||
CompressionPolicy,
|
||||
};
|
||||
use rstest::rstest;
|
||||
|
||||
use super::{StoreResult, UseCompression};
|
||||
use crate::{sha256::Sha256, shard::StoreArgs};
|
||||
|
||||
pub async fn make_shard_with_compression(use_compression: UseCompression) -> super::Shard {
|
||||
pub async fn make_shard_with(compressor: CompressionManagerArc) -> super::Shard {
|
||||
let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
|
||||
super::Shard::open(0, use_compression, conn).await.unwrap()
|
||||
super::Shard::open(0, conn, compressor).await.unwrap()
|
||||
}
|
||||
|
||||
pub async fn make_shard() -> super::Shard {
|
||||
make_shard_with_compression(UseCompression::Auto).await
|
||||
make_shard_with(make_compressor().into_arc()).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -97,7 +99,7 @@ pub mod test {
|
||||
async fn test_not_found_get() {
|
||||
let shard = make_shard().await;
|
||||
let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
|
||||
let get_result = shard.get(sha256).await.unwrap();
|
||||
let get_result = shard.get(GetArgs { sha256 }).await.unwrap();
|
||||
assert!(get_result.is_none());
|
||||
}
|
||||
|
||||
@@ -111,16 +113,17 @@ pub mod test {
|
||||
sha256,
|
||||
content_type: "text/plain".to_string(),
|
||||
data: data.into(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
match store_result {
|
||||
StoreResult::Created {
|
||||
stored_size,
|
||||
stored_size: _,
|
||||
data_size,
|
||||
created_at,
|
||||
} => {
|
||||
assert_eq!(stored_size, data.len());
|
||||
// assert_eq!(stored_size, data.len());
|
||||
assert_eq!(data_size, data.len());
|
||||
assert!(created_at > chrono::Utc::now() - chrono::Duration::seconds(1));
|
||||
}
|
||||
@@ -128,7 +131,7 @@ pub mod test {
|
||||
}
|
||||
assert_eq!(shard.num_entries().await.unwrap(), 1);
|
||||
|
||||
let get_result = shard.get(sha256).await.unwrap().unwrap();
|
||||
let get_result = shard.get(GetArgs { sha256 }).await.unwrap().unwrap();
|
||||
assert_eq!(get_result.content_type, "text/plain");
|
||||
assert_eq!(get_result.data, data);
|
||||
assert_eq!(get_result.stored_size, data.len());
|
||||
@@ -145,6 +148,7 @@ pub mod test {
|
||||
sha256,
|
||||
content_type: "text/plain".to_string(),
|
||||
data: data.into(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -168,6 +172,7 @@ pub mod test {
|
||||
sha256,
|
||||
content_type: "text/plain".to_string(),
|
||||
data: data.into(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -185,12 +190,16 @@ pub mod test {
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_compression_store_get(
|
||||
#[values(UseCompression::Auto, UseCompression::None, UseCompression::Zstd)]
|
||||
use_compression: UseCompression,
|
||||
#[values(
|
||||
CompressionPolicy::Auto,
|
||||
CompressionPolicy::None,
|
||||
CompressionPolicy::ForceZstd
|
||||
)]
|
||||
compression_policy: CompressionPolicy,
|
||||
#[values(true, false)] incompressible_data: bool,
|
||||
#[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
|
||||
) {
|
||||
let shard = make_shard_with_compression(use_compression).await;
|
||||
let shard = make_shard_with(make_compressor_with(compression_policy).into_arc()).await;
|
||||
let mut data = vec![b'.'; 1024];
|
||||
if incompressible_data {
|
||||
for byte in data.iter_mut() {
|
||||
@@ -204,12 +213,13 @@ pub mod test {
|
||||
sha256,
|
||||
content_type: content_type.clone(),
|
||||
data: data.clone().into(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(matches!(store_result, StoreResult::Created { .. }));
|
||||
|
||||
let get_result = shard.get(sha256).await.unwrap().unwrap();
|
||||
let get_result = shard.get(GetArgs { sha256 }).await.unwrap().unwrap();
|
||||
assert_eq!(get_result.content_type, content_type);
|
||||
assert_eq!(get_result.data, data);
|
||||
}
|
||||
@@ -1,14 +1,16 @@
|
||||
use crate::{sha256::Sha256, shard::Shard};
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub type ShardsArc = Arc<Shards>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Shards(Vec<Shard>);
|
||||
pub struct Shards(Vec<Arc<Shard>>);
|
||||
impl Shards {
|
||||
pub fn new(shards: Vec<Shard>) -> Option<Self> {
|
||||
if shards.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(Self(shards))
|
||||
Some(Self(shards.into_iter().map(Arc::new).collect()))
|
||||
}
|
||||
|
||||
pub fn shard_for(&self, sha256: &Sha256) -> &Shard {
|
||||
@@ -16,15 +18,8 @@ impl Shards {
|
||||
&self.0[shard_id]
|
||||
}
|
||||
|
||||
pub async fn close_all(self) -> Result<(), Box<dyn Error>> {
|
||||
for shard in self.0 {
|
||||
shard.close().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> std::slice::Iter<'_, Shard> {
|
||||
self.0.iter()
|
||||
pub fn iter(&self) -> impl Iterator<Item = &Shard> {
|
||||
self.0.iter().map(|shard| shard.as_ref())
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
@@ -34,15 +29,19 @@ impl Shards {
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test {
|
||||
use crate::{shard::test::make_shard_with_compression, UseCompression};
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::Shards;
|
||||
use crate::{
|
||||
compression_manager::{test::make_compressor, CompressionManagerArc},
|
||||
shard::test::make_shard_with,
|
||||
};
|
||||
|
||||
pub async fn make_shards_with_compression(use_compression: UseCompression) -> Shards {
|
||||
Shards::new(vec![make_shard_with_compression(use_compression).await]).unwrap()
|
||||
use super::{Shards, ShardsArc};
|
||||
|
||||
pub async fn make_shards() -> ShardsArc {
|
||||
make_shards_with(make_compressor().into_arc()).await
|
||||
}
|
||||
|
||||
pub async fn make_shards() -> Shards {
|
||||
make_shards_with_compression(UseCompression::Auto).await
|
||||
pub async fn make_shards_with(compressor: CompressionManagerArc) -> ShardsArc {
|
||||
Arc::new(Shards::new(vec![make_shard_with(compressor.clone()).await]).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
src/sql_types/compression_id.rs (new file, 85 lines)
@@ -0,0 +1,85 @@
use crate::AsyncBoxError;
use rusqlite::{
    types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, ValueRef},
    ToSql,
};

#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, PartialOrd, Ord)]
pub enum CompressionId {
    None,
    Zstd,
    Brotli,
}

const NONE_PREFIX: &str = "none";
const ZSTD_PREFIX: &str = "zstd";
const BROTLI_PREFIX: &str = "brotli";

impl CompressionId {
    pub fn prefix(&self) -> &'static str {
        match self {
            CompressionId::None => NONE_PREFIX,
            CompressionId::Zstd => ZSTD_PREFIX,
            CompressionId::Brotli => BROTLI_PREFIX,
        }
    }

    fn from_str(id_as_str: &str) -> Result<Self, AsyncBoxError> {
        match id_as_str {
            NONE_PREFIX => Ok(CompressionId::None),
            ZSTD_PREFIX => Ok(CompressionId::Zstd),
            BROTLI_PREFIX => Ok(CompressionId::Brotli),
            _ => Err(format!("invalid DictId: {}", id_as_str).into()),
        }
    }
}

impl ToString for CompressionId {
    fn to_string(&self) -> String {
        match self {
            CompressionId::None => NONE_PREFIX.to_string(),
            CompressionId::Zstd => ZSTD_PREFIX.to_string(),
            CompressionId::Brotli => BROTLI_PREFIX.to_string(),
        }
    }
}

impl ToSql for CompressionId {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        self.prefix().to_sql()
    }
}

impl FromSql for CompressionId {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        let s = value.as_str()?;
        CompressionId::from_str(s).map_err(FromSqlError::Other)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[rstest(
        case(CompressionId::None, "none"),
        case(CompressionId::Zstd, "zstd"),
        case(CompressionId::Brotli, "brotli")
    )]
    #[test]
    fn test_dict_id(#[case] id: CompressionId, #[case] id_str: &str) {
        use rusqlite::types::Value;

        let sql_to_id = CompressionId::column_result(ValueRef::Text(id_str.as_bytes()));
        assert_eq!(id, sql_to_id.unwrap());

        let id_to_sql = match id.to_sql() {
            Ok(ToSqlOutput::Borrowed(ValueRef::Text(text))) => text.to_owned(),
            Ok(ToSqlOutput::Owned(Value::Text(text))) => text.as_bytes().to_owned(),
            _ => panic!("unexpected ToSqlOutput: {:?}", id.to_sql()),
        };
        let id_to_sql = std::str::from_utf8(&id_to_sql).unwrap();
        assert_eq!(id_str, id_to_sql);
    }
}

src/sql_types/entry_id.rs (new file, 24 lines)
@@ -0,0 +1,24 @@
use rusqlite::{
    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
    ToSql,
};

#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub struct EntryId(pub i64);
impl From<i64> for EntryId {
    fn from(id: i64) -> Self {
        Self(id)
    }
}

impl FromSql for EntryId {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        Ok(value.as_i64()?.into())
    }
}

impl ToSql for EntryId {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        self.0.to_sql()
    }
}

src/sql_types/mod.rs (new file, 5 lines)
@@ -0,0 +1,5 @@
mod compression_id;
mod utc_date_time;

pub use compression_id::CompressionId;
pub use utc_date_time::UtcDateTime;

src/sql_types/utc_date_time.rs (new file, 51 lines)
@@ -0,0 +1,51 @@
use chrono::DateTime;
use rusqlite::{
    types::{FromSql, FromSqlError, ToSqlOutput, ValueRef},
    Result, ToSql,
};

#[derive(PartialEq, Debug, PartialOrd)]
pub struct UtcDateTime(DateTime<chrono::Utc>);

impl UtcDateTime {
    pub fn now() -> Self {
        Self(chrono::Utc::now())
    }

    #[cfg(test)]
    pub fn from_string(s: &str) -> Result<Self, chrono::ParseError> {
        Ok(Self(DateTime::parse_from_rfc3339(s)?.to_utc()))
    }
}

impl ToString for UtcDateTime {
    fn to_string(&self) -> String {
        self.0.to_rfc3339()
    }
}

impl PartialEq<DateTime<chrono::Utc>> for UtcDateTime {
    fn eq(&self, other: &DateTime<chrono::Utc>) -> bool {
        self.0 == *other
    }
}

impl PartialOrd<DateTime<chrono::Utc>> for UtcDateTime {
    fn partial_cmp(&self, other: &DateTime<chrono::Utc>) -> Option<std::cmp::Ordering> {
        self.0.partial_cmp(other)
    }
}

impl ToSql for UtcDateTime {
    fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
        Ok(ToSqlOutput::from(self.0.to_rfc3339()))
    }
}

impl FromSql for UtcDateTime {
    fn column_result(value: ValueRef<'_>) -> Result<Self, FromSqlError> {
        let parsed = DateTime::parse_from_rfc3339(value.as_str()?)
            .map_err(|e| FromSqlError::Other(e.into()))?;
        Ok(UtcDateTime(parsed.to_utc()))
    }
}
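
A minimal sketch of how this wrapper round-trips through SQLite via rusqlite, assuming the UtcDateTime type above is in scope; the in-memory connection and table name are illustrative:

use rusqlite::{params, Connection};

fn utc_date_time_roundtrip() -> rusqlite::Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute("CREATE TABLE t (created_at TEXT NOT NULL)", [])?;

    // ToSql stores the timestamp as an RFC 3339 string.
    let now = UtcDateTime::now();
    conn.execute("INSERT INTO t (created_at) VALUES (?)", params![now])?;

    // FromSql parses the stored string back into a UtcDateTime.
    let loaded: UtcDateTime = conn.query_row("SELECT created_at FROM t", [], |row| row.get(0))?;
    assert_eq!(loaded, now);
    Ok(())
}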