schema in database

This commit is contained in:
Dylan Knutson
2024-04-23 12:08:33 -07:00
parent e9c36cb805
commit 2045fcb89b
6 changed files with 144 additions and 87 deletions

1
src/handlers/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod store_handler;

View File

@@ -0,0 +1,48 @@
use crate::{
axum_result_type::AxumJsonResultOf,
request_response::store_request::{StoreRequest, StoreRequestWithSha256, StoreResponse},
sha256::Sha256,
shard::Shards,
};
use axum::{http::StatusCode, Extension, Json};
use axum_typed_multipart::TypedMultipart;
use std::collections::HashMap;
use tracing::error;
/// Handles `POST /store`: verifies the optional client-supplied SHA-256 and
/// forwards the upload to the shard selected for its digest.
///
/// The digest is always computed from the uploaded bytes. When the request
/// also carries a `sha256` field and it differs from the computed hex string,
/// nothing is stored and the handler responds with `400 Bad Request` and a
/// `StoreResponse::Error` that echoes the client's value.
///
/// On success the result of `Shard::store` is returned unchanged.
#[axum::debug_handler]
pub async fn store_handler(
    Extension(shards): Extension<Shards>,
    TypedMultipart(request): TypedMultipart<StoreRequest>,
) -> AxumJsonResultOf<StoreResponse> {
    // compute sha256 of data
    let sha256 = Sha256::from_bytes(&request.data.contents);
    let sha256_str = sha256.hex_string();
    let shard = shards.shard_for(&sha256);

    // NOTE(review): this is an exact string comparison; whether clients may
    // send a differently-cased hex digest depends on `hex_string` -- confirm.
    if let Some(req_sha256) = request.sha256 {
        if req_sha256 != sha256_str {
            error!(
                "client sent mismatched sha256: (client) {} != (computed) {}",
                req_sha256, sha256_str
            );
            return Err((
                StatusCode::BAD_REQUEST,
                Json(StoreResponse::Error {
                    sha256: Some(req_sha256),
                    message: "sha256 mismatch".to_owned(),
                }),
            ));
        }
    }

    // Store under the server-computed digest, never the client-supplied one.
    let request_parsed = StoreRequestWithSha256 {
        sha256: sha256_str,
        content_type: request.content_type,
        data: request.data.contents,
    };
    shard.store(request_parsed).await
}

View File

