add dictionary compression, first pass
This commit is contained in:
101
Cargo.lock
generated
101
Cargo.lock
generated
@@ -29,6 +29,15 @@ dependencies = [
|
|||||||
"zerocopy",
|
"zerocopy",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aho-corasick"
|
||||||
|
version = "1.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "allocator-api2"
|
name = "allocator-api2"
|
||||||
version = "0.2.18"
|
version = "0.2.18"
|
||||||
@@ -269,6 +278,7 @@ dependencies = [
|
|||||||
"kdam",
|
"kdam",
|
||||||
"rand",
|
"rand",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
"rstest",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
@@ -617,6 +627,12 @@ version = "0.3.30"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
|
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-timer"
|
||||||
|
version = "3.0.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-util"
|
name = "futures-util"
|
||||||
version = "0.3.30"
|
version = "0.3.30"
|
||||||
@@ -662,6 +678,12 @@ version = "0.28.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
|
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "glob"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "h2"
|
name = "h2"
|
||||||
version = "0.4.4"
|
version = "0.4.4"
|
||||||
@@ -1292,6 +1314,41 @@ dependencies = [
|
|||||||
"bitflags 1.3.2",
|
"bitflags 1.3.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex"
|
||||||
|
version = "1.10.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-automata",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-automata"
|
||||||
|
version = "0.4.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-syntax"
|
||||||
|
version = "0.8.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "relative-path"
|
||||||
|
version = "1.9.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e898588f33fdd5b9420719948f9f2a32c922a246964576f71ba7f24f80610fbc"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "reqwest"
|
name = "reqwest"
|
||||||
version = "0.12.4"
|
version = "0.12.4"
|
||||||
@@ -1336,6 +1393,35 @@ dependencies = [
|
|||||||
"winreg",
|
"winreg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rstest"
|
||||||
|
version = "0.19.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9d5316d2a1479eeef1ea21e7f9ddc67c191d497abc8fc3ba2467857abbb68330"
|
||||||
|
dependencies = [
|
||||||
|
"futures",
|
||||||
|
"futures-timer",
|
||||||
|
"rstest_macros",
|
||||||
|
"rustc_version",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rstest_macros"
|
||||||
|
version = "0.19.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "04a9df72cc1f67020b0d63ad9bfe4a323e459ea7eb68e03bd9824db49f9a4c25"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"glob",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"regex",
|
||||||
|
"relative-path",
|
||||||
|
"rustc_version",
|
||||||
|
"syn 2.0.60",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rusqlite"
|
name = "rusqlite"
|
||||||
version = "0.31.0"
|
version = "0.31.0"
|
||||||
@@ -1356,6 +1442,15 @@ version = "0.1.23"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
|
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustc_version"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
|
||||||
|
dependencies = [
|
||||||
|
"semver",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustix"
|
name = "rustix"
|
||||||
version = "0.38.34"
|
version = "0.38.34"
|
||||||
@@ -1435,6 +1530,12 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "semver"
|
||||||
|
version = "1.0.22"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde"
|
name = "serde"
|
||||||
version = "1.0.198"
|
version = "1.0.198"
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ axum_typed_multipart = "0.11.1"
|
|||||||
chrono = "0.4.38"
|
chrono = "0.4.38"
|
||||||
clap = { version = "4.5.4", features = ["derive"] }
|
clap = { version = "4.5.4", features = ["derive"] }
|
||||||
futures = "0.3.30"
|
futures = "0.3.30"
|
||||||
kdam = "0.5.1"
|
kdam = "0.5.1" # for load-test
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
rusqlite = { version = "0.31.0", features = ["vtab"] }
|
rusqlite = { version = "0.31.0", features = ["vtab"] }
|
||||||
serde = { version = "1.0.198", features = ["serde_derive"] }
|
serde = { version = "1.0.198", features = ["serde_derive"] }
|
||||||
@@ -31,3 +31,6 @@ tracing-subscriber = "0.3.18"
|
|||||||
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
|
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
|
||||||
hex = "0.4.3"
|
hex = "0.4.3"
|
||||||
zstd = "0.13.1"
|
zstd = "0.13.1"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
rstest = "0.19.0"
|
||||||
|
|||||||
54
src/main.rs
54
src/main.rs
@@ -1,15 +1,17 @@
|
|||||||
mod handlers;
|
mod handlers;
|
||||||
|
mod manifest;
|
||||||
mod sha256;
|
mod sha256;
|
||||||
mod shard;
|
mod shard;
|
||||||
mod shards;
|
mod shards;
|
||||||
mod shutdown_signal;
|
mod shutdown_signal;
|
||||||
|
|
||||||
use crate::shards::Shards;
|
use crate::{manifest::Manifest, shards::Shards};
|
||||||
use axum::{
|
use axum::{
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
Extension, Router,
|
Extension, Router,
|
||||||
};
|
};
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
use futures::executor::block_on;
|
||||||
use shard::Shard;
|
use shard::Shard;
|
||||||
use std::{error::Error, path::PathBuf};
|
use std::{error::Error, path::PathBuf};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
@@ -48,11 +50,6 @@ pub enum UseCompression {
|
|||||||
Zstd,
|
Zstd,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// JSON document kept in `manifest.json` inside the database directory.
///
/// `validate_manifest` deserializes it when the file exists and serializes a
/// fresh one when a new database is created.
#[derive(Debug, serde::Deserialize, serde::Serialize)]
|
|
||||||
struct ManifestData {
|
|
||||||
// Number of shards the database was created with.
shards: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<(), Box<dyn Error>> {
|
fn main() -> Result<(), Box<dyn Error>> {
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_max_level(tracing::Level::DEBUG)
|
.with_max_level(tracing::Level::DEBUG)
|
||||||
@@ -60,11 +57,20 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let db_path = PathBuf::from(&args.db_path);
|
let db_path = PathBuf::from(&args.db_path);
|
||||||
let num_shards = validate_manifest(&args)?;
|
let num_shards = args.shards;
|
||||||
|
|
||||||
|
// block on opening the manifest
|
||||||
|
let manifest = block_on(async {
|
||||||
|
Manifest::open(
|
||||||
|
Connection::open(db_path.join("manifest.sqlite")).await?,
|
||||||
|
num_shards,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
})?;
|
||||||
|
|
||||||
// max num_shards threads
|
// max num_shards threads
|
||||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||||
.worker_threads(num_shards as usize)
|
.worker_threads(manifest.num_shards())
|
||||||
.enable_all()
|
.enable_all()
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
||||||
@@ -73,10 +79,10 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
info!(
|
info!(
|
||||||
"listening on {} with {} shards",
|
"listening on {} with {} shards",
|
||||||
server.local_addr()?,
|
server.local_addr()?,
|
||||||
num_shards
|
manifest.num_shards()
|
||||||
);
|
);
|
||||||
let mut shards_vec = vec![];
|
let mut shards_vec = vec![];
|
||||||
for shard_id in 0..num_shards {
|
for shard_id in 0..manifest.num_shards() {
|
||||||
let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
|
let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
|
||||||
let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
|
let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
|
||||||
let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?;
|
let shard = Shard::open(shard_id, UseCompression::Auto, shard_sqlite_conn).await?;
|
||||||
@@ -111,31 +117,3 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn
|
|||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_manifest(args: &Args) -> Result<usize, Box<dyn Error>> {
|
|
||||||
let manifest_path = PathBuf::from(&args.db_path).join("manifest.json");
|
|
||||||
if manifest_path.exists() {
|
|
||||||
let file_content = std::fs::read_to_string(manifest_path)?;
|
|
||||||
let manifest: ManifestData = serde_json::from_str(&file_content)?;
|
|
||||||
info!("loading existing database with {} shards", manifest.shards);
|
|
||||||
if let Some(shards) = args.shards {
|
|
||||||
if shards != manifest.shards {
|
|
||||||
return Err(format!(
|
|
||||||
"manifest indicates {} shards, expected {}",
|
|
||||||
manifest.shards, shards
|
|
||||||
)
|
|
||||||
.into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(manifest.shards)
|
|
||||||
} else if let Some(shards) = args.shards {
|
|
||||||
info!("creating new database with {} shards", shards);
|
|
||||||
std::fs::create_dir_all(&args.db_path)?;
|
|
||||||
let manifest = ManifestData { shards };
|
|
||||||
let manifest_json = serde_json::to_string(&manifest)?;
|
|
||||||
std::fs::write(manifest_path, manifest_json)?;
|
|
||||||
Ok(shards)
|
|
||||||
} else {
|
|
||||||
Err("new database needs --shards argument".into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
242
src/manifest.rs
Normal file
242
src/manifest.rs
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
use std::{collections::HashMap, error::Error, io};
|
||||||
|
|
||||||
|
use rusqlite::{params, types::FromSql, OptionalExtension, ToSql};
|
||||||
|
use tokio_rusqlite::Connection;
|
||||||
|
|
||||||
|
/// In-memory view of the manifest database: the shard count plus every
/// zstd dictionary stored in the `dictionaries` table.
///
/// Built by `initialize` (via `Manifest::open`).
pub struct Manifest {
|
||||||
|
// Connection to the manifest database; `train_dictionary` inserts new
// dictionary rows through it.
conn: Connection,
|
||||||
|
// Shard count, either read from the `num_shards` manifest key or written
// there when the database is first opened.
num_shards: usize,
|
||||||
|
// Dictionaries keyed by name. The tuple holds the row id from the
// `dictionaries` table and the prepared dictionary.
zstd_dicts: HashMap<String, (i64, ZstdDict)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ZstdDict {
|
||||||
|
level: i32,
|
||||||
|
encoder_dict: zstd::dict::EncoderDictionary<'static>,
|
||||||
|
decoder_dict: zstd::dict::DecoderDictionary<'static>,
|
||||||
|
}
|
||||||
|
impl ZstdDict {
|
||||||
|
pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||||
|
let mut wrapper = io::Cursor::new(data);
|
||||||
|
let mut out_buffer = Vec::with_capacity(data.len());
|
||||||
|
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
|
||||||
|
|
||||||
|
let mut encoder = zstd::stream::Encoder::with_prepared_dictionary(
|
||||||
|
&mut output_wrapper,
|
||||||
|
&self.encoder_dict,
|
||||||
|
)?;
|
||||||
|
io::copy(&mut wrapper, &mut encoder)?;
|
||||||
|
encoder.finish()?;
|
||||||
|
Ok(out_buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||||
|
let mut wrapper = io::Cursor::new(data);
|
||||||
|
let mut out_buffer = Vec::with_capacity(data.len());
|
||||||
|
let mut output_wrapper = io::Cursor::new(&mut out_buffer);
|
||||||
|
|
||||||
|
let mut decoder =
|
||||||
|
zstd::stream::Decoder::with_prepared_dictionary(&mut wrapper, &self.decoder_dict)?;
|
||||||
|
io::copy(&mut decoder, &mut output_wrapper)?;
|
||||||
|
Ok(out_buffer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ZstdDict {
|
||||||
|
fn create(dict: &[u8], level: i32) -> Self {
|
||||||
|
let encoder = zstd::dict::EncoderDictionary::copy(dict, level);
|
||||||
|
let decoder = zstd::dict::DecoderDictionary::copy(dict);
|
||||||
|
Self {
|
||||||
|
level,
|
||||||
|
encoder_dict: encoder,
|
||||||
|
decoder_dict: decoder,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Manifest {
|
||||||
|
pub async fn open(conn: Connection, num_shards: Option<usize>) -> Result<Self, Box<dyn Error>> {
|
||||||
|
initialize(conn, num_shards).await
|
||||||
|
}
|
||||||
|
pub fn num_shards(&self) -> usize {
|
||||||
|
self.num_shards
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn train_dictionary(
|
||||||
|
&mut self,
|
||||||
|
name: String,
|
||||||
|
samples: Vec<&[u8]>,
|
||||||
|
) -> Result<(), Box<dyn Error>> {
|
||||||
|
if self.zstd_dicts.contains_key(&name) {
|
||||||
|
return Err(format!("dictionary {} already exists", name).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let level = 3;
|
||||||
|
let dict = zstd::dict::from_samples(
|
||||||
|
&samples,
|
||||||
|
1024 * 1024, // 1MB max dictionary size
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let zstd_dict = ZstdDict::create(&dict, level);
|
||||||
|
|
||||||
|
let name_copy = name.clone();
|
||||||
|
let dict_id = self
|
||||||
|
.conn
|
||||||
|
.call(move |conn| {
|
||||||
|
let mut stmt = conn.prepare(
|
||||||
|
"INSERT INTO dictionaries (level, name, dict)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
RETURNING id",
|
||||||
|
)?;
|
||||||
|
let dict_id = stmt.query_row(params![level, name_copy, dict], |row| row.get(0))?;
|
||||||
|
Ok(dict_id)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.zstd_dicts.insert(name, (dict_id, zstd_dict));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A typed entry of the `manifest` key/value table, used by
/// `get_manifest_key` and `set_manifest_key`.
trait ManifestKey {
|
||||||
|
// Rust type the entry's `value` column is bound from and read back as.
type Value: ToSql + FromSql;
|
||||||
|
// Text used in the table's `key` column for this entry.
fn sql_key(&self) -> &'static str;
|
||||||
|
}
|
||||||
|
/// Manifest key under which the shard count is stored.
struct NumShards;
|
||||||
|
impl ManifestKey for NumShards {
|
||||||
|
// The shard count is stored and loaded as a `usize`.
type Value = usize;
|
||||||
|
// Row key in the `manifest` table.
fn sql_key(&self) -> &'static str {
|
||||||
|
"num_shards"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_manifest_key<Key: ManifestKey>(
|
||||||
|
conn: &rusqlite::Connection,
|
||||||
|
key: Key,
|
||||||
|
) -> Result<Option<Key::Value>, rusqlite::Error> {
|
||||||
|
conn.query_row(
|
||||||
|
"SELECT value FROM manifest WHERE key = ?",
|
||||||
|
params![key.sql_key()],
|
||||||
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.optional()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_manifest_key<Key: ManifestKey>(
|
||||||
|
conn: &rusqlite::Connection,
|
||||||
|
key: Key,
|
||||||
|
value: Key::Value,
|
||||||
|
) -> Result<(), rusqlite::Error> {
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO manifest (key, value) VALUES (?, ?)",
|
||||||
|
params![key.sql_key(), value],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn initialize(
|
||||||
|
conn: Connection,
|
||||||
|
num_shards: Option<usize>,
|
||||||
|
) -> Result<Manifest, Box<dyn Error>> {
|
||||||
|
let stored_num_shards: Option<usize> = conn
|
||||||
|
.call(|conn| {
|
||||||
|
conn.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS manifest (key TEXT PRIMARY KEY, value)",
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS dictionaries (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
level INTEGER NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
dict BLOB NOT NULL
|
||||||
|
)",
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(get_manifest_key(conn, NumShards)?)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let num_shards = match (stored_num_shards, num_shards) {
|
||||||
|
(Some(stored_num_shards), Some(num_shards)) => {
|
||||||
|
if stored_num_shards == num_shards {
|
||||||
|
num_shards
|
||||||
|
} else {
|
||||||
|
return Err(format!(
|
||||||
|
"manifest indicates {} shards, but expected {} shards",
|
||||||
|
stored_num_shards, num_shards
|
||||||
|
)
|
||||||
|
.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(None, None) => return Err("Must supply --shards [number] to new database".into()),
|
||||||
|
(None, Some(num_shards)) => {
|
||||||
|
// new database with num_shards provided
|
||||||
|
conn.call(move |conn| {
|
||||||
|
// insert num_shards into manifest
|
||||||
|
set_manifest_key(conn, NumShards, num_shards)?;
|
||||||
|
Ok(num_shards)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
num_shards
|
||||||
|
}
|
||||||
|
(Some(num_shards), None) => {
|
||||||
|
// existing database, use loaded num_shards
|
||||||
|
num_shards
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let rows = conn
|
||||||
|
.call(|conn| {
|
||||||
|
let mut stmt = conn.prepare("SELECT id, name, level, dict FROM dictionaries")?;
|
||||||
|
let mut rows = vec![];
|
||||||
|
for r in stmt.query_map([], |row| {
|
||||||
|
let id: i64 = row.get(0)?;
|
||||||
|
let name: String = row.get(1)?;
|
||||||
|
let level: i32 = row.get(2)?;
|
||||||
|
let dict: Vec<u8> = row.get(3)?;
|
||||||
|
Ok((id, name, level, dict))
|
||||||
|
})? {
|
||||||
|
rows.push(r?);
|
||||||
|
}
|
||||||
|
Ok(rows)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut zstd_dicts = HashMap::new();
|
||||||
|
for (id, name, level, dict) in rows {
|
||||||
|
let zstd_dict = ZstdDict::create(&dict, level);
|
||||||
|
zstd_dicts.insert(name, (id, zstd_dict));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Manifest {
|
||||||
|
conn,
|
||||||
|
num_shards,
|
||||||
|
zstd_dicts,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory manifest, trains a dictionary, and checks that a
    /// payload round-trips through it and comes out smaller when compressed.
    #[tokio::test]
    async fn test_manifest() {
        let conn = Connection::open_in_memory().await.unwrap();
        let mut manifest = initialize(conn, Some(3)).await.unwrap();

        let training_set: Vec<&[u8]> = vec![b"hello world test of long string"; 100];
        manifest
            .train_dictionary("test".to_string(), training_set)
            .await
            .unwrap();

        let (_, dict) = manifest.zstd_dicts.get("test").unwrap();
        let payload = b"hello world, this is a test of a sort of long string";

        let packed = dict.compress(payload).unwrap();
        let unpacked = dict.decompress(&packed).unwrap();

        assert_eq!(unpacked, payload);
        assert!(payload.len() > packed.len());
    }
}
|
||||||
Reference in New Issue
Block a user