more refactors, type wrappers

This commit is contained in:
Dylan Knutson
2024-04-23 11:13:41 -07:00
parent b5c367d26c
commit e9c36cb805
8 changed files with 261 additions and 152 deletions

4
Cargo.lock generated
View File

@@ -1359,9 +1359,9 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
version = "1.4.1"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247"
checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54"
[[package]]
name = "rustversion"

View File

@@ -10,7 +10,7 @@ path = "src/main.rs"
[[bin]]
name = "load-test"
path = "src/load_test.rs"
path = "load_test/main.rs"
[dependencies]
axum = { version = "0.7.5", features = ["macros"] }
@@ -20,7 +20,6 @@ clap = { version = "4.5.4", features = ["derive"] }
futures = "0.3.30"
kdam = "0.5.1"
rand = "0.8.5"
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
rusqlite = "0.31.0"
serde = { version = "1.0.198", features = ["serde_derive"] }
serde_json = "1.0.116"
@@ -29,3 +28,4 @@ tokio = { version = "1.37.0", features = ["full", "rt-multi-thread"] }
tokio-rusqlite = "0.5.1"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }

View File

@@ -1,5 +1,3 @@
use std::sync::Arc;
use std::sync::Mutex;

3
src/axum_result_type.rs Normal file
View File

@@ -0,0 +1,3 @@
use axum::{http::StatusCode, Json};
pub type AxumJsonResultOf<T> = Result<(StatusCode, Json<T>), (StatusCode, Json<T>)>;

View File

