add dictionary compression, first pass

This commit is contained in:
Dylan Knutson
2024-04-26 13:28:56 -07:00
parent 34e46ed020
commit 20dcf84c91
4 changed files with 363 additions and 39 deletions

101
Cargo.lock generated
View File

@@ -29,6 +29,15 @@ dependencies = [
"zerocopy", "zerocopy",
] ]
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "allocator-api2" name = "allocator-api2"
version = "0.2.18" version = "0.2.18"
@@ -269,6 +278,7 @@ dependencies = [
"kdam", "kdam",
"rand", "rand",
"reqwest", "reqwest",
"rstest",
"rusqlite", "rusqlite",
"serde", "serde",
"serde_json", "serde_json",
@@ -617,6 +627,12 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.30" version = "0.3.30"
@@ -662,6 +678,12 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "glob"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.4" version = "0.4.4"
@@ -1292,6 +1314,41 @@ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
] ]
[[package]]
name = "regex"
version = "1.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
[[package]]
name = "relative-path"
version = "1.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e898588f33fdd5b9420719948f9f2a32c922a246964576f71ba7f24f80610fbc"
[[package]] [[package]]
name = "reqwest" name = "reqwest"
version = "0.12.4" version = "0.12.4"
@@ -1336,6 +1393,35 @@ dependencies = [
"winreg", "winreg",
] ]
[[package]]
name = "rstest"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d5316d2a1479eeef1ea21e7f9ddc67c191d497abc8fc3ba2467857abbb68330"
dependencies = [
"futures",
"futures-timer",
"rstest_macros",
"rustc_version",
]
[[package]]
name = "rstest_macros"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04a9df72cc1f67020b0d63ad9bfe4a323e459ea7eb68e03bd9824db49f9a4c25"
dependencies = [
"cfg-if",
"glob",
"proc-macro2",
"quote",
"regex",
"relative-path",
"rustc_version",
"syn 2.0.60",
"unicode-ident",
]
[[package]] [[package]]
name = "rusqlite" name = "rusqlite"
version = "0.31.0" version = "0.31.0"
@@ -1356,6 +1442,15 @@ version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver",
]
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.34" version = "0.38.34"
@@ -1435,6 +1530,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "semver"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.198" version = "1.0.198"

View File

@@ -18,7 +18,7 @@ axum_typed_multipart = "0.11.1"
chrono = "0.4.38" chrono = "0.4.38"
clap = { version = "4.5.4", features = ["derive"] } clap = { version = "4.5.4", features = ["derive"] }
futures = "0.3.30" futures = "0.3.30"
kdam = "0.5.1" kdam = "0.5.1" # for load-test
rand = "0.8.5" rand = "0.8.5"
rusqlite = { version = "0.31.0", features = ["vtab"] } rusqlite = { version = "0.31.0", features = ["vtab"] }
serde = { version = "1.0.198", features = ["serde_derive"] } serde = { version = "1.0.198", features = ["serde_derive"] }
@@ -31,3 +31,6 @@ tracing-subscriber = "0.3.18"
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] } reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
hex = "0.4.3" hex = "0.4.3"
zstd = "0.13.1" zstd = "0.13.1"
[dev-dependencies]
rstest = "0.19.0"

View File

