diff --git a/Cargo.lock b/Cargo.lock index b5d7018..b1571d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "allocator-api2" version = "0.2.18" @@ -269,6 +278,7 @@ dependencies = [ "kdam", "rand", "reqwest", + "rstest", "rusqlite", "serde", "serde_json", @@ -617,6 +627,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.30" @@ -662,6 +678,12 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "h2" version = "0.4.4" @@ -1292,6 +1314,41 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + +[[package]] +name = "relative-path" +version = "1.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e898588f33fdd5b9420719948f9f2a32c922a246964576f71ba7f24f80610fbc" + [[package]] name = "reqwest" version = "0.12.4" @@ -1336,6 +1393,35 @@ dependencies = [ "winreg", ] +[[package]] +name = "rstest" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d5316d2a1479eeef1ea21e7f9ddc67c191d497abc8fc3ba2467857abbb68330" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04a9df72cc1f67020b0d63ad9bfe4a323e459ea7eb68e03bd9824db49f9a4c25" +dependencies = [ + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn 2.0.60", + "unicode-ident", +] + [[package]] name = "rusqlite" version = "0.31.0" @@ -1356,6 +1442,15 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.34" @@ -1435,6 +1530,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" + [[package]] name = "serde" version = "1.0.198" diff --git a/Cargo.toml b/Cargo.toml index 6938c79..8dc3f73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ axum_typed_multipart = "0.11.1" chrono = "0.4.38" clap = { version = "4.5.4", features = ["derive"] } futures = "0.3.30" -kdam = "0.5.1" +kdam = "0.5.1" # for load-test rand = "0.8.5" rusqlite = { version = "0.31.0", features = ["vtab"] } serde = { version = "1.0.198", features = ["serde_derive"] } @@ -31,3 +31,6 @@ tracing-subscriber = "0.3.18" reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] } hex = "0.4.3" zstd = "0.13.1" + +[dev-dependencies] +rstest = "0.19.0" diff --git a/src/main.rs b/src/main.rs index b1a9871..ce7ae0a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,17 @@ mod handlers; +mod manifest; mod sha256; mod shard; mod shards; mod shutdown_signal; -use crate::shards::Shards; +use crate::{manifest::Manifest, shards::Shards}; use axum::{ routing::{get, post}, Extension, Router, }; use clap::{Parser, ValueEnum}; +use futures::executor::block_on; use shard::Shard; use std::{error::Error, path::PathBuf}; use tokio::net::TcpListener; @@ -48,11 +50,6 @@ pub enum UseCompression { Zstd, } -#[derive(Debug, serde::Deserialize, serde::Serialize)] -struct ManifestData { - shards: usize, -} - fn main() -> Result<(), Box> { tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) @@ -60,11 +57,20 @@ fn main() -> Result<(), Box> { let args = Args::parse(); let db_path = PathBuf::from(&args.db_path); - let num_shards = validate_manifest(&args)?; + let num_shards = args.shards; + + // block on opening the manifest + let manifest = block_on(async { + Manifest::open( + Connection::open(db_path.join("manifest.sqlite")).await?, + num_shards, + ) + .await + })?; // max num_shards threads let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(num_shards as usize) + .worker_threads(manifest.num_shards()) .enable_all() .build()?; @@ -73,10 +79,10 @@ fn main() -> Result<(), Box> { info!( "listening on {} with {} shards", server.local_addr()?, - num_shards + manifest.num_shards() ); let mut shards_vec = vec![]; - for shard_id in 0..num_shards { + for shard_id in 0..manifest.num_shards() { let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id)); let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?; let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?; @@ -111,31 +117,3 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box Result> { - let manifest_path = PathBuf::from(&args.db_path).join("manifest.json"); - if manifest_path.exists() { - let file_content = std::fs::read_to_string(manifest_path)?; - let manifest: ManifestData = serde_json::from_str(&file_content)?; - info!("loading existing database with {} shards", manifest.shards); - if let Some(shards) = args.shards { - if shards != manifest.shards { - return Err(format!( - "manifest indicates {} shards, expected {}", - manifest.shards, shards - ) - .into()); - } - } - Ok(manifest.shards) - } else if let Some(shards) = args.shards { - info!("creating new database with {} shards", shards); - std::fs::create_dir_all(&args.db_path)?; - let manifest = ManifestData { shards }; - let manifest_json = serde_json::to_string(&manifest)?; - std::fs::write(manifest_path, manifest_json)?; - Ok(shards) - } else { - Err("new database needs --shards argument".into()) - } -} diff --git a/src/manifest.rs b/src/manifest.rs new file mode 100644 index 0000000..4a1846e --- /dev/null +++ b/src/manifest.rs @@ -0,0 +1,242 @@ +use std::{collections::HashMap, error::Error, io}; + +use rusqlite::{params, types::FromSql, OptionalExtension, ToSql}; +use tokio_rusqlite::Connection; + +pub struct Manifest { + conn: Connection, + num_shards: usize, + zstd_dicts: HashMap, +} + +pub struct ZstdDict { + level: i32, + encoder_dict: zstd::dict::EncoderDictionary<'static>, + decoder_dict: zstd::dict::DecoderDictionary<'static>, +} +impl ZstdDict { + pub fn compress(&self, data: &[u8]) -> Result, Box> { + let mut wrapper = io::Cursor::new(data); + let mut out_buffer = Vec::with_capacity(data.len()); + let mut output_wrapper = io::Cursor::new(&mut out_buffer); + + let mut encoder = zstd::stream::Encoder::with_prepared_dictionary( + &mut output_wrapper, + &self.encoder_dict, + )?; + io::copy(&mut wrapper, &mut encoder)?; + encoder.finish()?; + Ok(out_buffer) + } + + pub fn decompress(&self, data: &[u8]) -> Result, Box> { + let mut wrapper = io::Cursor::new(data); + let mut out_buffer = Vec::with_capacity(data.len()); + let mut output_wrapper = io::Cursor::new(&mut out_buffer); + + let mut decoder = + zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, &self.decoder_dict)?; + io::copy(&mut decoder, &mut output_wrapper)?; + Ok(out_buffer) + } +} + +impl ZstdDict { + fn create(dict: &[u8], level: i32) -> Self { + let encoder = zstd::dict::EncoderDictionary::copy(dict, level); + let decoder = zstd::dict::DecoderDictionary::copy(dict); + Self { + level, + encoder_dict: encoder, + decoder_dict: decoder, + } + } +} + +impl Manifest { + pub async fn open(conn: Connection, num_shards: Option) -> Result> { + initialize(conn, num_shards).await + } + pub fn num_shards(&self) -> usize { + self.num_shards + } + + pub async fn train_dictionary( + &mut self, + name: String, + samples: Vec<&[u8]>, + ) -> Result<(), Box> { + if self.zstd_dicts.contains_key(&name) { + return Err(format!("dictionary {} already exists", name).into()); + } + + let level = 3; + let dict = zstd::dict::from_samples( + &samples, + 1024 * 1024, // 1MB max dictionary size + ) + .unwrap(); + let zstd_dict = ZstdDict::create(&dict, level); + + let name_copy = name.clone(); + let dict_id = self + .conn + .call(move |conn| { + let mut stmt = conn.prepare( + "INSERT INTO dictionaries (level, name, dict) + VALUES (?, ?, ?) + RETURNING id", + )?; + let dict_id = stmt.query_row(params![level, name_copy, dict], |row| row.get(0))?; + Ok(dict_id) + }) + .await?; + + self.zstd_dicts.insert(name, (dict_id, zstd_dict)); + Ok(()) + } +} + +trait ManifestKey { + type Value: ToSql + FromSql; + fn sql_key(&self) -> &'static str; +} +struct NumShards; +impl ManifestKey for NumShards { + type Value = usize; + fn sql_key(&self) -> &'static str { + "num_shards" + } +} + +fn get_manifest_key( + conn: &rusqlite::Connection, + key: Key, +) -> Result, rusqlite::Error> { + conn.query_row( + "SELECT value FROM manifest WHERE key = ?", + params![key.sql_key()], + |row| row.get(0), + ) + .optional() +} + +fn set_manifest_key( + conn: &rusqlite::Connection, + key: Key, + value: Key::Value, +) -> Result<(), rusqlite::Error> { + conn.execute( + "INSERT OR REPLACE INTO manifest (key, value) VALUES (?, ?)", + params![key.sql_key(), value], + )?; + Ok(()) +} + +async fn initialize( + conn: Connection, + num_shards: Option, +) -> Result> { + let stored_num_shards: Option = conn + .call(|conn| { + conn.execute( + "CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)", + [], + )?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS dictionaries ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + level INTEGER NOT NULL, + name TEXT NOT NULL, + dict BLOB NOT NULL + )", + [], + )?; + + Ok(get_manifest_key(conn, NumShards)?) + }) + .await?; + + let num_shards = match (stored_num_shards, num_shards) { + (Some(stored_num_shards), Some(num_shards)) => { + if stored_num_shards == num_shards { + num_shards + } else { + return Err(format!( + "manifest indicates {} shards, but expected {} shards", + stored_num_shards, num_shards + ) + .into()); + } + } + (None, None) => return Err("Must supply --shards [number] to new database".into()), + (None, Some(num_shards)) => { + // new database with num_shards provided + conn.call(move |conn| { + // insert num_shards into manifest + set_manifest_key(conn, NumShards, num_shards)?; + Ok(num_shards) + }) + .await?; + num_shards + } + (Some(num_shards), None) => { + // existing database, use loaded num_shards + num_shards + } + }; + + let rows = conn + .call(|conn| { + let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?; + let mut rows = vec![]; + for r in stmt.query_map([], |row| { + let id: i64 = row.get(0)?; + let name: String = row.get(1)?; + let level: i32 = row.get(2)?; + let dict: Vec = row.get(3)?; + Ok((id, name, level, dict)) + })? { + rows.push(r?); + } + Ok(rows) + }) + .await?; + + let mut zstd_dicts = HashMap::new(); + for (id, name, level, dict) in rows { + let zstd_dict = ZstdDict::create(&dict, level); + zstd_dicts.insert(name, (id, zstd_dict)); + } + + Ok(Manifest { + conn, + num_shards, + zstd_dicts, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_manifest() { + let conn = Connection::open_in_memory().await.unwrap(); + let mut manifest = initialize(conn, Some(3)).await.unwrap(); + + let samples: Vec<&[u8]> = vec![b"hello world test of long string"; 100]; + manifest + .train_dictionary("test".to_string(), samples) + .await + .unwrap(); + + let zstd_dict = &manifest.zstd_dicts.get("test").unwrap().1; + let data = b"hello world, this is a test of a sort of long string"; + let compressed = zstd_dict.compress(data).unwrap(); + let decompressed = zstd_dict.decompress(&compressed).unwrap(); + assert_eq!(decompressed, data); + assert!(data.len() > compressed.len()); + } +}