diff --git a/src/axum_result_type.rs b/src/axum_result_type.rs deleted file mode 100644 index 7918a22..0000000 --- a/src/axum_result_type.rs +++ /dev/null @@ -1,3 +0,0 @@ -use axum::{http::StatusCode, Json}; - -pub type AxumJsonResultOf = Result<(StatusCode, Json), (StatusCode, Json)>; diff --git a/src/handlers/get_handler.rs b/src/handlers/get_handler.rs index d485b32..77bf0ec 100644 --- a/src/handlers/get_handler.rs +++ b/src/handlers/get_handler.rs @@ -6,7 +6,7 @@ use axum::{ Extension, Json, }; -use crate::shard::Shards; +use crate::shards::Shards; #[derive(Debug, serde::Serialize)] pub struct GetError { diff --git a/src/handlers/info_handler.rs b/src/handlers/info_handler.rs index f6f727d..aadeb9d 100644 --- a/src/handlers/info_handler.rs +++ b/src/handlers/info_handler.rs @@ -1,4 +1,4 @@ -use crate::shard::Shards; +use crate::shards::Shards; use axum::{http::StatusCode, Extension, Json}; use tracing::error; @@ -7,13 +7,14 @@ use tracing::error; pub struct InfoResponse { num_shards: usize, shards: Vec, + db_size_bytes: usize, } #[derive(serde::Serialize)] pub struct ShardInfo { id: usize, num_entries: usize, - size_bytes: u64, + db_size_bytes: usize, } #[axum::debug_handler] @@ -21,19 +22,22 @@ pub async fn info_handler( Extension(shards): Extension, ) -> Result<(StatusCode, Json), StatusCode> { let mut shard_infos = vec![]; + let mut total_db_size_bytes = 0; + for shard in shards.iter() { let num_entries = shard.num_entries().await.map_err(|e| { error!("error getting num entries: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; - let size_bytes = shard.size_bytes().await.map_err(|e| { + let db_size_bytes = shard.db_size_bytes().await.map_err(|e| { error!("error getting size bytes: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; + total_db_size_bytes += db_size_bytes; shard_infos.push(ShardInfo { id: shard.id(), num_entries, - size_bytes, + db_size_bytes, }); } Ok(( @@ -41,6 +45,7 @@ pub async fn info_handler( Json(InfoResponse { num_shards: shards.len(), shards: shard_infos, + db_size_bytes: total_db_size_bytes, }), )) } diff --git a/src/handlers/store_handler.rs b/src/handlers/store_handler.rs index 71d8873..688350e 100644 --- a/src/handlers/store_handler.rs +++ b/src/handlers/store_handler.rs @@ -1,47 +1,32 @@ use crate::{ - axum_result_type::AxumJsonResultOf, - request_response::store_request::{StoreRequest, StoreRequestWithSha256, StoreResponse}, + request_response::store_request::{StoreRequest, StoreResult}, sha256::Sha256, - shard::Shards, + shards::Shards, }; -use axum::{http::StatusCode, Extension, Json}; +use axum::Extension; use axum_typed_multipart::TypedMultipart; -use std::collections::HashMap; + use tracing::error; #[axum::debug_handler] pub async fn store_handler( Extension(shards): Extension, TypedMultipart(request): TypedMultipart, -) -> AxumJsonResultOf { +) -> StoreResult { let sha256 = Sha256::from_bytes(&request.data.contents); - let sha256_str = sha256.hex_string(); let shard = shards.shard_for(&sha256); if let Some(req_sha256) = request.sha256 { - if req_sha256 != sha256_str { + if sha256 != req_sha256 { + let sha256_str = sha256.hex_string(); error!( "client sent mismatched sha256: (client) {} != (computed) {}", req_sha256, sha256_str ); - let mut response = HashMap::new(); - response.insert("status", "error".to_owned()); - response.insert("message", "sha256 mismatch".to_owned()); - return Err(( - StatusCode::BAD_REQUEST, - Json(StoreResponse::Error { - sha256: Some(req_sha256), - message: "sha256 mismatch".to_owned(), - }), - )); + return StoreResult::Sha256Mismatch { sha256: sha256_str }; } } - - let request_parsed = StoreRequestWithSha256 { - sha256: sha256_str, - content_type: request.content_type, - data: request.data.contents, - }; - - shard.store(request_parsed).await + shard + .store(sha256, request.content_type, request.data.contents) + .await } diff --git a/src/main.rs b/src/main.rs index f7a7bfc..5c7e14d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,9 @@ -mod axum_result_type; mod handlers; mod request_response; mod sha256; mod shard; +mod shards; mod shutdown_signal; -use crate::shard::Shards; use axum::{ routing::{get, post}, Extension, Router, @@ -13,8 +12,11 @@ use clap::Parser; use shard::Shard; use std::{error::Error, path::PathBuf}; use tokio::net::TcpListener; +use tokio_rusqlite::Connection; use tracing::info; +use crate::shards::Shards; + #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Args { @@ -64,8 +66,9 @@ fn main() -> Result<(), Box> { ); let mut shards_vec = vec![]; for shard_id in 0..num_shards { - let shard_path = db_path.join(format!("shard{}.sqlite", shard_id)); - let shard = Shard::open(shard_id, &shard_path).await?; + let shard_sqlite_path = db_path.join(format!("shard{}.sqlite", shard_id)); + let shard_sqlite_conn = Connection::open(&shard_sqlite_path).await?; + let shard = Shard::open(shard_id, shard_sqlite_conn).await?; info!( "shard {} has {} entries", shard.id(), diff --git a/src/request_response/store_request.rs b/src/request_response/store_request.rs index 559afb8..4208048 100644 --- a/src/request_response/store_request.rs +++ b/src/request_response/store_request.rs @@ -1,4 +1,4 @@ -use axum::body::Bytes; +use axum::{body::Bytes, http::StatusCode, response::IntoResponse, Json}; use axum_typed_multipart::{FieldData, TryFromMultipart}; use serde::Serialize; @@ -9,25 +9,33 @@ pub struct StoreRequest { pub data: FieldData, } -pub struct StoreRequestWithSha256 { - pub sha256: String, - pub content_type: String, - pub data: Bytes, -} - -// serializes to: -// {"status": "ok", "sha256": ..., "message": ...} -// {"status": "error", ["sha256": ...], "message": ...} -#[derive(Serialize)] +#[derive(Serialize, PartialEq, Debug)] #[serde(tag = "status", rename_all = "snake_case")] -pub enum StoreResponse { - Ok { - sha256: String, - message: &'static str, +pub enum StoreResult { + Created { + stored_size: usize, + data_size: usize, }, - - Error { - sha256: Option, + Exists { + stored_size: usize, + data_size: usize, + }, + Sha256Mismatch { + sha256: String, + }, + InternalError { message: String, }, } + +impl IntoResponse for StoreResult { + fn into_response(self) -> axum::response::Response { + let status_code = match &self { + StoreResult::Created { .. } => StatusCode::CREATED, + StoreResult::Exists { .. } => StatusCode::OK, + StoreResult::Sha256Mismatch { .. } => StatusCode::BAD_REQUEST, + StoreResult::InternalError { .. } => StatusCode::INTERNAL_SERVER_ERROR, + }; + (status_code, Json(self)).into_response() + } +} diff --git a/src/sha256.rs b/src/sha256.rs index 070b799..add8195 100644 --- a/src/sha256.rs +++ b/src/sha256.rs @@ -54,3 +54,13 @@ impl LowerHex for Sha256 { Ok(()) } } + +impl PartialEq for Sha256 { + fn eq(&self, other: &String) -> bool { + if let Ok(other_sha256) = Sha256::from_hex_string(other) { + self.0 == other_sha256.0 + } else { + false + } + } +} diff --git a/src/shard/fn_migrate.rs b/src/shard/fn_migrate.rs new file mode 100644 index 0000000..20db125 --- /dev/null +++ b/src/shard/fn_migrate.rs @@ -0,0 +1,78 @@ +use super::*; + +impl Shard { + pub(super) async fn migrate(&self) -> Result<(), Box> { + let shard_id = self.id(); + // create tables, indexes, etc + self.conn + .call(move |conn| { + ensure_schema_versions_table(conn)?; + let schema_rows = load_schema_rows(conn)?; + + if let Some((version, date_time)) = schema_rows.first() { + debug!( + "shard {}: latest schema version: {} @ {}", + shard_id, version, date_time + ); + + if *version == 1 { + // no-op + } else { + return Err(tokio_rusqlite::Error::Other(Box::new(ShardError::new( + format!("shard {}: unsupported schema version {}", shard_id, version), + )))); + } + } else { + debug!("shard {}: no schema version found, initializing", shard_id); + migrate_to_version_1(conn)?; + } + Ok(()) + }) + .await?; + Ok(()) + } +} + +fn ensure_schema_versions_table(conn: &rusqlite::Connection) -> Result { + conn.execute( + "CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER PRIMARY KEY, + created_at TEXT NOT NULL + )", + [], + ) +} + +fn load_schema_rows(conn: &rusqlite::Connection) -> Result, rusqlite::Error> { + let mut stmt = conn + .prepare("SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1")?; + let rows = stmt.query_map([], |row| { + let version = row.get(0)?; + let created_at = row.get(1)?; + Ok((version, created_at)) + })?; + rows.collect() +} + +fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> { + debug!("migrating to version 1"); + conn.execute( + "CREATE TABLE IF NOT EXISTS entries ( + sha256 BLOB PRIMARY KEY, + content_type TEXT NOT NULL, + compression INTEGER NOT NULL, + uncompressed_size INTEGER NOT NULL, + compressed_size INTEGER NOT NULL, + data BLOB NOT NULL, + created_at TEXT NOT NULL + )", + [], + )?; + + conn.execute( + "INSERT INTO schema_version (version, created_at) VALUES (1, ?)", + [chrono::Utc::now().to_rfc3339()], + )?; + + Ok(()) +} diff --git a/src/shard/mod.rs b/src/shard/mod.rs index 5ac6f90..2716e43 100644 --- a/src/shard/mod.rs +++ b/src/shard/mod.rs @@ -1,145 +1,124 @@ +mod fn_migrate; pub mod shard_error; use crate::{ - axum_result_type::AxumJsonResultOf, - request_response::store_request::{StoreRequestWithSha256, StoreResponse}, - sha256::Sha256, - shard::shard_error::ShardError, + request_response::store_request::StoreResult, sha256::Sha256, shard::shard_error::ShardError, }; -use axum::{http::StatusCode, Json}; +use axum::body::Bytes; -use rusqlite::{params, OptionalExtension}; -use std::{ - error::Error, - path::{Path, PathBuf}, -}; +use rusqlite::{params, types::FromSql, OptionalExtension}; +use std::error::Error; use tokio_rusqlite::Connection; -use tracing::{debug, error, info}; +use tracing::{debug, error}; -type UtcDateTime = chrono::DateTime; - -#[derive(Clone)] -pub struct Shards(Vec); -impl Shards { - pub fn new(shards: Vec) -> Self { - Self(shards) - } - - pub fn shard_for(&self, sha256: &Sha256) -> &Shard { - let shard_id = sha256.modulo(self.0.len()); - &self.0[shard_id] - } - - pub async fn close_all(self) -> Result<(), Box> { - for shard in self.0 { - shard.close().await?; - } - Ok(()) - } - - pub fn iter(&self) -> std::slice::Iter<'_, Shard> { - self.0.iter() - } - - pub fn len(&self) -> usize { - self.0.len() - } -} +pub type UtcDateTime = chrono::DateTime; #[derive(Clone)] pub struct Shard { id: usize, - sqlite: Connection, - file_path: PathBuf, + conn: Connection, } pub struct GetResult { pub content_type: String, + pub stored_size: usize, pub created_at: UtcDateTime, pub data: Vec, } impl Shard { - pub async fn open(id: usize, file_path: &Path) -> Result> { - let sqlite = Connection::open(file_path).await?; - let shard = Self { - id, - sqlite, - file_path: file_path.to_owned(), - }; + pub async fn open(id: usize, conn: Connection) -> Result> { + let shard = Self { id, conn }; shard.migrate().await?; Ok(shard) } pub async fn close(self) -> Result<(), Box> { - self.sqlite.close().await?; - Ok(()) + self.conn.close().await.map_err(|e| e.into()) } pub fn id(&self) -> usize { self.id } - pub async fn size_bytes(&self) -> Result> { - // stat the file to get its size in bytes - let metadata = tokio::fs::metadata(&self.file_path).await?; - Ok(metadata.len()) + pub async fn db_size_bytes(&self) -> Result> { + self.query_single_row( + "SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()", + ) + .await } - pub async fn store( + async fn query_single_row( &self, - store_request: StoreRequestWithSha256, - ) -> AxumJsonResultOf { - let sha256 = store_request.sha256.clone(); + query: &'static str, + ) -> Result> { + self.conn + .call(move |conn| { + let value: T = conn.query_row(query, [], |row| row.get(0))?; + Ok(value) + }) + .await + .map_err(|e| e.into()) + } - // let data = &store_request.data; + pub async fn store(&self, sha256: Sha256, content_type: String, data: Bytes) -> StoreResult { + let sha256 = sha256.hex_string(); + self.conn.call(move |conn| { + // check for existing entry + let maybe_existing: Option = conn + .query_row( + "SELECT uncompressed_size, compressed_size, created_at FROM entries WHERE sha256 = ?", + params![sha256], + |row| Ok(StoreResult::Exists { + stored_size: row.get(0)?, + data_size: row.get(1)?, + }) + ) + .optional()?; + + if let Some(existing) = maybe_existing { + return Ok(existing); + } - self.sqlite.call(move |conn| { let created_at = chrono::Utc::now().to_rfc3339(); - let maybe_error = conn.execute( + let data_size = data.len(); + + conn.execute( "INSERT INTO entries (sha256, content_type, compression, uncompressed_size, compressed_size, data, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", params![ - store_request.sha256, - store_request.content_type, + sha256, + content_type, 0, - store_request.data.len() as i64, - store_request.data.len() as i64, - &store_request.data[..], + data_size, + data_size, + &data[..], created_at, ], - ); + )?; - if let Err(e) = &maybe_error { - if is_duplicate_entry_err(e) { - info!("entry {} already exists", store_request.sha256); - return Ok((StatusCode::OK, Json(StoreResponse::Ok{ - sha256: store_request.sha256, - message: "exists", - }))); - } - } - maybe_error?; - Ok((StatusCode::CREATED, Json(StoreResponse::Ok { sha256: store_request.sha256, message: "created" }))) + Ok(StoreResult::Created { stored_size: data_size, data_size }) }) - .await.map_err(|e| { + .await.unwrap_or_else(|e| { error!("store failed: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, Json(StoreResponse::Error { sha256: Some(sha256), message: e.to_string() })) + StoreResult::InternalError { message: e.to_string() } }) } pub async fn get(&self, sha256: Sha256) -> Result, Box> { - self.sqlite + self.conn .call(move |conn| { let get_result = conn .query_row( - "SELECT content_type, created_at, data FROM entries WHERE sha256 = ?", + "SELECT content_type, compressed_size, created_at, data FROM entries WHERE sha256 = ?", params![sha256.hex_string()], |row| { let content_type = row.get(0)?; - let created_at = parse_created_at_str(row.get(1)?)?; - let data = row.get(2)?; + let stored_size = row.get(1)?; + let created_at = parse_created_at_str(row.get(2)?)?; + let data = row.get(3)?; Ok(GetResult { content_type, + stored_size, created_at, data, }) @@ -156,38 +135,7 @@ impl Shard { } pub async fn num_entries(&self) -> Result> { - get_num_entries(&self.sqlite).await.map_err(|e| e.into()) - } - - async fn migrate(&self) -> Result<(), Box> { - let shard_id = self.id(); - // create tables, indexes, etc - self.sqlite - .call(move |conn| { - ensure_schema_versions_table(conn)?; - let schema_rows = load_schema_rows(conn)?; - - if let Some((version, date_time)) = schema_rows.first() { - debug!( - "shard {}: latest schema version: {} @ {}", - shard_id, version, date_time - ); - - if *version == 1 { - // no-op - } else { - return Err(tokio_rusqlite::Error::Other(Box::new(ShardError::new( - format!("shard {}: unsupported schema version {}", shard_id, version), - )))); - } - } else { - debug!("shard {}: no schema version found, initializing", shard_id); - migrate_to_version_1(conn)?; - } - Ok(()) - }) - .await?; - Ok(()) + get_num_entries(&self.conn).await.map_err(|e| e.into()) } } @@ -197,68 +145,6 @@ fn parse_created_at_str(created_at_str: String) -> Result bool { - use rusqlite::*; - - if let Error::SqliteFailure( - ffi::Error { - code: ffi::ErrorCode::ConstraintViolation, - .. - }, - Some(err_str), - ) = error - { - if err_str.contains("UNIQUE constraint failed: entries.sha256") { - return true; - } - } - false -} - -fn ensure_schema_versions_table(conn: &rusqlite::Connection) -> Result { - conn.execute( - "CREATE TABLE IF NOT EXISTS schema_version ( - version INTEGER PRIMARY KEY, - created_at TEXT NOT NULL - )", - [], - ) -} - -fn load_schema_rows(conn: &rusqlite::Connection) -> Result, rusqlite::Error> { - let mut stmt = conn - .prepare("SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1")?; - let rows = stmt.query_map([], |row| { - let version = row.get(0)?; - let created_at = row.get(1)?; - Ok((version, created_at)) - })?; - rows.collect() -} - -fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> { - debug!("migrating to version 1"); - conn.execute( - "CREATE TABLE IF NOT EXISTS entries ( - sha256 BLOB PRIMARY KEY, - content_type TEXT NOT NULL, - compression INTEGER NOT NULL, - uncompressed_size INTEGER NOT NULL, - compressed_size INTEGER NOT NULL, - data BLOB NOT NULL, - created_at TEXT NOT NULL - )", - [], - )?; - - conn.execute( - "INSERT INTO schema_version (version, created_at) VALUES (1, ?)", - [chrono::Utc::now().to_rfc3339()], - )?; - - Ok(()) -} - async fn get_num_entries(conn: &Connection) -> Result { conn.call(|conn| { let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?; @@ -266,3 +152,88 @@ async fn get_num_entries(conn: &Connection) -> Result super::Shard { + let conn = tokio_rusqlite::Connection::open_in_memory().await.unwrap(); + super::Shard::open(0, conn).await.unwrap() + } + + #[tokio::test] + async fn test_num_entries() { + let shard = make_shard().await; + let num_entries = shard.num_entries().await.unwrap(); + assert_eq!(num_entries, 0); + } + + #[tokio::test] + async fn test_db_size_bytes() { + let shard = make_shard().await; + let db_size = shard.db_size_bytes().await.unwrap(); + assert!(db_size > 0); + } + + #[tokio::test] + async fn test_not_found_get() { + let shard = make_shard().await; + let sha256 = Sha256::from_bytes("hello, world!".as_bytes()); + let get_result = shard.get(sha256).await.unwrap(); + assert!(get_result.is_none()); + } + + #[tokio::test] + async fn test_store_and_get() { + let shard = make_shard().await; + let data = "hello, world!".as_bytes(); + let sha256 = Sha256::from_bytes(data); + let store_result = shard + .store(sha256, "text/plain".to_string(), data.into()) + .await; + assert_eq!( + store_result, + super::StoreResult::Created { + data_size: data.len(), + stored_size: data.len() + } + ); + assert_eq!(shard.num_entries().await.unwrap(), 1); + + let get_result = shard.get(sha256).await.unwrap().unwrap(); + assert_eq!(get_result.content_type, "text/plain"); + assert_eq!(get_result.data, data); + assert_eq!(get_result.stored_size, data.len()); + } + + #[tokio::test] + async fn test_store_duplicate() { + let shard = make_shard().await; + let data = "hello, world!".as_bytes(); + let sha256 = Sha256::from_bytes(data); + + let store_result = shard + .store(sha256, "text/plain".to_string(), data.into()) + .await; + assert_eq!( + store_result, + super::StoreResult::Created { + data_size: data.len(), + stored_size: data.len() + } + ); + + let store_result = shard + .store(sha256, "text/plain".to_string(), data.into()) + .await; + assert_eq!( + store_result, + super::StoreResult::Exists { + data_size: data.len(), + stored_size: data.len() + } + ); + assert_eq!(shard.num_entries().await.unwrap(), 1); + } +} diff --git a/src/shards.rs b/src/shards.rs new file mode 100644 index 0000000..8b54615 --- /dev/null +++ b/src/shards.rs @@ -0,0 +1,31 @@ +use std::error::Error; + +use crate::{sha256::Sha256, shard::Shard}; + +#[derive(Clone)] +pub struct Shards(Vec); +impl Shards { + pub fn new(shards: Vec) -> Self { + Self(shards) + } + + pub fn shard_for(&self, sha256: &Sha256) -> &Shard { + let shard_id = sha256.modulo(self.0.len()); + &self.0[shard_id] + } + + pub async fn close_all(self) -> Result<(), Box> { + for shard in self.0 { + shard.close().await?; + } + Ok(()) + } + + pub fn iter(&self) -> std::slice::Iter<'_, Shard> { + self.0.iter() + } + + pub fn len(&self) -> usize { + self.0.len() + } +}