@@ -1,15 +1,17 @@
mod handlers; mod handlers;
mod manifest;
mod sha256; mod sha256;
mod shard; mod shard;
mod shards; mod shards;
mod shutdown_signal; mod shutdown_signal;
use crate::shards::Shards; use crate::{manifest::Manifest, shards::Shards};
use axum::{ use axum::{
routing::{get, post}, routing::{get, post},
Extension, Router, Extension, Router,
}; };
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use futures::executor::block_on;
use shard::Shard; use shard::Shard;
use std::{error::Error, path::PathBuf}; use std::{error::Error, path::PathBuf};
use tokio::net::TcpListener; use tokio::net::TcpListener;
@@ -48,11 +50,6 @@ pub enum UseCompression {
Zstd, Zstd,
} }
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ManifestData {
shards: usize,
}
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG) .with_max_level(tracing::Level::DEBUG)
@@ -60,11 +57,20 @@ fn main() -> Result<(), Box<dyn Error>> {
let args = Args::parse(); let args = Args::parse();
let db_path = PathBuf::from(&args.db_path); let db_path = PathBuf::from(&args.db_path);
let num_shards = validate_manifest(&args)?; let num_shards = args.shards;
// block on opening the manifest
let manifest = block_on(async {
Manifest::open(
Connection::open(db_path.join("manifest.sqlite")).await?,
num_shards,
)
.await
})?;
// max num_shards threads // max num_shards threads
let runtime = tokio::runtime::Builder::new_multi_thread() let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_shards as usize) .worker_threads(manifest.num_shards())
.enable_all() .enable_all()
.build()?; .build()?;
@@ -73,10 +79,10 @@ fn main() -> Result<(), Box<dyn Error>> {
info!( info!(
"listening on {} with {} shards", "listening on {} with {} shards",
server.local_addr()?, server.local_addr()?,
num_shards manifest.num_shards()
); );
let mut shards_vec = vec![]; let mut shards_vec = vec![];
for shard_id in 0..num_shards { for shard_id in 0..manifest.num_shards() {
let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id)); let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?; let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?; let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?;
@@ -111,31 +117,3 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn
.await?; .await?;
Ok(()) Ok(())
} }
fn validate_manifest(args: &Args) -> Result<usize, Box<dyn Error>> {
let manifest_path = PathBuf::from(&args.db_path).join("manifest.json");
if manifest_path.exists() {
let file_content = std::fs::read_to_string(manifest_path)?;
let manifest: ManifestData = serde_json::from_str(&file_content)?;
info!("loading existing database with {} shards", manifest.shards);
if let Some(shards) = args.shards {
if shards != manifest.shards {
return Err(format!(
"manifest indicates {} shards, expected {}",
manifest.shards, shards
)
.into());
}
}
Ok(manifest.shards)
} else if let Some(shards) = args.shards {
info!("creating new database with {} shards", shards);
std::fs::create_dir_all(&args.db_path)?;
let manifest = ManifestData { shards };
let manifest_json = serde_json::to_string(&manifest)?;
std::fs::write(manifest_path, manifest_json)?;
Ok(shards)
} else {
Err("new database needs --shards argument".into())
}
}

242
src/manifest.rs Normal file
View File

