clippy, concat_lines

This commit is contained in:
Dylan Knutson
2024-05-06 09:05:36 -07:00
parent 83b4dacede
commit 6deb909c43
12 changed files with 243 additions and 169 deletions

View File

@@ -35,3 +35,7 @@ ouroboros = "0.18.3"
[dev-dependencies]
rstest = "0.19.0"
[lints.rust]
unsafe_code = "forbid"
unused_must_use = "forbid"

View File

@@ -39,29 +39,14 @@ impl Compressor {
Arc::new(RwLock::new(self))
}
fn _add_from_samples<Str: Into<String>>(
&mut self,
id: ZstdDictId,
name: Str,
samples: Vec<&[u8]>,
) -> Result<ZstdDictArc> {
let name = name.into();
self.check(id, &name)?;
let zstd_dict = ZstdDict::from_samples(id, name, 3, samples)?;
Ok(self.add(zstd_dict))
}
pub fn add_from_bytes<Str: Into<String>>(
&mut self,
id: ZstdDictId,
name: Str,
level: i32,
dict_bytes: Vec<u8>,
) -> Result<ZstdDictArc> {
let name = name.into();
self.check(id, &name)?;
let zstd_dict = ZstdDict::from_dict_bytes(id, name, level, dict_bytes);
Ok(self.add(zstd_dict))
/// Registers `zstd_dict` with this compressor and returns the shared handle.
///
/// The dictionary is first validated with `self.check` using its id and
/// name; any error from that check is returned and nothing is inserted.
/// On success the dictionary is wrapped in an `Arc` and indexed both by
/// id and by name.
pub fn add(&mut self, zstd_dict: ZstdDict) -> Result<ZstdDictArc> {
    self.check(zstd_dict.id(), zstd_dict.name())?;

    let dict_arc = Arc::new(zstd_dict);
    let id = dict_arc.id();
    let name = dict_arc.name().to_string();

    // Both indexes share the same allocation; only the refcount is bumped.
    self.zstd_dict_by_id.insert(id, Arc::clone(&dict_arc));
    self.zstd_dict_by_name.insert(name, Arc::clone(&dict_arc));

    Ok(dict_arc)
}
pub fn by_id(&self, id: ZstdDictId) -> Option<&ZstdDict> {
@@ -70,6 +55,9 @@ impl Compressor {
/// Looks up a registered dictionary by its name.
///
/// Returns `None` when no dictionary is stored under `name`.
pub fn by_name(&self, name: &str) -> Option<&ZstdDict> {
    let found = self.zstd_dict_by_name.get(name);
    // Borrow the dictionary behind the shared pointer rather than exposing it.
    found.map(|dict| dict.as_ref())
}
/// Iterates over the names of all registered dictionaries, i.e. the keys
/// of `zstd_dict_by_name`, in whatever order that map yields them.
pub fn names(&self) -> impl Iterator<Item = &String> {
self.zstd_dict_by_name.keys()
}
pub fn compress<Data: Into<CompressibleData>>(
&self,
@@ -143,15 +131,6 @@ impl Compressor {
}
Ok(())
}
fn add(&mut self, zstd_dict: ZstdDict) -> ZstdDictArc {
let zstd_dict = Arc::new(zstd_dict);
self.zstd_dict_by_id
.insert(zstd_dict.id(), zstd_dict.clone());
self.zstd_dict_by_name
.insert(zstd_dict.name().to_string(), zstd_dict.clone());
zstd_dict
}
}
fn auto_compressible_content_type(content_type: &str) -> bool {
@@ -179,7 +158,7 @@ pub mod test {
pub fn make_compressor_with(compression_policy: CompressionPolicy) -> Compressor {
let mut compressor = Compressor::new(compression_policy);
let zstd_dict = make_zstd_dict(1.into(), "dict1");
compressor.add(zstd_dict);
compressor.add(zstd_dict).unwrap();
compressor
}

7
src/concat_lines.rs Normal file
View File

@@ -0,0 +1,7 @@
/// Like `concat!`, but appends a `"\n"` after every argument.
///
/// Note that the newline follows *each* expression, including the last one,
/// so the resulting string always ends with a newline (and is empty when no
/// arguments are given). A trailing comma after the final argument is allowed.
#[macro_export]
macro_rules! concat_lines {
    ($($line:expr),* $(,)?) => {
        concat!($($line, "\n"),*)
    };
}

View File

@@ -1,5 +1,6 @@
mod compressible_data;
mod compressor;
mod concat_lines;
mod handlers;
mod manifest;
mod sha256;
@@ -22,7 +23,7 @@ use shards::ShardsArc;
use std::{error::Error, path::PathBuf, sync::Arc};
use tokio::{net::TcpListener, select, spawn};
use tokio_rusqlite::Connection;
use tracing::info;
use tracing::{debug, info};
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
@@ -117,7 +118,7 @@ fn main() -> Result<(), AsyncBoxError> {
let compressor = manifest.compressor();
let dict_loop_handle = spawn(dict_loop(manifest, shards.clone()));
let server_handle = spawn(server_loop(server, shards, compressor));
dict_loop_handle.await?;
dict_loop_handle.await??;
server_handle.await??;
info!("server closed sqlite connections. bye!");
Ok::<_, AsyncBoxError>(())
@@ -126,18 +127,35 @@ fn main() -> Result<(), AsyncBoxError> {
Ok(())
}
async fn dict_loop(manifest: Manifest, shards: ShardsArc) {
async fn dict_loop(manifest: Manifest, shards: ShardsArc) -> Result<(), AsyncBoxError> {
loop {
info!("dict loop: running");
let mut hint_names = shards.hint_names().await?;
{
// find what hint names don't have a corresponding dictionary
let compressor = manifest.compressor();
let _compressor = compressor.read().await;
for _shard in shards.iter() {}
let compressor = compressor.read().await;
compressor.names().for_each(|name| {
hint_names.remove(name);
});
}
for hint_name in hint_names {
debug!("creating dictionary for {}", hint_name);
let samples = shards.samples_for_hint_name(&hint_name, 10).await?;
let samples_bytes = samples
.iter()
.map(|s| s.data.as_ref())
.collect::<Vec<&[u8]>>();
manifest
.insert_dict_from_samples(hint_name, samples_bytes)
.await?;
}
select! {
_ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {}
_ = crate::shutdown_signal::shutdown_signal() => {
info!("dict loop: shutdown signal received");
break;
return Ok(());
}
}
}

View File

@@ -8,8 +8,9 @@ use tokio_rusqlite::Connection;
use crate::{
compressor::{Compressor, CompressorArc},
into_tokio_rusqlite_err,
zstd_dict::ZstdDictArc,
concat_lines,
sql_types::ZstdDictId,
zstd_dict::{ZstdDict, ZstdDictArc, ZstdEncoder},
AsyncBoxError,
};
@@ -21,17 +22,11 @@ pub struct Manifest {
compressor: Arc<RwLock<Compressor>>,
}
pub type ManifestArc = Arc<Manifest>;
impl Manifest {
pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, AsyncBoxError> {
initialize(conn, num_shards).await
}
pub fn into_arc(self) -> ManifestArc {
Arc::new(self)
}
pub fn num_shards(&self) -> usize {
self.num_shards
}
@@ -40,35 +35,29 @@ impl Manifest {
self.compressor.clone()
}
pub async fn insert_zstd_dict_from_samples<Str: Into<String>>(
pub async fn insert_dict_from_samples<Str: Into<String>>(
&self,
name: Str,
samples: Vec<&[u8]>,
) -> Result<ZstdDictArc, Box<dyn Error>> {
) -> Result<ZstdDictArc, AsyncBoxError> {
let name = name.into();
let dict_bytes = zstd::dict::from_samples(
&samples,
1024 * 1024, // 1MB max dictionary size
)?;
let compressor = self.compressor.clone();
let encoder = ZstdEncoder::from_samples(3, samples);
let zstd_dict = self
.conn
.call(move |conn| {
let level = 3;
let mut stmt = conn.prepare(
"INSERT INTO dictionaries (name, level, dict)
VALUES (?, ?, ?)
RETURNING id",
)?;
let dict_id = stmt.query_row(params![name, level, dict_bytes], |row| row.get(0))?;
let mut compressor = compressor.blocking_write();
let zstd_dict = compressor
.add_from_bytes(dict_id, name, level, dict_bytes)
.map_err(into_tokio_rusqlite_err)?;
Ok(zstd_dict)
let mut stmt = conn.prepare(concat_lines!(
"INSERT INTO dictionaries (name, level, dict)",
"VALUES (?, ?, ?)",
"RETURNING id"
))?;
let dict_id =
stmt.query_row(params![name, level, encoder.dict_bytes()], |row| row.get(0))?;
Ok(ZstdDict::new(dict_id, name, encoder))
})
.await?;
Ok(zstd_dict)
let mut compressor = self.compressor.write().await;
compressor.add(zstd_dict)
}
}
@@ -84,12 +73,14 @@ async fn initialize(
)?;
conn.execute(
"CREATE TABLE IF NOT EXISTS dictionaries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
level INTEGER NOT NULL,
name TEXT NOT NULL,
dict BLOB NOT NULL
)",
concat_lines!(
"CREATE TABLE IF NOT EXISTS dictionaries (",
" id INTEGER PRIMARY KEY AUTOINCREMENT,",
" level INTEGER NOT NULL,",
" name TEXT NOT NULL,",
" dict BLOB NOT NULL",
")"
),
[],
)?;
@@ -126,17 +117,12 @@ async fn initialize(
}
};
let rows = conn
type DictRow = (ZstdDictId, String, i32, Vec<u8>);
let rows: Vec<DictRow> = conn
.call(|conn| {
let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
let mut rows = vec![];
for r in stmt.query_map([], |row| {
let id = row.get(0)?;
let name: String = row.get(1)?;
let level: i32 = row.get(2)?;
let dict: Vec<u8> = row.get(3)?;
Ok((id, name, level, dict))
})? {
for r in stmt.query_map([], |row| DictRow::try_from(row))? {
rows.push(r?);
}
Ok(rows)
@@ -144,8 +130,12 @@ async fn initialize(
.await?;
let mut compressor = Compressor::default();
for (id, name, level, dict_bytes) in rows {
compressor.add_from_bytes(id, name, level, dict_bytes)?;
for (dict_id, name, level, dict_bytes) in rows {
compressor.add(ZstdDict::new(
dict_id,
name,
ZstdEncoder::from_dict_bytes(level, dict_bytes),
))?;
}
let compressor = compressor.into_arc();
Ok(Manifest {
@@ -166,7 +156,7 @@ mod tests {
let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
let zstd_dict = manifest
.insert_zstd_dict_from_samples("test", samples)
.insert_dict_from_samples("test", samples)
.await
.unwrap();

View File

@@ -0,0 +1,16 @@
use crate::AsyncBoxError;
use super::Shard;
impl Shard {
    /// Returns every distinct `name` found in this shard's
    /// `compression_hints` table.
    ///
    /// # Errors
    ///
    /// Any failure from preparing the statement, reading a row, or the
    /// connection call itself is converted into an `AsyncBoxError`.
    pub async fn hint_names(&self) -> Result<Vec<String>, AsyncBoxError> {
        let names = self
            .conn
            .call(|conn| {
                let mut stmt = conn.prepare("SELECT DISTINCT name FROM compression_hints")?;
                let mut names = Vec::new();
                // Stop at the first row that fails to decode.
                for name in stmt.query_map([], |row| row.get(0))? {
                    names.push(name?);
                }
                Ok(names)
            })
            .await;
        names.map_err(Into::into)
    }
}

View File

@@ -1,4 +1,4 @@
use crate::AsyncBoxError;
use crate::{concat_lines, AsyncBoxError};
use super::*;
@@ -37,17 +37,23 @@ impl Shard {
fn ensure_schema_versions_table(conn: &rusqlite::Connection) -> Result<usize, rusqlite::Error> {
conn.execute(
"CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER PRIMARY KEY,
created_at TEXT NOT NULL
)",
concat_lines!(
"CREATE TABLE IF NOT EXISTS schema_version (",
" version INTEGER PRIMARY KEY,",
" created_at TEXT NOT NULL",
")"
),
[],
)
}
fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, rusqlite::Error> {
let mut stmt = conn
.prepare("SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1")?;
let mut stmt = conn.prepare(concat_lines!(
"SELECT version, created_at",
"FROM schema_version",
"ORDER BY version",
"DESC LIMIT 1"
))?;
let rows = stmt.query_map([], |row| {
let version = row.get(0)?;
let created_at = row.get(1)?;
@@ -59,35 +65,45 @@ fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, r
fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
debug!("migrating to version 1");
conn.execute(
"CREATE TABLE IF NOT EXISTS entries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
sha256 BLOB NOT NULL,
content_type TEXT NOT NULL,
compression_id INTEGER NOT NULL,
uncompressed_size INTEGER NOT NULL,
compressed_size INTEGER NOT NULL,
data BLOB NOT NULL,
created_at TEXT NOT NULL
)",
concat_lines!(
"CREATE TABLE IF NOT EXISTS entries (",
" id INTEGER PRIMARY KEY AUTOINCREMENT,",
" sha256 BLOB NOT NULL,",
" content_type TEXT NOT NULL,",
" compression_id INTEGER NOT NULL,",
" uncompressed_size INTEGER NOT NULL,",
" compressed_size INTEGER NOT NULL,",
" data BLOB NOT NULL,",
" created_at TEXT NOT NULL",
")"
),
[],
)?;
conn.execute(
"CREATE INDEX IF NOT EXISTS entries_sha256_idx ON entries (sha256)",
concat_lines!(
"CREATE INDEX IF NOT EXISTS entries_sha256_idx",
"ON entries (sha256)"
),
[],
)?;
conn.execute(
"CREATE TABLE IF NOT EXISTS compression_hints (
name TEXT NOT NULL,
ordering INTEGER NOT NULL,
entry_id INTEGER NOT NULL
)",
concat_lines!(
"CREATE TABLE IF NOT EXISTS compression_hints (",
" name TEXT NOT NULL,",
" ordering INTEGER NOT NULL,",
" entry_id INTEGER NOT NULL",
")"
),
[],
)?;
conn.execute(
"CREATE INDEX IF NOT EXISTS compression_hints_name_idx ON compression_hints (name, ordering)",
concat_lines!(
"CREATE INDEX IF NOT EXISTS compression_hints_name_idx",
"ON compression_hints (name, ordering)",
),
[],
)?;

View File

@@ -10,12 +10,12 @@ pub struct SampleForHintResult {
}
impl Shard {
pub async fn samples_for_hint(
pub async fn samples_for_hint_name(
&self,
compression_hint: &str,
hint_name: &str,
limit: usize,
) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
let compression_hint = compression_hint.to_owned();
let hint_name = hint_name.to_owned();
let result = self
.conn
.call(move |conn| {
@@ -24,7 +24,7 @@ impl Shard {
SELECT entry_id FROM compression_hints WHERE name = ?1 ORDER BY ordering
) LIMIT ?2",
)?;
let rows = stmt.query_map(params![compression_hint, limit], |row| {
let rows = stmt.query_map(params![hint_name, limit], |row| {
let sha256: Sha256 = row.get(0)?;
let data: Vec<u8> = row.get(1)?;
Ok(SampleForHintResult { sha256, data })

View File

@@ -1,4 +1,5 @@
mod fn_get;
mod fn_hint_names;
mod fn_migrate;
mod fn_samples_for_hint;
mod fn_store;
@@ -6,6 +7,7 @@ pub mod shard_error;
mod shard_struct;
pub use fn_get::{GetArgs, GetResult};
pub use fn_samples_for_hint::SampleForHintResult;
pub use fn_store::{StoreArgs, StoreResult};
pub use shard_struct::Shard;

View File

@@ -246,7 +246,7 @@ pub mod test {
.unwrap();
assert!(matches!(store_result, StoreResult::Created { .. }));
let results = shard.samples_for_hint("hint1", 10).await.unwrap();
let results = shard.samples_for_hint_name("hint1", 10).await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].sha256, sha256);
assert_eq!(results[0].data, data);
@@ -298,7 +298,7 @@ pub mod test {
assert_eq!(hint2.len(), insert_num);
let hint1samples = shard
.samples_for_hint("hint1", sample_num)
.samples_for_hint_name("hint1", sample_num)
.await
.unwrap()
.into_iter()
@@ -306,7 +306,7 @@ pub mod test {
.collect::<HashMap<_, _>>();
let hint2samples = shard
.samples_for_hint("hint2", sample_num)
.samples_for_hint_name("hint2", sample_num)
.await
.unwrap()
.into_iter()

View File

@@ -1,17 +1,23 @@
use std::sync::Arc;
use std::{collections::HashSet, sync::Arc};
use crate::{sha256::Sha256, shard::Shard};
use tokio::task::JoinSet;
use crate::{
sha256::Sha256,
shard::{SampleForHintResult, Shard},
AsyncBoxError,
};
pub type ShardsArc = Arc<Shards>;
#[derive(Clone)]
pub struct Shards(Vec<Shard>);
pub struct Shards(Vec<Arc<Shard>>);
impl Shards {
pub fn new(shards: Vec<Shard>) -> Option<Self> {
if shards.is_empty() {
return None;
}
Some(Self(shards))
Some(Self(shards.into_iter().map(Arc::new).collect()))
}
pub fn shard_for(&self, sha256: &Sha256) -> &Shard {
@@ -19,13 +25,43 @@ impl Shards {
&self.0[shard_id]
}
pub fn iter(&self) -> std::slice::Iter<'_, Shard> {
self.0.iter()
pub fn iter(&self) -> impl Iterator<Item = &Shard> {
self.0.iter().map(|shard| shard.as_ref())
}
/// Number of shards held by this collection.
///
/// `Shards::new` refuses an empty list, so a value built through it
/// always reports at least one shard.
pub fn len(&self) -> usize {
self.0.len()
}
/// Collects the union of compression hint names across all shards.
///
/// Shards are queried one after another; the first shard that fails
/// aborts the whole call and its error is returned.
pub async fn hint_names(&self) -> Result<HashSet<String>, AsyncBoxError> {
    let mut all_names = HashSet::new();
    for shard in self.iter() {
        let shard_names = shard.hint_names().await?;
        // The set takes care of names that appear in more than one shard.
        all_names.extend(shard_names);
    }
    Ok(all_names)
}
pub async fn samples_for_hint_name(
&self,
hint_name: &str,
limit: usize,
) -> Result<Vec<SampleForHintResult>, AsyncBoxError> {
let mut tasks = JoinSet::new();
for shard in self.0.iter() {
let shard = shard.clone();
let hint_name = hint_name.to_owned();
tasks.spawn(async move { shard.samples_for_hint_name(&hint_name, limit).await });
}
let mut hints: Vec<SampleForHintResult> = Vec::new();
while let Some(result) = tasks.join_next().await {
let result = result??;
hints.extend(result);
}
Ok(hints)
}
}
#[cfg(test)]

View File

@@ -1,5 +1,5 @@
use ouroboros::self_referencing;
use std::{error::Error, io, sync::Arc};
use std::{io, sync::Arc};
use zstd::dict::{DecoderDictionary, EncoderDictionary};
use crate::{sql_types::ZstdDictId, AsyncBoxError};
@@ -7,9 +7,7 @@ use crate::{sql_types::ZstdDictId, AsyncBoxError};
pub type ZstdDictArc = Arc<ZstdDict>;
#[self_referencing]
pub struct ZstdDict {
id: crate::sql_types::ZstdDictId,
name: String,
pub struct ZstdEncoder {
level: i32,
dict_bytes: Vec<u8>,
@@ -22,6 +20,33 @@ pub struct ZstdDict {
decoder_dict: DecoderDictionary<'this>,
}
impl ZstdEncoder {
pub fn from_samples(level: i32, samples: Vec<&[u8]>) -> Self {
let dict_bytes = zstd::dict::from_samples(&samples, 1024 * 1024).unwrap();
Self::from_dict_bytes(level, dict_bytes)
}
pub fn from_dict_bytes(level: i32, dict_bytes: Vec<u8>) -> Self {
ZstdEncoderBuilder {
level,
dict_bytes,
encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
}
.build()
}
pub fn dict_bytes(&self) -> &[u8] {
self.borrow_dict_bytes()
}
}
/// A named zstd dictionary: an id and a name paired with the encoder state
/// built from the dictionary bytes.
pub struct ZstdDict {
// Identifier returned by `id()`.
id: ZstdDictId,
// Name returned by `name()`.
name: String,
// Holds the compression level and dictionary bytes; `level()`,
// `dict_bytes()`, `compress` and `decompress` all go through it.
encoder: ZstdEncoder,
}
impl PartialEq for ZstdDict {
fn eq(&self, other: &Self) -> bool {
self.id() == other.id()
@@ -42,42 +67,21 @@ impl std::fmt::Debug for ZstdDict {
}
impl ZstdDict {
pub fn from_samples(
id: ZstdDictId,
name: String,
level: i32,
samples: Vec<&[u8]>,
) -> Result<Self, Box<dyn Error + Send + Sync>> {
let dict_bytes = zstd::dict::from_samples(
&samples,
1024 * 1024, // 1MB max dictionary size
)?;
Ok(Self::from_dict_bytes(id, name, level, dict_bytes))
}
pub fn from_dict_bytes(id: ZstdDictId, name: String, level: i32, dict_bytes: Vec<u8>) -> Self {
ZstdDictBuilder {
id,
name,
level,
dict_bytes,
encoder_dict_builder: |dict_bytes| EncoderDictionary::new(dict_bytes, level),
decoder_dict_builder: |dict_bytes| DecoderDictionary::new(dict_bytes),
}
.build()
pub fn new(id: ZstdDictId, name: String, encoder: ZstdEncoder) -> Self {
Self { id, name, encoder }
}
pub fn id(&self) -> ZstdDictId {
*self.borrow_id()
self.id
}
pub fn name(&self) -> &str {
self.borrow_name()
&self.name
}
pub fn level(&self) -> i32 {
*self.borrow_level()
*self.encoder.borrow_level()
}
pub fn dict_bytes(&self) -> &[u8] {
self.borrow_dict_bytes()
self.encoder.borrow_dict_bytes()
}
pub fn compress<DataRef: AsRef<[u8]>>(&self, data: DataRef) -> Result<Vec<u8>, AsyncBoxError> {
@@ -86,7 +90,7 @@ impl ZstdDict {
let mut out_buffer = Vec::with_capacity(as_ref.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
self.with_encoder_dict(|encoder_dict| {
self.encoder.with_encoder_dict(|encoder_dict| {
let mut encoder =
zstd::stream::Encoder::with_prepared_dictionary(&mut output_wrapper, encoder_dict)?;
io::copy(&mut wrapper, &mut encoder)?;
@@ -104,7 +108,7 @@ impl ZstdDict {
let mut out_buffer = Vec::with_capacity(as_ref.len());
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
self.with_decoder_dict(|decoder_dict| {
self.encoder.with_decoder_dict(|decoder_dict| {
let mut decoder =
zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, decoder_dict)?;
io::copy(&mut decoder, &mut output_wrapper)
@@ -117,10 +121,13 @@ impl ZstdDict {
pub mod test {
use crate::sql_types::ZstdDictId;
use super::ZstdEncoder;
pub fn make_zstd_dict(id: ZstdDictId, name: &str) -> super::ZstdDict {
super::ZstdDict::from_dict_bytes(
super::ZstdDict::new(
id,
name.to_owned(),
ZstdEncoder::from_samples(
3,
vec![
"hello, world",
@@ -129,15 +136,14 @@ pub mod test {
]
.into_iter()
.chain(vec!["foo", "bar", "baz"].repeat(100))
.map(|s| s.as_bytes().to_owned())
.flat_map(|s| s.into_iter())
.map(|s| s.as_bytes())
.collect(),
),
)
}
#[test]
fn test_zstd_dict() {
let dict_bytes = vec![1, 2, 3, 4];
fn test_zstd_dict_basics() {
let zstd_dict = make_zstd_dict(1.into(), "dict1");
let compressed = zstd_dict.compress(b"hello world").unwrap();
let decompressed = zstd_dict.decompress(&compressed).unwrap();