move manifest to its own db, zstd dictionary compression
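The manifest gets its own SQLite database and module tree: typed manifest keys move to src/manifest/manifest_key.rs, dictionary ids get a DictId newtype in src/manifest/dict_id.rs, and the zstd dictionary wrapper moves to src/manifest/zstd_dict.rs. Dictionaries are trained from samples with zstd::dict::from_samples, persisted in a dictionaries table, and cached in memory behind Arc under two indexes (by id and by name). The Shard struct and its tests also move out of src/shard/mod.rs into the new src/shard/shard.rs.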
Cargo.lock (generated, 72 changed lines)
@@ -38,6 +38,12 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "aliasable"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd"
+
 [[package]]
 name = "allocator-api2"
 version = "0.2.18"
@@ -276,6 +282,7 @@ dependencies = [
  "futures",
  "hex",
  "kdam",
+ "ouroboros",
  "rand",
  "reqwest",
  "rstest",
@@ -483,6 +490,12 @@ dependencies = [
  "crypto-common",
 ]
 
+[[package]]
+name = "either"
+version = "1.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2"
+
 [[package]]
 name = "encoding_rs"
 version = "0.8.34"
@@ -904,6 +917,15 @@ version = "2.9.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
 
+[[package]]
+name = "itertools"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itoa"
 version = "1.0.11"
@@ -1154,6 +1176,31 @@ dependencies = [
  "vcpkg",
 ]
 
+[[package]]
+name = "ouroboros"
+version = "0.18.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "97b7be5a8a3462b752f4be3ff2b2bf2f7f1d00834902e46be2a4d68b87b0573c"
+dependencies = [
+ "aliasable",
+ "ouroboros_macro",
+ "static_assertions",
+]
+
+[[package]]
+name = "ouroboros_macro"
+version = "0.18.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b645dcde5f119c2c454a92d0dfa271a2a3b205da92e4292a68ead4bdbfde1f33"
+dependencies = [
+ "heck 0.4.1",
+ "itertools",
+ "proc-macro2",
+ "proc-macro2-diagnostics",
+ "quote",
+ "syn 2.0.60",
+]
+
 [[package]]
 name = "overload"
 version = "0.1.1"
@@ -1266,6 +1313,19 @@ dependencies = [
  "unicode-ident",
 ]
 
+[[package]]
+name = "proc-macro2-diagnostics"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.60",
+ "version_check",
+ "yansi",
+]
+
 [[package]]
 name = "quote"
 version = "1.0.36"
@@ -1649,6 +1709,12 @@ version = "0.9.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
 
+[[package]]
+name = "static_assertions"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
+
 [[package]]
 name = "strsim"
 version = "0.10.0"
@@ -2293,6 +2359,12 @@ dependencies = [
  "windows-sys 0.48.0",
 ]
 
+[[package]]
+name = "yansi"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
+
 [[package]]
 name = "zerocopy"
 version = "0.7.32"
Cargo.toml
@@ -30,7 +30,8 @@ tracing = "0.1.40"
 tracing-subscriber = "0.3.18"
 reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
 hex = "0.4.3"
-zstd = "0.13.1"
+zstd = { version = "0.13.1", features = ["experimental"] }
+ouroboros = "0.18.3"
 
 [dev-dependencies]
 rstest = "0.19.0"
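The "experimental" feature on zstd is apparently what exposes the by-reference dictionary constructors (EncoderDictionary::new / DecoderDictionary::new) that src/manifest/zstd_dict.rs uses below; the ::copy constructors used before work on the default feature set but duplicate the dictionary bytes. ouroboros is added for the self-referencing ZstdDict struct.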
src/manifest/dict_id.rs (new file, 9 lines)
@@ -0,0 +1,9 @@
+use rusqlite::types::FromSql;
+
+#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
+pub struct DictId(i64);
+impl FromSql for DictId {
+    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
+        Ok(DictId(value.as_i64()?))
+    }
+}
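DictId is a newtype over the SQLite rowid so dictionary ids cannot be mixed up with other i64 values; only FromSql is implemented, which is all the read path needs. If an id ever has to be bound back into a query, a ToSql impl delegating to the inner i64 would be the natural counterpart. A sketch, not part of this commit:

impl rusqlite::ToSql for DictId {
    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
        // Bind the wrapped rowid as a plain SQLite integer.
        Ok(rusqlite::types::ToSqlOutput::from(self.0))
    }
}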
src/manifest/manifest_key.rs (new file, 38 lines)
@@ -0,0 +1,38 @@
+use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
+
+pub trait ManifestKey {
+    type Value: ToSql + FromSql;
+    fn sql_key(&self) -> &'static str;
+}
+
+pub struct NumShards;
+impl ManifestKey for NumShards {
+    type Value = usize;
+    fn sql_key(&self) -> &'static str {
+        "num_shards"
+    }
+}
+
+pub fn get_manifest_key<Key: ManifestKey>(
+    conn: &rusqlite::Connection,
+    key: Key,
+) -> Result<Option<Key::Value>, rusqlite::Error> {
+    conn.query_row(
+        "SELECT value FROM manifest WHERE key = ?",
+        params![key.sql_key()],
+        |row| row.get(0),
+    )
+    .optional()
+}
+
+pub fn set_manifest_key<Key: ManifestKey>(
+    conn: &rusqlite::Connection,
+    key: Key,
+    value: Key::Value,
+) -> Result<(), rusqlite::Error> {
+    conn.execute(
+        "INSERT OR REPLACE INTO manifest (key, value) VALUES (?, ?)",
+        params![key.sql_key(), value],
+    )?;
+    Ok(())
+}
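Each key type fixes its value type at compile time, so get_manifest_key(conn, NumShards) can only come back as Option<usize>. A minimal usage sketch; the manifest table DDL here is an assumption, the real schema comes from the migration in src/manifest/mod.rs:

fn demo() -> Result<(), rusqlite::Error> {
    let conn = rusqlite::Connection::open_in_memory()?;
    // Assumed schema; the manifest migration owns the real DDL.
    conn.execute_batch("CREATE TABLE manifest (key TEXT PRIMARY KEY, value)")?;

    set_manifest_key(&conn, NumShards, 3)?;
    assert_eq!(get_manifest_key(&conn, NumShards)?, Some(3));
    Ok(())
}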
src/manifest/mod.rs
@@ -1,56 +1,27 @@
-use std::{collections::HashMap, error::Error, io};
+mod dict_id;
+mod manifest_key;
+mod zstd_dict;
 
-use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
+use std::{collections::HashMap, error::Error, sync::Arc};
+
+use rusqlite::params;
 use tokio_rusqlite::Connection;
 
+use crate::shards::Shards;
+
+use self::{
+    dict_id::DictId,
+    manifest_key::{get_manifest_key, set_manifest_key, NumShards},
+    zstd_dict::ZstdDict,
+};
+
+pub type ZstdDictArc = Arc<ZstdDict>;
+
 pub struct Manifest {
     conn: Connection,
     num_shards: usize,
-    zstd_dicts: HashMap<String, (i64, ZstdDict)>,
-}
-
-pub struct ZstdDict {
-    level: i32,
-    encoder_dict: zstd::dict::EncoderDictionary<'static>,
-    decoder_dict: zstd::dict::DecoderDictionary<'static>,
-}
-impl ZstdDict {
-    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
-        let mut wrapper = io::Cursor::new(data);
-        let mut out_buffer = Vec::with_capacity(data.len());
-        let mut output_wrapper = io::Cursor::new(&mut out_buffer);
-
-        let mut encoder = zstd::stream::Encoder::with_prepared_dictionary(
-            &mut output_wrapper,
-            &self.encoder_dict,
-        )?;
-        io::copy(&mut wrapper, &mut encoder)?;
-        encoder.finish()?;
-        Ok(out_buffer)
-    }
-
-    pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
-        let mut wrapper = io::Cursor::new(data);
-        let mut out_buffer = Vec::with_capacity(data.len());
-        let mut output_wrapper = io::Cursor::new(&mut out_buffer);
-
-        let mut decoder =
-            zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, &self.decoder_dict)?;
-        io::copy(&mut decoder, &mut output_wrapper)?;
-        Ok(out_buffer)
-    }
-}
-
-impl ZstdDict {
-    fn create(dict: &[u8], level: i32) -> Self {
-        let encoder = zstd::dict::EncoderDictionary::copy(dict, level);
-        let decoder = zstd::dict::DecoderDictionary::copy(dict);
-        Self {
-            level,
-            encoder_dict: encoder,
-            decoder_dict: decoder,
-        }
-    }
+    zstd_dict_by_id: HashMap<DictId, ZstdDictArc>,
+    zstd_dict_by_name: HashMap<String, ZstdDictArc>,
 }
 
 impl Manifest {
@@ -61,76 +32,62 @@ impl Manifest {
         self.num_shards
     }
 
-    pub async fn train_dictionary(
+    async fn train_zstd_dict_with_tag(
         &mut self,
-        name: String,
+        _name: &str,
+        _shards: Shards,
+    ) -> Result<ZstdDictArc, Box<dyn Error>> {
+        // let mut queries = vec![];
+        // for shard in shards.iter() {
+        //     queries.push(shard.entries_for_tag(name));
+        // }
+        todo!();
+    }
+
+    async fn create_zstd_dict_from_samples(
+        &mut self,
+        name: &str,
         samples: Vec<&[u8]>,
-    ) -> Result<(), Box<dyn Error>> {
-        if self.zstd_dicts.contains_key(&name) {
+    ) -> Result<ZstdDictArc, Box<dyn Error>> {
+        if self.zstd_dict_by_name.contains_key(name) {
             return Err(format!("dictionary {} already exists", name).into());
         }
 
         let level = 3;
-        let dict = zstd::dict::from_samples(
+        let dict_bytes = zstd::dict::from_samples(
             &samples,
             1024 * 1024, // 1MB max dictionary size
         )
         .unwrap();
-        let zstd_dict = ZstdDict::create(&dict, level);
 
-        let name_copy = name.clone();
-        let dict_id = self
+        let name_copy = name.to_string();
+        let (dict_id, zstd_dict) = self
            .conn
            .call(move |conn| {
                let mut stmt = conn.prepare(
-                    "INSERT INTO dictionaries (level, name, dict)
+                    "INSERT INTO dictionaries (name, level, dict)
                      VALUES (?, ?, ?)
                      RETURNING id",
                )?;
-                let dict_id = stmt.query_row(params![level, name_copy, dict], |row| row.get(0))?;
-                Ok(dict_id)
+                let dict_id =
+                    stmt.query_row(params![name_copy, level, dict_bytes], |row| row.get(0))?;
+                let zstd_dict = Arc::new(ZstdDict::create(dict_id, name_copy, level, dict_bytes));
+                Ok((dict_id, zstd_dict))
            })
            .await?;
 
-        self.zstd_dicts.insert(name, (dict_id, zstd_dict));
-        Ok(())
+        self.zstd_dict_by_id.insert(dict_id, zstd_dict.clone());
+        self.zstd_dict_by_name
+            .insert(name.to_string(), zstd_dict.clone());
+        Ok(zstd_dict)
     }
-}
 
-trait ManifestKey {
-    type Value: ToSql + FromSql;
-    fn sql_key(&self) -> &'static str;
-}
-struct NumShards;
-impl ManifestKey for NumShards {
-    type Value = usize;
-    fn sql_key(&self) -> &'static str {
-        "num_shards"
+    fn get_dictionary_by_id(&self, id: DictId) -> Option<&ZstdDict> {
+        self.zstd_dict_by_id.get(&id).map(|d| &**d)
+    }
+    fn get_dictionary_by_name(&self, name: &str) -> Option<&ZstdDict> {
+        self.zstd_dict_by_name.get(name).map(|d| &**d)
     }
 }
 
-fn get_manifest_key<Key: ManifestKey>(
-    conn: &rusqlite::Connection,
-    key: Key,
-) -> Result<Option<Key::Value>, rusqlite::Error> {
-    conn.query_row(
-        "SELECT value FROM manifest WHERE key = ?",
-        params![key.sql_key()],
-        |row| row.get(0),
-    )
-    .optional()
-}
-
-fn set_manifest_key<Key: ManifestKey>(
-    conn: &rusqlite::Connection,
-    key: Key,
-    value: Key::Value,
-) -> Result<(), rusqlite::Error> {
-    conn.execute(
-        "INSERT OR REPLACE INTO manifest (key, value) VALUES (?, ?)",
-        params![key.sql_key(), value],
-    )?;
-    Ok(())
-}
-
 async fn initialize(
@@ -192,7 +149,7 @@ async fn initialize(
            let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
            let mut rows = vec![];
            for r in stmt.query_map([], |row| {
-                let id: i64 = row.get(0)?;
+                let id = row.get(0)?;
                let name: String = row.get(1)?;
                let level: i32 = row.get(2)?;
                let dict: Vec<u8> = row.get(3)?;
@@ -204,16 +161,19 @@
        })
        .await?;
 
-    let mut zstd_dicts = HashMap::new();
-    for (id, name, level, dict) in rows {
-        let zstd_dict = ZstdDict::create(&dict, level);
-        zstd_dicts.insert(name, (id, zstd_dict));
+    let mut zstd_dicts_by_id = HashMap::new();
+    let mut zstd_dicts_by_name = HashMap::new();
+    for (id, name, level, dict_bytes) in rows {
+        let zstd_dict = Arc::new(ZstdDict::create(id, name.clone(), level, dict_bytes));
+        zstd_dicts_by_id.insert(id, zstd_dict.clone());
+        zstd_dicts_by_name.insert(name, zstd_dict);
    }
 
    Ok(Manifest {
        conn,
        num_shards,
-        zstd_dicts,
+        zstd_dict_by_id: zstd_dicts_by_id,
+        zstd_dict_by_name: zstd_dicts_by_name,
    })
 }
@@ -227,12 +187,21 @@ mod tests {
        let mut manifest = initialize(conn, Some(3)).await.unwrap();
 
        let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
-        manifest
-            .train_dictionary("test".to_string(), samples)
+        let zstd_dict = manifest
+            .create_zstd_dict_from_samples("test", samples)
            .await
            .unwrap();
 
-        let zstd_dict = &manifest.zstd_dicts.get("test").unwrap().1;
+        // test that indexes are created correctly
+        assert_eq!(
+            zstd_dict.as_ref(),
+            manifest.get_dictionary_by_id(zstd_dict.id()).unwrap()
+        );
+        assert_eq!(
+            zstd_dict.as_ref(),
+            manifest.get_dictionary_by_name(zstd_dict.name()).unwrap()
+        );
 
        let data = b"hello world, this is a test of a sort of long string";
        let compressed = zstd_dict.compress(data).unwrap();
        let decompressed = zstd_dict.decompress(&compressed).unwrap();
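The Manifest now holds each dictionary once behind an Arc and indexes it twice: zstd_dict_by_name serves the compression path, where a dictionary is picked by tag, and zstd_dict_by_id serves the decompression path, where a stored entry presumably records the DictId it was compressed with. A sketch of that flow as module-internal code (the accessors are private); the caller, payload, and the source of stored_id are assumptions:

// Sketch only: `stored_id` is assumed to be the DictId persisted next to the blob.
fn roundtrip(manifest: &Manifest, payload: &[u8], stored_id: DictId) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    let dict = manifest.get_dictionary_by_name("test").ok_or("unknown dictionary")?;
    let compressed = dict.compress(payload)?; // persist alongside dict.id()

    let dict = manifest.get_dictionary_by_id(stored_id).ok_or("unknown dictionary id")?;
    dict.decompress(&compressed)
}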
src/manifest/zstd_dict.rs (new file, 94 lines)
@@ -0,0 +1,94 @@
+use super::dict_id::DictId;
+use ouroboros::self_referencing;
+use std::{error::Error, io};
+use zstd::dict::{DecoderDictionary, EncoderDictionary};
+
+#[self_referencing]
+pub struct ZstdDict {
+    id: DictId,
+    name: String,
+    level: i32,
+    dict_bytes: Vec<u8>,
+
+    #[borrows(dict_bytes)]
+    #[not_covariant]
+    encoder_dict: EncoderDictionary<'this>,
+
+    #[borrows(dict_bytes)]
+    #[not_covariant]
+    decoder_dict: DecoderDictionary<'this>,
+}
+
+impl PartialEq for ZstdDict {
+    fn eq(&self, other: &Self) -> bool {
+        self.id() == other.id()
+            && self.name() == other.name()
+            && self.level() == other.level()
+            && self.dict_bytes() == other.dict_bytes()
+    }
+}
+
+impl std::fmt::Debug for ZstdDict {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ZstdDict")
+            .field("id", &self.id())
+            .field("name", &self.name())
+            .field("level", &self.level())
+            .field("dict_bytes.len", &self.dict_bytes().len())
+            .finish()
+    }
+}
+
+impl ZstdDict {
+    pub fn create(id: DictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
+        ZstdDictBuilder {
+            id,
+            name,
+            level,
+            dict_bytes,
+            encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
+            decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
+        }
+        .build()
+    }
+
+    pub fn id(&self) -> DictId {
+        self.with_id(|id| *id)
+    }
+    pub fn name(&self) -> &str {
+        self.borrow_name()
+    }
+    pub fn level(&self) -> i32 {
+        *self.borrow_level()
+    }
+    pub fn dict_bytes(&self) -> &[u8] {
+        self.borrow_dict_bytes()
+    }
+
+    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
+        let mut wrapper = io::Cursor::new(data);
+        let mut out_buffer = Vec::with_capacity(data.len());
+        let mut output_wrapper = io::Cursor::new(&mut out_buffer);
+
+        self.with_encoder_dict(|encoder_dict| {
+            let mut encoder =
+                zstd::stream::Encoder::with_prepared_dictionary(&mut output_wrapper, encoder_dict)?;
+            io::copy(&mut wrapper, &mut encoder)?;
+            encoder.finish()
+        })?;
+        Ok(out_buffer)
+    }
+
+    pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
+        let mut wrapper = io::Cursor::new(data);
+        let mut out_buffer = Vec::with_capacity(data.len());
+        let mut output_wrapper = io::Cursor::new(&mut out_buffer);
+
+        self.with_decoder_dict(|decoder_dict| {
+            let mut decoder =
+                zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, decoder_dict)?;
+            io::copy(&mut decoder, &mut output_wrapper)
+        })?;
+        Ok(out_buffer)
+    }
+}
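ouroboros::self_referencing is what lets EncoderDictionary<'this> and DecoderDictionary<'this> borrow from the dict_bytes field of the same struct, so the dictionary bytes live in the struct exactly once; the old version used the ::copy constructors with 'static lifetimes, which duplicate the bytes inside the zstd objects. A roundtrip sketch against this API; the DictId value is assumed to come from the dictionaries table, since the newtype can only be built through its FromSql impl:

fn roundtrip(dict_id: DictId) -> Result<(), Box<dyn std::error::Error>> {
    // `dict_id` is assumed to come from SQLite, e.g. the `RETURNING id`
    // insert in src/manifest/mod.rs.
    let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
    let dict_bytes = zstd::dict::from_samples(&samples, 1024 * 1024)?;
    let dict = ZstdDict::create(dict_id, "test".to_string(), 3, dict_bytes);

    let data = b"hello world, this is a test of a sort of long string";
    let compressed = dict.compress(data)?;
    assert_eq!(dict.decompress(&compressed)?, data.to_vec());
    Ok(())
}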
src/shard/mod.rs (222 changed lines)
@@ -1,10 +1,17 @@
 mod fn_get;
 mod fn_migrate;
 mod fn_store;
+mod shard;
 pub mod shard_error;
 
 pub use fn_get::GetResult;
 pub use fn_store::{StoreArgs, StoreResult};
+pub use shard::Shard;
+
+#[cfg(test)]
+pub mod test {
+    pub use super::shard::test::*;
+}
 
 use crate::{sha256::Sha256, shard::shard_error::ShardError, UseCompression};
 use axum::body::Bytes;
@@ -15,13 +22,6 @@ use tracing::{debug, error};
 
 pub type UtcDateTime = chrono::DateTime<chrono::Utc>;
 
-#[derive(Clone)]
-pub struct Shard {
-    id: usize,
-    conn: Connection,
-    use_compression: UseCompression,
-}
-
 #[derive(Debug, PartialEq, Eq)]
 enum Compression {
     None,
@@ -45,62 +45,6 @@
    }
 }
 
-impl Shard {
-    pub async fn open(
-        id: usize,
-        use_compression: UseCompression,
-        conn: Connection,
-    ) -> Result<Self, Box<dyn Error>> {
-        let shard = Self {
-            id,
-            use_compression,
-            conn,
-        };
-        shard.migrate().await?;
-        Ok(shard)
-    }
-
-    pub async fn close(self) -> Result<(), Box<dyn Error>> {
-        self.conn.close().await.map_err(|e| e.into())
-    }
-
-    pub fn id(&self) -> usize {
-        self.id
-    }
-
-    pub async fn db_size_bytes(&self) -> Result<usize, Box<dyn Error>> {
-        self.query_single_row(
-            "SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
-        )
-        .await
-    }
-
-    async fn query_single_row<T: FromSql + Send + 'static>(
-        &self,
-        query: &'static str,
-    ) -> Result<T, Box<dyn Error>> {
-        self.conn
-            .call(move |conn| {
-                let value: T = conn.query_row(query, [], |row| row.get(0))?;
-                Ok(value)
-            })
-            .await
-            .map_err(|e| e.into())
-    }
-
-    pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
-        get_num_entries(&self.conn).await.map_err(|e| e.into())
-    }
-}
-
-async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Error> {
-    conn.call(|conn| {
-        let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
-        Ok(count)
-    })
-    .await
-}
-
 fn parse_created_at_str(created_at_str: String) -> Result<UtcDateTime, rusqlite::Error> {
    let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str)
        .map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
@@ -112,155 +56,3 @@ fn into_tokio_rusqlite_err<E: Into<Box<dyn Error + Send + Sync + 'static>>>(
 ) -> tokio_rusqlite::Error {
    tokio_rusqlite::Error::Other(e.into())
 }
-
-#[cfg(test)]
-pub mod test {
-    use rstest::rstest;
-
-    use super::{StoreResult, UseCompression};
-    use crate::{sha256::Sha256, shard::StoreArgs};
-
-    pub async fn make_shard_with_compression(use_compression: UseCompression) -> super::Shard {
-        let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
-        super::Shard::open(0, use_compression, conn).await.unwrap()
-    }
-
-    pub async fn make_shard() -> super::Shard {
-        make_shard_with_compression(UseCompression::Auto).await
-    }
-
-    #[tokio::test]
-    async fn test_num_entries() {
-        let shard = make_shard().await;
-        let num_entries = shard.num_entries().await.unwrap();
-        assert_eq!(num_entries, 0);
-    }
-
-    #[tokio::test]
-    async fn test_db_size_bytes() {
-        let shard = make_shard().await;
-        let db_size = shard.db_size_bytes().await.unwrap();
-        assert!(db_size > 0);
-    }
-
-    #[tokio::test]
-    async fn test_not_found_get() {
-        let shard = make_shard().await;
-        let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
-        let get_result = shard.get(sha256).await.unwrap();
-        assert!(get_result.is_none());
-    }
-
-    #[tokio::test]
-    async fn test_store_and_get() {
-        let shard = make_shard().await;
-        let data = "hello, world!".as_bytes();
-        let sha256 = Sha256::from_bytes(data);
-        let store_result = shard
-            .store(StoreArgs {
-                sha256,
-                content_type: "text/plain".to_string(),
-                data: data.into(),
-            })
-            .await
-            .unwrap();
-        match store_result {
-            StoreResult::Created {
-                stored_size,
-                data_size,
-                created_at,
-            } => {
-                assert_eq!(stored_size, data.len());
-                assert_eq!(data_size, data.len());
-                assert!(created_at > chrono::Utc::now() - chrono::Duration::seconds(1));
-            }
-            _ => panic!("expected StoreResult::Created"),
-        }
-        assert_eq!(shard.num_entries().await.unwrap(), 1);
-
-        let get_result = shard.get(sha256).await.unwrap().unwrap();
-        assert_eq!(get_result.content_type, "text/plain");
-        assert_eq!(get_result.data, data);
-        assert_eq!(get_result.stored_size, data.len());
-    }
-
-    #[tokio::test]
-    async fn test_store_duplicate() {
-        let shard = make_shard().await;
-        let data = "hello, world!".as_bytes();
-        let sha256 = Sha256::from_bytes(data);
-
-        let store_result = shard
-            .store(StoreArgs {
-                sha256,
-                content_type: "text/plain".to_string(),
-                data: data.into(),
-            })
-            .await
-            .unwrap();
-        assert!(matches!(store_result, StoreResult::Created { .. }));
-
-        let created_at = match store_result {
-            StoreResult::Created {
-                stored_size,
-                data_size,
-                created_at,
-            } => {
-                assert_eq!(stored_size, data.len());
-                assert_eq!(data_size, data.len());
-                created_at
-            }
-            _ => panic!("expected StoreResult::Created"),
-        };
-
-        let store_result = shard
-            .store(StoreArgs {
-                sha256,
-                content_type: "text/plain".to_string(),
-                data: data.into(),
-            })
-            .await
-            .unwrap();
-        assert_eq!(
-            store_result,
-            StoreResult::Exists {
-                data_size: data.len(),
-                stored_size: data.len(),
-                created_at
-            }
-        );
-        assert_eq!(shard.num_entries().await.unwrap(), 1);
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_compression_store_get(
-        #[values(UseCompression::Auto, UseCompression::None, UseCompression::Zstd)]
-        use_compression: UseCompression,
-        #[values(true, false)] incompressible_data: bool,
-        #[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
-    ) {
-        let shard = make_shard_with_compression(use_compression).await;
-        let mut data = vec![b'.'; 1024];
-        if incompressible_data {
-            for byte in data.iter_mut() {
-                *byte = rand::random();
-            }
-        }
-
-        let sha256 = Sha256::from_bytes(&data);
-        let store_result = shard
-            .store(StoreArgs {
-                sha256,
-                content_type: content_type.clone(),
-                data: data.clone().into(),
-            })
-            .await
-            .unwrap();
-        assert!(matches!(store_result, StoreResult::Created { .. }));
-
-        let get_result = shard.get(sha256).await.unwrap().unwrap();
-        assert_eq!(get_result.content_type, content_type);
-        assert_eq!(get_result.data, data);
-    }
-}
src/shard/shard.rs (new file, 216 lines)
@@ -0,0 +1,216 @@
+use super::*;
+
+#[derive(Clone)]
+pub struct Shard {
+    pub(super) id: usize,
+    pub(super) conn: Connection,
+    pub(super) use_compression: UseCompression,
+}
+
+impl Shard {
+    pub async fn open(
+        id: usize,
+        use_compression: UseCompression,
+        conn: Connection,
+    ) -> Result<Self, Box<dyn Error>> {
+        let shard = Self {
+            id,
+            use_compression,
+            conn,
+        };
+        shard.migrate().await?;
+        Ok(shard)
+    }
+
+    pub async fn close(self) -> Result<(), Box<dyn Error>> {
+        self.conn.close().await.map_err(|e| e.into())
+    }
+
+    pub fn id(&self) -> usize {
+        self.id
+    }
+
+    pub async fn db_size_bytes(&self) -> Result<usize, Box<dyn Error>> {
+        self.query_single_row(
+            "SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
+        )
+        .await
+    }
+
+    async fn query_single_row<T: FromSql + Send + 'static>(
+        &self,
+        query: &'static str,
+    ) -> Result<T, Box<dyn Error>> {
+        self.conn
+            .call(move |conn| {
+                let value: T = conn.query_row(query, [], |row| row.get(0))?;
+                Ok(value)
+            })
+            .await
+            .map_err(|e| e.into())
+    }
+
+    pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
+        get_num_entries(&self.conn).await.map_err(|e| e.into())
+    }
+}
+
+async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Error> {
+    conn.call(|conn| {
+        let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
+        Ok(count)
+    })
+    .await
+}
+
+#[cfg(test)]
+pub mod test {
+    use rstest::rstest;
+
+    use super::{StoreResult, UseCompression};
+    use crate::{sha256::Sha256, shard::StoreArgs};
+
+    pub async fn make_shard_with_compression(use_compression: UseCompression) -> super::Shard {
+        let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
+        super::Shard::open(0, use_compression, conn).await.unwrap()
+    }
+
+    pub async fn make_shard() -> super::Shard {
+        make_shard_with_compression(UseCompression::Auto).await
+    }
+
+    #[tokio::test]
+    async fn test_num_entries() {
+        let shard = make_shard().await;
+        let num_entries = shard.num_entries().await.unwrap();
+        assert_eq!(num_entries, 0);
+    }
+
+    #[tokio::test]
+    async fn test_db_size_bytes() {
+        let shard = make_shard().await;
+        let db_size = shard.db_size_bytes().await.unwrap();
+        assert!(db_size > 0);
+    }
+
+    #[tokio::test]
+    async fn test_not_found_get() {
+        let shard = make_shard().await;
+        let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
+        let get_result = shard.get(sha256).await.unwrap();
+        assert!(get_result.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_store_and_get() {
+        let shard = make_shard().await;
+        let data = "hello, world!".as_bytes();
+        let sha256 = Sha256::from_bytes(data);
+        let store_result = shard
+            .store(StoreArgs {
+                sha256,
+                content_type: "text/plain".to_string(),
+                data: data.into(),
+            })
+            .await
+            .unwrap();
+        match store_result {
+            StoreResult::Created {
+                stored_size,
+                data_size,
+                created_at,
+            } => {
+                assert_eq!(stored_size, data.len());
+                assert_eq!(data_size, data.len());
+                assert!(created_at > chrono::Utc::now() - chrono::Duration::seconds(1));
+            }
+            _ => panic!("expected StoreResult::Created"),
+        }
+        assert_eq!(shard.num_entries().await.unwrap(), 1);
+
+        let get_result = shard.get(sha256).await.unwrap().unwrap();
+        assert_eq!(get_result.content_type, "text/plain");
+        assert_eq!(get_result.data, data);
+        assert_eq!(get_result.stored_size, data.len());
+    }
+
+    #[tokio::test]
+    async fn test_store_duplicate() {
+        let shard = make_shard().await;
+        let data = "hello, world!".as_bytes();
+        let sha256 = Sha256::from_bytes(data);
+
+        let store_result = shard
+            .store(StoreArgs {
+                sha256,
+                content_type: "text/plain".to_string(),
+                data: data.into(),
+            })
+            .await
+            .unwrap();
+        assert!(matches!(store_result, StoreResult::Created { .. }));
+
+        let created_at = match store_result {
+            StoreResult::Created {
+                stored_size,
+                data_size,
+                created_at,
+            } => {
+                assert_eq!(stored_size, data.len());
+                assert_eq!(data_size, data.len());
+                created_at
+            }
+            _ => panic!("expected StoreResult::Created"),
+        };
+
+        let store_result = shard
+            .store(StoreArgs {
+                sha256,
+                content_type: "text/plain".to_string(),
+                data: data.into(),
+            })
+            .await
+            .unwrap();
+        assert_eq!(
+            store_result,
+            StoreResult::Exists {
+                data_size: data.len(),
+                stored_size: data.len(),
+                created_at
+            }
+        );
+        assert_eq!(shard.num_entries().await.unwrap(), 1);
+    }
+
+    #[rstest]
+    #[tokio::test]
+    async fn test_compression_store_get(
+        #[values(UseCompression::Auto, UseCompression::None, UseCompression::Zstd)]
+        use_compression: UseCompression,
+        #[values(true, false)] incompressible_data: bool,
+        #[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
+    ) {
+        let shard = make_shard_with_compression(use_compression).await;
+        let mut data = vec![b'.'; 1024];
+        if incompressible_data {
+            for byte in data.iter_mut() {
+                *byte = rand::random();
+            }
+        }
+
+        let sha256 = Sha256::from_bytes(&data);
+        let store_result = shard
+            .store(StoreArgs {
+                sha256,
+                content_type: content_type.clone(),
+                data: data.clone().into(),
+            })
+            .await
+            .unwrap();
+        assert!(matches!(store_result, StoreResult::Created { .. }));
+
+        let get_result = shard.get(sha256).await.unwrap().unwrap();
+        assert_eq!(get_result.content_type, content_type);
+        assert_eq!(get_result.data, data);
+    }
+}