diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs new file mode 100644 index 0000000..db8f00c --- /dev/null +++ b/src/handlers/mod.rs @@ -0,0 +1 @@ +pub mod store_handler; diff --git a/src/handlers/store_handler.rs b/src/handlers/store_handler.rs new file mode 100644 index 0000000..85f8688 --- /dev/null +++ b/src/handlers/store_handler.rs @@ -0,0 +1,48 @@ +use crate::{ + axum_result_type::AxumJsonResultOf, + request_response::store_request::{StoreRequest, StoreRequestWithSha256, StoreResponse}, + sha256::Sha256, + shard::Shards, +}; +use axum::{http::StatusCode, Extension, Json}; +use axum_typed_multipart::TypedMultipart; +use std::collections::HashMap; +use tracing::error; + +#[axum::debug_handler] +pub async fn store_handler( + Extension(shards): Extension, + TypedMultipart(request): TypedMultipart, +) -> AxumJsonResultOf { + // compute sha256 of data + let sha256 = Sha256::from_bytes(&request.data.contents); + let sha256_str = sha256.hex_string(); + let shard = shards.shard_for(&sha256); + + if let Some(req_sha256) = request.sha256 { + if req_sha256 != sha256_str { + error!( + "client sent mismatched sha256: (client) {} != (computed) {}", + req_sha256, sha256_str + ); + let mut response = HashMap::new(); + response.insert("status", "error".to_owned()); + response.insert("message", "sha256 mismatch".to_owned()); + return Err(( + StatusCode::BAD_REQUEST, + Json(StoreResponse::Error { + sha256: Some(req_sha256), + message: "sha256 mismatch".to_owned(), + }), + )); + } + } + + let request_parsed = StoreRequestWithSha256 { + sha256: sha256_str, + content_type: request.content_type, + data: request.data.contents, + }; + + shard.store(request_parsed).await +} diff --git a/src/main.rs b/src/main.rs index 61ec2e3..b1930b5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,26 +1,16 @@ mod axum_result_type; +mod handlers; +mod request_response; mod sha256; mod shard; mod shutdown_signal; -mod store_request; - -use axum::Json; -use axum::{http::StatusCode, routing::post, Extension, Router}; -use axum_result_type::AxumJsonResultOf; -use axum_typed_multipart::TypedMultipart; -use clap::Parser; - -use shard::Shard; -use std::collections::HashMap; -use std::{error::Error, path::PathBuf}; -use store_request::StoreResponse; -use tokio::net::TcpListener; - -use tracing::{error, info}; - -use crate::sha256::Sha256; use crate::shard::Shards; -use crate::store_request::{StoreRequest, StoreRequestParsed}; +use axum::{routing::post, Extension, Router}; +use clap::Parser; +use shard::Shard; +use std::{error::Error, path::PathBuf}; +use tokio::net::TcpListener; +use tracing::info; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] @@ -94,7 +84,7 @@ fn main() -> Result<(), Box> { async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box> { let app = Router::new() - .route("/store", post(store_request_handler)) + .route("/store", post(handlers::store_handler::store_handler)) .layer(Extension(shards)); axum::serve(server, app.into_make_service()) @@ -103,49 +93,12 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box, - TypedMultipart(request): TypedMultipart, -) -> AxumJsonResultOf { - // compute sha256 of data - let sha256 = Sha256::from_bytes(&request.data.contents); - let sha256_str = sha256.hex_string(); - let shard = shards.shard_for(&sha256); - - if let Some(req_sha256) = request.sha256 { - if req_sha256 != sha256_str { - error!( - "client sent mismatched sha256: (client) {} != (computed) {}", - req_sha256, sha256_str - ); - let mut response = HashMap::new(); - response.insert("status", "error".to_owned()); - response.insert("message", "sha256 mismatch".to_owned()); - return Err(( - StatusCode::BAD_REQUEST, - Json(StoreResponse::Error { - sha256: Some(req_sha256), - message: "sha256 mismatch".to_owned(), - }), - )); - } - } - - let request_parsed = StoreRequestParsed { - sha256: sha256_str, - content_type: request.content_type, - data: request.data.contents, - }; - - shard.store(request_parsed).await -} - fn validate_manifest(args: &Args) -> Result> { let manifest_path = PathBuf::from(&args.db_path).join("manifest.json"); if manifest_path.exists() { let file_content = std::fs::read_to_string(manifest_path)?; let manifest: ManifestData = serde_json::from_str(&file_content)?; + info!("loading existing database with {} shards", manifest.shards); if let Some(shards) = args.shards { if shards != manifest.shards { return Err(format!( @@ -157,6 +110,7 @@ fn validate_manifest(args: &Args) -> Result> { } Ok(manifest.shards) } else if let Some(shards) = args.shards { + info!("creating new database with {} shards", shards); std::fs::create_dir_all(&args.db_path)?; let manifest = ManifestData { shards }; let manifest_json = serde_json::to_string(&manifest)?; diff --git a/src/request_response/mod.rs b/src/request_response/mod.rs new file mode 100644 index 0000000..4065c30 --- /dev/null +++ b/src/request_response/mod.rs @@ -0,0 +1 @@ +pub mod store_request; diff --git a/src/store_request.rs b/src/request_response/store_request.rs similarity index 71% rename from src/store_request.rs rename to src/request_response/store_request.rs index 9ca0fe8..559afb8 100644 --- a/src/store_request.rs +++ b/src/request_response/store_request.rs @@ -9,18 +9,23 @@ pub struct StoreRequest { pub data: FieldData, } -pub struct StoreRequestParsed { +pub struct StoreRequestWithSha256 { pub sha256: String, pub content_type: String, pub data: Bytes, } +// serializes to: +// {"status": "ok", "sha256": ..., "message": ...} +// {"status": "error", ["sha256": ...], "message": ...} #[derive(Serialize)] +#[serde(tag = "status", rename_all = "snake_case")] pub enum StoreResponse { Ok { sha256: String, message: &'static str, }, + Error { sha256: Option, message: String, diff --git a/src/shard/mod.rs b/src/shard/mod.rs index 799ca97..005d1d0 100644 --- a/src/shard/mod.rs +++ b/src/shard/mod.rs @@ -1,16 +1,17 @@ -use std::{error::Error, path::Path}; - -use rusqlite::params; -use tokio_rusqlite::Connection; -use tracing::{error, info}; - use crate::{ axum_result_type::AxumJsonResultOf, + request_response::store_request::{StoreRequestWithSha256, StoreResponse}, sha256::Sha256, - store_request::{StoreRequestParsed, StoreResponse}, }; use axum::{http::StatusCode, Json}; +use rusqlite::{params, OptionalExtension}; +use std::{error::Error, path::Path}; +use tokio_rusqlite::Connection; +use tracing::{debug, error, info}; + +type UtcDateTime = chrono::DateTime; + #[derive(Clone)] pub struct Shards(Vec); impl Shards { @@ -40,8 +41,9 @@ pub struct Shard { impl Shard { pub async fn open(id: usize, db_path: &Path) -> Result> { let sqlite = Connection::open(db_path).await?; - migrate(&sqlite).await?; - Ok(Self { id, sqlite }) + let shard = Self { id, sqlite }; + shard.migrate().await?; + Ok(shard) } pub async fn close(self) -> Result<(), Box> { @@ -55,7 +57,7 @@ impl Shard { pub async fn store( &self, - store_request: StoreRequestParsed, + store_request: StoreRequestWithSha256, ) -> AxumJsonResultOf { let sha256 = store_request.sha256.clone(); @@ -91,7 +93,53 @@ impl Shard { } pub async fn num_entries(&self) -> Result> { - get_num_entries(&self.sqlite).await + get_num_entries(&self.sqlite).await.map_err(|e| e.into()) + } + + async fn migrate(&self) -> Result<(), Box> { + let shard_id = self.id(); + // create tables, indexes, etc + self.sqlite + .call(move |conn| { + conn.execute( + "CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER PRIMARY KEY, + created_at TEXT NOT NULL + )", + [], + )?; + + let schema_row: Option<(i64, UtcDateTime)> = conn.query_row( + "SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1", + [], + |row| { + let ver = row.get(0)?; + let created_at_str: String = row.get(1)?; + let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str).map_err(|e| { + rusqlite::Error::ToSqlConversionFailure(e.into()) + })?.to_utc(); + Ok((ver, created_at)) + } + ).optional()?; + + if let Some((version, date_time)) = schema_row { + debug!( + "shard {}: latest schema version: {} @ {}", + shard_id, + version, date_time + ); + + if version < 1 { + migrate_to_version_1(conn)?; + } + } else { + debug!("shard {}: no schema version found, initializing", shard_id); + migrate_to_version_1(conn)?; + } + Ok(()) + }) + .await?; + Ok(()) } } @@ -113,30 +161,30 @@ fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool { false } -async fn migrate(conn: &Connection) -> Result<(), Box> { - // create tables, indexes, etc - conn.call(|conn| { - conn.execute( - "CREATE TABLE IF NOT EXISTS entries ( - sha256 BLOB PRIMARY KEY, - content_type TEXT NOT NULL, - size INTEGER NOT NULL, - data BLOB NOT NULL, - created_at TEXT NOT NULL - )", - [], - )?; - Ok(()) - }) - .await?; +fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> { + conn.execute( + "INSERT INTO schema_version (version, created_at) VALUES (1, ?)", + [chrono::Utc::now().to_rfc3339()], + )?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS entries ( + sha256 BLOB PRIMARY KEY, + content_type TEXT NOT NULL, + size INTEGER NOT NULL, + data BLOB NOT NULL, + created_at TEXT NOT NULL + )", + [], + )?; + Ok(()) } -async fn get_num_entries(conn: &Connection) -> Result> { +async fn get_num_entries(conn: &Connection) -> Result { conn.call(|conn| { let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?; Ok(count) }) .await - .map_err(|e| e.into()) }