@@ -0,0 +1,242 @@
use std::{collections::HashMap, error::Error, io};
use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
use tokio_rusqlite::Connection;
/// In-memory view of the manifest database.
///
/// Holds the resolved shard count and every zstd dictionary that was read
/// from (or inserted into) the `dictionaries` table.
pub struct Manifest {
    /// Connection to the SQLite database backing the manifest.
    conn: Connection,
    /// Shard count resolved when the manifest was opened.
    num_shards: usize,
    /// Dictionaries keyed by name; each value is the `dictionaries.id` of the
    /// row paired with the prepared dictionary.
    zstd_dicts: HashMap<String, (i64, ZstdDict)>,
}
/// A zstd dictionary prepared for both compression and decompression.
pub struct ZstdDict {
    /// Compression level the encoder dictionary was built with.
    level: i32,
    /// Dictionary prepared for the streaming encoder.
    encoder_dict: zstd::dict::EncoderDictionary<'static>,
    /// Dictionary prepared for the streaming decoder.
    decoder_dict: zstd::dict::DecoderDictionary<'static>,
}
impl ZstdDict {
    /// Compresses `data` with this dictionary and returns the zstd frame.
    ///
    /// # Errors
    /// Returns any I/O error reported while creating, feeding, or finishing
    /// the streaming encoder.
    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
        // The encoder owns the output vector and hands it back from `finish`.
        let sink = Vec::with_capacity(data.len());
        let mut encoder =
            zstd::stream::Encoder::with_prepared_dictionary(sink, &self.encoder_dict)?;
        let mut source = data;
        io::copy(&mut source, &mut encoder)?;
        let compressed = encoder.finish()?;
        Ok(compressed)
    }

    /// Decompresses a zstd frame that was produced with this dictionary.
    ///
    /// # Errors
    /// Returns any I/O error reported while creating the streaming decoder or
    /// while reading the decompressed bytes out of it.
    pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
        let mut decoder =
            zstd::stream::Decoder::with_prepared_dictionary(data, &self.decoder_dict)?;
        // The compressed length is only a starting capacity; the vector grows as needed.
        let mut plain = Vec::with_capacity(data.len());
        io::copy(&mut decoder, &mut plain)?;
        Ok(plain)
    }
}
impl ZstdDict {
    /// Builds the encoder and decoder dictionaries from raw dictionary bytes.
    ///
    /// `dict` is copied into both prepared dictionaries; `level` is the
    /// compression level baked into the encoder side.
    fn create(dict: &[u8], level: i32) -> Self {
        Self {
            level,
            encoder_dict: zstd::dict::EncoderDictionary::copy(dict, level),
            decoder_dict: zstd::dict::DecoderDictionary::copy(dict),
        }
    }
}
impl Manifest {
    /// Opens the manifest stored behind `conn`.
    ///
    /// `num_shards` is the shard count requested by the caller, if any; it is
    /// reconciled with the value already stored in the database.
    ///
    /// # Errors
    /// Returns an error if the database cannot be read or written, or if the
    /// requested and stored shard counts cannot be reconciled.
    pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, Box<dyn Error>> {
        initialize(conn, num_shards).await
    }

    /// Number of shards recorded in the manifest.
    pub fn num_shards(&self) -> usize {
        self.num_shards
    }

    /// Trains a new zstd dictionary from `samples`, stores it in the
    /// `dictionaries` table, and registers it in memory under `name`.
    ///
    /// # Errors
    /// Returns an error if a dictionary called `name` is already loaded, if
    /// zstd cannot train a dictionary from the samples, or if the insert into
    /// the database fails.
    pub async fn train_dictionary(
        &mut self,
        name: String,
        samples: Vec<&[u8]>,
    ) -> Result<(), Box<dyn Error>> {
        if self.zstd_dicts.contains_key(&name) {
            return Err(format!("dictionary {} already exists", name).into());
        }

        let level = 3;
        // Training can fail on unsuitable samples; report that to the caller
        // instead of panicking.
        let dict = zstd::dict::from_samples(
            &samples,
            1024 * 1024, // 1MB max dictionary size
        )?;

        let zstd_dict = ZstdDict::create(&dict, level);
        // The closure must own its inputs, so it gets its own copy of the name.
        let name_copy = name.clone();
        let dict_id: i64 = self
            .conn
            .call(move |conn| {
                let mut stmt = conn.prepare(
                    "INSERT INTO dictionaries (level, name, dict)
                    VALUES (?, ?, ?)
                    RETURNING id",
                )?;
                let dict_id = stmt.query_row(params![level, name_copy, dict], |row| row.get(0))?;
                Ok(dict_id)
            })
            .await?;

        self.zstd_dicts.insert(name, (dict_id, zstd_dict));
        Ok(())
    }
}
/// A typed key of the `manifest` key/value table.
///
/// Implementors tie a key name to the Rust type its value is stored as.
trait ManifestKey {
    /// Rust type of the value stored under this key.
    type Value: ToSql + FromSql;

    /// Text stored in the table's `key` column for this key.
    fn sql_key(&self) -> &'static str;
}
/// Manifest key under which the shard count is stored.
struct NumShards;

impl ManifestKey for NumShards {
    type Value = usize;

