From 83b4dacede36125b33ede43c7909ab880461ac23 Mon Sep 17 00:00:00 2001 From: Dylan Knutson Date: Sun, 5 May 2024 23:02:14 -0700 Subject: [PATCH] tests for hints --- src/sha256.rs | 2 +- src/shard/fn_get.rs | 12 +--- src/shard/fn_migrate.rs | 10 ++- src/shard/fn_samples_for_hint.rs | 6 +- src/shard/fn_store.rs | 42 ++++++------ src/shard/mod.rs | 6 +- src/shard/{shard.rs => shard_struct.rs} | 89 ++++++++++++++++++++++++- src/sql_types/utc_date_time.rs | 11 ++- 8 files changed, 136 insertions(+), 42 deletions(-) rename src/shard/{shard.rs => shard_struct.rs} (72%) diff --git a/src/sha256.rs b/src/sha256.rs index 9ae8330..261d667 100644 --- a/src/sha256.rs +++ b/src/sha256.rs @@ -21,7 +21,7 @@ impl Display for Sha256Error { } impl Error for Sha256Error {} -#[derive(Clone, Copy, PartialEq, Eq, Default)] +#[derive(Clone, Copy, PartialEq, Eq, Default, Hash)] pub struct Sha256([u8; 32]); impl Sha256 { pub fn from_hex_string(hex: &str) -> Result> { diff --git a/src/shard/fn_get.rs b/src/shard/fn_get.rs index 4c5dd43..1899b48 100644 --- a/src/shard/fn_get.rs +++ b/src/shard/fn_get.rs @@ -48,23 +48,17 @@ impl Shard { } } +type CompressedRowResult = (String, usize, UtcDateTime, CompressionId, Vec); fn get_compressed_row( conn: &mut rusqlite::Connection, sha256: &Sha256, -) -> Result)>, rusqlite::Error> { +) -> Result, rusqlite::Error> { conn.query_row( "SELECT content_type, compressed_size, created_at, compression_id, data FROM entries WHERE sha256 = ?", params![sha256], - |row| { - let content_type = row.get(0)?; - let stored_size = row.get(1)?; - let created_at = row.get(2)?; - let compression_id = row.get(3)?; - let data: Vec = row.get(4)?; - Ok((content_type, stored_size, created_at, compression_id, data)) - }, + |row| CompressedRowResult::try_from(row), ) .optional() } diff --git a/src/shard/fn_migrate.rs b/src/shard/fn_migrate.rs index 7bb2396..59d8c2d 100644 --- a/src/shard/fn_migrate.rs +++ b/src/shard/fn_migrate.rs @@ -60,7 +60,8 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err debug!("migrating to version 1"); conn.execute( "CREATE TABLE IF NOT EXISTS entries ( - sha256 BLOB PRIMARY KEY, + id INTEGER PRIMARY KEY AUTOINCREMENT, + sha256 BLOB NOT NULL, content_type TEXT NOT NULL, compression_id INTEGER NOT NULL, uncompressed_size INTEGER NOT NULL, @@ -71,11 +72,16 @@ fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Err [], )?; + conn.execute( + "CREATE INDEX IF NOT EXISTS entries_sha256_idx ON entries (sha256)", + [], + )?; + conn.execute( "CREATE TABLE IF NOT EXISTS compression_hints ( name TEXT NOT NULL, ordering INTEGER NOT NULL, - sha256 BLOB NOT NULL + entry_id INTEGER NOT NULL )", [], )?; diff --git a/src/shard/fn_samples_for_hint.rs b/src/shard/fn_samples_for_hint.rs index 3f43a3b..2af02a5 100644 --- a/src/shard/fn_samples_for_hint.rs +++ b/src/shard/fn_samples_for_hint.rs @@ -20,9 +20,9 @@ impl Shard { .conn .call(move |conn| { let mut stmt = conn.prepare( - "SELECT sha256, data FROM entries WHERE sha256 IN ( - SELECT sha256 FROM compression_hints WHERE name = ? ORDER BY ordering - ) LIMIT ?", + "SELECT sha256, data FROM entries WHERE id IN ( + SELECT entry_id FROM compression_hints WHERE name = ?1 ORDER BY ordering + ) LIMIT ?2", )?; let rows = stmt.query_map(params![compression_hint, limit], |row| { let sha256: Sha256 = row.get(0)?; diff --git a/src/shard/fn_store.rs b/src/shard/fn_store.rs index 37481cf..1349c63 100644 --- a/src/shard/fn_store.rs +++ b/src/shard/fn_store.rs @@ -50,9 +50,10 @@ impl Shard { let uncompressed_size = data.len(); - let compressor = self.compressor.read().await; - let (compression_id, data) = - compressor.compress(compression_hint.as_deref(), &content_type, data)?; + let (compression_id, data) = { + let compressor = self.compressor.read().await; + compressor.compress(compression_hint.as_deref(), &content_type, data)? + }; self.conn .call(move |conn| { @@ -102,28 +103,31 @@ fn insert( let created_at = UtcDateTime::now(); let compressed_size = data.len(); - conn.execute("INSERT INTO entries - (sha256, content_type, compression_id, uncompressed_size, compressed_size, data, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?) - ", - params![ - sha256, - content_type, - compression_id, - uncompressed_size, - compressed_size, - data.as_ref(), - created_at, - ], - )?; + let entry_id: i64 = conn.query_row( + "INSERT INTO entries + (sha256, content_type, compression_id, uncompressed_size, compressed_size, data, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + RETURNING id + ", + params![ + sha256, + content_type, + compression_id, + uncompressed_size, + compressed_size, + data.as_ref(), + created_at, + ], + |row| row.get(0) + )?; if let Some(compression_hint) = compression_hint { let rand_ordering = rand::random::(); conn.execute( "INSERT INTO compression_hints - (name, ordering, sha256) + (name, ordering, entry_id) VALUES (?, ?, ?)", - params![compression_hint, rand_ordering, sha256], + params![compression_hint, rand_ordering, entry_id], )?; } diff --git a/src/shard/mod.rs b/src/shard/mod.rs index 4a63960..3645290 100644 --- a/src/shard/mod.rs +++ b/src/shard/mod.rs @@ -2,16 +2,16 @@ mod fn_get; mod fn_migrate; mod fn_samples_for_hint; mod fn_store; -mod shard; pub mod shard_error; +mod shard_struct; pub use fn_get::{GetArgs, GetResult}; pub use fn_store::{StoreArgs, StoreResult}; -pub use shard::Shard; +pub use shard_struct::Shard; #[cfg(test)] pub mod test { - pub use super::shard::test::*; + pub use super::shard_struct::test::*; } use crate::{sha256::Sha256, shard::shard_error::ShardError}; diff --git a/src/shard/shard.rs b/src/shard/shard_struct.rs similarity index 72% rename from src/shard/shard.rs rename to src/shard/shard_struct.rs index d71a29a..714c7ed 100644 --- a/src/shard/shard.rs +++ b/src/shard/shard_struct.rs @@ -63,8 +63,10 @@ async fn get_num_entries(conn: &Connection) -> Result super::Shard { let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap(); super::Shard::open(0, conn, compressor).await.unwrap() @@ -229,7 +233,6 @@ pub mod test { #[tokio::test] async fn test_compression_hint() { let shard = make_shard().await; - let compressor = make_compressor().into_arc(); let data = "hello, world!".as_bytes(); let sha256 = Sha256::from_bytes(data); let store_result = shard @@ -248,4 +251,86 @@ pub mod test { assert_eq!(results[0].sha256, sha256); assert_eq!(results[0].data, data); } + + async fn store_random(shard: &Shard, hint: &str) -> (Sha256, Vec) { + let data = (0..1024).map(|_| rand::random::()).collect::>(); + let sha256 = Sha256::from_bytes(&data); + let store_result = shard + .store(StoreArgs { + sha256, + content_type: "text/plain".to_string(), + data: data.clone().into(), + compression_hint: Some(hint.to_string()), + }) + .await + .unwrap(); + assert!(matches!(store_result, StoreResult::Created { .. })); + (sha256, data) + } + + #[tokio::test] + async fn test_compression_hint_limits() { + let get_keys_set = |hash_map: &HashMap>| { + hash_map + .keys() + .into_iter() + .map(|k| *k) + .collect::>() + }; + + let insert_num = 500; + let sample_num = 100; + let shard = make_shard().await; + let mut hint1 = HashMap::new(); + let mut hint2 = HashMap::new(); + + for _ in 0..insert_num { + let (a, b) = store_random(&shard, "hint1").await; + hint1.insert(a, b); + } + + for _ in 0..insert_num { + let (a, b) = store_random(&shard, "hint2").await; + hint2.insert(a, b); + } + + assert_eq!(hint1.len(), insert_num); + assert_eq!(hint2.len(), insert_num); + + let hint1samples = shard + .samples_for_hint("hint1", sample_num) + .await + .unwrap() + .into_iter() + .map(|r| (r.sha256, r.data)) + .collect::>(); + + let hint2samples = shard + .samples_for_hint("hint2", sample_num) + .await + .unwrap() + .into_iter() + .map(|r| (r.sha256, r.data)) + .collect::>(); + + let hint1_keys = get_keys_set(&hint1); + let hint2_keys = get_keys_set(&hint2); + let hint1samples_keys = get_keys_set(&hint1samples); + let hint2samples_keys = get_keys_set(&hint2samples); + + assert_eq!(hint1_keys.len(), insert_num); + assert_eq!(hint2_keys.len(), insert_num); + assert!(hint1_keys.is_disjoint(&hint2_keys)); + assert!( + hint1samples_keys.is_disjoint(&hint2samples_keys), + "hint1: {:?}, hint2: {:?}", + hint1samples_keys, + hint2samples_keys + ); + assert_eq!(hint1samples.len(), sample_num); + assert_eq!(hint2samples.len(), sample_num); + assert_eq!(hint2samples_keys.len(), sample_num); + assert!(hint1_keys.is_superset(&hint1samples_keys)); + assert!(hint2_keys.is_superset(&hint2samples_keys)); + } } diff --git a/src/sql_types/utc_date_time.rs b/src/sql_types/utc_date_time.rs index 650712c..f1f55b1 100644 --- a/src/sql_types/utc_date_time.rs +++ b/src/sql_types/utc_date_time.rs @@ -11,14 +11,19 @@ impl UtcDateTime { pub fn now() -> Self { Self(chrono::Utc::now()) } - pub fn to_string(&self) -> String { - self.0.to_rfc3339() - } + + #[cfg(test)] pub fn from_string(s: &str) -> Result { Ok(Self(DateTime::parse_from_rfc3339(s)?.to_utc())) } } +impl ToString for UtcDateTime { + fn to_string(&self) -> String { + self.0.to_rfc3339() + } +} + impl PartialEq> for UtcDateTime { fn eq(&self, other: &DateTime) -> bool { self.0 == *other