move manifest to its own db, zstd dictionary compression

Dylan Knutson committed 2024-04-27 10:57:09 -07:00
parent 20dcf84c91
commit 182584cbe9
8 changed files with 510 additions and 319 deletions
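For context on the second half of the commit title: zstd dictionary compression trains a small shared dictionary from sample payloads and reuses it on every compress/decompress call, which pays off when storing many small, similar blobs. A minimal standalone sketch of the technique with the zstd crate (hypothetical sample data; it mirrors the calls this commit uses, but is not code from the commit):

use std::{error::Error, io};

fn main() -> Result<(), Box<dyn Error>> {
    // Train a dictionary from many similar samples (hypothetical data).
    let samples: Vec<&[u8]> = vec![b"user:alice;role:admin;theme:dark"; 200];
    let dict_bytes = zstd::dict::from_samples(&samples, 16 * 1024)?;

    // Prepare encoder/decoder dictionaries once; reuse them per call.
    let cdict = zstd::dict::EncoderDictionary::copy(&dict_bytes, 3);
    let ddict = zstd::dict::DecoderDictionary::copy(&dict_bytes);

    // Compress with the prepared dictionary.
    let data = b"user:bob;role:viewer;theme:dark";
    let mut compressed = Vec::new();
    let mut encoder = zstd::stream::Encoder::with_prepared_dictionary(&mut compressed, &cdict)?;
    io::copy(&mut io::Cursor::new(&data[..]), &mut encoder)?;
    encoder.finish()?;

    // Decompress with the matching dictionary.
    let mut decompressed = Vec::new();
    let mut decoder =
        zstd::stream::Decoder::with_prepared_dictionary(io::Cursor::new(&compressed[..]), &ddict)?;
    io::copy(&mut decoder, &mut decompressed)?;
    assert_eq!(decompressed, data);
    Ok(())
}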

Cargo.lock (generated)

@@ -38,6 +38,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "aliasable"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd"
[[package]]
name = "allocator-api2"
version = "0.2.18"
@@ -276,6 +282,7 @@ dependencies = [
"futures",
"hex",
"kdam",
"ouroboros",
"rand",
"reqwest",
"rstest",
@@ -483,6 +490,12 @@ dependencies = [
"crypto-common",
]
[[package]]
name = "either"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2"
[[package]]
name = "encoding_rs"
version = "0.8.34"
@@ -904,6 +917,15 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.11"
@@ -1154,6 +1176,31 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "ouroboros"
version = "0.18.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b7be5a8a3462b752f4be3ff2b2bf2f7f1d00834902e46be2a4d68b87b0573c"
dependencies = [
"aliasable",
"ouroboros_macro",
"static_assertions",
]
[[package]]
name = "ouroboros_macro"
version = "0.18.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b645dcde5f119c2c454a92d0dfa271a2a3b205da92e4292a68ead4bdbfde1f33"
dependencies = [
"heck 0.4.1",
"itertools",
"proc-macro2",
"proc-macro2-diagnostics",
"quote",
"syn 2.0.60",
]
[[package]]
name = "overload"
version = "0.1.1"
@@ -1266,6 +1313,19 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "proc-macro2-diagnostics"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
"version_check",
"yansi",
]
[[package]]
name = "quote"
version = "1.0.36"
@@ -1649,6 +1709,12 @@ version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "strsim"
version = "0.10.0"
@@ -2293,6 +2359,12 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "yansi"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
[[package]]
name = "zerocopy"
version = "0.7.32"

Cargo.toml

@@ -30,7 +30,8 @@ tracing = "0.1.40"
tracing-subscriber = "0.3.18"
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
hex = "0.4.3"
zstd = "0.13.1"
zstd = { version = "0.13.1", features = ["experimental"] }
ouroboros = "0.18.3"
[dev-dependencies]
rstest = "0.19.0"
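A note on the feature flag added above: the experimental feature on zstd is presumably what gates the by-reference dictionary constructors EncoderDictionary::new / DecoderDictionary::new used in the new src/manifest/zstd_dict.rs below; the by-copy EncoderDictionary::copy / DecoderDictionary::copy constructors the old code used work without it.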

src/manifest/dict_id.rs (new file)

@@ -0,0 +1,9 @@
use rusqlite::types::FromSql;
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub struct DictId(i64);
impl FromSql for DictId {
fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
Ok(DictId(value.as_i64()?))
}
}
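Since DictId implements FromSql, an id can be read straight off a query row. A hypothetical lookup helper (the dictionaries table matches the one this commit creates):

use rusqlite::params;

// Hypothetical helper: fetch a dictionary's id by name; row.get(0)
// resolves to DictId through the FromSql impl above.
fn lookup_dict_id(conn: &rusqlite::Connection, name: &str) -> rusqlite::Result<DictId> {
    conn.query_row(
        "SELECT id FROM dictionaries WHERE name = ?",
        params![name],
        |row| row.get(0),
    )
}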

src/manifest/manifest_key.rs (new file)

@@ -0,0 +1,38 @@
use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
pub trait ManifestKey {
type Value: ToSql + FromSql;
fn sql_key(&self) -> &'static str;
}
pub struct NumShards;
impl ManifestKey for NumShards {
type Value = usize;
fn sql_key(&self) -> &'static str {
"num_shards"
}
}
pub fn get_manifest_key<Key: ManifestKey>(
conn: &rusqlite::Connection,
key: Key,
) -> Result<Option<Key::Value>, rusqlite::Error> {
conn.query_row(
"SELECT value FROM manifest WHERE key = ?",
params![key.sql_key()],
|row| row.get(0),
)
.optional()
}
pub fn set_manifest_key<Key: ManifestKey>(
conn: &rusqlite::Connection,
key: Key,
value: Key::Value,
) -> Result<(), rusqlite::Error> {
conn.execute(
"INSERT OR REPLACE INTO manifest (key, value) VALUES (?, ?)",
params![key.sql_key(), value],
)?;
Ok(())
}
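A usage sketch for the typed-key helpers above, assuming a manifest table with key/value columns as the SQL implies:

use rusqlite::Connection;

// Hypothetical round trip: write the shard count, then read it back.
fn example(conn: &Connection) -> Result<(), rusqlite::Error> {
    set_manifest_key(conn, NumShards, 8)?;
    let num_shards = get_manifest_key(conn, NumShards)?;
    assert_eq!(num_shards, Some(8));
    Ok(())
}

The ManifestKey trait ties each string key to one value type, so a key can only be read or written at the type its definition declares.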

src/manifest/mod.rs

@@ -1,56 +1,27 @@
use std::{collections::HashMap, error::Error, io};
mod dict_id;
mod manifest_key;
mod zstd_dict;
use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
use std::{collections::HashMap, error::Error, sync::Arc};
use rusqlite::params;
use tokio_rusqlite::Connection;
use crate::shards::Shards;
use self::{
dict_id::DictId,
manifest_key::{get_manifest_key, set_manifest_key, NumShards},
zstd_dict::ZstdDict,
};
pub type ZstdDictArc = Arc<ZstdDict>;
pub struct Manifest {
conn: Connection,
num_shards: usize,
zstd_dicts: HashMap<String, (i64, ZstdDict)>,
}
pub struct ZstdDict {
level: i32,
encoder_dict: zstd::dict::EncoderDictionary<'static>,
decoder_dict: zstd::dict::DecoderDictionary<'static>,
}
impl ZstdDict {
pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let mut wrapper = io::Cursor::new(data);
let mut out_buffer = Vec::with_capacity(data.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
let mut encoder = zstd::stream::Encoder::with_prepared_dictionary(
&mut output_wrapper,
&self.encoder_dict,
)?;
io::copy(&mut wrapper, &mut encoder)?;
encoder.finish()?;
Ok(out_buffer)
}
pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let mut wrapper = io::Cursor::new(data);
let mut out_buffer = Vec::with_capacity(data.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
let mut decoder =
zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, &self.decoder_dict)?;
io::copy(&mut decoder, &mut output_wrapper)?;
Ok(out_buffer)
}
}
impl ZstdDict {
fn create(dict: &[u8], level: i32) -> Self {
let encoder = zstd::dict::EncoderDictionary::copy(dict, level);
let decoder = zstd::dict::DecoderDictionary::copy(dict);
Self {
level,
encoder_dict: encoder,
decoder_dict: decoder,
}
}
zstd_dict_by_id: HashMap<DictId, ZstdDictArc>,
zstd_dict_by_name: HashMap<String, ZstdDictArc>,
}
impl Manifest {
@@ -61,76 +32,62 @@ impl Manifest {
self.num_shards
}
pub async fn train_dictionary(
async fn train_zstd_dict_with_tag(
&mut self,
name: String,
_name: &str,
_shards: Shards,
) -> Result<ZstdDictArc, Box<dyn Error>> {
// let mut queries = vec![];
// for shard in shards.iter() {
// queries.push(shard.entries_for_tag(name));
// }
todo!();
}
async fn create_zstd_dict_from_samples(
&mut self,
name: &str,
samples: Vec<&[u8]>,
) -> Result<(), Box<dyn Error>> {
if self.zstd_dicts.contains_key(&name) {
) -> Result<ZstdDictArc, Box<dyn Error>> {
if self.zstd_dict_by_name.contains_key(name) {
return Err(format!("dictionary {} already exists", name).into());
}
let level = 3;
let dict = zstd::dict::from_samples(
let dict_bytes = zstd::dict::from_samples(
&samples,
1024 * 1024, // 1MB max dictionary size
)
.unwrap();
let zstd_dict = ZstdDict::create(&dict, level);
let name_copy = name.clone();
let dict_id = self
let name_copy = name.to_string();
let (dict_id, zstd_dict) = self
.conn
.call(move |conn| {
let mut stmt = conn.prepare(
"INSERT INTO dictionaries (level, name, dict)
"INSERT INTO dictionaries (name, level, dict)
VALUES (?, ?, ?)
RETURNING id",
)?;
let dict_id = stmt.query_row(params![level, name_copy, dict], |row| row.get(0))?;
Ok(dict_id)
let dict_id =
stmt.query_row(params![name_copy, level, dict_bytes], |row| row.get(0))?;
let zstd_dict = Arc::new(ZstdDict::create(dict_id, name_copy, level, dict_bytes));
Ok((dict_id, zstd_dict))
})
.await?;
self.zstd_dicts.insert(name, (dict_id, zstd_dict));
Ok(())
self.zstd_dict_by_id.insert(dict_id, zstd_dict.clone());
self.zstd_dict_by_name
.insert(name.to_string(), zstd_dict.clone());
Ok(zstd_dict)
}
}
trait ManifestKey {
type Value: ToSql + FromSql;
fn sql_key(&self) -> &'static str;
}
struct NumShards;
impl ManifestKey for NumShards {
type Value = usize;
fn sql_key(&self) -> &'static str {
"num_shards"
fn get_dictionary_by_id(&self, id: DictId) -> Option<&ZstdDict> {
self.zstd_dict_by_id.get(&id).map(|d| &**d)
}
fn get_dictionary_by_name(&self, name: &str) -> Option<&ZstdDict> {
self.zstd_dict_by_name.get(name).map(|d| &**d)
}
}
fn get_manifest_key<Key: ManifestKey>(
conn: &rusqlite::Connection,
key: Key,
) -> Result<Option<Key::Value>, rusqlite::Error> {
conn.query_row(
"SELECT value FROM manifest WHERE key = ?",
params![key.sql_key()],
|row| row.get(0),
)
.optional()
}
fn set_manifest_key<Key: ManifestKey>(
conn: &rusqlite::Connection,
key: Key,
value: Key::Value,
) -> Result<(), rusqlite::Error> {
conn.execute(
"INSERT OR REPLACE INTO manifest (key, value) VALUES (?, ?)",
params![key.sql_key(), value],
)?;
Ok(())
}
async fn initialize(
@@ -192,7 +149,7 @@ async fn initialize(
let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
let mut rows = vec![];
for r in stmt.query_map([], |row| {
let id: i64 = row.get(0)?;
let id = row.get(0)?;
let name: String = row.get(1)?;
let level: i32 = row.get(2)?;
let dict: Vec<u8> = row.get(3)?;
@@ -204,16 +161,19 @@ async fn initialize(
})
.await?;
let mut zstd_dicts = HashMap::new();
for (id, name, level, dict) in rows {
let zstd_dict = ZstdDict::create(&dict, level);
zstd_dicts.insert(name, (id, zstd_dict));
let mut zstd_dicts_by_id = HashMap::new();
let mut zstd_dicts_by_name = HashMap::new();
for (id, name, level, dict_bytes) in rows {
let zstd_dict = Arc::new(ZstdDict::create(id, name.clone(), level, dict_bytes));
zstd_dicts_by_id.insert(id, zstd_dict.clone());
zstd_dicts_by_name.insert(name, zstd_dict);
}
Ok(Manifest {
conn,
num_shards,
zstd_dicts,
zstd_dict_by_id: zstd_dicts_by_id,
zstd_dict_by_name: zstd_dicts_by_name,
})
}
@@ -227,12 +187,21 @@ mod tests {
let mut manifest = initialize(conn, Some(3)).await.unwrap();
let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
manifest
.train_dictionary("test".to_string(), samples)
let zstd_dict = manifest
.create_zstd_dict_from_samples("test", samples)
.await
.unwrap();
let zstd_dict = &manifest.zstd_dicts.get("test").unwrap().1;
// test that indexes are created correctly
assert_eq!(
zstd_dict.as_ref(),
manifest.get_dictionary_by_id(zstd_dict.id()).unwrap()
);
assert_eq!(
zstd_dict.as_ref(),
manifest.get_dictionary_by_name(zstd_dict.name()).unwrap()
);
let data = b"hello world, this is a test of a sort of long string";
let compressed = zstd_dict.compress(data).unwrap();
let decompressed = zstd_dict.decompress(&compressed).unwrap();

src/manifest/zstd_dict.rs (new file)

@@ -0,0 +1,94 @@
use super::dict_id::DictId;
use ouroboros::self_referencing;
use std::{error::Error, io};
use zstd::dict::{DecoderDictionary, EncoderDictionary};
#[self_referencing]
pub struct ZstdDict {
id: DictId,
name: String,
level: i32,
dict_bytes: Vec<u8>,
#[borrows(dict_bytes)]
#[not_covariant]
encoder_dict: EncoderDictionary<'this>,
#[borrows(dict_bytes)]
#[not_covariant]
decoder_dict: DecoderDictionary<'this>,
}
impl PartialEq for ZstdDict {
fn eq(&self, other: &Self) -> bool {
self.id() == other.id()
&& self.name() == other.name()
&& self.level() == other.level()
&& self.dict_bytes() == other.dict_bytes()
}
}
impl std::fmt::Debug for ZstdDict {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ZstdDict")
.field("id", &self.id())
.field("name", &self.name())
.field("level", &self.level())
.field("dict_bytes.len", &self.dict_bytes().len())
.finish()
}
}
impl ZstdDict {
pub fn create(id: DictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
ZstdDictBuilder {
id,
name,
level,
dict_bytes,
encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
}
.build()
}
pub fn id(&self) -> DictId {
self.with_id(|id| *id)
}
pub fn name(&self) -> &str {
self.borrow_name()
}
pub fn level(&self) -> i32 {
*self.borrow_level()
}
pub fn dict_bytes(&self) -> &[u8] {
self.borrow_dict_bytes()
}
pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let mut wrapper = io::Cursor::new(data);
let mut out_buffer = Vec::with_capacity(data.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
self.with_encoder_dict(|encoder_dict| {
let mut encoder =
zstd::stream::Encoder::with_prepared_dictionary(&mut output_wrapper, encoder_dict)?;
io::copy(&mut wrapper, &mut encoder)?;
encoder.finish()
})?;
Ok(out_buffer)
}
pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let mut wrapper = io::Cursor::new(data);
let mut out_buffer = Vec::with_capacity(data.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
self.with_decoder_dict(|decoder_dict| {
let mut decoder =
zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, decoder_dict)?;
io::copy(&mut decoder, &mut output_wrapper)
})?;
Ok(out_buffer)
}
}
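The self-referencing struct is the crux of this file: EncoderDictionary::new and DecoderDictionary::new borrow the raw dictionary bytes rather than copying them, and ouroboros lets one struct own dict_bytes together with the two prepared dictionaries that borrow it. A hypothetical round trip (sketch only: DictId's field is private, so a real id comes from the dictionaries-table insert):

// Assumes `id: DictId` was obtained from the database insert.
fn round_trip(id: DictId) -> Result<(), Box<dyn std::error::Error>> {
    let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
    let dict_bytes = zstd::dict::from_samples(&samples, 1024 * 1024)?;
    let dict = ZstdDict::create(id, "test".to_string(), 3, dict_bytes);
    let compressed = dict.compress(b"hello world")?;
    assert_eq!(dict.decompress(&compressed)?, b"hello world");
    Ok(())
}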

src/shard/mod.rs

@@ -1,10 +1,17 @@
mod fn_get;
mod fn_migrate;
mod fn_store;
mod shard;
pub mod shard_error;
pub use fn_get::GetResult;
pub use fn_store::{StoreArgs, StoreResult};
pub use shard::Shard;
#[cfg(test)]
pub mod test {
pub use super::shard::test::*;
}
use crate::{sha256::Sha256, shard::shard_error::ShardError, UseCompression};
use axum::body::Bytes;
@@ -15,13 +22,6 @@ use tracing::{debug, error};
pub type UtcDateTime = chrono::DateTime<chrono::Utc>;
#[derive(Clone)]
pub struct Shard {
id: usize,
conn: Connection,
use_compression: UseCompression,
}
#[derive(Debug, PartialEq, Eq)]
enum Compression {
None,
@@ -45,62 +45,6 @@ impl FromSql for Compression {
}
}
impl Shard {
pub async fn open(
id: usize,
use_compression: UseCompression,
conn: Connection,
) -> Result<Self, Box<dyn Error>> {
let shard = Self {
id,
use_compression,
conn,
};
shard.migrate().await?;
Ok(shard)
}
pub async fn close(self) -> Result<(), Box<dyn Error>> {
self.conn.close().await.map_err(|e| e.into())
}
pub fn id(&self) -> usize {
self.id
}
pub async fn db_size_bytes(&self) -> Result<usize, Box<dyn Error>> {
self.query_single_row(
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
)
.await
}
async fn query_single_row<T: FromSql + Send + 'static>(
&self,
query: &'static str,
) -> Result<T, Box<dyn Error>> {
self.conn
.call(move |conn| {
let value: T = conn.query_row(query, [], |row| row.get(0))?;
Ok(value)
})
.await
.map_err(|e| e.into())
}
pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
get_num_entries(&self.conn).await.map_err(|e| e.into())
}
}
async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Error> {
conn.call(|conn| {
let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
Ok(count)
})
.await
}
fn parse_created_at_str(created_at_str: String) -> Result<UtcDateTime, rusqlite::Error> {
let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str)
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
@@ -112,155 +56,3 @@ fn into_tokio_rusqlite_err<E: Into<Box<dyn Error + Send + Sync + 'static>>>(
) -> tokio_rusqlite::Error {
tokio_rusqlite::Error::Other(e.into())
}
#[cfg(test)]
pub mod test {
use rstest::rstest;
use super::{StoreResult, UseCompression};
use crate::{sha256::Sha256, shard::StoreArgs};
pub async fn make_shard_with_compression(use_compression: UseCompression) -> super::Shard {
let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
super::Shard::open(0, use_compression, conn).await.unwrap()
}
pub async fn make_shard() -> super::Shard {
make_shard_with_compression(UseCompression::Auto).await
}
#[tokio::test]
async fn test_num_entries() {
let shard = make_shard().await;
let num_entries = shard.num_entries().await.unwrap();
assert_eq!(num_entries, 0);
}
#[tokio::test]
async fn test_db_size_bytes() {
let shard = make_shard().await;
let db_size = shard.db_size_bytes().await.unwrap();
assert!(db_size > 0);
}
#[tokio::test]
async fn test_not_found_get() {
let shard = make_shard().await;
let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
let get_result = shard.get(sha256).await.unwrap();
assert!(get_result.is_none());
}
#[tokio::test]
async fn test_store_and_get() {
let shard = make_shard().await;
let data = "hello, world!".as_bytes();
let sha256 = Sha256::from_bytes(data);
let store_result = shard
.store(StoreArgs {
sha256,
content_type: "text/plain".to_string(),
data: data.into(),
})
.await
.unwrap();
match store_result {
StoreResult::Created {
stored_size,
data_size,
created_at,
} => {
assert_eq!(stored_size, data.len());
assert_eq!(data_size, data.len());
assert!(created_at > chrono::Utc::now() - chrono::Duration::seconds(1));
}
_ => panic!("expected StoreResult::Created"),
}
assert_eq!(shard.num_entries().await.unwrap(), 1);
let get_result = shard.get(sha256).await.unwrap().unwrap();
assert_eq!(get_result.content_type, "text/plain");
assert_eq!(get_result.data, data);
assert_eq!(get_result.stored_size, data.len());
}
#[tokio::test]
async fn test_store_duplicate() {
let shard = make_shard().await;
let data = "hello, world!".as_bytes();
let sha256 = Sha256::from_bytes(data);
let store_result = shard
.store(StoreArgs {
sha256,
content_type: "text/plain".to_string(),
data: data.into(),
})
.await
.unwrap();
assert!(matches!(store_result, StoreResult::Created { .. }));
let created_at = match store_result {
StoreResult::Created {
stored_size,
data_size,
created_at,
} => {
assert_eq!(stored_size, data.len());
assert_eq!(data_size, data.len());
created_at
}
_ => panic!("expected StoreResult::Created"),
};
let store_result = shard
.store(StoreArgs {
sha256,
content_type: "text/plain".to_string(),
data: data.into(),
})
.await
.unwrap();
assert_eq!(
store_result,
StoreResult::Exists {
data_size: data.len(),
stored_size: data.len(),
created_at
}
);
assert_eq!(shard.num_entries().await.unwrap(), 1);
}
#[rstest]
#[tokio::test]
async fn test_compression_store_get(
#[values(UseCompression::Auto, UseCompression::None, UseCompression::Zstd)]
use_compression: UseCompression,
#[values(true, false)] incompressible_data: bool,
#[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
) {
let shard = make_shard_with_compression(use_compression).await;
let mut data = vec![b'.'; 1024];
if incompressible_data {
for byte in data.iter_mut() {
*byte = rand::random();
}
}
let sha256 = Sha256::from_bytes(&data);
let store_result = shard
.store(StoreArgs {
sha256,
content_type: content_type.clone(),
data: data.clone().into(),
})
.await
.unwrap();
assert!(matches!(store_result, StoreResult::Created { .. }));
let get_result = shard.get(sha256).await.unwrap().unwrap();
assert_eq!(get_result.content_type, content_type);
assert_eq!(get_result.data, data);
}
}

src/shard/shard.rs (new file)

@@ -0,0 +1,216 @@
use super::*;
#[derive(Clone)]
pub struct Shard {
pub(super) id: usize,
pub(super) conn: Connection,
pub(super) use_compression: UseCompression,
}
impl Shard {
pub async fn open(
id: usize,
use_compression: UseCompression,
conn: Connection,
) -> Result<Self, Box<dyn Error>> {
let shard = Self {
id,
use_compression,
conn,
};
shard.migrate().await?;
Ok(shard)
}
pub async fn close(self) -> Result<(), Box<dyn Error>> {
self.conn.close().await.map_err(|e| e.into())
}
pub fn id(&self) -> usize {
self.id
}
pub async fn db_size_bytes(&self) -> Result<usize, Box<dyn Error>> {
self.query_single_row(
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
)
.await
}
async fn query_single_row<T: FromSql + Send + 'static>(
&self,
query: &'static str,
) -> Result<T, Box<dyn Error>> {
self.conn
.call(move |conn| {
let value: T = conn.query_row(query, [], |row| row.get(0))?;
Ok(value)
})
.await
.map_err(|e| e.into())
}
pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
get_num_entries(&self.conn).await.map_err(|e| e.into())
}
}
async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Error> {
conn.call(|conn| {
let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
Ok(count)
})
.await
}
#[cfg(test)]
pub mod test {
use rstest::rstest;
use super::{StoreResult, UseCompression};
use crate::{sha256::Sha256, shard::StoreArgs};
pub async fn make_shard_with_compression(use_compression: UseCompression) -> super::Shard {
let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
super::Shard::open(0, use_compression, conn).await.unwrap()
}
pub async fn make_shard() -> super::Shard {
make_shard_with_compression(UseCompression::Auto).await
}
#[tokio::test]
async fn test_num_entries() {
let shard = make_shard().await;
let num_entries = shard.num_entries().await.unwrap();
assert_eq!(num_entries, 0);
}
#[tokio::test]
async fn test_db_size_bytes() {
let shard = make_shard().await;
let db_size = shard.db_size_bytes().await.unwrap();
assert!(db_size > 0);
}
#[tokio::test]
async fn test_not_found_get() {
let shard = make_shard().await;
let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
let get_result = shard.get(sha256).await.unwrap();
assert!(get_result.is_none());
}
#[tokio::test]
async fn test_store_and_get() {
let shard = make_shard().await;
let data = "hello, world!".as_bytes();
let sha256 = Sha256::from_bytes(data);
let store_result = shard
.store(StoreArgs {
sha256,
content_type: "text/plain".to_string(),
data: data.into(),
})
.await
.unwrap();
match store_result {
StoreResult::Created {
stored_size,
data_size,
created_at,
} => {
assert_eq!(stored_size, data.len());
assert_eq!(data_size, data.len());
assert!(created_at > chrono::Utc::now() - chrono::Duration::seconds(1));
}
_ => panic!("expected StoreResult::Created"),
}
assert_eq!(shard.num_entries().await.unwrap(), 1);
let get_result = shard.get(sha256).await.unwrap().unwrap();
assert_eq!(get_result.content_type, "text/plain");
assert_eq!(get_result.data, data);
assert_eq!(get_result.stored_size, data.len());
}
#[tokio::test]
async fn test_store_duplicate() {
let shard = make_shard().await;
let data = "hello, world!".as_bytes();
let sha256 = Sha256::from_bytes(data);
let store_result = shard
.store(StoreArgs {
sha256,
content_type: "text/plain".to_string(),
data: data.into(),
})
.await
.unwrap();
assert!(matches!(store_result, StoreResult::Created { .. }));
let created_at = match store_result {
StoreResult::Created {
stored_size,
data_size,
created_at,
} => {
assert_eq!(stored_size, data.len());
assert_eq!(data_size, data.len());
created_at
}
_ => panic!("expected StoreResult::Created"),
};
let store_result = shard
.store(StoreArgs {
sha256,
content_type: "text/plain".to_string(),
data: data.into(),
})
.await
.unwrap();
assert_eq!(
store_result,
StoreResult::Exists {
data_size: data.len(),
stored_size: data.len(),
created_at
}
);
assert_eq!(shard.num_entries().await.unwrap(), 1);
}
#[rstest]
#[tokio::test]
async fn test_compression_store_get(
#[values(UseCompression::Auto, UseCompression::None, UseCompression::Zstd)]
use_compression: UseCompression,
#[values(true, false)] incompressible_data: bool,
#[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
) {
let shard = make_shard_with_compression(use_compression).await;
let mut data = vec![b'.'; 1024];
if incompressible_data {
for byte in data.iter_mut() {
*byte = rand::random();
}
}
let sha256 = Sha256::from_bytes(&data);
let store_result = shard
.store(StoreArgs {
sha256,
content_type: content_type.clone(),
data: data.clone().into(),
})
.await
.unwrap();
assert!(matches!(store_result, StoreResult::Created { .. }));
let get_result = shard.get(sha256).await.unwrap().unwrap();
assert_eq!(get_result.content_type, content_type);
assert_eq!(get_result.data, data);
}
}