zstd_dict_id and such
src/compressible_data.rs (new file, 67 lines)
@@ -0,0 +1,67 @@
use axum::{body::Bytes, response::IntoResponse};

#[derive(Debug, Eq, PartialEq)]
pub enum CompressibleData {
    Bytes(Bytes),
    Vec(Vec<u8>),
}

impl CompressibleData {
    pub fn len(&self) -> usize {
        match self {
            CompressibleData::Bytes(b) => b.len(),
            CompressibleData::Vec(v) => v.len(),
        }
    }
}

impl AsRef<[u8]> for CompressibleData {
    fn as_ref(&self) -> &[u8] {
        match self {
            CompressibleData::Bytes(b) => b.as_ref(),
            CompressibleData::Vec(v) => v.as_ref(),
        }
    }
}

impl From<Bytes> for CompressibleData {
    fn from(b: Bytes) -> Self {
        CompressibleData::Bytes(b)
    }
}

impl From<Vec<u8>> for CompressibleData {
    fn from(v: Vec<u8>) -> Self {
        CompressibleData::Vec(v)
    }
}

impl IntoResponse for CompressibleData {
    fn into_response(self) -> axum::response::Response {
        match self {
            CompressibleData::Bytes(b) => b.into_response(),
            CompressibleData::Vec(v) => v.into_response(),
        }
    }
}

impl PartialEq<&[u8]> for CompressibleData {
    fn eq(&self, other: &&[u8]) -> bool {
        let as_ref = self.as_ref();
        as_ref == *other
    }
}

impl PartialEq<Vec<u8>> for CompressibleData {
    fn eq(&self, other: &Vec<u8>) -> bool {
        let as_ref = self.as_ref();
        as_ref == other.as_slice()
    }
}

impl PartialEq<Bytes> for CompressibleData {
    fn eq(&self, other: &Bytes) -> bool {
        let as_ref = self.as_ref();
        as_ref == other.as_ref()
    }
}
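A quick usage sketch (illustrative only, not part of the commit): both variants convert in via From, expose the same byte view, and the PartialEq impls let tests compare results without caring which variant came back.

use axum::body::Bytes;

fn compressible_data_sketch() {
    let from_vec: CompressibleData = vec![1u8, 2, 3].into();
    let from_bytes: CompressibleData = Bytes::from_static(b"abc").into();

    // Same byte view regardless of variant.
    assert_eq!(from_vec.len(), 3);
    assert_eq!(from_bytes.as_ref(), b"abc");

    // PartialEq against Vec<u8> (and &[u8], Bytes) keeps asserts terse.
    assert_eq!(from_vec, vec![1u8, 2, 3]);
}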
src/compressor.rs (new file, 218 lines)
@@ -0,0 +1,218 @@
use std::{collections::HashMap, sync::Arc};

use tokio::sync::RwLock;

use crate::{
    compressible_data::CompressibleData,
    sql_types::{CompressionId, ZstdDictId},
    zstd_dict::{ZstdDict, ZstdDictArc},
    AsyncBoxError, CompressionPolicy,
};

pub type CompressorArc = Arc<RwLock<Compressor>>;

pub struct Compressor {
    zstd_dict_by_id: HashMap<ZstdDictId, ZstdDictArc>,
    zstd_dict_by_name: HashMap<String, ZstdDictArc>,
    compression_policy: CompressionPolicy,
}

impl Compressor {
    pub fn new(compression_policy: CompressionPolicy) -> Self {
        Self {
            zstd_dict_by_id: HashMap::new(),
            zstd_dict_by_name: HashMap::new(),
            compression_policy,
        }
    }
}

impl Default for Compressor {
    fn default() -> Self {
        Self::new(CompressionPolicy::Auto)
    }
}

type Result<T> = std::result::Result<T, AsyncBoxError>;

impl Compressor {
    pub fn into_arc(self) -> CompressorArc {
        Arc::new(RwLock::new(self))
    }

    fn _add_from_samples<Str: Into<String>>(
        &mut self,
        id: ZstdDictId,
        name: Str,
        samples: Vec<&[u8]>,
    ) -> Result<ZstdDictArc> {
        let name = name.into();
        self.check(id, &name)?;
        let zstd_dict = ZstdDict::from_samples(id, name, 3, samples)?;
        Ok(self.add(zstd_dict))
    }

    pub fn add_from_bytes<Str: Into<String>>(
        &mut self,
        id: ZstdDictId,
        name: Str,
        level: i32,
        dict_bytes: Vec<u8>,
    ) -> Result<ZstdDictArc> {
        let name = name.into();
        self.check(id, &name)?;
        let zstd_dict = ZstdDict::from_dict_bytes(id, name, level, dict_bytes);
        Ok(self.add(zstd_dict))
    }

    pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> {
        self.zstd_dict_by_id.get(&id).map(|d| &**d)
    }
    pub fn by_name(&self, name: &str) -> Option<&ZstdDict> {
        self.zstd_dict_by_name.get(name).map(|d| &**d)
    }

    pub fn compress<Data: Into<CompressibleData>>(
        &self,
        name: &str,
        content_type: &str,
        data: Data,
    ) -> Result<(CompressionId, CompressibleData)> {
        let data = data.into();
        let should_compress = match self.compression_policy {
            CompressionPolicy::None => false,
            CompressionPolicy::Auto => auto_compressible_content_type(content_type),
            CompressionPolicy::ForceZstd => true,
        };

        if should_compress {
            let (id, compressed): (_, CompressibleData) = if let Some(dict) = self.by_name(name) {
                (
                    CompressionId::ZstdDictId(dict.id()),
                    dict.compress(&data)?.into(),
                )
            } else {
                (
                    CompressionId::ZstdGeneric,
                    zstd::stream::encode_all(data.as_ref(), 3)?.into(),
                )
            };

            if compressed.len() < data.len() {
                return Ok((id, compressed));
            }
        }

        Ok((CompressionId::None, data))
    }

    pub fn decompress<Data: Into<CompressibleData>>(
        &self,
        compression_id: CompressionId,
        data: Data,
    ) -> Result<CompressibleData> {
        let data = data.into();
        match compression_id {
            CompressionId::None => Ok(data),
            CompressionId::ZstdDictId(id) => {
                if let Some(dict) = self.by_id(id) {
                    Ok(CompressibleData::Vec(dict.decompress(data.as_ref())?))
                } else {
                    Err(format!("zstd dictionary {:?} not found", id).into())
                }
            }
            CompressionId::ZstdGeneric => Ok(CompressibleData::Vec(zstd::stream::decode_all(
                data.as_ref(),
            )?)),
        }
    }

    fn check(&self, id: ZstdDictId, name: &str) -> Result<()> {
        if self.zstd_dict_by_id.contains_key(&id) {
            return Err(format!("zstd dictionary {:?} already exists", id).into());
        }
        if self.zstd_dict_by_name.contains_key(name) {
            return Err(format!("zstd dictionary {} already exists", name).into());
        }
        Ok(())
    }

    fn add(&mut self, zstd_dict: ZstdDict) -> ZstdDictArc {
        let zstd_dict = Arc::new(zstd_dict);
        self.zstd_dict_by_id
            .insert(zstd_dict.id(), zstd_dict.clone());
        self.zstd_dict_by_name
            .insert(zstd_dict.name().to_string(), zstd_dict.clone());
        zstd_dict
    }
}

fn auto_compressible_content_type(content_type: &str) -> bool {
    [
        "text/",
        "application/xml",
        "application/json",
        "application/javascript",
    ]
    .iter()
    .any(|ct| content_type.starts_with(ct))
}

#[cfg(test)]
pub mod test {
    use rstest::rstest;

    use super::*;
    use crate::zstd_dict::test::make_zstd_dict;

    pub fn make_compressor() -> Compressor {
        make_compressor_with_policy(CompressionPolicy::Auto)
    }

    pub fn make_compressor_with_policy(compression_policy: CompressionPolicy) -> Compressor {
        let mut compressor = Compressor::new(compression_policy);
        let zstd_dict = make_zstd_dict(1.into(), "dict1");
        compressor.add(zstd_dict);
        compressor
    }

    #[test]
    fn test_auto_compressible_content_type() {
        assert!(auto_compressible_content_type("text/plain"));
        assert!(auto_compressible_content_type("application/xml"));
        assert!(auto_compressible_content_type("application/json"));
        assert!(auto_compressible_content_type("application/javascript"));
        assert!(!auto_compressible_content_type("image/png"));
    }

    #[rstest]
    #[test]
    fn test_compression_policies(
        #[values(
            CompressionPolicy::Auto,
            CompressionPolicy::None,
            CompressionPolicy::ForceZstd
        )]
        compression_policy: CompressionPolicy,
        #[values("text/plain", "application/json", "image/png")] content_type: &str,
    ) {
        let compressor = make_compressor_with_policy(compression_policy);
        let data = b"hello, world!".to_vec();
        let (compression_id, compressed) = compressor
            .compress("dict1", content_type, data.clone())
            .unwrap();

        let data_uncompressed = compressor.decompress(compression_id, compressed).unwrap();
        assert_eq!(data_uncompressed, data);
    }

    #[test]
    fn test_skip_compressing_small_data() {
        let compressor = make_compressor();
        let data = b"hello, world".to_vec();
        let (compression_id, compressed) = compressor
            .compress("dict1", "text/plain", data.clone())
            .unwrap();
        assert_eq!(compression_id, CompressionId::None);
        assert_eq!(compressed, data);
    }
}
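For orientation, a hedged round-trip sketch using only this file's API (the dictionary name "logs" is hypothetical): with no dictionary registered under that name, textual payloads fall back to CompressionId::ZstdGeneric, and the size guard returns CompressionId::None whenever compression does not actually shrink the data.

fn round_trip_sketch() -> Result<()> {
    let compressor = Compressor::new(CompressionPolicy::Auto);
    let payload = "hello, world! ".repeat(100).into_bytes();

    // Auto policy + text/plain: compressed and tagged ZstdGeneric here.
    let (id, stored) = compressor.compress("logs", "text/plain", payload.clone())?;
    assert!(stored.len() < payload.len());

    // The stored CompressionId is all that is needed to reverse it.
    let restored = compressor.decompress(id, stored)?;
    assert_eq!(restored, payload);
    Ok(())
}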
src/handlers/get_handler.rs
@@ -1,16 +1,22 @@
-use crate::{sha256::Sha256, shard::GetResult, shards::Shards};
+use crate::{
+    compressor::CompressorArc,
+    sha256::Sha256,
+    shard::{GetArgs, GetResult},
+    shards::Shards,
+    AsyncBoxError,
+};
 use axum::{
     extract::Path,
     http::{header, HeaderMap, HeaderValue, StatusCode},
     response::IntoResponse,
     Extension, Json,
 };
-use std::{collections::HashMap, error::Error};
+use std::{collections::HashMap, sync::Arc};

 pub enum GetResponse {
     MissingSha256,
     InvalidSha256 { message: String },
-    InternalError { error: Box<dyn Error> },
+    InternalError { error: AsyncBoxError },
     NotFound,
     Found { get_result: GetResult },
 }
@@ -69,7 +75,7 @@ fn make_found_response(
         Err(e) => return GetResponse::from(e).into_response(),
     };

-    let created_at = match HeaderValue::from_str(&created_at.to_rfc3339()) {
+    let created_at = match HeaderValue::from_str(&created_at.to_string()) {
         Ok(created_at) => created_at,
         Err(e) => return GetResponse::from(e).into_response(),
     };
@@ -98,7 +104,7 @@ fn make_found_response(
     (StatusCode::OK, headers, data).into_response()
 }

-impl<E: Into<Box<dyn Error>>> From<E> for GetResponse {
+impl<E: Into<AsyncBoxError>> From<E> for GetResponse {
     fn from(error: E) -> Self {
         GetResponse::InternalError {
             error: error.into(),
@@ -109,7 +115,8 @@ impl<E: Into<Box<dyn Error>>> From<E> for GetResponse {
 #[axum::debug_handler]
 pub async fn get_handler(
     Path(params): Path<HashMap<String, String>>,
-    Extension(shards): Extension<Shards>,
+    Extension(shards): Extension<Arc<Shards>>,
+    Extension(compressor): Extension<CompressorArc>,
 ) -> GetResponse {
     let sha256_str = match params.get("sha256") {
         Some(sha256_str) => sha256_str.clone(),
@@ -128,7 +135,7 @@ pub async fn get_handler(
     };

     let shard = shards.shard_for(&sha256);
-    let get_result = match shard.get(sha256).await {
+    let get_result = match shard.get(GetArgs { sha256, compressor }).await {
         Ok(get_result) => get_result,
         Err(e) => return e.into(),
     };
@@ -141,7 +148,10 @@ pub async fn get_handler(

 #[cfg(test)]
 mod test {
-    use crate::{sha256::Sha256, shard::GetResult, shards::test::make_shards};
+    use crate::{
+        compressor::test::make_compressor, sha256::Sha256, shard::GetResult,
+        shards::test::make_shards, sql_types::UtcDateTime,
+    };
     use axum::{extract::Path, response::IntoResponse, Extension};
     use std::collections::HashMap;

@@ -150,13 +160,16 @@ mod test {
     #[tokio::test]
     async fn test_get_invalid_sha256() {
         let shards = Extension(make_shards().await);
-        let response = super::get_handler(Path(HashMap::new()), shards.clone()).await;
+        let compressor = Extension(make_compressor().into_arc());
+
+        let response =
+            super::get_handler(Path(HashMap::new()), shards.clone(), compressor.clone()).await;
         assert!(matches!(response, super::GetResponse::MissingSha256 { .. }));

         let shards = Extension(make_shards().await);
         let response = super::get_handler(
             Path(HashMap::from([(String::from("sha256"), String::from(""))])),
             shards.clone(),
+            compressor.clone(),
         )
         .await;
         assert!(matches!(response, super::GetResponse::InvalidSha256 { .. }));
@@ -167,6 +180,7 @@ mod test {
                 String::from("invalid"),
             )])),
             shards.clone(),
+            compressor.clone(),
         )
         .await;
         assert!(matches!(response, super::GetResponse::InvalidSha256 { .. }));
@@ -174,8 +188,8 @@ mod test {

     #[test]
     fn test_get_response_found_into_response() {
-        let data = "hello, world!";
-        let sha256 = Sha256::from_bytes(data.as_bytes());
+        let data = "hello, world!".as_bytes().to_owned();
+        let sha256 = Sha256::from_bytes(&data);
         let sha256_str = sha256.hex_string();
         let created_at = "2022-03-04T08:12:34+00:00";
         let response = GetResponse::Found {
@@ -183,9 +197,7 @@ mod test {
                 sha256,
                 content_type: "text/plain".to_string(),
                 stored_size: 12345,
-                created_at: chrono::DateTime::parse_from_rfc3339(created_at)
-                    .unwrap()
-                    .to_utc(),
+                created_at: UtcDateTime::from_string(created_at).unwrap(),
                 data: data.into(),
             },
         }
src/handlers/info_handler.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use crate::shards::Shards;
 use axum::{http::StatusCode, Extension, Json};

@@ -19,7 +21,7 @@ pub struct ShardInfo {

 #[axum::debug_handler]
 pub async fn info_handler(
-    Extension(shards): Extension<Shards>,
+    Extension(shards): Extension<Arc<Shards>>,
 ) -> Result<(StatusCode, Json<InfoResponse>), StatusCode> {
     let mut shard_infos = vec![];
     let mut total_db_size_bytes = 0;
src/handlers/store_handler.rs
@@ -1,7 +1,8 @@
 use crate::{
+    compressor::CompressorArc,
     sha256::Sha256,
     shard::{StoreArgs, StoreResult},
-    shards::Shards,
+    shards::ShardsArc,
 };
 use axum::{body::Bytes, response::IntoResponse, Extension, Json};
 use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
@@ -59,7 +60,7 @@ impl From<StoreResult> for StoreResponse {
         } => StoreResponse::Created {
             stored_size,
             data_size,
-            created_at: created_at.to_rfc3339(),
+            created_at: created_at.to_string(),
         },
         StoreResult::Exists {
             stored_size,
@@ -68,7 +69,7 @@ impl From<StoreResult> for StoreResponse {
         } => StoreResponse::Exists {
             stored_size,
             data_size,
-            created_at: created_at.to_rfc3339(),
+            created_at: created_at.to_string(),
         },
     }
 }
@@ -82,7 +83,8 @@ impl IntoResponse for StoreResponse {

 #[axum::debug_handler]
 pub async fn store_handler(
-    Extension(shards): Extension<Shards>,
+    Extension(shards): Extension<ShardsArc>,
+    Extension(compressor): Extension<CompressorArc>,
     TypedMultipart(request): TypedMultipart<StoreRequest>,
 ) -> StoreResponse {
     let sha256 = Sha256::from_bytes(&request.data.contents);
@@ -106,6 +108,7 @@ pub async fn store_handler(
             sha256,
             content_type: request.content_type,
             data: request.data.contents,
+            compressor,
         })
         .await
     {
@@ -118,20 +121,23 @@ pub async fn store_handler(

 #[cfg(test)]
 pub mod test {
+    use crate::{compressor::Compressor, shards::test::make_shards};
+
     use super::*;
-    use crate::{shards::test::make_shards_with_compression, UseCompression};
+    use crate::CompressionPolicy;
     use axum::body::Bytes;
     use axum_typed_multipart::FieldData;
     use rstest::rstest;

     async fn send_request<D: Into<Bytes>>(
+        compression_policy: CompressionPolicy,
         sha256: Option<Sha256>,
         content_type: &str,
-        use_compression: UseCompression,
         data: D,
     ) -> StoreResponse {
         store_handler(
-            Extension(make_shards_with_compression(use_compression).await),
+            Extension(make_shards().await),
+            Extension(Compressor::new(compression_policy).into_arc()),
             TypedMultipart(StoreRequest {
                 sha256: sha256.map(|s| s.hex_string()),
                 content_type: content_type.to_string(),
@@ -146,8 +152,9 @@ pub mod test {

     #[tokio::test]
     async fn test_store_handler() {
-        let result = send_request(None, "text/plain", UseCompression::Auto, "hello, world!").await;
-        assert_eq!(result.status_code(), StatusCode::CREATED);
+        let result =
+            send_request(CompressionPolicy::Auto, None, "text/plain", "hello, world!").await;
+        assert_eq!(result.status_code(), StatusCode::CREATED, "{:?}", result);
         assert!(matches!(result, StoreResponse::Created { .. }));
     }

@@ -156,9 +163,9 @@ pub mod test {
         let not_hello_world = Sha256::from_bytes("goodbye, planet!".as_bytes());
         let hello_world = Sha256::from_bytes("hello, world!".as_bytes());
         let result = send_request(
+            CompressionPolicy::Auto,
             Some(not_hello_world),
             "text/plain",
-            UseCompression::Auto,
             "hello, world!",
         )
         .await;
@@ -175,9 +182,9 @@ pub mod test {
     async fn test_store_handler_matching_sha256() {
         let hello_world = Sha256::from_bytes("hello, world!".as_bytes());
         let result = send_request(
+            CompressionPolicy::Auto,
             Some(hello_world),
             "text/plain",
-            UseCompression::Auto,
             "hello, world!",
         )
         .await;
@@ -194,20 +201,20 @@ pub mod test {

     #[rstest]
     // textual should be compressed by default
-    #[case("text/plain", UseCompression::Auto, make_assert_lt(1024))]
-    #[case("text/plain", UseCompression::Zstd, make_assert_lt(1024))]
-    #[case("text/plain", UseCompression::None, make_assert_eq(1024))]
+    #[case("text/plain", CompressionPolicy::Auto, make_assert_lt(1024))]
+    #[case("text/plain", CompressionPolicy::ForceZstd, make_assert_lt(1024))]
+    #[case("text/plain", CompressionPolicy::None, make_assert_eq(1024))]
     // images, etc should not be compressed by default
-    #[case("image/jpg", UseCompression::Auto, make_assert_eq(1024))]
-    #[case("image/jpg", UseCompression::Zstd, make_assert_lt(1024))]
-    #[case("image/jpg", UseCompression::None, make_assert_eq(1024))]
+    #[case("image/jpg", CompressionPolicy::Auto, make_assert_eq(1024))]
+    #[case("image/jpg", CompressionPolicy::ForceZstd, make_assert_lt(1024))]
+    #[case("image/jpg", CompressionPolicy::None, make_assert_eq(1024))]
     #[tokio::test]
     async fn test_compressible_data<F: Fn(usize)>(
         #[case] content_type: &str,
-        #[case] use_compression: UseCompression,
+        #[case] compression_policy: CompressionPolicy,
         #[case] assert_stored_size: F,
     ) {
-        let result = send_request(None, content_type, use_compression, vec![0; 1024]).await;
+        let result = send_request(compression_policy, None, content_type, vec![0; 1024]).await;
         assert_eq!(result.status_code(), StatusCode::CREATED);
         match result {
             StoreResponse::Created {
src/main.rs (63 lines changed)
@@ -1,9 +1,13 @@
+mod compressible_data;
+mod compressor;
 mod handlers;
 mod manifest;
 mod sha256;
 mod shard;
 mod shards;
 mod shutdown_signal;
+mod sql_types;
+mod zstd_dict;

 use crate::{manifest::Manifest, shards::Shards};
 use axum::{
@@ -11,10 +15,12 @@ use axum::{
     Extension, Router,
 };
 use clap::{Parser, ValueEnum};
+use compressor::CompressorArc;
 use futures::executor::block_on;
 use shard::Shard;
-use std::{error::Error, path::PathBuf};
-use tokio::net::TcpListener;
+use shards::ShardsArc;
+use std::{error::Error, path::PathBuf, sync::Arc};
+use tokio::{net::TcpListener, select, spawn};
 use tokio_rusqlite::Connection;
 use tracing::info;

@@ -39,18 +45,24 @@ struct Args {

     /// How to compress stored data
     #[arg(short, long, default_value = "auto")]
-    compression: UseCompression,
+    compression: CompressionPolicy,
 }

 #[derive(Default, PartialEq, Debug, Copy, Clone, ValueEnum)]
-pub enum UseCompression {
+pub enum CompressionPolicy {
     #[default]
     Auto,
     None,
-    Zstd,
+    ForceZstd,
 }

-fn main() -> Result<(), Box<dyn Error>> {
+pub type AsyncBoxError = Box<dyn Error + Send + Sync + 'static>;
+
+pub fn into_tokio_rusqlite_err<E: Into<AsyncBoxError>>(e: E) -> tokio_rusqlite::Error {
+    tokio_rusqlite::Error::Other(e.into())
+}
+
+fn main() -> Result<(), AsyncBoxError> {
     tracing_subscriber::fmt()
         .with_max_level(tracing::Level::DEBUG)
         .init();
@@ -85,7 +97,7 @@ fn main() -> Result<(), Box<dyn Error>> {
         for shard_id in 0..manifest.num_shards() {
             let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
             let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
-            let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?;
+            let shard = Shard::open(shard_id, shard_sqlite_conn).await?;
             info!(
                 "shard {} has {} entries",
                 shard.id(),
@@ -94,23 +106,46 @@ fn main() -> Result<(), Box<dyn Error>> {
             shards_vec.push(shard);
         }

-        let shards = Shards::new(shards_vec).ok_or("num shards must be > 0")?;
-        server_loop(server, shards.clone()).await?;
-        info!("shutting down server...");
-        shards.close_all().await?;
+        let shards = Arc::new(Shards::new(shards_vec).ok_or("num shards must be > 0")?);
+        let compressor = manifest.compressor();
+        let dict_loop_handle = spawn(dict_loop(manifest, shards.clone()));
+        let server_handle = spawn(server_loop(server, shards, compressor));
+        dict_loop_handle.await?;
+        server_handle.await??;
         info!("server closed sqlite connections. bye!");
-        Ok::<_, Box<dyn Error>>(())
+        Ok::<_, AsyncBoxError>(())
     })?;

     Ok(())
 }

-async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn Error>> {
+async fn dict_loop(manifest: Manifest, shards: ShardsArc) {
+    loop {
+        select! {
+            _ = tokio::time::sleep(tokio::time::Duration::from_secs(1)) => {}
+            _ = crate::shutdown_signal::shutdown_signal() => {
+                info!("dict loop: shutdown signal received");
+                break;
+            }
+        }
+        info!("dict loop: running");
+        let compressor = manifest.compressor();
+        let _compressor = compressor.read().await;
+        for _shard in shards.iter() {}
+    }
+}
+
+async fn server_loop(
+    server: TcpListener,
+    shards: ShardsArc,
+    compressor: CompressorArc,
+) -> Result<(), AsyncBoxError> {
     let app = Router::new()
         .route("/store", post(handlers::store_handler::store_handler))
         .route("/get/:sha256", get(handlers::get_handler::get_handler))
         .route("/info", get(handlers::info_handler::info_handler))
-        .layer(Extension(shards));
+        .layer(Extension(shards))
+        .layer(Extension(compressor));

     axum::serve(server, app.into_make_service())
         .with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
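Because CompressionPolicy derives clap's ValueEnum, the --compression flag accepts the kebab-cased variant names ("auto", "none", "force-zstd" under clap 4's default casing; an assumption worth checking against the clap version in Cargo.toml). A small sketch of the mapping:

use clap::ValueEnum;

fn cli_value_sketch() {
    // ValueEnum::from_str(input, ignore_case) parses the generated value strings.
    let policy = CompressionPolicy::from_str("force-zstd", true).unwrap();
    assert_eq!(policy, CompressionPolicy::ForceZstd);
}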
src/manifest/dict_id.rs (deleted)
@@ -1,9 +0,0 @@
-use rusqlite::types::FromSql;
-
-#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
-pub struct DictId(i64);
-impl FromSql for DictId {
-    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
-        Ok(DictId(value.as_i64()?))
-    }
-}
src/manifest/mod.rs
@@ -1,99 +1,81 @@
-mod dict_id;
 mod manifest_key;
-mod zstd_dict;

-use std::{collections::HashMap, error::Error, sync::Arc};
+use std::{error::Error, sync::Arc};

 use rusqlite::params;
+use tokio::sync::RwLock;
 use tokio_rusqlite::Connection;

-use crate::shards::Shards;
-
-use self::{
-    dict_id::DictId,
-    manifest_key::{get_manifest_key, set_manifest_key, NumShards},
-    zstd_dict::ZstdDict,
-};
+use crate::{
+    compressor::{Compressor, CompressorArc},
+    into_tokio_rusqlite_err,
+    zstd_dict::ZstdDictArc,
+    AsyncBoxError,
+};

-pub type ZstdDictArc = Arc<ZstdDict>;
+use self::manifest_key::{get_manifest_key, set_manifest_key, NumShards};

 pub struct Manifest {
     conn: Connection,
     num_shards: usize,
-    zstd_dict_by_id: HashMap<DictId, ZstdDictArc>,
-    zstd_dict_by_name: HashMap<String, ZstdDictArc>,
+    compressor: Arc<RwLock<Compressor>>,
 }

 pub type ManifestArc = Arc<Manifest>;

 impl Manifest {
-    pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, Box<dyn Error>> {
+    pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, AsyncBoxError> {
         initialize(conn, num_shards).await
     }

     pub fn into_arc(self) -> ManifestArc {
         Arc::new(self)
     }

     pub fn num_shards(&self) -> usize {
         self.num_shards
     }

-    async fn train_zstd_dict_with_tag(
-        &mut self,
-        _name: &str,
-        _shards: Shards,
-    ) -> Result<ZstdDictArc, Box<dyn Error>> {
-        // let mut queries = vec![];
-        // for shard in shards.iter() {
-        //     queries.push(shard.entries_for_tag(name));
-        // }
-        todo!();
+    pub fn compressor(&self) -> CompressorArc {
+        self.compressor.clone()
     }

-    async fn create_zstd_dict_from_samples(
-        &mut self,
-        name: &str,
+    pub async fn insert_zstd_dict_from_samples<Str: Into<String>>(
+        &self,
+        name: Str,
         samples: Vec<&[u8]>,
     ) -> Result<ZstdDictArc, Box<dyn Error>> {
-        if self.zstd_dict_by_name.contains_key(name) {
-            return Err(format!("dictionary {} already exists", name).into());
-        }
-
-        let level = 3;
+        let name = name.into();
         let dict_bytes = zstd::dict::from_samples(
             &samples,
             1024 * 1024, // 1MB max dictionary size
-        )
-        .unwrap();
-
-        let name_copy = name.to_string();
-        let (dict_id, zstd_dict) = self
+        )?;
+        let compressor = self.compressor.clone();
+        let zstd_dict = self
             .conn
             .call(move |conn| {
+                let level = 3;
                 let mut stmt = conn.prepare(
                     "INSERT INTO dictionaries (name, level, dict)
-                    VALUES (?, ?, ?)
-                    RETURNING id",
+                     VALUES (?, ?, ?)
+                     RETURNING id",
                 )?;
-                let dict_id =
-                    stmt.query_row(params![name_copy, level, dict_bytes], |row| row.get(0))?;
-                let zstd_dict = Arc::new(ZstdDict::create(dict_id, name_copy, level, dict_bytes));
-                Ok((dict_id, zstd_dict))
+                let dict_id = stmt.query_row(params![name, level, dict_bytes], |row| row.get(0))?;
+                let mut compressor = compressor.blocking_write();
+                let zstd_dict = compressor
+                    .add_from_bytes(dict_id, name, level, dict_bytes)
+                    .map_err(into_tokio_rusqlite_err)?;
+                Ok(zstd_dict)
             })
             .await?;

-        self.zstd_dict_by_id.insert(dict_id, zstd_dict.clone());
-        self.zstd_dict_by_name
-            .insert(name.to_string(), zstd_dict.clone());
         Ok(zstd_dict)
     }
-
-    fn get_dictionary_by_id(&self, id: DictId) -> Option<&ZstdDict> {
-        self.zstd_dict_by_id.get(&id).map(|d| &**d)
-    }
-    fn get_dictionary_by_name(&self, name: &str) -> Option<&ZstdDict> {
-        self.zstd_dict_by_name.get(name).map(|d| &**d)
-    }
 }

 async fn initialize(
     conn: Connection,
     num_shards: Option<usize>,
-) -> Result<Manifest, Box<dyn Error>> {
+) -> Result<Manifest, Box<dyn Error + Send + Sync>> {
     let stored_num_shards: Option<usize> = conn
         .call(|conn| {
             conn.execute(
@@ -161,19 +143,15 @@ async fn initialize(
         })
         .await?;

-    let mut zstd_dicts_by_id = HashMap::new();
-    let mut zstd_dicts_by_name = HashMap::new();
+    let mut compressor = Compressor::default();
     for (id, name, level, dict_bytes) in rows {
-        let zstd_dict = Arc::new(ZstdDict::create(id, name.clone(), level, dict_bytes));
-        zstd_dicts_by_id.insert(id, zstd_dict.clone());
-        zstd_dicts_by_name.insert(name, zstd_dict);
+        compressor.add_from_bytes(id, name, level, dict_bytes)?;
     }

+    let compressor = compressor.into_arc();
     Ok(Manifest {
         conn,
         num_shards,
-        zstd_dict_by_id: zstd_dicts_by_id,
-        zstd_dict_by_name: zstd_dicts_by_name,
+        compressor,
     })
 }

@@ -184,22 +162,32 @@ mod tests {
     #[tokio::test]
     async fn test_manifest() {
         let conn = Connection::open_in_memory().await.unwrap();
-        let mut manifest = initialize(conn, Some(3)).await.unwrap();
+        let manifest = initialize(conn, Some(3)).await.unwrap();

         let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
         let zstd_dict = manifest
-            .create_zstd_dict_from_samples("test", samples)
+            .insert_zstd_dict_from_samples("test", samples)
             .await
             .unwrap();

         // test that indexes are created correctly
         assert_eq!(
             zstd_dict.as_ref(),
-            manifest.get_dictionary_by_id(zstd_dict.id()).unwrap()
+            manifest
+                .compressor()
+                .read()
+                .await
+                .by_id(zstd_dict.id())
+                .unwrap()
         );
         assert_eq!(
             zstd_dict.as_ref(),
-            manifest.get_dictionary_by_name(zstd_dict.name()).unwrap()
+            manifest
+                .compressor()
+                .read()
+                .await
+                .by_name(zstd_dict.name())
+                .unwrap()
         );

         let data = b"hello world, this is a test of a sort of long string";
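Putting the manifest pieces together, a hedged end-to-end sketch (the tag name "tag1" and shard count are illustrative): open a manifest, train and persist a dictionary from samples, then hand the shared compressor to whoever needs it.

async fn manifest_sketch() -> Result<(), AsyncBoxError> {
    let conn = Connection::open_in_memory().await?;
    let manifest = Manifest::open(conn, Some(4)).await?;

    // Trains a dictionary, writes it to the dictionaries table, and
    // registers it with the shared Compressor in one call.
    let samples: Vec<&[u8]> = vec![b"repetitive sample payload"; 100];
    let dict = manifest.insert_zstd_dict_from_samples("tag1", samples).await?;

    // Anything holding the CompressorArc can now look it up and use it.
    let compressor = manifest.compressor();
    assert!(compressor.read().await.by_id(dict.id()).is_some());
    Ok(())
}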
src/shard/fn_get.rs
@@ -1,73 +1,71 @@
-use std::io::Read;
+use crate::{
+    compressible_data::CompressibleData,
+    compressor::CompressorArc,
+    into_tokio_rusqlite_err,
+    sql_types::{CompressionId, UtcDateTime},
+    AsyncBoxError,
+};

 use super::*;

+pub struct GetArgs {
+    pub sha256: Sha256,
+    pub compressor: CompressorArc,
+}
+
 pub struct GetResult {
     pub sha256: Sha256,
     pub content_type: String,
     pub stored_size: usize,
     pub created_at: UtcDateTime,
-    pub data: Vec<u8>,
+    pub data: CompressibleData,
 }

 impl Shard {
-    pub async fn get(&self, sha256: Sha256) -> Result<Option<GetResult>, Box<dyn Error>> {
-        self.conn
-            .call(move |conn| get_impl(conn, sha256))
+    pub async fn get(&self, args: GetArgs) -> Result<Option<GetResult>, AsyncBoxError> {
+        let sha256 = args.sha256;
+        let maybe_row = self
+            .conn
+            .call(move |conn| get_compressed_row(conn, sha256).map_err(into_tokio_rusqlite_err))
             .await
             .map_err(|e| {
                 error!("get failed: {}", e);
-                e.into()
-            })
+                Box::new(e)
+            })?;
+
+        if let Some((content_type, stored_size, created_at, compression_id, data)) = maybe_row {
+            let compressor = args.compressor.read().await;
+            let data = compressor.decompress(compression_id, data)?;
+            Ok(Some(GetResult {
+                sha256: args.sha256,
+                content_type,
+                stored_size,
+                created_at,
+                data,
+            }))
+        } else {
+            Ok(None)
+        }
     }
 }

-fn get_impl(
+fn get_compressed_row(
     conn: &mut rusqlite::Connection,
     sha256: Sha256,
-) -> Result<Option<GetResult>, tokio_rusqlite::Error> {
-    let maybe_row = conn
-        .query_row(
-            "SELECT content_type, compressed_size, created_at, compression, data
-            FROM entries
-            WHERE sha256 = ?",
-            params![sha256.hex_string()],
-            |row| {
-                let content_type = row.get(0)?;
-                let stored_size = row.get(1)?;
-                let created_at = parse_created_at_str(row.get(2)?)?;
-                let compression = row.get(3)?;
-                let data: Vec<u8> = row.get(4)?;
-                Ok((content_type, stored_size, created_at, compression, data))
-            },
-        )
-        .optional()
-        .map_err(into_tokio_rusqlite_err)?;
-
-    let row = match maybe_row {
-        Some(row) => row,
-        None => return Ok(None),
-    };
-
-    let (content_type, stored_size, created_at, compression, data) = row;
-    let data = match compression {
-        Compression::None => data,
-        Compression::Zstd => {
-            let mut decoder =
-                zstd::Decoder::new(data.as_slice()).map_err(into_tokio_rusqlite_err)?;
-            let mut decompressed = vec![];
-            decoder
-                .read_to_end(&mut decompressed)
-                .map_err(into_tokio_rusqlite_err)?;
-            decompressed
-        }
-    };
-
-    Ok(Some(GetResult {
-        sha256,
-        content_type,
-        stored_size,
-        created_at,
-        data,
-    }))
+) -> Result<Option<(String, usize, UtcDateTime, CompressionId, Vec<u8>)>, rusqlite::Error> {
+    conn.query_row(
+        "SELECT content_type, compressed_size, created_at, compression_id, data
+            FROM entries
+            WHERE sha256 = ?",
+        params![sha256.hex_string()],
+        |row| {
+            let content_type = row.get(0)?;
+            let stored_size = row.get(1)?;
+            let created_at = row.get(2)?;
+            let compression_id = row.get(3)?;
+            let data: Vec<u8> = row.get(4)?;
+            Ok((content_type, stored_size, created_at, compression_id, data))
+        },
+    )
+    .optional()
+}
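One detail worth calling out: query_row returns Err(QueryReturnedNoRows) when nothing matches, and rusqlite's OptionalExtension::optional converts exactly that case into Ok(None), which is what lets a missing sha256 surface as NotFound rather than an internal error. A standalone illustration (table and value hypothetical):

use rusqlite::{Connection, OptionalExtension};

fn optional_sketch() -> rusqlite::Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute("CREATE TABLE entries (sha256 TEXT, content_type TEXT)", [])?;

    // No rows yet: bare query_row would error, .optional() yields None.
    let row: Option<String> = conn
        .query_row(
            "SELECT content_type FROM entries WHERE sha256 = ?",
            ["deadbeef"],
            |r| r.get(0),
        )
        .optional()?;
    assert!(row.is_none());
    Ok(())
}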
@@ -1,7 +1,9 @@
+use crate::AsyncBoxError;
+
 use super::*;

 impl Shard {
-    pub(super) async fn migrate(&self) -> Result<(), Box<dyn Error>> {
+    pub(super) async fn migrate(&self) -> Result<(), AsyncBoxError> {
         let shard_id = self.id();
         // create tables, indexes, etc
         self.conn
@@ -60,7 +62,7 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err
             "CREATE TABLE IF NOT EXISTS entries (
                 sha256 BLOB PRIMARY KEY,
                 content_type TEXT NOT NULL,
-                compression INTEGER NOT NULL,
+                compression_id INTEGER NOT NULL,
                 uncompressed_size INTEGER NOT NULL,
                 compressed_size INTEGER NOT NULL,
                 data BLOB NOT NULL,
src/shard/fn_store.rs
@@ -1,3 +1,10 @@
+use crate::{
+    compressor::CompressorArc,
+    into_tokio_rusqlite_err,
+    sql_types::{CompressionId, UtcDateTime},
+    AsyncBoxError,
+};
+
 use super::*;

 #[derive(PartialEq, Debug)]
@@ -19,89 +26,100 @@ pub struct StoreArgs {
     pub sha256: Sha256,
     pub content_type: String,
     pub data: Bytes,
+    pub compressor: CompressorArc,
 }

 impl Shard {
-    pub async fn store(&self, store_args: StoreArgs) -> Result<StoreResult, Box<dyn Error>> {
-        let use_compression = self.use_compression;
-        self.conn
-            .call(move |conn| store(conn, use_compression, store_args))
-            .await
-            .map_err(|e| {
-                error!("store failed: {}", e);
-                e.into()
+    pub async fn store(
+        &self,
+        StoreArgs {
+            sha256,
+            data,
+            content_type,
+            compressor,
+        }: StoreArgs,
+    ) -> Result<StoreResult, AsyncBoxError> {
+        let sha256 = sha256.hex_string();
+
+        // check for existing entry
+        let sha256_clone = sha256.clone();
+        let maybe_existing_entry = self
+            .conn
+            .call(move |conn| {
+                find_with_sha256(conn, sha256_clone.as_str()).map_err(into_tokio_rusqlite_err)
             })
+            .await?;
+
+        if let Some(entry) = maybe_existing_entry {
+            return Ok(entry);
+        }
+
+        let uncompressed_size = data.len();
+
+        let compressor = compressor.read().await;
+        let (compression_id, data) = compressor.compress("foobar", &content_type, data)?;
+
+        self.conn
+            .call(move |conn| {
+                insert(
+                    conn,
+                    sha256,
+                    content_type,
+                    compression_id,
+                    uncompressed_size,
+                    data.as_ref(),
+                )
+                .map_err(into_tokio_rusqlite_err)
+            })
+            .await
+            .map_err(|e| e.into())
     }
 }

-fn store(
+fn find_with_sha256(
     conn: &mut rusqlite::Connection,
-    use_compression: UseCompression,
-    StoreArgs {
-        sha256,
-        content_type,
-        data,
-    }: StoreArgs,
-) -> Result<StoreResult, tokio_rusqlite::Error> {
-    let sha256 = sha256.hex_string();
-
-    // check for existing entry
-    let maybe_existing: Option<StoreResult> = conn
-        .query_row(
-            "SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?",
-            params![sha256],
-            |row| {
-                Ok(StoreResult::Exists {
-                    stored_size: row.get(0)?,
-                    data_size: row.get(1)?,
-                    created_at: parse_created_at_str(row.get(2)?)?,
-                })
-            },
-        )
-        .optional()?;
-
-    if let Some(existing) = maybe_existing {
-        return Ok(existing);
-    }
-
-    let created_at = chrono::Utc::now();
-    let uncompressed_size = data.len();
-    let tmp_data_holder;
-
-    let use_compression = match use_compression {
-        UseCompression::None => false,
-        UseCompression::Auto => auto_compressible_content_type(&content_type),
-        UseCompression::Zstd => true,
-    };
-
-    let (compression, data) = if use_compression {
-        tmp_data_holder = zstd::encode_all(&data[..], 0).map_err(into_tokio_rusqlite_err)?;
-        if tmp_data_holder.len() < data.len() {
-            (Compression::Zstd, &tmp_data_holder[..])
-        } else {
-            (Compression::None, &data[..])
-        }
-    } else {
-        (Compression::None, &data[..])
-    };
+    sha256: &str,
+) -> Result<Option<StoreResult>, rusqlite::Error> {
+    conn.query_row(
+        "SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?",
+        params![sha256],
+        |row| {
+            Ok(StoreResult::Exists {
+                stored_size: row.get(0)?,
+                data_size: row.get(1)?,
+                created_at: row.get(2)?,
+            })
+        },
+    )
+    .optional()
+}
+
+fn insert(
+    conn: &mut rusqlite::Connection,
+    sha256: String,
+    content_type: String,
+    compression_id: CompressionId,
+    uncompressed_size: usize,
+    data: &[u8],
+) -> Result<StoreResult, rusqlite::Error> {
+    let created_at = UtcDateTime::now();
     let compressed_size = data.len();

-    conn.execute(
-        "INSERT INTO entries
-        (sha256, content_type, compression, uncompressed_size, compressed_size, data, created_at)
-        VALUES
-        (?, ?, ?, ?, ?, ?, ?)
-        ",
-        params![
-            sha256,
-            content_type,
-            compression,
-            uncompressed_size,
-            compressed_size,
-            data,
-            created_at.to_rfc3339(),
-        ],
-    )?;
+    conn.execute(
+        "INSERT INTO entries
+        (sha256, content_type, compression_id, uncompressed_size, compressed_size, data, created_at)
+        VALUES
+        (?, ?, ?, ?, ?, ?, ?)
+        ",
+        params![
+            sha256,
+            content_type,
+            compression_id,
+            uncompressed_size,
+            compressed_size,
+            data,
+            created_at,
+        ],
+    )?;

     Ok(StoreResult::Created {
         stored_size: compressed_size,
@@ -109,14 +127,3 @@ fn store(
         created_at,
     })
 }
-
-fn auto_compressible_content_type(content_type: &str) -> bool {
-    [
-        "text/",
-        "application/xml",
-        "application/json",
-        "application/javascript",
-    ]
-    .iter()
-    .any(|ct| content_type.starts_with(ct))
-}
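The rewritten insert leans on the new sql_types: CompressionId and UtcDateTime implement ToSql, so they go into params![] directly instead of being flattened to integers and RFC 3339 strings by hand at each call site. A minimal sketch of the same pattern (table "t" is hypothetical):

use rusqlite::{params, Connection};

fn to_sql_sketch() -> rusqlite::Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute("CREATE TABLE t (compression_id INTEGER, created_at TEXT)", [])?;

    // ToSql turns these into an INTEGER (-2 here) and an RFC 3339 TEXT value.
    conn.execute(
        "INSERT INTO t (compression_id, created_at) VALUES (?, ?)",
        params![CompressionId::ZstdGeneric, UtcDateTime::now()],
    )?;
    Ok(())
}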
src/shard/mod.rs
@@ -4,7 +4,7 @@ mod fn_store;
 mod shard;
 pub mod shard_error;

-pub use fn_get::GetResult;
+pub use fn_get::{GetArgs, GetResult};
 pub use fn_store::{StoreArgs, StoreResult};
 pub use shard::Shard;

@@ -13,46 +13,9 @@ pub mod test {
     pub use super::shard::test::*;
 }

-use crate::{sha256::Sha256, shard::shard_error::ShardError, UseCompression};
+use crate::{sha256::Sha256, shard::shard_error::ShardError};
 use axum::body::Bytes;
-use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
-use std::error::Error;
+use rusqlite::{params, types::FromSql, OptionalExtension};

 use tokio_rusqlite::Connection;
 use tracing::{debug, error};

-pub type UtcDateTime = chrono::DateTime<chrono::Utc>;
-
-#[derive(Debug, PartialEq, Eq)]
-enum Compression {
-    None,
-    Zstd,
-}
-impl ToSql for Compression {
-    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
-        match self {
-            Compression::None => 0.to_sql(),
-            Compression::Zstd => 1.to_sql(),
-        }
-    }
-}
-impl FromSql for Compression {
-    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
-        match value.as_i64()? {
-            0 => Ok(Compression::None),
-            1 => Ok(Compression::Zstd),
-            _ => Err(rusqlite::types::FromSqlError::InvalidType),
-        }
-    }
-}
-
-fn parse_created_at_str(created_at_str: String) -> Result<UtcDateTime, rusqlite::Error> {
-    let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str)
-        .map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
-    Ok(parsed.to_utc())
-}
-
-fn into_tokio_rusqlite_err<E: Into<Box<dyn Error + Send + Sync + 'static>>>(
-    e: E,
-) -> tokio_rusqlite::Error {
-    tokio_rusqlite::Error::Other(e.into())
-}
src/shard/shard.rs
@@ -1,36 +1,25 @@
+use crate::AsyncBoxError;
+
 use super::*;

 #[derive(Clone)]
 pub struct Shard {
     pub(super) id: usize,
     pub(super) conn: Connection,
-    pub(super) use_compression: UseCompression,
 }

 impl Shard {
-    pub async fn open(
-        id: usize,
-        use_compression: UseCompression,
-        conn: Connection,
-    ) -> Result<Self, Box<dyn Error>> {
-        let shard = Self {
-            id,
-            use_compression,
-            conn,
-        };
+    pub async fn open(id: usize, conn: Connection) -> Result<Self, AsyncBoxError> {
+        let shard = Self { id, conn };
         shard.migrate().await?;
         Ok(shard)
     }

-    pub async fn close(self) -> Result<(), Box<dyn Error>> {
-        self.conn.close().await.map_err(|e| e.into())
-    }
-
     pub fn id(&self) -> usize {
         self.id
     }

-    pub async fn db_size_bytes(&self) -> Result<usize, Box<dyn Error>> {
+    pub async fn db_size_bytes(&self) -> Result<usize, AsyncBoxError> {
         self.query_single_row(
             "SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
         )
@@ -40,7 +29,7 @@ impl Shard {
     async fn query_single_row<T: FromSql + Send + 'static>(
         &self,
         query: &'static str,
-    ) -> Result<T, Box<dyn Error>> {
+    ) -> Result<T, AsyncBoxError> {
         self.conn
             .call(move |conn| {
                 let value: T = conn.query_row(query, [], |row| row.get(0))?;
@@ -50,7 +39,7 @@ impl Shard {
         .map_err(|e| e.into())
     }

-    pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
+    pub async fn num_entries(&self) -> Result<usize, AsyncBoxError> {
         get_num_entries(&self.conn).await.map_err(|e| e.into())
     }
 }
@@ -65,18 +54,17 @@ async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Err

 #[cfg(test)]
 pub mod test {
+    use crate::{
+        compressor::test::make_compressor,
+        sha256::Sha256,
+        shard::{GetArgs, StoreArgs, StoreResult},
+        CompressionPolicy,
+    };
     use rstest::rstest;

-    use super::{StoreResult, UseCompression};
-    use crate::{sha256::Sha256, shard::StoreArgs};
-
-    pub async fn make_shard_with_compression(use_compression: UseCompression) -> super::Shard {
-        let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
-        super::Shard::open(0, use_compression, conn).await.unwrap()
-    }
-
     pub async fn make_shard() -> super::Shard {
-        make_shard_with_compression(UseCompression::Auto).await
+        let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
+        super::Shard::open(0, conn).await.unwrap()
     }

     #[tokio::test]
@@ -96,8 +84,9 @@ pub mod test {
     #[tokio::test]
     async fn test_not_found_get() {
         let shard = make_shard().await;
+        let compressor = make_compressor().into_arc();
         let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
-        let get_result = shard.get(sha256).await.unwrap();
+        let get_result = shard.get(GetArgs { sha256, compressor }).await.unwrap();
         assert!(get_result.is_none());
     }

@@ -106,11 +95,13 @@ pub mod test {
         let shard = make_shard().await;
         let data = "hello, world!".as_bytes();
         let sha256 = Sha256::from_bytes(data);
+        let compressor = make_compressor().into_arc();
         let store_result = shard
             .store(StoreArgs {
                 sha256,
                 content_type: "text/plain".to_string(),
                 data: data.into(),
+                compressor: compressor.clone(),
             })
             .await
             .unwrap();
@@ -128,7 +119,11 @@ pub mod test {
         }
         assert_eq!(shard.num_entries().await.unwrap(), 1);

-        let get_result = shard.get(sha256).await.unwrap().unwrap();
+        let get_result = shard
+            .get(GetArgs { sha256, compressor })
+            .await
+            .unwrap()
+            .unwrap();
         assert_eq!(get_result.content_type, "text/plain");
         assert_eq!(get_result.data, data);
         assert_eq!(get_result.stored_size, data.len());
@@ -145,6 +140,7 @@ pub mod test {
                 sha256,
                 content_type: "text/plain".to_string(),
                 data: data.into(),
+                ..Default::default()
             })
             .await
             .unwrap();
@@ -168,6 +164,7 @@ pub mod test {
                 sha256,
                 content_type: "text/plain".to_string(),
                 data: data.into(),
+                ..Default::default()
             })
             .await
             .unwrap();
@@ -185,12 +182,19 @@ pub mod test {
     #[rstest]
     #[tokio::test]
     async fn test_compression_store_get(
-        #[values(UseCompression::Auto, UseCompression::None, UseCompression::Zstd)]
-        use_compression: UseCompression,
+        #[values(
+            CompressionPolicy::Auto,
+            CompressionPolicy::None,
+            CompressionPolicy::ForceZstd
+        )]
+        compression_policy: CompressionPolicy,
         #[values(true, false)] incompressible_data: bool,
         #[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
     ) {
-        let shard = make_shard_with_compression(use_compression).await;
+        use crate::compressor::Compressor;
+
+        let shard = make_shard().await;
+        let compressor = Compressor::new(compression_policy).into_arc();
         let mut data = vec![b'.'; 1024];
         if incompressible_data {
             for byte in data.iter_mut() {
@@ -204,12 +208,17 @@ pub mod test {
                 sha256,
                 content_type: content_type.clone(),
                 data: data.clone().into(),
+                ..Default::default()
             })
             .await
             .unwrap();
         assert!(matches!(store_result, StoreResult::Created { .. }));

-        let get_result = shard.get(sha256).await.unwrap().unwrap();
+        let get_result = shard
+            .get(GetArgs { sha256, compressor })
+            .await
+            .unwrap()
+            .unwrap();
         assert_eq!(get_result.content_type, content_type);
         assert_eq!(get_result.data, data);
     }
src/shards.rs
@@ -1,5 +1,8 @@
+use std::sync::Arc;
+
 use crate::{sha256::Sha256, shard::Shard};
-use std::error::Error;
+
+pub type ShardsArc = Arc<Shards>;

 #[derive(Clone)]
 pub struct Shards(Vec<Shard>);
@@ -16,13 +19,6 @@ impl Shards {
         &self.0[shard_id]
     }

-    pub async fn close_all(self) -> Result<(), Box<dyn Error>> {
-        for shard in self.0 {
-            shard.close().await?;
-        }
-        Ok(())
-    }
-
     pub fn iter(&self) -> std::slice::Iter<'_, Shard> {
         self.0.iter()
     }
@@ -34,15 +30,13 @@ impl Shards {

 #[cfg(test)]
 pub mod test {
-    use crate::{shard::test::make_shard_with_compression, UseCompression};
+    use std::sync::Arc;

-    use super::Shards;
+    use super::{Shards, ShardsArc};
+    use crate::shard::test::make_shard;

-    pub async fn make_shards_with_compression(use_compression: UseCompression) -> Shards {
-        Shards::new(vec![make_shard_with_compression(use_compression).await]).unwrap()
-    }
-
-    pub async fn make_shards() -> Shards {
-        make_shards_with_compression(UseCompression::Auto).await
+    pub async fn make_shards() -> ShardsArc {
+        Arc::new(Shards::new(vec![make_shard().await]).unwrap())
     }
 }
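With Shards now shared behind an Arc, handlers and background loops clone the pointer rather than the shard set itself; routing is unchanged. A hedged sketch using the test helper:

async fn routing_sketch() {
    let shards = make_shards().await; // ShardsArc, i.e. Arc<Shards>
    let sha256 = Sha256::from_bytes(b"hello, world!");

    // Cheap to clone into spawned tasks; shard_for picks the home shard.
    let shards_for_task = shards.clone();
    let shard = shards_for_task.shard_for(&sha256);
    assert_eq!(shard.num_entries().await.unwrap(), 0);
}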
src/sql_types/compression_id.rs (new file, 61 lines)
@@ -0,0 +1,61 @@
use rusqlite::{
    types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, Value::Integer, ValueRef},
    Error::ToSqlConversionFailure,
    ToSql,
};

use crate::AsyncBoxError;

use super::ZstdDictId;

#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub enum CompressionId {
    None,
    ZstdGeneric,
    ZstdDictId(ZstdDictId),
}

impl FromSql for CompressionId {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        Ok(match value.as_i64()? {
            -1 => CompressionId::None,
            -2 => CompressionId::ZstdGeneric,
            id => CompressionId::ZstdDictId(ZstdDictId(id)),
        })
    }
}

impl ToSql for CompressionId {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        let value = match self {
            CompressionId::None => -1,
            CompressionId::ZstdGeneric => -2,
            CompressionId::ZstdDictId(ZstdDictId(id)) => *id,
        };
        Ok(ToSqlOutput::Owned(Integer(value)))
    }
}

impl FromSql for ZstdDictId {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        match value.as_i64()? {
            id @ (-1 | -2) => Err(FromSqlError::Other(invalid_zstd_dict_id_err(id))),
            id => Ok(ZstdDictId(id)),
        }
    }
}

impl ToSql for ZstdDictId {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        let value = match self.0 {
            id @ (-1 | -2) => return Err(ToSqlConversionFailure(invalid_zstd_dict_id_err(id))),
            id => id,
        };

        Ok(ToSqlOutput::Owned(Integer(value)))
    }
}

fn invalid_zstd_dict_id_err(id: i64) -> AsyncBoxError {
    format!("Invalid ZstdDictId: {}", id).into()
}
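The encoding reserves the negative sentinels -1 (no compression) and -2 (generic zstd) and leaves every other integer free for dictionary ids, which is also why ZstdDictId refuses to store -1 or -2. A round-trip sketch through a hypothetical in-memory table:

use rusqlite::Connection;

fn encoding_sketch() -> rusqlite::Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute("CREATE TABLE c (id INTEGER)", [])?;
    conn.execute(
        "INSERT INTO c (id) VALUES (?)",
        [CompressionId::ZstdDictId(ZstdDictId(7))],
    )?;

    // -1/-2 decode to the sentinels; everything else becomes a dictionary id.
    let back: CompressionId = conn.query_row("SELECT id FROM c", [], |r| r.get(0))?;
    assert_eq!(back, CompressionId::ZstdDictId(ZstdDictId(7)));
    Ok(())
}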
src/sql_types/mod.rs (new file, 7 lines)
@@ -0,0 +1,7 @@
mod compression_id;
mod utc_date_time;
mod zstd_dict_id;

pub use compression_id::CompressionId;
pub use utc_date_time::UtcDateTime;
pub use zstd_dict_id::ZstdDictId;
src/sql_types/utc_date_time.rs (new file, 46 lines)
@@ -0,0 +1,46 @@
use chrono::DateTime;
use rusqlite::{
    types::{FromSql, FromSqlError, ToSqlOutput, ValueRef},
    Result, ToSql,
};

#[derive(PartialEq, Debug, PartialOrd)]
pub struct UtcDateTime(DateTime<chrono::Utc>);

impl UtcDateTime {
    pub fn now() -> Self {
        Self(chrono::Utc::now())
    }
    pub fn to_string(&self) -> String {
        self.0.to_rfc3339()
    }
    pub fn from_string(s: &str) -> Result<Self, chrono::ParseError> {
        Ok(Self(DateTime::parse_from_rfc3339(s)?.to_utc()))
    }
}

impl PartialEq<DateTime<chrono::Utc>> for UtcDateTime {
    fn eq(&self, other: &DateTime<chrono::Utc>) -> bool {
        self.0 == *other
    }
}

impl PartialOrd<DateTime<chrono::Utc>> for UtcDateTime {
    fn partial_cmp(&self, other: &DateTime<chrono::Utc>) -> Option<std::cmp::Ordering> {
        self.0.partial_cmp(other)
    }
}

impl ToSql for UtcDateTime {
    fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
        Ok(ToSqlOutput::from(self.0.to_rfc3339()))
    }
}

impl FromSql for UtcDateTime {
    fn column_result(value: ValueRef<'_>) -> Result<Self, FromSqlError> {
        let parsed = DateTime::parse_from_rfc3339(value.as_str()?)
            .map_err(|e| FromSqlError::Other(e.into()))?;
        Ok(UtcDateTime(parsed.to_utc()))
    }
}
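UtcDateTime keeps one canonical wire format, RFC 3339 text, for both SQL storage and HTTP headers; to_string and from_string round-trip through it. For instance:

fn rfc3339_sketch() {
    let now = UtcDateTime::now();
    let text = now.to_string(); // e.g. "2022-03-04T08:12:34.567+00:00"
    let parsed = UtcDateTime::from_string(&text).unwrap();
    assert_eq!(parsed, now);
}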
src/sql_types/zstd_dict_id.rs (new file, 7 lines)
@@ -0,0 +1,7 @@
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub struct ZstdDictId(pub i64);
impl From<i64> for ZstdDictId {
    fn from(id: i64) -> Self {
        Self(id)
    }
}
src/zstd_dict.rs
@@ -1,11 +1,14 @@
-use super::dict_id::DictId;
 use ouroboros::self_referencing;
-use std::{error::Error, io};
+use std::{error::Error, io, sync::Arc};
 use zstd::dict::{DecoderDictionary, EncoderDictionary};

+use crate::{sql_types::ZstdDictId, AsyncBoxError};
+
+pub type ZstdDictArc = Arc<ZstdDict>;
+
 #[self_referencing]
 pub struct ZstdDict {
-    id: DictId,
+    id: crate::sql_types::ZstdDictId,
     name: String,
     level: i32,
     dict_bytes: Vec<u8>,
@@ -31,7 +34,6 @@ impl PartialEq for ZstdDict {
 impl std::fmt::Debug for ZstdDict {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("ZstdDict")
-            .field("id", &self.id())
             .field("name", &self.name())
             .field("level", &self.level())
             .field("dict_bytes.len", &self.dict_bytes().len())
@@ -40,7 +42,20 @@ impl std::fmt::Debug for ZstdDict {
 }

 impl ZstdDict {
-    pub fn create(id: DictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
+    pub fn from_samples(
+        id: ZstdDictId,
+        name: String,
+        level: i32,
+        samples: Vec<&[u8]>,
+    ) -> Result<Self, Box<dyn Error + Send + Sync>> {
+        let dict_bytes = zstd::dict::from_samples(
+            &samples,
+            1024 * 1024, // 1MB max dictionary size
+        )?;
+        Ok(Self::from_dict_bytes(id, name, level, dict_bytes))
+    }
+
+    pub fn from_dict_bytes(id: ZstdDictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
         ZstdDictBuilder {
             id,
             name,
@@ -52,8 +67,8 @@ impl ZstdDict {
         .build()
     }

-    pub fn id(&self) -> DictId {
-        self.with_id(|id| *id)
+    pub fn id(&self) -> ZstdDictId {
+        *self.borrow_id()
     }
     pub fn name(&self) -> &str {
         self.borrow_name()
@@ -65,9 +80,10 @@ impl ZstdDict {
         self.borrow_dict_bytes()
     }

-    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
-        let mut wrapper = io::Cursor::new(data);
-        let mut out_buffer = Vec::with_capacity(data.len());
+    pub fn compress<DataRef: AsRef<[u8]>>(&self, data: DataRef) -> Result<Vec<u8>, AsyncBoxError> {
+        let as_ref = data.as_ref();
+        let mut wrapper = io::Cursor::new(as_ref);
+        let mut out_buffer = Vec::with_capacity(as_ref.len());
         let mut output_wrapper = io::Cursor::new(&mut out_buffer);

         self.with_encoder_dict(|encoder_dict| {
@@ -79,9 +95,13 @@ impl ZstdDict {
         Ok(out_buffer)
     }

-    pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
-        let mut wrapper = io::Cursor::new(data);
-        let mut out_buffer = Vec::with_capacity(data.len());
+    pub fn decompress<DataRef: AsRef<[u8]>>(
+        &self,
+        data: DataRef,
+    ) -> Result<Vec<u8>, AsyncBoxError> {
+        let as_ref = data.as_ref();
+        let mut wrapper = io::Cursor::new(as_ref);
+        let mut out_buffer = Vec::with_capacity(as_ref.len());
         let mut output_wrapper = io::Cursor::new(&mut out_buffer);

         self.with_decoder_dict(|decoder_dict| {
@@ -92,3 +112,35 @@ impl ZstdDict {
         Ok(out_buffer)
     }
 }
+
+#[cfg(test)]
+pub mod test {
+    use crate::sql_types::ZstdDictId;
+
+    pub fn make_zstd_dict(id: ZstdDictId, name: &str) -> super::ZstdDict {
+        super::ZstdDict::from_dict_bytes(
+            id,
+            name.to_owned(),
+            3,
+            vec![
+                "hello, world",
+                "this is a test",
+                "of the emergency broadcast system",
+            ]
+            .into_iter()
+            .chain(vec!["foo", "bar", "baz"].repeat(100))
+            .map(|s| s.as_bytes().to_owned())
+            .flat_map(|s| s.into_iter())
+            .collect(),
+        )
+    }
+
+    #[test]
+    fn test_zstd_dict() {
+        let dict_bytes = vec![1, 2, 3, 4];
+        let zstd_dict = make_zstd_dict(1.into(), "dict1");
+        let compressed = zstd_dict.compress(b"hello world").unwrap();
+        let decompressed = zstd_dict.decompress(&compressed).unwrap();
+        assert_eq!(decompressed, b"hello world");
+    }
+}
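Why dictionaries at all: for many small, similar payloads, a trained dictionary gives zstd the shared context it cannot learn from any single tiny input. A hedged sketch using this file's API (the sample content is made up, and it assumes the training set is large enough for zstd's dictionary trainer to succeed):

fn dict_benefit_sketch() {
    let samples: Vec<&[u8]> = vec![b"user=alice action=login ok=true"; 200];
    let dict = ZstdDict::from_samples(1.into(), "logs".to_owned(), 3, samples).unwrap();

    // A payload far smaller than generic zstd handles well on its own.
    let compressed = dict.compress(b"user=bob action=login ok=true").unwrap();
    let restored = dict.decompress(&compressed).unwrap();
    assert_eq!(restored, b"user=bob action=login ok=true");
}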