store compression hint name

This commit is contained in:
Dylan Knutson
2024-05-05 20:52:21 -07:00
parent bd2de7cfac
commit a3b550526e
10 changed files with 183 additions and 45 deletions

View File

@@ -73,7 +73,7 @@ impl Compressor {
pub fn compress<Data: Into<CompressibleData>>( pub fn compress<Data: Into<CompressibleData>>(
&self, &self,
name: &str, hint: Option<&str>,
content_type: &str, content_type: &str,
data: Data, data: Data,
) -> Result<(CompressionId, CompressibleData)> { ) -> Result<(CompressionId, CompressibleData)> {
@@ -84,17 +84,25 @@ impl Compressor {
CompressionPolicy::ForceZstd => true, CompressionPolicy::ForceZstd => true,
}; };
let generic_compress = || -> Result<_> {
Ok((
CompressionId::ZstdGeneric,
zstd::stream::encode_all(data.as_ref(), 3)?.into(),
))
};
if should_compress { if should_compress {
let (id, compressed): (_, CompressibleData) = if let Some(dict) = self.by_name(name) { let (id, compressed): (_, CompressibleData) = if let Some(hint) = hint {
( if let Some(dict) = self.by_name(hint) {
CompressionId::ZstdDictId(dict.id()), (
dict.compress(&data)?.into(), CompressionId::ZstdDictId(dict.id()),
) dict.compress(&data)?.into(),
)
} else {
generic_compress()?
}
} else { } else {
( generic_compress()?
CompressionId::ZstdGeneric,
zstd::stream::encode_all(data.as_ref(), 3)?.into(),
)
}; };
if compressed.len() < data.len() { if compressed.len() < data.len() {
@@ -128,10 +136,10 @@ impl Compressor {
fn check(&self, id: ZstdDictId, name: &str) -> Result<()> { fn check(&self, id: ZstdDictId, name: &str) -> Result<()> {
if self.zstd_dict_by_id.contains_key(&id) { if self.zstd_dict_by_id.contains_key(&id) {
return Err(format!("zstd dictionary {:?} already exists", id).into()); return Err(format!("zstd dictionary id {} already exists", id.0).into());
} }
if self.zstd_dict_by_name.contains_key(name) { if self.zstd_dict_by_name.contains_key(name) {
return Err(format!("zstd dictionary {} already exists", name).into()); return Err(format!("zstd dictionary name {} already exists", name).into());
} }
Ok(()) Ok(())
} }
@@ -198,7 +206,7 @@ pub mod test {
let compressor = make_compressor_with_policy(compression_policy); let compressor = make_compressor_with_policy(compression_policy);
let data = b"hello, world!".to_vec(); let data = b"hello, world!".to_vec();
let (compression_id, compressed) = compressor let (compression_id, compressed) = compressor
.compress("dict1", content_type, data.clone()) .compress(Some("dict1"), content_type, data.clone())
.unwrap(); .unwrap();
let data_uncompressed = compressor.decompress(compression_id, compressed).unwrap(); let data_uncompressed = compressor.decompress(compression_id, compressed).unwrap();
@@ -210,9 +218,21 @@ pub mod test {
let compressor = make_compressor(); let compressor = make_compressor();
let data = b"hello, world".to_vec(); let data = b"hello, world".to_vec();
let (compression_id, compressed) = compressor let (compression_id, compressed) = compressor
.compress("dict1", "text/plain", data.clone()) .compress(Some("dict1"), "text/plain", data.clone())
.unwrap(); .unwrap();
assert_eq!(compression_id, CompressionId::None); assert_eq!(compression_id, CompressionId::None);
assert_eq!(compressed, data); assert_eq!(compressed, data);
} }
/// A 1 KiB run of a single repeated byte, compressed with the `dict1` hint,
/// must be tagged with zstd dictionary id 1 and come back smaller than the input.
#[test]
fn test_compresses_longer_data() {
    let input = vec![b'.'; 1024];
    let compressor = make_compressor();
    let (id, output) = compressor
        .compress(Some("dict1"), "text/plain", input.clone())
        .unwrap();

    // The dictionary path was taken, and the payload actually changed and shrank.
    assert_eq!(id, CompressionId::ZstdDictId(1.into()));
    assert_ne!(output, input);
    assert!(output.len() < input.len());
}
} }

View File

@@ -15,6 +15,7 @@ use tracing::error;
pub struct StoreRequest { pub struct StoreRequest {
pub sha256: Option<String>, pub sha256: Option<String>,
pub content_type: String, pub content_type: String,
pub compression_hint: Option<String>,
pub data: FieldData<Bytes>, pub data: FieldData<Bytes>,
} }
@@ -108,6 +109,7 @@ pub async fn store_handler(
sha256, sha256,
content_type: request.content_type, content_type: request.content_type,
data: request.data.contents, data: request.data.contents,
compression_hint: request.compression_hint,
compressor, compressor,
}) })
.await .await
@@ -141,6 +143,7 @@ pub mod test {
TypedMultipart(StoreRequest { TypedMultipart(StoreRequest {
sha256: sha256.map(|s| s.hex_string()), sha256: sha256.map(|s| s.hex_string()),
content_type: content_type.to_string(), content_type: content_type.to_string(),
compression_hint: None,
data: FieldData { data: FieldData {
metadata: Default::default(), metadata: Default::default(),
contents: data.into(), contents: data.into(),
@@ -200,11 +203,11 @@ pub mod test {
} }
#[rstest] #[rstest]
// textual should be compressed by default // textual should be compressed if 'auto'
#[case("text/plain", CompressionPolicy::Auto, make_assert_lt(1024))] #[case("text/plain", CompressionPolicy::Auto, make_assert_lt(1024))]
#[case("text/plain", CompressionPolicy::ForceZstd, make_assert_lt(1024))] #[case("text/plain", CompressionPolicy::ForceZstd, make_assert_lt(1024))]
#[case("text/plain", CompressionPolicy::None, make_assert_eq(1024))] #[case("text/plain", CompressionPolicy::None, make_assert_eq(1024))]
// images, etc should not be compressed by default // images, etc should not be compressed if 'auto'
#[case("image/jpg", CompressionPolicy::Auto, make_assert_eq(1024))] #[case("image/jpg", CompressionPolicy::Auto, make_assert_eq(1024))]
#[case("image/jpg", CompressionPolicy::ForceZstd, make_assert_lt(1024))] #[case("image/jpg", CompressionPolicy::ForceZstd, make_assert_lt(1024))]
#[case("image/jpg", CompressionPolicy::None, make_assert_eq(1024))] #[case("image/jpg", CompressionPolicy::None, make_assert_eq(1024))]

View File

@@ -71,6 +71,13 @@ fn main() -> Result<(), AsyncBoxError> {
let db_path = PathBuf::from(&args.db_path); let db_path = PathBuf::from(&args.db_path);
let num_shards = args.shards; let num_shards = args.shards;
if db_path.is_file() {
return Err("db_path must be a directory".into());
}
if !db_path.is_dir() {
std::fs::create_dir_all(&db_path)?;
}
// block on opening the manifest // block on opening the manifest
let manifest = block_on(async { let manifest = block_on(async {
Manifest::open( Manifest::open(
@@ -121,17 +128,18 @@ fn main() -> Result<(), AsyncBoxError> {
async fn dict_loop(manifest: Manifest, shards: ShardsArc) { async fn dict_loop(manifest: Manifest, shards: ShardsArc) {
loop { loop {
info!("dict loop: running");
let compressor = manifest.compressor();
let _compressor = compressor.read().await;
for _shard in shards.iter() {}
select! { select! {
_ = tokio::time::sleep(tokio::time::Duration::from_secs(1)) => {} _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {}
_ = crate::shutdown_signal::shutdown_signal() => { _ = crate::shutdown_signal::shutdown_signal() => {
info!("dict loop: shutdown signal received"); info!("dict loop: shutdown signal received");
break; break;
} }
} }
info!("dict loop: running");
let compressor = manifest.compressor();
let _compressor = compressor.read().await;
for _shard in shards.iter() {}
} }
} }

View File

@@ -3,6 +3,10 @@ use std::{
fmt::{Display, LowerHex}, fmt::{Display, LowerHex},
}; };
use rusqlite::{
types::{FromSql, ToSqlOutput},
ToSql,
};
use sha2::Digest; use sha2::Digest;
#[derive(Debug)] #[derive(Debug)]
@@ -47,6 +51,20 @@ impl Sha256 {
} }
} }
/// Binds a [`Sha256`] as a SQL BLOB made of its raw digest bytes.
impl ToSql for Sha256 {
    /// Returns a borrowed BLOB value pointing at the digest bytes held by `self`;
    /// nothing is copied.
    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput> {
        let blob = rusqlite::types::ValueRef::Blob(&self.0);
        Ok(ToSqlOutput::Borrowed(blob))
    }
}
/// Reads a [`Sha256`] back out of a SQL column.
impl FromSql for Sha256 {
    /// Converts the column value into a `[u8; 32]` and wraps it in `Sha256`.
    /// Any error from the array conversion is propagated unchanged.
    fn column_result(value: rusqlite::types::ValueRef) -> rusqlite::types::FromSqlResult<Self> {
        <[u8; 32]>::column_result(value).map(Self)
    }
}
impl LowerHex for Sha256 { impl LowerHex for Sha256 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for byte in self.0.iter() { for byte in self.0.iter() {

View File

@@ -23,10 +23,11 @@ pub struct GetResult {
impl Shard { impl Shard {
pub async fn get(&self, args: GetArgs) -> Result<Option<GetResult>, AsyncBoxError> { pub async fn get(&self, args: GetArgs) -> Result<Option<GetResult>, AsyncBoxError> {
let sha256 = args.sha256;
let maybe_row = self let maybe_row = self
.conn .conn
.call(move |conn| get_compressed_row(conn, sha256).map_err(into_tokio_rusqlite_err)) .call(move |conn| {
get_compressed_row(conn, &args.sha256).map_err(into_tokio_rusqlite_err)
})
.await .await
.map_err(|e| { .map_err(|e| {
error!("get failed: {}", e); error!("get failed: {}", e);
@@ -51,13 +52,13 @@ impl Shard {
fn get_compressed_row( fn get_compressed_row(
conn: &mut rusqlite::Connection, conn: &mut rusqlite::Connection,
sha256: Sha256, sha256: &Sha256,
) -> Result<Option<(String, usize, UtcDateTime, CompressionId, Vec<u8>)>, rusqlite::Error> { ) -> Result<Option<(String, usize, UtcDateTime, CompressionId, Vec<u8>)>, rusqlite::Error> {
conn.query_row( conn.query_row(
"SELECT content_type, compressed_size, created_at, compression_id, data "SELECT content_type, compressed_size, created_at, compression_id, data
FROM entries FROM entries
WHERE sha256 = ?", WHERE sha256 = ?",
params![sha256.hex_string()], params![sha256],
|row| { |row| {
let content_type = row.get(0)?; let content_type = row.get(0)?;
let stored_size = row.get(1)?; let stored_size = row.get(1)?;

View File

@@ -71,6 +71,20 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err
[], [],
)?; )?;
conn.execute(
"CREATE TABLE IF NOT EXISTS compression_hints (
name TEXT NOT NULL,
ordering INTEGER NOT NULL,
sha256 BLOB NOT NULL
)",
[],
)?;
conn.execute(
"CREATE INDEX IF NOT EXISTS compression_hints_name_idx ON compression_hints (name, ordering)",
[],
)?;
conn.execute( conn.execute(
"INSERT INTO schema_version (version, created_at) VALUES (1, ?)", "INSERT INTO schema_version (version, created_at) VALUES (1, ?)",
[chrono::Utc::now().to_rfc3339()], [chrono::Utc::now().to_rfc3339()],

View File

@@ -0,0 +1,39 @@
use rusqlite::params;
use crate::{into_tokio_rusqlite_err, sha256::Sha256, AsyncBoxError};
use super::Shard;
/// One row produced by `Shard::samples_for_hint`.
pub struct SampleForHintResult {
/// Value of the `sha256` column of the matching `entries` row.
pub sha256: Sha256,
/// Value of the `data` column of the matching `entries` row, exactly as stored.
/// NOTE(review): presumably this is the possibly-compressed payload — confirm
/// against the store path before treating it as the original bytes.
pub data: Vec<u8>,
}
impl Shard {
    /// Fetches stored entries that were tagged with `compression_hint`.
    ///
    /// At most `limit` rows of `compression_hints` whose `name` equals
    /// `compression_hint` are selected, taken in ascending `ordering`. For each
    /// selected hint the `sha256` and `data` columns of the matching `entries`
    /// row are returned. `data` is returned exactly as it is stored. The query
    /// does not request any particular order for the returned rows.
    ///
    /// # Errors
    /// Returns an error if preparing or running the query fails, or if a row
    /// cannot be converted into a [`SampleForHintResult`].
    pub async fn samples_for_hint(
        &self,
        compression_hint: &str,
        limit: usize,
    ) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
        // Owned copy so the `move` closure below does not borrow from the caller.
        let compression_hint = compression_hint.to_owned();
        let result = self
            .conn
            .call(move |conn| {
                // LIMIT lives inside the subquery, next to ORDER BY, so the
                // sample is the first `limit` hints by `ordering`. With LIMIT on
                // the outer query the subquery's ORDER BY had no influence on
                // which rows survived the limit.
                let mut stmt = conn.prepare(
                    "SELECT sha256, data FROM entries WHERE sha256 IN (
                        SELECT sha256 FROM compression_hints WHERE name = ? ORDER BY ordering LIMIT ?
                    )",
                )?;
                let rows = stmt.query_map(params![compression_hint, limit], |row| {
                    let sha256: Sha256 = row.get(0)?;
                    let data: Vec<u8> = row.get(1)?;
                    Ok(SampleForHintResult { sha256, data })
                })?;
                // Stop at the first row that fails to convert.
                rows.collect::<Result<Vec<_>, _>>()
                    .map_err(into_tokio_rusqlite_err)
            })
            .await?;
        Ok(result)
    }
}

View File

@@ -1,4 +1,5 @@
use crate::{ use crate::{
compressible_data::CompressibleData,
compressor::CompressorArc, compressor::CompressorArc,
into_tokio_rusqlite_err, into_tokio_rusqlite_err,
sql_types::{CompressionId, UtcDateTime}, sql_types::{CompressionId, UtcDateTime},
@@ -27,6 +28,7 @@ pub struct StoreArgs {
pub content_type: String, pub content_type: String,
pub data: Bytes, pub data: Bytes,
pub compressor: CompressorArc, pub compressor: CompressorArc,
pub compression_hint: Option<String>,
} }
impl Shard { impl Shard {
@@ -37,37 +39,34 @@ impl Shard {
data, data,
content_type, content_type,
compressor, compressor,
compression_hint,
}: StoreArgs, }: StoreArgs,
) -> Result<StoreResult, AsyncBoxError> { ) -> Result<StoreResult, AsyncBoxError> {
let sha256 = sha256.hex_string(); let existing_entry = self
// check for existing entry
let sha256_clone = sha256.clone();
let maybe_existing_entry = self
.conn .conn
.call(move |conn| { .call(move |conn| find_with_sha256(conn, &sha256).map_err(into_tokio_rusqlite_err))
find_with_sha256(conn, sha256_clone.as_str()).map_err(into_tokio_rusqlite_err)
})
.await?; .await?;
if let Some(entry) = maybe_existing_entry { if let Some(entry) = existing_entry {
return Ok(entry); return Ok(entry);
} }
let uncompressed_size = data.len(); let uncompressed_size = data.len();
let compressor = compressor.read().await; let compressor = compressor.read().await;
let (compression_id, data) = compressor.compress("foobar", &content_type, data)?; let (compression_id, data) =
compressor.compress(compression_hint.as_deref(), &content_type, data)?;
self.conn self.conn
.call(move |conn| { .call(move |conn| {
insert( insert(
conn, conn,
sha256, &sha256,
content_type, content_type,
compression_id, compression_id,
uncompressed_size, uncompressed_size,
data.as_ref(), data,
compression_hint,
) )
.map_err(into_tokio_rusqlite_err) .map_err(into_tokio_rusqlite_err)
}) })
@@ -78,7 +77,7 @@ impl Shard {
fn find_with_sha256( fn find_with_sha256(
conn: &mut rusqlite::Connection, conn: &mut rusqlite::Connection,
sha256: &str, sha256: &Sha256,
) -> Result<Option<StoreResult>, rusqlite::Error> { ) -> Result<Option<StoreResult>, rusqlite::Error> {
conn.query_row( conn.query_row(
"SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?", "SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?",
@@ -96,19 +95,19 @@ fn find_with_sha256(
fn insert( fn insert(
conn: &mut rusqlite::Connection, conn: &mut rusqlite::Connection,
sha256: String, sha256: &Sha256,
content_type: String, content_type: String,
compression_id: CompressionId, compression_id: CompressionId,
uncompressed_size: usize, uncompressed_size: usize,
data: &[u8], data: CompressibleData,
compression_hint: Option<String>,
) -> Result<StoreResult, rusqlite::Error> { ) -> Result<StoreResult, rusqlite::Error> {
let created_at = UtcDateTime::now(); let created_at = UtcDateTime::now();
let compressed_size = data.len(); let compressed_size = data.len();
conn.execute("INSERT INTO entries conn.execute("INSERT INTO entries
(sha256, content_type, compression_id, uncompressed_size, compressed_size, data, created_at) (sha256, content_type, compression_id, uncompressed_size, compressed_size, data, created_at)
VALUES VALUES (?, ?, ?, ?, ?, ?, ?)
(?, ?, ?, ?, ?, ?, ?)
", ",
params![ params![
sha256, sha256,
@@ -116,11 +115,21 @@ fn insert(
compression_id, compression_id,
uncompressed_size, uncompressed_size,
compressed_size, compressed_size,
data, data.as_ref(),
created_at, created_at,
], ],
)?; )?;
if let Some(compression_hint) = compression_hint {
let rand_ordering = rand::random::<i64>();
conn.execute(
"INSERT INTO compression_hints
(name, ordering, sha256)
VALUES (?, ?, ?)",
params![compression_hint, rand_ordering, sha256],
)?;
}
Ok(StoreResult::Created { Ok(StoreResult::Created {
stored_size: compressed_size, stored_size: compressed_size,
data_size: uncompressed_size, data_size: uncompressed_size,

View File

@@ -1,5 +1,6 @@
mod fn_get; mod fn_get;
mod fn_migrate; mod fn_migrate;
mod fn_samples_for_hint;
mod fn_store; mod fn_store;
mod shard; mod shard;
pub mod shard_error; pub mod shard_error;

View File

@@ -54,12 +54,14 @@ async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Err
#[cfg(test)] #[cfg(test)]
pub mod test { pub mod test {
use crate::compressor::Compressor;
use crate::{ use crate::{
compressor::test::make_compressor, compressor::test::make_compressor,
sha256::Sha256, sha256::Sha256,
shard::{GetArgs, StoreArgs, StoreResult}, shard::{GetArgs, StoreArgs, StoreResult},
CompressionPolicy, CompressionPolicy,
}; };
use rstest::rstest; use rstest::rstest;
pub async fn make_shard() -> super::Shard { pub async fn make_shard() -> super::Shard {
@@ -102,6 +104,7 @@ pub mod test {
content_type: "text/plain".to_string(), content_type: "text/plain".to_string(),
data: data.into(), data: data.into(),
compressor: compressor.clone(), compressor: compressor.clone(),
..Default::default()
}) })
.await .await
.unwrap(); .unwrap();
@@ -191,8 +194,6 @@ pub mod test {
#[values(true, false)] incompressible_data: bool, #[values(true, false)] incompressible_data: bool,
#[values("text/string", "image/jpg", "application/octet-stream")] content_type: String, #[values("text/string", "image/jpg", "application/octet-stream")] content_type: String,
) { ) {
use crate::compressor::Compressor;
let shard = make_shard().await; let shard = make_shard().await;
let compressor = Compressor::new(compression_policy).into_arc(); let compressor = Compressor::new(compression_policy).into_arc();
let mut data = vec![b'.'; 1024]; let mut data = vec![b'.'; 1024];
@@ -222,4 +223,28 @@ pub mod test {
assert_eq!(get_result.content_type, content_type); assert_eq!(get_result.content_type, content_type);
assert_eq!(get_result.data, data); assert_eq!(get_result.data, data);
} }
/// Storing an entry with a compression hint must make it discoverable through
/// `samples_for_hint` under that same hint name.
#[tokio::test]
async fn test_compression_hint() {
    let payload = "hello, world!".as_bytes();
    let digest = Sha256::from_bytes(payload);
    let shard = make_shard().await;
    let compressor = make_compressor().into_arc();

    let args = StoreArgs {
        sha256: digest,
        content_type: "text/plain".to_string(),
        data: payload.into(),
        compressor: compressor.clone(),
        compression_hint: Some("hint1".to_string()),
    };
    let outcome = shard.store(args).await.unwrap();
    assert!(matches!(outcome, StoreResult::Created { .. }));

    // Exactly one sample is recorded under the hint, and it is the entry just stored.
    // NOTE(review): comparing against the raw payload presumably relies on this
    // short input being stored uncompressed — confirm.
    let samples = shard.samples_for_hint("hint1", 10).await.unwrap();
    assert_eq!(samples.len(), 1);
    assert_eq!(samples[0].sha256, digest);
    assert_eq!(samples[0].data, payload);
}
} }