@@ -1,20 +1,27 @@
mod axum_result_type;
mod sha256;
mod shard;
mod shutdown_signal;
mod store_request;
use axum::Json;
use axum::{body::Bytes, http::StatusCode, routing::post, Extension, Router};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use axum::{http::StatusCode, routing::post, Extension, Router};
use axum_result_type::AxumJsonResultOf;
use axum_typed_multipart::TypedMultipart;
use clap::Parser;
use rusqlite::ffi;
use rusqlite::params;
use rusqlite::Error::SqliteFailure;
use sha2::Digest;
use shard::Shard;
use std::collections::HashMap;
use std::{borrow::Borrow, error::Error, path::PathBuf, sync::Arc};
use std::{error::Error, path::PathBuf};
use store_request::StoreResponse;
use tokio::net::TcpListener;
use tokio_rusqlite::Connection;
use tracing::{error, info};
use crate::sha256::Sha256;
use crate::shard::Shards;
use crate::store_request::{StoreRequest, StoreRequestParsed};
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
@@ -22,6 +29,14 @@ struct Args {
#[arg(short, long)]
db_path: String,
/// Port to listen on
#[arg(short, long, default_value_t = 7692)]
port: u16,
/// Host to bind to
#[arg(short, long, default_value = "127.0.0.1")]
bind: String,
/// Number of shards
#[arg(short, long)]
shards: Option<usize>,
@@ -32,22 +47,6 @@ struct ManifestData {
shards: usize,
}
#[derive(TryFromMultipart)]
struct StoreRequest {
sha256: Option<String>,
content_type: String,
data: FieldData<Bytes>,
}
struct StoreRequestParsed {
sha256: String,
content_type: String,
data: Bytes,
}
#[derive(Clone)]
struct Shards(Vec<Arc<Connection>>);
fn main() -> Result<(), Box<dyn Error>> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
@@ -55,7 +54,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let args = Args::parse();
let db_path = PathBuf::from(&args.db_path);
let num_shards = validate_manifest(args)?;
let num_shards = validate_manifest(&args)?;
// max num_shards threads
let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -64,32 +63,30 @@ fn main() -> Result<(), Box<dyn Error>> {
.build()?;
runtime.block_on(async {
let server = TcpListener::bind("127.0.0.1:7692").await?;
let server = TcpListener::bind(format!("{}:{}", args.bind, args.port)).await?;
info!(
"listening on {} with {} shards",
server.local_addr()?,
num_shards
);
let mut shards = vec![];
for shard in 0..num_shards {
let shard_path = db_path.join(format!("shard{}.sqlite", shard));
let conn = Connection::open(shard_path).await?;
migrate(&conn).await?;
shards.push(Arc::new(conn));
let mut shards_vec = vec![];
for shard_id in 0..num_shards {
let shard_path = db_path.join(format!("shard{}.sqlite", shard_id));
let shard = Shard::open(shard_id, &shard_path).await?;
info!(
"shard {} has {} entries",
shard.id(),
shard.num_entries().await?
);
shards_vec.push(shard);
}
for (shard, conn) in shards.iter().enumerate() {
let count = num_entries_in(conn).await?;
info!("shard {} has {} entries", shard, count);
}
server_loop(server, Shards(shards.clone())).await?;
let shards = Shards::new(shards_vec);
server_loop(server, shards.clone()).await?;
info!("shutting down server...");
for conn in shards.into_iter() {
(*conn).clone().close().await?;
}
shards.close_all().await?;
info!("server closed sqlite connections. bye!");
Ok::<(), Box<dyn Error>>(())
Ok::<_, Box<dyn Error>>(())
})?;
Ok(())
@@ -99,138 +96,52 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn
let app = Router::new()
.route("/store", post(store_request_handler))
.layer(Extension(shards));
axum::serve(server, app.into_make_service())
.with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
.await?;
Ok(())
}
type ResultType = Result<
(StatusCode, Json<HashMap<&'static str, String>>),
(StatusCode, Json<HashMap<&'static str, String>>),
>;
#[axum::debug_handler]
async fn store_request_handler(
Extension(shards): Extension<Shards>,
TypedMultipart(request): TypedMultipart<StoreRequest>,
) -> ResultType {
) -> AxumJsonResultOf<StoreResponse> {
// compute sha256 of data
let data_bytes = &request.data.contents;
let sha256 = sha2::Sha256::digest(data_bytes);
let sha256_str = format!("{:x}", sha256);
let num_shards = shards.0.len();
// select shard
let shard_num = sha256[0] as usize % num_shards;
let conn = &shards.0[shard_num];
let sha256 = Sha256::from_bytes(&request.data.contents);
let sha256_str = sha256.hex_string();
let shard = shards.shard_for(&sha256);
if let Some(req_sha256) = request.sha256 {
if req_sha256 != sha256_str {
error!("sha256 mismatch: {} != {}", req_sha256, sha256_str);
error!(
"client sent mismatched sha256: (client) {} != (computed) {}",
req_sha256, sha256_str
);
let mut response = HashMap::new();
response.insert("status", "error".to_owned());
response.insert("message", "sha256 mismatch".to_owned());
return Err((StatusCode::BAD_REQUEST, Json(response)));
return Err((
StatusCode::BAD_REQUEST,
Json(StoreResponse::Error {
sha256: Some(req_sha256),
message: "sha256 mismatch".to_owned(),
}),
));
}
}
// info!("storing {} on shard {}", sha256_str, shard_num);
let request_parsed = StoreRequestParsed {
sha256: sha256_str,
content_type: request.content_type,
data: request.data.contents,
};
let conn = conn.borrow();
perform_store(conn, request_parsed).await
shard.store(request_parsed).await
}
async fn perform_store(conn: &Connection, store_request: StoreRequestParsed) -> ResultType {
conn.call(move |conn| {
let created_at = chrono::Utc::now().to_rfc3339();
let maybe_error = conn.execute(
"INSERT INTO entries (sha256, content_type, size, data, created_at) VALUES (?, ?, ?, ?, ?)",
params![
store_request.sha256,
store_request.content_type,
store_request.data.len() as i64,
store_request.data.to_vec(),
created_at,
],
);
let mut response = HashMap::new();
response.insert("sha256", store_request.sha256.clone());
if let Err(e) = &maybe_error {
if is_duplicate_entry_err(e) {
info!("entry {} already exists", store_request.sha256);
response.insert("status","ok".to_owned());
response.insert("message", "already exists".to_owned());
return Ok((StatusCode::OK, Json(response)));
}
}
maybe_error?;
let mut response = HashMap::new();
response.insert("status", "ok".to_owned());
response.insert("message", "created".to_owned());
Ok((StatusCode::CREATED, Json(response)))
})
.await.map_err(|e| {
error!("store failed: {}", e);
let mut response = HashMap::new();
response.insert("status", "error".to_owned());
response.insert("message", e.to_string());
(StatusCode::INTERNAL_SERVER_ERROR, Json(response))
})
}
fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool {
if let SqliteFailure(
ffi::Error {
code: ffi::ErrorCode::ConstraintViolation,
..
},
Some(err_str),
) = error
{
if err_str.contains("UNIQUE constraint failed: entries.sha256") {
return true;
}
}
false
}
async fn migrate(conn: &Connection) -> Result<(), Box<dyn Error>> {
// create tables, indexes, etc
conn.call(|conn| {
conn.execute(
"CREATE TABLE IF NOT EXISTS entries (
sha256 BLOB PRIMARY KEY,
content_type TEXT NOT NULL,
size INTEGER NOT NULL,
data BLOB NOT NULL,
created_at TEXT NOT NULL
)",
[],
)?;
Ok(())
})
.await?;
Ok(())
}
async fn num_entries_in(conn: &Connection) -> Result<i64, Box<dyn Error>> {
conn.call(|conn| {
let count: i64 = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
Ok(count)
})
.await
.map_err(|e| e.into())
}
fn validate_manifest(args: Args) -> Result<usize, Box<dyn Error>> {
fn validate_manifest(args: &Args) -> Result<usize, Box<dyn Error>> {
let manifest_path = PathBuf::from(&args.db_path).join("manifest.json");
if manifest_path.exists() {
let file_content = std::fs::read_to_string(manifest_path)?;

27
src/sha256.rs Normal file
View File

@@ -0,0 +1,27 @@
use std::fmt::LowerHex;
use sha2::Digest;
#[derive(Clone, Copy)]
pub struct Sha256([u8; 32]);
impl Sha256 {
pub fn from_bytes(bytes: &[u8]) -> Self {
let hash = sha2::Sha256::digest(bytes);
Self(hash.into())
}
pub fn hex_string(&self) -> String {
format!("{:x}", self)
}
pub fn modulo(&self, num: usize) -> usize {
self.0[0] as usize % num
}
}
impl LowerHex for Sha256 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for byte in self.0.iter() {
write!(f, "{:02x}", byte)?;
}
Ok(())
}
}

142
src/shard/mod.rs Normal file
View File

@@ -0,0 +1,142 @@
use std::{error::Error, path::Path};
use rusqlite::params;
use tokio_rusqlite::Connection;
use tracing::{error, info};
use crate::{
axum_result_type::AxumJsonResultOf,
sha256::Sha256,
store_request::{StoreRequestParsed, StoreResponse},
};
use axum::{http::StatusCode, Json};
#[derive(Clone)]
pub struct Shards(Vec<Shard>);
impl Shards {
pub fn new(shards: Vec<Shard>) -> Self {
Self(shards)
}
pub fn shard_for(&self, sha256: &Sha256) -> &Shard {
let shard_id = sha256.modulo(self.0.len());
&self.0[shard_id]
}
pub async fn close_all(self) -> Result<(), Box<dyn Error>> {
for shard in self.0 {
shard.close().await?;
}
Ok(())
}
}
#[derive(Clone)]
pub struct Shard {
id: usize,
sqlite: Connection,
}
impl Shard {
pub async fn open(id: usize, db_path: &Path) -> Result<Self, Box<dyn Error>> {
let sqlite = Connection::open(db_path).await?;
migrate(&sqlite).await?;
Ok(Self { id, sqlite })
}
pub async fn close(self) -> Result<(), Box<dyn Error>> {
self.sqlite.close().await?;
Ok(())
}
pub fn id(&self) -> usize {
self.id
}
pub async fn store(
&self,
store_request: StoreRequestParsed,
) -> AxumJsonResultOf<StoreResponse> {
let sha256 = store_request.sha256.clone();
self.sqlite.call(move |conn| {
let created_at = chrono::Utc::now().to_rfc3339();
let maybe_error = conn.execute(
"INSERT INTO entries (sha256, content_type, size, data, created_at) VALUES (?, ?, ?, ?, ?)",
params![
store_request.sha256,
store_request.content_type,
store_request.data.len() as i64,
store_request.data.to_vec(),
created_at,
],
);
if let Err(e) = &maybe_error {
if is_duplicate_entry_err(e) {
info!("entry {} already exists", store_request.sha256);
return Ok((StatusCode::OK, Json(StoreResponse::Ok{
sha256: store_request.sha256,
message: "exists",
})));
}
}
maybe_error?;
Ok((StatusCode::CREATED, Json(StoreResponse::Ok { sha256: store_request.sha256, message: "created" })))
})
.await.map_err(|e| {
error!("store failed: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, Json(StoreResponse::Error { sha256: Some(sha256), message: e.to_string() }))
})
}
pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
get_num_entries(&self.sqlite).await
}
}
fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool {
use rusqlite::*;
if let Error::SqliteFailure(
ffi::Error {
code: ffi::ErrorCode::ConstraintViolation,
..
},
Some(err_str),
) = error
{
if err_str.contains("UNIQUE constraint failed: entries.sha256") {
return true;
}
}
false
}
async fn migrate(conn: &Connection) -> Result<(), Box<dyn Error>> {
// create tables, indexes, etc
conn.call(|conn| {
conn.execute(
"CREATE TABLE IF NOT EXISTS entries (
sha256 BLOB PRIMARY KEY,
content_type TEXT NOT NULL,
size INTEGER NOT NULL,
data BLOB NOT NULL,
created_at TEXT NOT NULL
)",
[],
)?;
Ok(())
})
.await?;
Ok(())
}
async fn get_num_entries(conn: &Connection) -> Result<usize, Box<dyn Error>> {
conn.call(|conn| {
let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
Ok(count)
})
.await
.map_err(|e| e.into())
}

28
src/store_request.rs Normal file
View File

@@ -0,0 +1,28 @@
use axum::body::Bytes;
use axum_typed_multipart::{FieldData, TryFromMultipart};
use serde::Serialize;
#[derive(TryFromMultipart)]
pub struct StoreRequest {
pub sha256: Option<String>,
pub content_type: String,
pub data: FieldData<Bytes>,
}
pub struct StoreRequestParsed {
pub sha256: String,
pub content_type: String,
pub data: Bytes,
}
#[derive(Serialize)]
pub enum StoreResponse {
Ok {
sha256: String,
message: &'static str,
},
Error {
sha256: Option<String>,
message: String,
},
}