initial commit

This commit is contained in:
Dylan Knutson
2024-04-23 09:47:36 -07:00
commit 37cc74bfd1
7 changed files with 2592 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
.DS_Store
target

2191
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

31
Cargo.toml Normal file
View File

@@ -0,0 +1,31 @@
[package]
name = "blob-store-app"
version = "0.1.0"
edition = "2021"
default-run = "blob-store-app"
[[bin]]
name = "blob-store-app"
path = "src/main.rs"
[[bin]]
name = "load-test"
path = "src/load_test.rs"
[dependencies]
axum = { version = "0.7.5", features = ["macros"] }
axum_typed_multipart = "0.11.1"
chrono = "0.4.38"
clap = { version = "4.5.4", features = ["derive"] }
futures = "0.3.30"
kdam = "0.5.1"
rand = "0.8.5"
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
rusqlite = "0.31.0"
serde = { version = "1.0.198", features = ["serde_derive"] }
serde_json = "1.0.116"
sha2 = "0.10.8"
tokio = { version = "1.37.0", features = ["full", "rt-multi-thread"] }
tokio-rusqlite = "0.5.1"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"

59
src/load_test.rs Normal file
View File

@@ -0,0 +1,59 @@
use std::borrow::Borrow;
use std::borrow::BorrowMut;
use std::sync::Arc;
use std::sync::Mutex;
use kdam::tqdm;
use kdam::Bar;
use kdam::BarExt;
use rand::Rng;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let pb = Arc::new(Mutex::new(tqdm!()));
let mut handles = vec![];
let num_shards = 8;
for _ in (0..num_shards).into_iter() {
let pb = pb.clone();
handles.push(std::thread::spawn(move || {
run_loop(pb).unwrap();
}));
}
for handle in handles {
handle.join().unwrap();
}
Ok(())
}
fn run_loop(mut pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
let client = reqwest::blocking::Client::new();
let mut rng = rand::thread_rng();
let mut rand_data = vec![0u8; 1024 * 1024];
rng.fill(&mut rand_data[..]);
loop {
// tweak a byte in the data
let idx = rng.gen_range(0..rand_data.len());
rand_data[idx] = rng.gen();
let form = reqwest::blocking::multipart::Form::new()
.text("content_type", "text/plain")
.part(
"data",
reqwest::blocking::multipart::Part::bytes(rand_data.clone()),
);
let resp = client
.post("http://localhost:7692/store")
.multipart(form)
.send()?;
// update progress bar
let mut pb = pb.lock().unwrap();
pb.update(1)?;
if resp.status() != 200 && resp.status() != 201 {}
}
Ok(())
}

280
src/main.rs Normal file
View File

