more refactors, type wrappers
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -1359,9 +1359,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.4.1"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247"
|
||||
checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54"
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
|
||||
@@ -10,7 +10,7 @@ path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "load-test"
|
||||
path = "src/load_test.rs"
|
||||
path = "load_test/main.rs"
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.7.5", features = ["macros"] }
|
||||
@@ -20,7 +20,6 @@ clap = { version = "4.5.4", features = ["derive"] }
|
||||
futures = "0.3.30"
|
||||
kdam = "0.5.1"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
|
||||
rusqlite = "0.31.0"
|
||||
serde = { version = "1.0.198", features = ["serde_derive"] }
|
||||
serde_json = "1.0.116"
|
||||
@@ -29,3 +28,4 @@ tokio = { version = "1.37.0", features = ["full", "rt-multi-thread"] }
|
||||
tokio-rusqlite = "0.5.1"
|
||||
tracing = "0.1.40"
|
||||
tracing-subscriber = "0.3.18"
|
||||
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
3
src/axum_result_type.rs
Normal file
3
src/axum_result_type.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
use axum::{http::StatusCode, Json};
|
||||
|
||||
pub type AxumJsonResultOf<T> = Result<(StatusCode, Json<T>), (StatusCode, Json<T>)>;
|
||||
203
src/main.rs
203
src/main.rs
@@ -1,20 +1,27 @@
|
||||
mod axum_result_type;
|
||||
mod sha256;
|
||||
mod shard;
|
||||
mod shutdown_signal;
|
||||
mod store_request;
|
||||
|
||||
use axum::Json;
|
||||
use axum::{body::Bytes, http::StatusCode, routing::post, Extension, Router};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
use axum::{http::StatusCode, routing::post, Extension, Router};
|
||||
use axum_result_type::AxumJsonResultOf;
|
||||
use axum_typed_multipart::TypedMultipart;
|
||||
use clap::Parser;
|
||||
use rusqlite::ffi;
|
||||
use rusqlite::params;
|
||||
use rusqlite::Error::SqliteFailure;
|
||||
use sha2::Digest;
|
||||
|
||||
use shard::Shard;
|
||||
use std::collections::HashMap;
|
||||
use std::{borrow::Borrow, error::Error, path::PathBuf, sync::Arc};
|
||||
use std::{error::Error, path::PathBuf};
|
||||
use store_request::StoreResponse;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use tokio_rusqlite::Connection;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::sha256::Sha256;
|
||||
use crate::shard::Shards;
|
||||
use crate::store_request::{StoreRequest, StoreRequestParsed};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version, about, long_about = None)]
|
||||
struct Args {
|
||||
@@ -22,6 +29,14 @@ struct Args {
|
||||
#[arg(short, long)]
|
||||
db_path: String,
|
||||
|
||||
/// Port to listen on
|
||||
#[arg(short, long, default_value_t = 7692)]
|
||||
port: u16,
|
||||
|
||||
/// Host to bind to
|
||||
#[arg(short, long, default_value = "127.0.0.1")]
|
||||
bind: String,
|
||||
|
||||
/// Number of shards
|
||||
#[arg(short, long)]
|
||||
shards: Option<usize>,
|
||||
@@ -32,22 +47,6 @@ struct ManifestData {
|
||||
shards: usize,
|
||||
}
|
||||
|
||||
#[derive(TryFromMultipart)]
|
||||
struct StoreRequest {
|
||||
sha256: Option<String>,
|
||||
content_type: String,
|
||||
data: FieldData<Bytes>,
|
||||
}
|
||||
|
||||
struct StoreRequestParsed {
|
||||
sha256: String,
|
||||
content_type: String,
|
||||
data: Bytes,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Shards(Vec<Arc<Connection>>);
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::DEBUG)
|
||||
@@ -55,7 +54,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
|
||||
let args = Args::parse();
|
||||
let db_path = PathBuf::from(&args.db_path);
|
||||
let num_shards = validate_manifest(args)?;
|
||||
let num_shards = validate_manifest(&args)?;
|
||||
|
||||
// max num_shards threads
|
||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||
@@ -64,32 +63,30 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
.build()?;
|
||||
|
||||
runtime.block_on(async {
|
||||
let server = TcpListener::bind("127.0.0.1:7692").await?;
|
||||
let server = TcpListener::bind(format!("{}:{}", args.bind, args.port)).await?;
|
||||
info!(
|
||||
"listening on {} with {} shards",
|
||||
server.local_addr()?,
|
||||
num_shards
|
||||
);
|
||||
let mut shards = vec![];
|
||||
for shard in 0..num_shards {
|
||||
let shard_path = db_path.join(format!("shard{}.sqlite", shard));
|
||||
let conn = Connection::open(shard_path).await?;
|
||||
migrate(&conn).await?;
|
||||
shards.push(Arc::new(conn));
|
||||
let mut shards_vec = vec![];
|
||||
for shard_id in 0..num_shards {
|
||||
let shard_path = db_path.join(format!("shard{}.sqlite", shard_id));
|
||||
let shard = Shard::open(shard_id, &shard_path).await?;
|
||||
info!(
|
||||
"shard {} has {} entries",
|
||||
shard.id(),
|
||||
shard.num_entries().await?
|
||||
);
|
||||
shards_vec.push(shard);
|
||||
}
|
||||
|
||||
for (shard, conn) in shards.iter().enumerate() {
|
||||
let count = num_entries_in(conn).await?;
|
||||
info!("shard {} has {} entries", shard, count);
|
||||
}
|
||||
|
||||
server_loop(server, Shards(shards.clone())).await?;
|
||||
let shards = Shards::new(shards_vec);
|
||||
server_loop(server, shards.clone()).await?;
|
||||
info!("shutting down server...");
|
||||
for conn in shards.into_iter() {
|
||||
(*conn).clone().close().await?;
|
||||
}
|
||||
shards.close_all().await?;
|
||||
info!("server closed sqlite connections. bye!");
|
||||
Ok::<(), Box<dyn Error>>(())
|
||||
Ok::<_, Box<dyn Error>>(())
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
@@ -99,138 +96,52 @@ async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn
|
||||
let app = Router::new()
|
||||
.route("/store", post(store_request_handler))
|
||||
.layer(Extension(shards));
|
||||
|
||||
axum::serve(server, app.into_make_service())
|
||||
.with_graceful_shutdown(crate::shutdown_signal::shutdown_signal())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
type ResultType = Result<
|
||||
(StatusCode, Json<HashMap<&'static str, String>>),
|
||||
(StatusCode, Json<HashMap<&'static str, String>>),
|
||||
>;
|
||||
|
||||
#[axum::debug_handler]
|
||||
async fn store_request_handler(
|
||||
Extension(shards): Extension<Shards>,
|
||||
TypedMultipart(request): TypedMultipart<StoreRequest>,
|
||||
) -> ResultType {
|
||||
) -> AxumJsonResultOf<StoreResponse> {
|
||||
// compute sha256 of data
|
||||
let data_bytes = &request.data.contents;
|
||||
let sha256 = sha2::Sha256::digest(data_bytes);
|
||||
let sha256_str = format!("{:x}", sha256);
|
||||
let num_shards = shards.0.len();
|
||||
// select shard
|
||||
let shard_num = sha256[0] as usize % num_shards;
|
||||
let conn = &shards.0[shard_num];
|
||||
let sha256 = Sha256::from_bytes(&request.data.contents);
|
||||
let sha256_str = sha256.hex_string();
|
||||
let shard = shards.shard_for(&sha256);
|
||||
|
||||
if let Some(req_sha256) = request.sha256 {
|
||||
if req_sha256 != sha256_str {
|
||||
error!("sha256 mismatch: {} != {}", req_sha256, sha256_str);
|
||||
error!(
|
||||
"client sent mismatched sha256: (client) {} != (computed) {}",
|
||||
req_sha256, sha256_str
|
||||
);
|
||||
let mut response = HashMap::new();
|
||||
response.insert("status", "error".to_owned());
|
||||
response.insert("message", "sha256 mismatch".to_owned());
|
||||
return Err((StatusCode::BAD_REQUEST, Json(response)));
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(StoreResponse::Error {
|
||||
sha256: Some(req_sha256),
|
||||
message: "sha256 mismatch".to_owned(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// info!("storing {} on shard {}", sha256_str, shard_num);
|
||||
|
||||
let request_parsed = StoreRequestParsed {
|
||||
sha256: sha256_str,
|
||||
content_type: request.content_type,
|
||||
data: request.data.contents,
|
||||
};
|
||||
|
||||
let conn = conn.borrow();
|
||||
perform_store(conn, request_parsed).await
|
||||
shard.store(request_parsed).await
|
||||
}
|
||||
|
||||
async fn perform_store(conn: &Connection, store_request: StoreRequestParsed) -> ResultType {
|
||||
conn.call(move |conn| {
|
||||
let created_at = chrono::Utc::now().to_rfc3339();
|
||||
let maybe_error = conn.execute(
|
||||
"INSERT INTO entries (sha256, content_type, size, data, created_at) VALUES (?, ?, ?, ?, ?)",
|
||||
params![
|
||||
store_request.sha256,
|
||||
store_request.content_type,
|
||||
store_request.data.len() as i64,
|
||||
store_request.data.to_vec(),
|
||||
created_at,
|
||||
],
|
||||
);
|
||||
|
||||
let mut response = HashMap::new();
|
||||
response.insert("sha256", store_request.sha256.clone());
|
||||
|
||||
if let Err(e) = &maybe_error {
|
||||
if is_duplicate_entry_err(e) {
|
||||
info!("entry {} already exists", store_request.sha256);
|
||||
response.insert("status","ok".to_owned());
|
||||
response.insert("message", "already exists".to_owned());
|
||||
return Ok((StatusCode::OK, Json(response)));
|
||||
}
|
||||
}
|
||||
maybe_error?;
|
||||
let mut response = HashMap::new();
|
||||
response.insert("status", "ok".to_owned());
|
||||
response.insert("message", "created".to_owned());
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
})
|
||||
.await.map_err(|e| {
|
||||
error!("store failed: {}", e);
|
||||
let mut response = HashMap::new();
|
||||
response.insert("status", "error".to_owned());
|
||||
response.insert("message", e.to_string());
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(response))
|
||||
})
|
||||
}
|
||||
|
||||
fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool {
|
||||
if let SqliteFailure(
|
||||
ffi::Error {
|
||||
code: ffi::ErrorCode::ConstraintViolation,
|
||||
..
|
||||
},
|
||||
Some(err_str),
|
||||
) = error
|
||||
{
|
||||
if err_str.contains("UNIQUE constraint failed: entries.sha256") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
async fn migrate(conn: &Connection) -> Result<(), Box<dyn Error>> {
|
||||
// create tables, indexes, etc
|
||||
conn.call(|conn| {
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS entries (
|
||||
sha256 BLOB PRIMARY KEY,
|
||||
content_type TEXT NOT NULL,
|
||||
size INTEGER NOT NULL,
|
||||
data BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn num_entries_in(conn: &Connection) -> Result<i64, Box<dyn Error>> {
|
||||
conn.call(|conn| {
|
||||
let count: i64 = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
|
||||
Ok(count)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
fn validate_manifest(args: Args) -> Result<usize, Box<dyn Error>> {
|
||||
fn validate_manifest(args: &Args) -> Result<usize, Box<dyn Error>> {
|
||||
let manifest_path = PathBuf::from(&args.db_path).join("manifest.json");
|
||||
if manifest_path.exists() {
|
||||
let file_content = std::fs::read_to_string(manifest_path)?;
|
||||
|
||||
27
src/sha256.rs
Normal file
27
src/sha256.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use std::fmt::LowerHex;
|
||||
|
||||
use sha2::Digest;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Sha256([u8; 32]);
|
||||
impl Sha256 {
|
||||
pub fn from_bytes(bytes: &[u8]) -> Self {
|
||||
let hash = sha2::Sha256::digest(bytes);
|
||||
Self(hash.into())
|
||||
}
|
||||
pub fn hex_string(&self) -> String {
|
||||
format!("{:x}", self)
|
||||
}
|
||||
pub fn modulo(&self, num: usize) -> usize {
|
||||
self.0[0] as usize % num
|
||||
}
|
||||
}
|
||||
|
||||
impl LowerHex for Sha256 {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
for byte in self.0.iter() {
|
||||
write!(f, "{:02x}", byte)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
142
src/shard/mod.rs
Normal file
142
src/shard/mod.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use std::{error::Error, path::Path};
|
||||
|
||||
use rusqlite::params;
|
||||
use tokio_rusqlite::Connection;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::{
|
||||
axum_result_type::AxumJsonResultOf,
|
||||
sha256::Sha256,
|
||||
store_request::{StoreRequestParsed, StoreResponse},
|
||||
};
|
||||
use axum::{http::StatusCode, Json};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Shards(Vec<Shard>);
|
||||
impl Shards {
|
||||
pub fn new(shards: Vec<Shard>) -> Self {
|
||||
Self(shards)
|
||||
}
|
||||
|
||||
pub fn shard_for(&self, sha256: &Sha256) -> &Shard {
|
||||
let shard_id = sha256.modulo(self.0.len());
|
||||
&self.0[shard_id]
|
||||
}
|
||||
|
||||
pub async fn close_all(self) -> Result<(), Box<dyn Error>> {
|
||||
for shard in self.0 {
|
||||
shard.close().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Shard {
|
||||
id: usize,
|
||||
sqlite: Connection,
|
||||
}
|
||||
|
||||
impl Shard {
|
||||
pub async fn open(id: usize, db_path: &Path) -> Result<Self, Box<dyn Error>> {
|
||||
let sqlite = Connection::open(db_path).await?;
|
||||
migrate(&sqlite).await?;
|
||||
Ok(Self { id, sqlite })
|
||||
}
|
||||
|
||||
pub async fn close(self) -> Result<(), Box<dyn Error>> {
|
||||
self.sqlite.close().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn id(&self) -> usize {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub async fn store(
|
||||
&self,
|
||||
store_request: StoreRequestParsed,
|
||||
) -> AxumJsonResultOf<StoreResponse> {
|
||||
let sha256 = store_request.sha256.clone();
|
||||
|
||||
self.sqlite.call(move |conn| {
|
||||
let created_at = chrono::Utc::now().to_rfc3339();
|
||||
let maybe_error = conn.execute(
|
||||
"INSERT INTO entries (sha256, content_type, size, data, created_at) VALUES (?, ?, ?, ?, ?)",
|
||||
params![
|
||||
store_request.sha256,
|
||||
store_request.content_type,
|
||||
store_request.data.len() as i64,
|
||||
store_request.data.to_vec(),
|
||||
created_at,
|
||||
],
|
||||
);
|
||||
|
||||
if let Err(e) = &maybe_error {
|
||||
if is_duplicate_entry_err(e) {
|
||||
info!("entry {} already exists", store_request.sha256);
|
||||
return Ok((StatusCode::OK, Json(StoreResponse::Ok{
|
||||
sha256: store_request.sha256,
|
||||
message: "exists",
|
||||
})));
|
||||
}
|
||||
}
|
||||
maybe_error?;
|
||||
Ok((StatusCode::CREATED, Json(StoreResponse::Ok { sha256: store_request.sha256, message: "created" })))
|
||||
})
|
||||
.await.map_err(|e| {
|
||||
error!("store failed: {}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(StoreResponse::Error { sha256: Some(sha256), message: e.to_string() }))
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
|
||||
get_num_entries(&self.sqlite).await
|
||||
}
|
||||
}
|
||||
|
||||
fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool {
|
||||
use rusqlite::*;
|
||||
|
||||
if let Error::SqliteFailure(
|
||||
ffi::Error {
|
||||
code: ffi::ErrorCode::ConstraintViolation,
|
||||
..
|
||||
},
|
||||
Some(err_str),
|
||||
) = error
|
||||
{
|
||||
if err_str.contains("UNIQUE constraint failed: entries.sha256") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
async fn migrate(conn: &Connection) -> Result<(), Box<dyn Error>> {
|
||||
// create tables, indexes, etc
|
||||
conn.call(|conn| {
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS entries (
|
||||
sha256 BLOB PRIMARY KEY,
|
||||
content_type TEXT NOT NULL,
|
||||
size INTEGER NOT NULL,
|
||||
data BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_num_entries(conn: &Connection) -> Result<usize, Box<dyn Error>> {
|
||||
conn.call(|conn| {
|
||||
let count: usize = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
|
||||
Ok(count)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
28
src/store_request.rs
Normal file
28
src/store_request.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use axum::body::Bytes;
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(TryFromMultipart)]
|
||||
pub struct StoreRequest {
|
||||
pub sha256: Option<String>,
|
||||
pub content_type: String,
|
||||
pub data: FieldData<Bytes>,
|
||||
}
|
||||
|
||||
pub struct StoreRequestParsed {
|
||||
pub sha256: String,
|
||||
pub content_type: String,
|
||||
pub data: Bytes,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub enum StoreResponse {
|
||||
Ok {
|
||||
sha256: String,
|
||||
message: &'static str,
|
||||
},
|
||||
Error {
|
||||
sha256: Option<String>,
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
Reference in New Issue
Block a user