    /// Always `"num_shards"`.
    fn sql_key(&self) -> &'static str {
        "num_shards"
    }
}
/// Reads the value stored under `key` in the `manifest` table.
///
/// Returns `Ok(None)` when no row exists for the key; any other database
/// error is passed through unchanged.
fn get_manifest_key<Key: ManifestKey>(
    conn: &rusqlite::Connection,
    key: Key,
) -> Result<Option<Key::Value>, rusqlite::Error> {
    let sql_key = key.sql_key();
    let lookup = conn.query_row(
        "SELECT value FROM manifest WHERE key = ?",
        params![sql_key],
        |row| row.get::<_, Key::Value>(0),
    );
    // `optional` maps the "no rows" error to `None`.
    lookup.optional()
}
/// Stores `value` under `key` in the `manifest` table, replacing any value
/// already stored for that key.
fn set_manifest_key<Key: ManifestKey>(
    conn: &rusqlite::Connection,
    key: Key,
    value: Key::Value,
) -> Result<(), rusqlite::Error> {
    let sql_key = key.sql_key();
    conn.execute(
        "INSERT OR REPLACE INTO manifest (key, value) VALUES (?, ?)",
        params![sql_key, value],
    )
    // The affected-row count is of no interest to callers.
    .map(|_rows_changed| ())
}
async fn initialize(
conn: Connection,
num_shards: Option<usize>,
) -> Result<Manifest, Box<dyn Error>> {
let stored_num_shards: Option<usize> = conn
.call(|conn| {
conn.execute(
"CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)",
[],
)?;
conn.execute(
"CREATE TABLE IF NOT EXISTS dictionaries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
level INTEGER NOT NULL,
name TEXT NOT NULL,
dict BLOB NOT NULL
)",
[],
)?;
Ok(get_manifest_key(conn, NumShards)?)
})
.await?;
let num_shards = match (stored_num_shards, num_shards) {
(Some(stored_num_shards), Some(num_shards)) => {
if stored_num_shards == num_shards {
num_shards
} else {
return Err(format!(
"manifest indicates {} shards, but expected {} shards",
stored_num_shards, num_shards
)
.into());
}
}
(None, None) => return Err("Must supply --shards [number] to new database".into()),
(None, Some(num_shards)) => {
// new database with num_shards provided
conn.call(move |conn| {
// insert num_shards into manifest
set_manifest_key(conn, NumShards, num_shards)?;
Ok(num_shards)
})
.await?;
num_shards
}
(Some(num_shards), None) => {
// existing database, use loaded num_shards
num_shards
}
};
let rows = conn
.call(|conn| {
let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
let mut rows = vec![];
for r in stmt.query_map([], |row| {
let id: i64 = row.get(0)?;
let name: String = row.get(1)?;
let level: i32 = row.get(2)?;
let dict: Vec<u8> = row.get(3)?;
Ok((id, name, level, dict))
})? {
rows.push(r?);
}
Ok(rows)
})
.await?;
let mut zstd_dicts = HashMap::new();
for (id, name, level, dict) in rows {
let zstd_dict = ZstdDict::create(&dict, level);
zstd_dicts.insert(name, (id, zstd_dict));
}
Ok(Manifest {
conn,
num_shards,
zstd_dicts,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Trains a dictionary on an in-memory manifest, then checks that data
    /// round-trips through it and comes out smaller when compressed.
    #[tokio::test]
    async fn test_manifest() {
        let conn = Connection::open_in_memory().await.unwrap();
        let mut manifest = initialize(conn, Some(3)).await.unwrap();

        let sample: &[u8] = b"hello world test of long string";
        let samples = vec![sample; 100];
        manifest
            .train_dictionary("test".to_string(), samples)
            .await
            .unwrap();

        let (_, zstd_dict) = manifest.zstd_dicts.get("test").unwrap();
        let data = b"hello world, this is a test of a sort of long string";
        let compressed = zstd_dict.compress(data).unwrap();
        let decompressed = zstd_dict.decompress(&compressed).unwrap();

        assert_eq!(decompressed, data);
        assert!(data.len() > compressed.len());
    }
}