@@ -0,0 +1,280 @@
use axum::Json;
use axum::{body::Bytes, http::StatusCode, routing::post, Extension, Router};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use clap::Parser;
use rusqlite::ffi;
use rusqlite::params;
use rusqlite::Error::SqliteFailure;
use sha2::Digest;
use tokio::signal;
use std::collections::HashMap;
use std::{borrow::Borrow, error::Error, path::PathBuf, sync::Arc};
use tokio::net::TcpListener;
use tokio_rusqlite::Connection;
use tracing::{error, info};
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// Directory that holds the backing database files
#[arg(short, long)]
db_path: String,
/// Number of shards
#[arg(short, long)]
shards: Option<usize>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ManifestData {
shards: usize,
}
#[derive(TryFromMultipart)]
struct StoreRequest {
sha256: Option<String>,
content_type: String,
data: FieldData<Bytes>,
}
struct StoreRequestParsed {
sha256: String,
content_type: String,
data: Bytes,
}
#[derive(Clone)]
struct Shards(Vec<Arc<Connection>>);
fn main() -> Result<(), Box<dyn Error>> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.init();
let args = Args::parse();
let db_path = PathBuf::from(&args.db_path);
let num_shards = validate_manifest(args)?;
// max num_shards threads
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_shards as usize)
.enable_all()
.build()?;
runtime.block_on(async {
let server = TcpListener::bind("127.0.0.1:7692").await?;
info!(
"listening on {} with {} shards",
server.local_addr()?,
num_shards
);
let mut shards = vec![];
for shard in 0..num_shards {
let shard_path = db_path.join(format!("shard{}.sqlite", shard));
let conn = Connection::open(shard_path).await?;
migrate(&conn).await?;
shards.push(Arc::new(conn));
}
for (shard, conn) in shards.iter().enumerate() {
let count = num_entries_in(&conn).await?;
info!("shard {} has {} entries", shard, count);
}
server_loop(server, Shards(shards.clone())).await?;
info!("shutting down server...");
for conn in shards.into_iter() {
(*conn).clone().close().await?;
}
info!("server closed sqlite connections. bye!");
Ok::<(), Box<dyn Error>>(())
})?;
Ok(())
}
async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn Error>> {
let app = Router::new()
.route("/store", post(store_request_handler))
.layer(Extension(shards));
axum::serve(server, app.into_make_service()).with_graceful_shutdown(shutdown_signal()).await?;
Ok(())
}
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
}
type ResultType = Result<
(StatusCode, Json<HashMap<&'static str, String>>),
(StatusCode, Json<HashMap<&'static str, String>>)
>;
#[axum::debug_handler]
async fn store_request_handler(
Extension(shards): Extension<Shards>,
TypedMultipart(request): TypedMultipart<StoreRequest>,
) -> ResultType {
// compute sha256 of data
let data_bytes = &request.data.contents;
let sha256 = sha2::Sha256::digest(&data_bytes);
let sha256_str = format!("{:x}", sha256);
let num_shards = shards.0.len();
// select shard
let shard_num = sha256[0] as usize % num_shards;
let conn = &shards.0[shard_num];
if let Some(req_sha256) = request.sha256 {
if req_sha256 != sha256_str {
error!("sha256 mismatch: {} != {}", req_sha256, sha256_str);
let mut response = HashMap::new();
response.insert("status", "error".to_owned());
response.insert("message", "sha256 mismatch".to_owned());
return Err((StatusCode::BAD_REQUEST, Json(response)));
}
}
// info!("storing {} on shard {}", sha256_str, shard_num);
let request_parsed = StoreRequestParsed {
sha256: sha256_str,
content_type: request.content_type,
data: request.data.contents,
};
let conn = conn.borrow();
perform_store(&conn, request_parsed).await
}
async fn perform_store(
conn: &Connection,
store_request: StoreRequestParsed,
) -> ResultType {
conn.call(move |conn| {
let created_at = chrono::Utc::now().to_rfc3339();
let maybe_error = conn.execute(
"INSERT INTO entries (sha256, content_type, size, data, created_at) VALUES (?, ?, ?, ?, ?)",
params![
store_request.sha256,
store_request.content_type,
store_request.data.len() as i64,
store_request.data.to_vec(),
created_at,
],
);
let mut response = HashMap::new();
response.insert("sha256", store_request.sha256.clone());
if let Err(e) = &maybe_error {
if is_duplicate_entry_err(e) {
info!("entry {} already exists", store_request.sha256);
response.insert("status","ok".to_owned());
response.insert("message", "already exists".to_owned());
return Ok((StatusCode::OK, Json(response)));
}
}
maybe_error?;
let mut response = HashMap::new();
response.insert("status", "ok".to_owned());
response.insert("message", "created".to_owned());
return Ok((StatusCode::CREATED, Json(response)));
})
.await.map_err(|e| {
error!("store failed: {}", e);
let mut response = HashMap::new();
response.insert("status", "error".to_owned());
response.insert("message", e.to_string());
(StatusCode::INTERNAL_SERVER_ERROR, Json(response))
})
}
fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool {
if let SqliteFailure(
ffi::Error {
code: ffi::ErrorCode::ConstraintViolation,
..
},
Some(err_str),
) = error
{
if err_str.contains("UNIQUE constraint failed: entries.sha256") {
return true;
}
}
false
}
async fn migrate(conn: &Connection) -> Result<(), Box<dyn Error>> {
// create tables, indexes, etc
conn.call(|conn| {
conn.execute(
"CREATE TABLE IF NOT EXISTS entries (
sha256 BLOB PRIMARY KEY,
content_type TEXT NOT NULL,
size INTEGER NOT NULL,
data BLOB NOT NULL,
created_at TEXT NOT NULL
)",
[],
)?;
Ok(())
})
.await?;
Ok(())
}
async fn num_entries_in(conn: &Connection) -> Result<i64, Box<dyn Error>> {
conn.call(|conn| {
let count: i64 = conn.query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))?;
Ok(count)
}).await.map_err(|e| e.into())
}
fn validate_manifest(args: Args) -> Result<usize, Box<dyn Error>> {
let manifest_path = PathBuf::from(&args.db_path).join("manifest.json");
if manifest_path.exists() {
let file_content = std::fs::read_to_string(manifest_path)?;
let manifest: ManifestData = serde_json::from_str(&file_content)?;
if let Some(shards) = args.shards {
if shards != manifest.shards {
return Err(format!(
"manifest indicates {} shards, expected {}",
manifest.shards, shards
)
.into());
}
}
Ok(manifest.shards)
} else {
if let Some(shards) = args.shards {
std::fs::create_dir_all(&args.db_path)?;
let manifest = ManifestData { shards };
let manifest_json = serde_json::to_string(&manifest)?;
std::fs::write(manifest_path, manifest_json)?;
return Ok(shards);
} else {
return Err("new database needs --shards argument".into());
}
}
}

BIN
test/cat1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 185 KiB

29
test/test.rb Normal file
View File

@@ -0,0 +1,29 @@
require "rest_client"
FILES = {
cat1: -> { File.new("cat1.jpg", mode: "rb") },
}
puts "response without sha256: "
puts RestClient.post("http://localhost:7692/store", {
content_type: "image/jpeg",
data: FILES[:cat1].call,
})
puts "response with correct sha256:"
puts RestClient.post("http://localhost:7692/store", {
content_type: "image/jpeg",
sha256: "e3705544cbf2fa93e16107d1821b312a7b825fc177fa28180a9c9a9d3ae8af37",
data: FILES[:cat1].call,
})
puts "response with incorrect sha256:"
begin
puts RestClient.post("http://localhost:7692/store", {
content_type: "image/jpeg",
sha256: "123",
data: FILES[:cat1].call,
})
rescue => e
puts e.response
end