add dictionary compression, first pass
This commit is contained in:
101
Cargo.lock
generated
101
Cargo.lock
generated
@@ -29,6 +29,15 @@ dependencies = [
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "allocator-api2"
|
||||
version = "0.2.18"
|
||||
@@ -269,6 +278,7 @@ dependencies = [
|
||||
"kdam",
|
||||
"rand",
|
||||
"reqwest",
|
||||
"rstest",
|
||||
"rusqlite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -617,6 +627,12 @@ version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
|
||||
|
||||
[[package]]
|
||||
name = "futures-timer"
|
||||
version = "3.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.30"
|
||||
@@ -662,6 +678,12 @@ version = "0.28.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.4"
|
||||
@@ -1292,6 +1314,41 @@ dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.10.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-automata",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
|
||||
|
||||
[[package]]
|
||||
name = "relative-path"
|
||||
version = "1.9.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e898588f33fdd5b9420719948f9f2a32c922a246964576f71ba7f24f80610fbc"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.12.4"
|
||||
@@ -1336,6 +1393,35 @@ dependencies = [
|
||||
"winreg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rstest"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d5316d2a1479eeef1ea21e7f9ddc67c191d497abc8fc3ba2467857abbb68330"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"rstest_macros",
|
||||
"rustc_version",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rstest_macros"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "04a9df72cc1f67020b0d63ad9bfe4a323e459ea7eb68e03bd9824db49f9a4c25"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"glob",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"relative-path",
|
||||
"rustc_version",
|
||||
"syn 2.0.60",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rusqlite"
|
||||
version = "0.31.0"
|
||||
@@ -1356,6 +1442,15 @@ version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
|
||||
dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.34"
|
||||
@@ -1435,6 +1530,12 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.198"
|
||||
|
||||
@@ -18,7 +18,7 @@ axum_typed_multipart = "0.11.1"
|
||||
chrono = "0.4.38"
|
||||
clap = { version = "4.5.4", features = ["derive"] }
|
||||
futures = "0.3.30"
|
||||
kdam = "0.5.1"
|
||||
kdam = "0.5.1" # for load-test
|
||||
rand = "0.8.5"
|
||||
rusqlite = { version = "0.31.0", features = ["vtab"] }
|
||||
serde = { version = "1.0.198", features = ["serde_derive"] }
|
||||
@@ -31,3 +31,6 @@ tracing-subscriber = "0.3.18"
|
||||
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
|
||||
hex = "0.4.3"
|
||||
zstd = "0.13.1"
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = "0.19.0"
|
||||
|
||||
54
src/main.rs
54
src/main.rs
@@ -1,15 +1,17 @@
|
||||
mod handlers;
|
||||
mod manifest;
|
||||
mod sha256;
|
||||
mod shard;
|
||||
mod shards;
|
||||
mod shutdown_signal;
|
||||
|
||||
use crate::shards::Shards;
|
||||
use crate::{manifest::Manifest, shards::Shards};
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
Extension, Router,
|
||||
};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use futures::executor::block_on;
|
||||
use shard::Shard;
|
||||
use std::{error::Error, path::PathBuf};
|
||||
use tokio::net::TcpListener;
|
||||
@@ -48,11 +50,6 @@ pub enum UseCompression {
|
||||
Zstd,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, serde::Serialize)]
|
||||
struct ManifestData {
|
||||
shards: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::DEBUG)
|
||||
@@ -60,11 +57,20 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
|
||||
let args = Args::parse();
|
||||
let db_path = PathBuf::from(&args.db_path);
|
||||
let num_shards = validate_manifest(&args)?;
|
||||
let num_shards = args.shards;
|
||||
|
||||
// block on opening the manifest
|
||||
let manifest = block_on(async {
|
||||
Manifest::open(
|
||||
Connection::open(db_path.join("manifest.sqlite")).await?,
|
||||
num_shards,
|
||||
)
|
||||
.await
|
||||
})?;
|
||||
|
||||
// max num_shards threads
|
||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(num_shards as usize)
|
||||
.worker_threads(manifest.num_shards())
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
@@ -73,10 +79,10 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
info!(
|
||||
"listening on {} with {} shards",
|
||||
server.local_addr()?,
|
||||
num_shards
|
||||
manifest.num_shards()
|
||||
);
|
||||
let mut shards_vec = vec![];
|
||||
for shard_id in 0..num_shards {
|
||||
for shard_id in 0..manifest.num_shards() {
|
||||
let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
|
||||
let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
|
||||
let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?;
|
||||
@@ -111,31 +117,3 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_manifest(args: &Args) -> Result<usize, Box<dyn Error>> {
|
||||
let manifest_path = PathBuf::from(&args.db_path).join("manifest.json");
|
||||
if manifest_path.exists() {
|
||||
let file_content = std::fs::read_to_string(manifest_path)?;
|
||||
let manifest: ManifestData = serde_json::from_str(&file_content)?;
|
||||
info!("loading existing database with {} shards", manifest.shards);
|
||||
if let Some(shards) = args.shards {
|
||||
if shards != manifest.shards {
|
||||
return Err(format!(
|
||||
"manifest indicates {} shards, expected {}",
|
||||
manifest.shards, shards
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
Ok(manifest.shards)
|
||||
} else if let Some(shards) = args.shards {
|
||||
info!("creating new database with {} shards", shards);
|
||||
std::fs::create_dir_all(&args.db_path)?;
|
||||
let manifest = ManifestData { shards };
|
||||
let manifest_json = serde_json::to_string(&manifest)?;
|
||||
std::fs::write(manifest_path, manifest_json)?;
|
||||
Ok(shards)
|
||||
} else {
|
||||
Err("new database needs --shards argument".into())
|
||||
}
|
||||
}
|
||||
|
||||
242
src/manifest.rs
Normal file
242
src/manifest.rs
Normal file
@@ -0,0 +1,242 @@
|
||||
use std::{collections::HashMap, error::Error, io};
|
||||
|
||||
use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
|
||||
use tokio_rusqlite::Connection;
|
||||
|
||||
pub struct Manifest {
|
||||
conn: Connection,
|
||||
num_shards: usize,
|
||||
zstd_dicts: HashMap<String, (i64, ZstdDict)>,
|
||||
}
|
||||
|
||||
pub struct ZstdDict {
|
||||
level: i32,
|
||||
encoder_dict: zstd::dict::EncoderDictionary<'static>,
|
||||
decoder_dict: zstd::dict::DecoderDictionary<'static>,
|
||||
}
|
||||
impl ZstdDict {
|
||||
pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
let mut wrapper = io::Cursor::new(data);
|
||||
let mut out_buffer = Vec::with_capacity(data.len());
|
||||
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
|
||||
|
||||
let mut encoder = zstd::stream::Encoder::with_prepared_dictionary(
|
||||
&mut output_wrapper,
|
||||
&self.encoder_dict,
|
||||
)?;
|
||||
io::copy(&mut wrapper, &mut encoder)?;
|
||||
encoder.finish()?;
|
||||
Ok(out_buffer)
|
||||
}
|
||||
|
||||
pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
let mut wrapper = io::Cursor::new(data);
|
||||
let mut out_buffer = Vec::with_capacity(data.len());
|
||||
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
|
||||
|
||||
let mut decoder =
|
||||
zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, &self.decoder_dict)?;
|
||||
io::copy(&mut decoder, &mut output_wrapper)?;
|
||||
Ok(out_buffer)
|
||||
}
|
||||
}
|
||||
|
||||
impl ZstdDict {
|
||||
fn create(dict: &[u8], level: i32) -> Self {
|
||||
let encoder = zstd::dict::EncoderDictionary::copy(dict, level);
|
||||
let decoder = zstd::dict::DecoderDictionary::copy(dict);
|
||||
Self {
|
||||
level,
|
||||
encoder_dict: encoder,
|
||||
decoder_dict: decoder,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Manifest {
|
||||
pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, Box<dyn Error>> {
|
||||
initialize(conn, num_shards).await
|
||||
}
|
||||
pub fn num_shards(&self) -> usize {
|
||||
self.num_shards
|
||||
}
|
||||
|
||||
pub async fn train_dictionary(
|
||||
&mut self,
|
||||
name: String,
|
||||
samples: Vec<&[u8]>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
if self.zstd_dicts.contains_key(&name) {
|
||||
return Err(format!("dictionary {} already exists", name).into());
|
||||
}
|
||||
|
||||
let level = 3;
|
||||
let dict = zstd::dict::from_samples(
|
||||
&samples,
|
||||
1024 * 1024, // 1MB max dictionary size
|
||||
)
|
||||
.unwrap();
|
||||
let zstd_dict = ZstdDict::create(&dict, level);
|
||||
|
||||
let name_copy = name.clone();
|
||||
let dict_id = self
|
||||
.conn
|
||||
.call(move |conn| {
|
||||
let mut stmt = conn.prepare(
|
||||
"INSERT INTO dictionaries (level, name, dict)
|
||||
VALUES (?, ?, ?)
|
||||
RETURNING id",
|
||||
)?;
|
||||
let dict_id = stmt.query_row(params![level, name_copy, dict], |row| row.get(0))?;
|
||||
Ok(dict_id)
|
||||
})
|
||||
.await?;
|
||||
|
||||
self.zstd_dicts.insert(name, (dict_id, zstd_dict));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
trait ManifestKey {
|
||||
type Value: ToSql + FromSql;
|
||||
fn sql_key(&self) -> &'static str;
|
||||
}
|
||||
struct NumShards;
|
||||
impl ManifestKey for NumShards {
|
||||
type Value = usize;
|
||||
fn sql_key(&self) -> &'static str {
|
||||
"num_shards"
|
||||
}
|
||||
}
|
||||
|
||||
fn get_manifest_key<Key: ManifestKey>(
|
||||
conn: &rusqlite::Connection,
|
||||
key: Key,
|
||||
) -> Result<Option<Key::Value>, rusqlite::Error> {
|
||||
conn.query_row(
|
||||
"SELECT value FROM manifest WHERE key = ?",
|
||||
params![key.sql_key()],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn set_manifest_key<Key: ManifestKey>(
|
||||
conn: &rusqlite::Connection,
|
||||
key: Key,
|
||||
value: Key::Value,
|
||||
) -> Result<(), rusqlite::Error> {
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO manifest (key, value) VALUES (?, ?)",
|
||||
params![key.sql_key(), value],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
conn: Connection,
|
||||
num_shards: Option<usize>,
|
||||
) -> Result<Manifest, Box<dyn Error>> {
|
||||
let stored_num_shards: Option<usize> = conn
|
||||
.call(|conn| {
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS dictionaries (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
level INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
dict BLOB NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
Ok(get_manifest_key(conn, NumShards)?)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let num_shards = match (stored_num_shards, num_shards) {
|
||||
(Some(stored_num_shards), Some(num_shards)) => {
|
||||
if stored_num_shards == num_shards {
|
||||
num_shards
|
||||
} else {
|
||||
return Err(format!(
|
||||
"manifest indicates {} shards, but expected {} shards",
|
||||
stored_num_shards, num_shards
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
(None, None) => return Err("Must supply --shards [number] to new database".into()),
|
||||
(None, Some(num_shards)) => {
|
||||
// new database with num_shards provided
|
||||
conn.call(move |conn| {
|
||||
// insert num_shards into manifest
|
||||
set_manifest_key(conn, NumShards, num_shards)?;
|
||||
Ok(num_shards)
|
||||
})
|
||||
.await?;
|
||||
num_shards
|
||||
}
|
||||
(Some(num_shards), None) => {
|
||||
// existing database, use loaded num_shards
|
||||
num_shards
|
||||
}
|
||||
};
|
||||
|
||||
let rows = conn
|
||||
.call(|conn| {
|
||||
let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
|
||||
let mut rows = vec![];
|
||||
for r in stmt.query_map([], |row| {
|
||||
let id: i64 = row.get(0)?;
|
||||
let name: String = row.get(1)?;
|
||||
let level: i32 = row.get(2)?;
|
||||
let dict: Vec<u8> = row.get(3)?;
|
||||
Ok((id, name, level, dict))
|
||||
})? {
|
||||
rows.push(r?);
|
||||
}
|
||||
Ok(rows)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut zstd_dicts = HashMap::new();
|
||||
for (id, name, level, dict) in rows {
|
||||
let zstd_dict = ZstdDict::create(&dict, level);
|
||||
zstd_dicts.insert(name, (id, zstd_dict));
|
||||
}
|
||||
|
||||
Ok(Manifest {
|
||||
conn,
|
||||
num_shards,
|
||||
zstd_dicts,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_manifest() {
|
||||
let conn = Connection::open_in_memory().await.unwrap();
|
||||
let mut manifest = initialize(conn, Some(3)).await.unwrap();
|
||||
|
||||
let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
|
||||
manifest
|
||||
.train_dictionary("test".to_string(), samples)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let zstd_dict = &manifest.zstd_dicts.get("test").unwrap().1;
|
||||
let data = b"hello world, this is a test of a sort of long string";
|
||||
let compressed = zstd_dict.compress(data).unwrap();
|
||||
let decompressed = zstd_dict.decompress(&compressed).unwrap();
|
||||
assert_eq!(decompressed, data);
|
||||
assert!(data.len() > compressed.len());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user