@@ -1,26 +1,16 @@
mod axum_result_type; mod axum_result_type;
mod handlers;
mod request_response;
mod sha256; mod sha256;
mod shard; mod shard;
mod shutdown_signal; mod shutdown_signal;
mod store_request;
use axum::Json;
use axum::{http::StatusCode, routing::post, Extension, Router};
use axum_result_type::AxumJsonResultOf;
use axum_typed_multipart::TypedMultipart;
use clap::Parser;
use shard::Shard;
use std::collections::HashMap;
use std::{error::Error, path::PathBuf};
use store_request::StoreResponse;
use tokio::net::TcpListener;
use tracing::{error, info};
use crate::sha256::Sha256;
use crate::shard::Shards; use crate::shard::Shards;
use crate::store_request::{StoreRequest, StoreRequestParsed}; use axum::{routing::post, Extension, Router};
use clap::Parser;
use shard::Shard;
use std::{error::Error, path::PathBuf};
use tokio::net::TcpListener;
use tracing::info;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
@@ -94,7 +84,7 @@ fn main() -> Result<(), Box<dyn Error>> {
async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn Error>> { async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn Error>> {
let app = Router::new() let app = Router::new()
.route("/store", post(store_request_handler)) .route("/store", post(handlers::store_handler::store_handler))
.layer(Extension(shards)); .layer(Extension(shards));
axum::serve(server, app.into_make_service()) axum::serve(server, app.into_make_service())
@@ -103,49 +93,12 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn
Ok(()) Ok(())
} }
/// Handles `POST /store`: verifies the optional client-supplied SHA-256 and
/// forwards the upload to the shard selected for its digest.
///
/// The digest is always computed from the uploaded bytes. When the request
/// also carries a `sha256` field and it differs from the computed hex string,
/// nothing is stored and the handler responds with `400 Bad Request` and a
/// `StoreResponse::Error` that echoes the client's value.
///
/// On success the result of `Shard::store` is returned unchanged.
#[axum::debug_handler]
async fn store_request_handler(
    Extension(shards): Extension<Shards>,
    TypedMultipart(request): TypedMultipart<StoreRequest>,
) -> AxumJsonResultOf<StoreResponse> {
    // compute sha256 of data
    let sha256 = Sha256::from_bytes(&request.data.contents);
    let sha256_str = sha256.hex_string();
    let shard = shards.shard_for(&sha256);

    if let Some(req_sha256) = request.sha256 {
        if req_sha256 != sha256_str {
            error!(
                "client sent mismatched sha256: (client) {} != (computed) {}",
                req_sha256, sha256_str
            );
            return Err((
                StatusCode::BAD_REQUEST,
                Json(StoreResponse::Error {
                    sha256: Some(req_sha256),
                    message: "sha256 mismatch".to_owned(),
                }),
            ));
        }
    }

    // Store under the server-computed digest, never the client-supplied one.
    let request_parsed = StoreRequestParsed {
        sha256: sha256_str,
        content_type: request.content_type,
        data: request.data.contents,
    };
    shard.store(request_parsed).await
}
fn validate_manifest(args: &Args) -> Result<usize, Box<dyn Error>> { fn validate_manifest(args: &Args) -> Result<usize, Box<dyn Error>> {
let manifest_path = PathBuf::from(&args.db_path).join("manifest.json"); let manifest_path = PathBuf::from(&args.db_path).join("manifest.json");
if manifest_path.exists() { if manifest_path.exists() {
let file_content = std::fs::read_to_string(manifest_path)?; let file_content = std::fs::read_to_string(manifest_path)?;
let manifest: ManifestData = serde_json::from_str(&file_content)?; let manifest: ManifestData = serde_json::from_str(&file_content)?;
info!("loading existing database with {} shards", manifest.shards);
if let Some(shards) = args.shards { if let Some(shards) = args.shards {
if shards != manifest.shards { if shards != manifest.shards {
return Err(format!( return Err(format!(
@@ -157,6 +110,7 @@ fn validate_manifest(args: &Args) -> Result<usize, Box<dyn Error>> {
} }
Ok(manifest.shards) Ok(manifest.shards)
} else if let Some(shards) = args.shards { } else if let Some(shards) = args.shards {
info!("creating new database with {} shards", shards);
std::fs::create_dir_all(&args.db_path)?; std::fs::create_dir_all(&args.db_path)?;
let manifest = ManifestData { shards }; let manifest = ManifestData { shards };
let manifest_json = serde_json::to_string(&manifest)?; let manifest_json = serde_json::to_string(&manifest)?;

View File

@@ -0,0 +1 @@
pub mod store_request;

View File

@@ -9,18 +9,23 @@ pub struct StoreRequest {
pub data: FieldData<Bytes>, pub data: FieldData<Bytes>,
} }
pub struct StoreRequestParsed { pub struct StoreRequestWithSha256 {
pub sha256: String, pub sha256: String,
pub content_type: String, pub content_type: String,
pub data: Bytes, pub data: Bytes,
} }
// serializes to:
// {"status": "ok", "sha256": ..., "message": ...}
// {"status": "error", ["sha256": ...], "message": ...}
#[derive(Serialize)] #[derive(Serialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum StoreResponse { pub enum StoreResponse {
Ok { Ok {
sha256: String, sha256: String,
message: &'static str, message: &'static str,
}, },
Error { Error {
sha256: Option<String>, sha256: Option<String>,
message: String, message: String,

View File

@@ -1,16 +1,17 @@
use std::{error::Error, path::Path};
use rusqlite::params;
use tokio_rusqlite::Connection;
use tracing::{error, info};
use crate::{ use crate::{
axum_result_type::AxumJsonResultOf, axum_result_type::AxumJsonResultOf,
request_response::store_request::{StoreRequestWithSha256, StoreResponse},
sha256::Sha256, sha256::Sha256,
store_request::{StoreRequestParsed, StoreResponse},
}; };
use axum::{http::StatusCode, Json}; use axum::{http::StatusCode, Json};
use rusqlite::{params, OptionalExtension};
use std::{error::Error, path::Path};
use tokio_rusqlite::Connection;
use tracing::{debug, error, info};
type UtcDateTime = chrono::DateTime<chrono::Utc>;
#[derive(Clone)] #[derive(Clone)]
pub struct Shards(Vec<Shard>); pub struct Shards(Vec<Shard>);
impl Shards { impl Shards {
@@ -40,8 +41,9 @@ pub struct Shard {
impl Shard { impl Shard {
pub async fn open(id: usize, db_path: &Path) -> Result<Self, Box<dyn Error>> { pub async fn open(id: usize, db_path: &Path) -> Result<Self, Box<dyn Error>> {
let sqlite = Connection::open(db_path).await?; let sqlite = Connection::open(db_path).await?;
migrate(&sqlite).await?; let shard = Self { id, sqlite };
Ok(Self { id, sqlite }) shard.migrate().await?;
Ok(shard)
} }
pub async fn close(self) -> Result<(), Box<dyn Error>> { pub async fn close(self) -> Result<(), Box<dyn Error>> {
@@ -55,7 +57,7 @@ impl Shard {
pub async fn store( pub async fn store(
&self, &self,
store_request: StoreRequestParsed, store_request: StoreRequestWithSha256,
) -> AxumJsonResultOf<StoreResponse> { ) -> AxumJsonResultOf<StoreResponse> {
let sha256 = store_request.sha256.clone(); let sha256 = store_request.sha256.clone();
@@ -91,7 +93,53 @@ impl Shard {
} }
pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> { pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
get_num_entries(&self.sqlite).await get_num_entries(&self.sqlite).await.map_err(|e| e.into())
}
async fn migrate(&self) -> Result<(), Box<dyn Error>> {
let shard_id = self.id();
// create tables, indexes, etc
self.sqlite
.call(move |conn| {
conn.execute(
"CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER PRIMARY KEY,
created_at TEXT NOT NULL
)",
[],
)?;
let schema_row: Option<(i64, UtcDateTime)> = conn.query_row(
"SELECT version, created_at FROM schema_version ORDER BY version DESC LIMIT 1",
[],
|row| {
let ver = row.get(0)?;
let created_at_str: String = row.get(1)?;
let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str).map_err(|e| {
rusqlite::Error::ToSqlConversionFailure(e.into())
})?.to_utc();
Ok((ver, created_at))
}
).optional()?;
if let Some((version, date_time)) = schema_row {
debug!(
"shard {}: latest schema version: {} @ {}",
shard_id,
version, date_time
);
if version < 1 {
migrate_to_version_1(conn)?;
}
} else {
debug!("shard {}: no schema version found, initializing", shard_id);
migrate_to_version_1(conn)?;
}
Ok(())
})
.await?;
Ok(())
} }
} }
@@ -113,30 +161,30 @@ fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool {
false false
} }
async fn migrate(conn: &Connection) -> Result<(), Box<dyn Error>> { fn migrate_to_version_1(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
// create tables, indexes, etc conn.execute(
conn.call(|conn| { "INSERT INTO schema_version (version, created_at) VALUES (1, ?)",
conn.execute( [chrono::Utc::now().to_rfc3339()],
"CREATE TABLE IF NOT EXISTS entries ( )?;
sha256 BLOB PRIMARY KEY,
content_type TEXT NOT NULL, conn.execute(
size INTEGER NOT NULL, "CREATE TABLE IF NOT EXISTS entries (
data BLOB NOT NULL, sha256 BLOB PRIMARY KEY,
created_at TEXT NOT NULL content_type TEXT NOT NULL,
)", size INTEGER NOT NULL,
[], data BLOB NOT NULL,
)?; created_at TEXT NOT NULL
Ok(()) )",
}) [],
.await?; )?;
Ok(()) Ok(())
} }
async fn get_num_entries(conn: &Connection) -> Result<usize, Box<dyn Error>> { async fn get_num_entries(conn: &Connection) -> Result<usize, tokio_rusqlite::Error> {
conn.call(|conn| { conn.call(|conn| {
let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?; let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
Ok(count) Ok(count)
}) })
.await .await
.map_err(|e| e.into())
} }