refactors, add unit tests to Shard

Dylan Knutson
2024-04-25 00:07:30 -07:00
parent 3f4615e42a
commit d93f9bd9df
10 changed files with 323 additions and 235 deletions


@@ -1,3 +0,0 @@
use axum::{http::StatusCode, Json};
pub type AxumJsonResultOf<T> = Result<(StatusCode, Json<T>), (StatusCode, Json<T>)>;


@@ -6,7 +6,7 @@ use axum::{
Extension, Json,
};
use crate::shard::Shards;
use crate::shards::Shards;
#[derive(Debug, serde::Serialize)]
pub struct GetError {


@@ -1,4 +1,4 @@
use crate::shard::Shards;
use crate::shards::Shards;
use axum::{http::StatusCode, Extension, Json};
use tracing::error;
@@ -7,13 +7,14 @@ use tracing::error;
pub struct InfoResponse {
num_shards: usize,
shards: Vec<ShardInfo>,
db_size_bytes: usize,
}
#[derive(serde::Serialize)]
pub struct ShardInfo {
id: usize,
num_entries: usize,
size_bytes: u64,
db_size_bytes: usize,
}
#[axum::debug_handler]
@@ -21,19 +22,22 @@ pub async fn info_handler(
Extension(shards): Extension<Shards>,
) -> Result<(StatusCode, Json<InfoResponse>), StatusCode> {
let mut shard_infos = vec![];
let mut total_db_size_bytes = 0;
for shard in shards.iter() {
let num_entries = shard.num_entries().await.map_err(|e| {
error!("error getting num entries: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let size_bytes = shard.size_bytes().await.map_err(|e| {
let db_size_bytes = shard.db_size_bytes().await.map_err(|e| {
error!("error getting size bytes: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
total_db_size_bytes += db_size_bytes;
shard_infos.push(ShardInfo {
id: shard.id(),
num_entries,
size_bytes,
db_size_bytes,
});
}
Ok((
@@ -41,6 +45,7 @@ pub async fn info_handler(
Json(InfoResponse {
num_shards: shards.len(),
shards: shard_infos,
db_size_bytes: total_db_size_bytes,
}),
))
}
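For reference, InfoResponse as defined above serializes to a shape like the following (values illustrative):
{"num_shards": 2, "shards": [{"id": 0, "num_entries": 3, "db_size_bytes": 4096}], "db_size_bytes": 8192}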


@@ -1,47 +1,32 @@
use crate::{
axum_result_type::AxumJsonResultOf,
request_response::store_request::{StoreRequest, StoreRequestWithSha256, StoreResponse},
request_response::store_request::{StoreRequest, StoreResult},
sha256::Sha256,
shard::Shards,
shards::Shards,
};
use axum::{http::StatusCode, Extension, Json};
use axum::Extension;
use axum_typed_multipart::TypedMultipart;
use std::collections::HashMap;
use tracing::error;
#[axum::debug_handler]
pub async fn store_handler(
Extension(shards): Extension<Shards>,
TypedMultipart(request): TypedMultipart<StoreRequest>,
) -> AxumJsonResultOf<StoreResponse> {
) -> StoreResult {
let sha256 = Sha256::from_bytes(&request.data.contents);
let sha256_str = sha256.hex_string();
let shard = shards.shard_for(&sha256);
if let Some(req_sha256) = request.sha256 {
if req_sha256 != sha256_str {
if sha256 != req_sha256 {
let sha256_str = sha256.hex_string();
error!(
"client sent mismatched sha256: (client) {} != (computed) {}",
req_sha256, sha256_str
);
let mut response = HashMap::new();
response.insert("status", "error".to_owned());
response.insert("message", "sha256 mismatch".to_owned());
return Err((
StatusCode::BAD_REQUEST,
Json(StoreResponse::Error {
sha256: Some(req_sha256),
message: "sha256 mismatch".to_owned(),
}),
));
return StoreResult::Sha256Mismatch { sha256: sha256_str };
}
}
let request_parsed = StoreRequestWithSha256 {
sha256: sha256_str,
content_type: request.content_type,
data: request.data.contents,
};
shard.store(request_parsed).await
shard
.store(sha256, request.content_type, request.data.contents)
.await
}
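Route registration is not part of this hunk; as a minimal sketch (the "/store" path and the app/shards bindings are assumptions, not from this commit), the handler can be wired up like:
let app = Router::new()
.route("/store", post(store_handler))
.layer(Extension(shards.clone()));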


@@ -1,10 +1,9 @@
mod axum_result_type;
mod handlers;
mod request_response;
mod sha256;
mod shard;
mod shards;
mod shutdown_signal;
use crate::shard::Shards;
use axum::{
routing::{get, post},
Extension, Router,
@@ -13,8 +12,11 @@ use clap::Parser;
use shard::Shard;
use std::{error::Error, path::PathBuf};
use tokio::net::TcpListener;
use tokio_rusqlite::Connection;
use tracing::info;
use crate::shards::Shards;
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
@@ -64,8 +66,9 @@ fn main() -> Result<(), Box<dyn Error>> {
);
let mut shards_vec = vec![];
for shard_id in 0..num_shards {
let shard_path = db_path.join(format!("shard{}.sqlite", shard_id));
let shard = Shard::open(shard_id, &shard_path).await?;
let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id));
let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?;
let shard = Shard::open(shard_id, shard_sqlite_conn).await?;
info!(
"shard {} has {} entries",
shard.id(),


@@ -1,4 +1,4 @@
use axum::body::Bytes;
use axum::{body::Bytes, http::StatusCode, response::IntoResponse, Json};
use axum_typed_multipart::{FieldData, TryFromMultipart};
use serde::Serialize;
@@ -9,25 +9,33 @@ pub struct StoreRequest {
pub data: FieldData<Bytes>,
}
pub struct StoreRequestWithSha256 {
pub sha256: String,
pub content_type: String,
pub data: Bytes,
}
// serializes to:
// {"status": "created", "stored_size": ..., "data_size": ...}
// {"status": "exists", "stored_size": ..., "data_size": ...}
// {"status": "sha256_mismatch", "sha256": ...}
// {"status": "internal_error", "message": ...}
#[derive(Serialize)]
#[derive(Serialize, PartialEq, Debug)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum StoreResponse {
Ok {
sha256: String,
message: &'static str,
pub enum StoreResult {
Created {
stored_size: usize,
data_size: usize,
},
Error {
sha256: Option<String>,
Exists {
stored_size: usize,
data_size: usize,
},
Sha256Mismatch {
sha256: String,
},
InternalError {
message: String,
},
}
impl IntoResponse for StoreResult {
fn into_response(self) -> axum::response::Response {
let status_code = match &self {
StoreResult::Created { .. } => StatusCode::CREATED,
StoreResult::Exists { .. } => StatusCode::OK,
StoreResult::Sha256Mismatch { .. } => StatusCode::BAD_REQUEST,
StoreResult::InternalError { .. } => StatusCode::INTERNAL_SERVER_ERROR,
};
(status_code, Json(self)).into_response()
}
}


@@ -54,3 +54,13 @@ impl LowerHex for Sha256 {
Ok(())
}
}
impl PartialEq<String> for Sha256 {
fn eq(&self, other: &String) -> bool {
if let Ok(other_sha256) = Sha256::from_hex_string(other) {
self.0 == other_sha256.0
} else {
false
}
}
}
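A brief illustration of the comparison semantics this impl adds (the assertions are illustrative, not part of the commit):
let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
assert!(sha256 == sha256.hex_string()); // a matching hex digest compares equal
assert!(!(sha256 == "not a hex digest".to_string())); // strings that fail to parse never compare equal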

src/shard/fn_migrate.rs (new file, 78 lines)

@@ -0,0 +1,78 @@
use super::*;
impl Shard {
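// Invoked from Shard::open: creates the schema on first open and rejects unsupported schema versions.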
pub(super) async fn migrate(&self) -> Result<(), Box<dyn Error>> {
let shard_id = self.id();
// create tables, indexes, etc
self.conn
.call(move |conn| {
ensure_schema_versions_table(conn)?;
let schema_rows = load_schema_rows(conn)?;
if let Some((version, date_time)) = schema_rows.first() {
debug!(
"shard {}: latest schema version: {} @ {}",
shard_id, version, date_time
);
if *version == 1 {
// no-op
} else {
return Err(tokio_rusqlite::Error::Other(Box::new(ShardError::new(
format!("shard {}: unsupported schema version {}", shard_id, version),
))));
}
} else {
debug!("shard {}: no schema version found, initializing", shard_id);
migrate_to_version_1(conn)?;
}
Ok(())
})
.await?;
Ok(())
}
}
fn ensure_schema_versions_table(conn: &rusqlite::Connection) -> Result<usize, rusqlite::Error> {
conn.execute(
"CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER PRIMARY KEY,
created_at TEXT NOT NULL
)",
[],
)
}
fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, rusqlite::Error> {
let mut stmt = conn
.prepare("SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1")?;
let rows = stmt.query_map([], |row| {
let version = row.get(0)?;
let created_at = row.get(1)?;
Ok((version, created_at))
})?;
rows.collect()
}
fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
debug!("migrating to version 1");
conn.execute(
"CREATE TABLE IF NOT EXISTS entries (
sha256 BLOB PRIMARY KEY,
content_type TEXT NOT NULL,
compression INTEGER NOT NULL,
uncompressed_size INTEGER NOT NULL,
compressed_size INTEGER NOT NULL,
data BLOB NOT NULL,
created_at TEXT NOT NULL
)",
[],
)?;
conn.execute(
"INSERT INTO schema_version (version, created_at) VALUES (1, ?)",
[chrono::Utc::now().to_rfc3339()],
)?;
Ok(())
}


@@ -1,145 +1,124 @@
mod fn_migrate;
pub mod shard_error;
use crate::{
axum_result_type::AxumJsonResultOf,
request_response::store_request::{StoreRequestWithSha256, StoreResponse},
sha256::Sha256,
shard::shard_error::ShardError,
request_response::store_request::StoreResult, sha256::Sha256, shard::shard_error::ShardError,
};
use axum::{http::StatusCode, Json};
use axum::body::Bytes;
use rusqlite::{params, OptionalExtension};
use std::{
error::Error,
path::{Path, PathBuf},
};
use rusqlite::{params, types::FromSql, OptionalExtension};
use std::error::Error;
use tokio_rusqlite::Connection;
use tracing::{debug, error, info};
use tracing::{debug, error};
type UtcDateTime = chrono::DateTime<chrono::Utc>;
#[derive(Clone)]
pub struct Shards(Vec<Shard>);
impl Shards {
pub fn new(shards: Vec<Shard>) -> Self {
Self(shards)
}
pub fn shard_for(&self, sha256: &Sha256) -> &Shard {
let shard_id = sha256.modulo(self.0.len());
&self.0[shard_id]
}
pub async fn close_all(self) -> Result<(), Box<dyn Error>> {
for shard in self.0 {
shard.close().await?;
}
Ok(())
}
pub fn iter(&self) -> std::slice::Iter<'_, Shard> {
self.0.iter()
}
pub fn len(&self) -> usize {
self.0.len()
}
}
pub type UtcDateTime = chrono::DateTime<chrono::Utc>;
#[derive(Clone)]
pub struct Shard {
id: usize,
sqlite: Connection,
file_path: PathBuf,
conn: Connection,
}
pub struct GetResult {
pub content_type: String,
pub stored_size: usize,
pub created_at: UtcDateTime,
pub data: Vec<u8>,
}
impl Shard {
pub async fn open(id: usize, file_path: &Path) -> Result<Self, Box<dyn Error>> {
let sqlite = Connection::open(file_path).await?;
let shard = Self {
id,
sqlite,
file_path: file_path.to_owned(),
};
pub async fn open(id: usize, conn: Connection) -> Result<Self, Box<dyn Error>> {
let shard = Self { id, conn };
shard.migrate().await?;
Ok(shard)
}
pub async fn close(self) -> Result<(), Box<dyn Error>> {
self.sqlite.close().await?;
Ok(())
self.conn.close().await.map_err(|e| e.into())
}
pub fn id(&self) -> usize {
self.id
}
pub async fn size_bytes(&self) -> Result<u64, Box<dyn Error>> {
// stat the file to get its size in bytes
let metadata = tokio::fs::metadata(&self.file_path).await?;
Ok(metadata.len())
pub async fn db_size_bytes(&self) -> Result<usize, Box<dyn Error>> {
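// page_count * page_size from the SQLite pragmas below is the allocated database size in bytes.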
self.query_single_row(
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
)
.await
}
pub async fn store(
async fn query_single_row<T: FromSql + Send + 'static>(
&self,
store_request: StoreRequestWithSha256,
) -> AxumJsonResultOf<StoreResponse> {
let sha256 = store_request.sha256.clone();
query: &'static str,
) -> Result<T, Box<dyn Error>> {
self.conn
.call(move |conn| {
let value: T = conn.query_row(query, [], |row| row.get(0))?;
Ok(value)
})
.await
.map_err(|e| e.into())
}
pub async fn store(&self, sha256: Sha256, content_type: String, data: Bytes) -> StoreResult {
let sha256 = sha256.hex_string();
self.conn.call(move |conn| {
// check for existing entry
let maybe_existing: Option<StoreResult> = conn
.query_row(
"SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?",
params![sha256],
|row| Ok(StoreResult::Exists {
stored_size: row.get(1)?,
data_size: row.get(0)?,
})
)
.optional()?;
if let Some(existing) = maybe_existing {
return Ok(existing);
}
self.sqlite.call(move |conn| {
let created_at = chrono::Utc::now().to_rfc3339();
let maybe_error = conn.execute(
let data_size = data.len();
conn.execute(
"INSERT INTO entries (sha256, content_type, compression, uncompressed_size, compressed_size, data, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
params![
store_request.sha256,
store_request.content_type,
sha256,
content_type,
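// compression flag: 0 = stored uncompressed, so stored and original sizes are equal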
0,
store_request.data.len() as i64,
store_request.data.len() as i64,
&store_request.data[..],
data_size,
data_size,
&data[..],
created_at,
],
);
)?;
if let Err(e) = &maybe_error {
if is_duplicate_entry_err(e) {
info!("entry {} already exists", store_request.sha256);
return Ok((StatusCode::OK, Json(StoreResponse::Ok{
sha256: store_request.sha256,
message: "exists",
})));
}
}
maybe_error?;
Ok((StatusCode::CREATED, Json(StoreResponse::Ok { sha256: store_request.sha256, message: "created" })))
Ok(StoreResult::Created { stored_size: data_size, data_size })
})
.await.map_err(|e| {
.await.unwrap_or_else(|e| {
error!("store failed: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, Json(StoreResponse::Error { sha256: Some(sha256), message: e.to_string() }))
StoreResult::InternalError { message: e.to_string() }
})
}
pub async fn get(&self, sha256: Sha256) -> Result<Option<GetResult>, Box<dyn Error>> {
self.sqlite
self.conn
.call(move |conn| {
let get_result = conn
.query_row(
"SELECT content_type, created_at, data FROM entries WHERE sha256 = ?",
"SELECT content_type, compressed_size, created_at, data FROM entries WHERE sha256 = ?",
params![sha256.hex_string()],
|row| {
let content_type = row.get(0)?;
let created_at = parse_created_at_str(row.get(1)?)?;
let data = row.get(2)?;
let stored_size = row.get(1)?;
let created_at = parse_created_at_str(row.get(2)?)?;
let data = row.get(3)?;
Ok(GetResult {
content_type,
stored_size,
created_at,
data,
})
@@ -156,38 +135,7 @@ impl Shard {
}
pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
get_num_entries(&self.sqlite).await.map_err(|e| e.into())
}
async fn migrate(&self) -> Result<(), Box<dyn Error>> {
let shard_id = self.id();
// create tables, indexes, etc
self.sqlite
.call(move |conn| {
ensure_schema_versions_table(conn)?;
let schema_rows = load_schema_rows(conn)?;
if let Some((version, date_time)) = schema_rows.first() {
debug!(
"shard {}: latest schema version: {} @ {}",
shard_id, version, date_time
);
if *version == 1 {
// no-op
} else {
return Err(tokio_rusqlite::Error::Other(Box::new(ShardError::new(
format!("shard {}: unsupported schema version {}", shard_id, version),
))));
}
} else {
debug!("shard {}: no schema version found, initializing", shard_id);
migrate_to_version_1(conn)?;
}
Ok(())
})
.await?;
Ok(())
get_num_entries(&self.conn).await.map_err(|e| e.into())
}
}
@@ -197,68 +145,6 @@ fn parse_created_at_str(created_at_str: String) -> Result<UtcDateTime, rusqlite:
Ok(parsed.to_utc())
}
fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool {
use rusqlite::*;
if let Error::SqliteFailure(
ffi::Error {
code: ffi::ErrorCode::ConstraintViolation,
..
},
Some(err_str),
) = error
{
if err_str.contains("UNIQUE constraint failed: entries.sha256") {
return true;
}
}
false
}
fn ensure_schema_versions_table(conn: &rusqlite::Connection) -> Result<usize, rusqlite::Error> {
conn.execute(
"CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER PRIMARY KEY,
created_at TEXT NOT NULL
)",
[],
)
}
fn load_schema_rows(conn: &rusqlite::Connection) -> Result<Vec<(i64, String)>, rusqlite::Error> {
let mut stmt = conn
.prepare("SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1")?;
let rows = stmt.query_map([], |row| {
let version = row.get(0)?;
let created_at = row.get(1)?;
Ok((version, created_at))
})?;
rows.collect()
}
fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
debug!("migrating to version 1");
conn.execute(
"CREATE TABLE IF NOT EXISTS entries (
sha256 BLOB PRIMARY KEY,
content_type TEXT NOT NULL,
compression INTEGER NOT NULL,
uncompressed_size INTEGER NOT NULL,
compressed_size INTEGER NOT NULL,
data BLOB NOT NULL,
created_at TEXT NOT NULL
)",
[],
)?;
conn.execute(
"INSERT INTO schema_version (version, created_at) VALUES (1, ?)",
[chrono::Utc::now().to_rfc3339()],
)?;
Ok(())
}
async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Error> {
conn.call(|conn| {
let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
@@ -266,3 +152,88 @@ async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Err
})
.await
}
#[cfg(test)]
mod test {
use crate::sha256::Sha256;
async fn make_shard() -> super::Shard {
let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap();
super::Shard::open(0, conn).await.unwrap()
}
#[tokio::test]
async fn test_num_entries() {
let shard = make_shard().await;
let num_entries = shard.num_entries().await.unwrap();
assert_eq!(num_entries, 0);
}
#[tokio::test]
async fn test_db_size_bytes() {
let shard = make_shard().await;
let db_size = shard.db_size_bytes().await.unwrap();
assert!(db_size > 0);
}
#[tokio::test]
async fn test_not_found_get() {
let shard = make_shard().await;
let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
let get_result = shard.get(sha256).await.unwrap();
assert!(get_result.is_none());
}
#[tokio::test]
async fn test_store_and_get() {
let shard = make_shard().await;
let data = "hello, world!".as_bytes();
let sha256 = Sha256::from_bytes(data);
let store_result = shard
.store(sha256, "text/plain".to_string(), data.into())
.await;
assert_eq!(
store_result,
super::StoreResult::Created {
data_size: data.len(),
stored_size: data.len()
}
);
assert_eq!(shard.num_entries().await.unwrap(), 1);
let get_result = shard.get(sha256).await.unwrap().unwrap();
assert_eq!(get_result.content_type, "text/plain");
assert_eq!(get_result.data, data);
assert_eq!(get_result.stored_size, data.len());
}
#[tokio::test]
async fn test_store_duplicate() {
let shard = make_shard().await;
let data = "hello, world!".as_bytes();
let sha256 = Sha256::from_bytes(data);
let store_result = shard
.store(sha256, "text/plain".to_string(), data.into())
.await;
assert_eq!(
store_result,
super::StoreResult::Created {
data_size: data.len(),
stored_size: data.len()
}
);
let store_result = shard
.store(sha256, "text/plain".to_string(), data.into())
.await;
assert_eq!(
store_result,
super::StoreResult::Exists {
data_size: data.len(),
stored_size: data.len()
}
);
assert_eq!(shard.num_entries().await.unwrap(), 1);
}
}

src/shards.rs (new file, 31 lines)

@@ -0,0 +1,31 @@
use std::error::Error;
use crate::{sha256::Sha256, shard::Shard};
#[derive(Clone)]
pub struct Shards(Vec<Shard>);
impl Shards {
pub fn new(shards: Vec<Shard>) -> Self {
Self(shards)
}
pub fn shard_for(&self, sha256: &Sha256) -> &Shard {
let shard_id = sha256.modulo(self.0.len());
&self.0[shard_id]
}
pub async fn close_all(self) -> Result<(), Box<dyn Error>> {
for shard in self.0 {
shard.close().await?;
}
Ok(())
}
pub fn iter(&self) -> std::slice::Iter<'_, Shard> {
self.0.iter()
}
pub fn len(&self) -> usize {
self.0.len()
}
}
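A minimal usage sketch of the new Shards type (shard0 and shard1 stand in for shards opened as in main.rs):
let shards = Shards::new(vec![shard0, shard1]);
let sha256 = Sha256::from_bytes("hello, world!".as_bytes());
let shard = shards.shard_for(&sha256); // sha256.modulo(shards.len()) picks a stable shard for this hash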