This commit is contained in:
Dylan Knutson
2024-04-23 13:46:06 -07:00
parent 2045fcb89b
commit 013de9c446
11 changed files with 326 additions and 39 deletions

7
Cargo.lock generated
View File

@@ -265,6 +265,7 @@ dependencies = [
"chrono",
"clap",
"futures",
"hex",
"kdam",
"rand",
"reqwest",
@@ -711,6 +712,12 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "http"
version = "1.1.0"

View File

@@ -29,3 +29,4 @@ tokio-rusqlite = "0.5.1"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
reqwest = { version = "0.12.4", features = ["json", "multipart", "blocking"] }
hex = "0.4.3"

View File

@@ -1,20 +1,33 @@
use std::sync::Arc;
use std::sync::Mutex;
use clap::Parser;
use kdam::tqdm;
use kdam::Bar;
use kdam::BarExt;
use rand::Rng;
use reqwest::StatusCode;
#[derive(Parser, Debug, Clone)]
#[command(version, about, long_about = None)]
struct Args {
#[arg(long)]
file_size: usize,
#[arg(long)]
num_threads: usize,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
let pb = Arc::new(Mutex::new(tqdm!()));
let mut handles = vec![];
let num_shards = 8;
for _ in 0..num_shards {
for _ in 0..args.num_threads {
let pb = pb.clone();
let args = args.clone();
handles.push(std::thread::spawn(move || {
run_loop(pb).unwrap();
run_loop(pb, args).unwrap();
}));
}
@@ -25,10 +38,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
fn run_loop(pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
fn run_loop(pb: Arc<Mutex<Bar>>, args: Args) -> Result<(), Box<dyn std::error::Error>> {
let client = reqwest::blocking::Client::new();
let mut rng = rand::thread_rng();
let mut rand_data = vec![0u8; 1024 * 1024];
let mut rand_data = vec![0u8; args.file_size];
rng.fill(&mut rand_data[..]);
loop {
@@ -50,7 +63,8 @@ fn run_loop(pb: Arc<Mutex<Bar>>) -> Result<(), Box<dyn std::error::Error>> {
// update progress bar
let mut pb = pb.lock().unwrap();
pb.update(1)?;
if resp.status() != 200 && resp.status() != 201 {}
if resp.status() == StatusCode::CREATED {
pb.update(1)?;
}
}
}

View File

@@ -0,0 +1,89 @@
use std::collections::HashMap;
use axum::{
extract::Path,
http::{header, HeaderMap, HeaderName, HeaderValue, StatusCode},
Extension, Json,
};
use crate::shard::Shards;
#[derive(Debug, serde::Serialize)]
pub struct GetError {
sha256: Option<String>,
message: String,
}
#[axum::debug_handler]
pub async fn get_handler(
Path(params): Path<HashMap<String, String>>,
Extension(shards): Extension<Shards>,
) -> Result<(StatusCode, HeaderMap, Vec<u8>), (StatusCode, Json<GetError>)> {
let sha256_str = match params.get("sha256") {
Some(sha256_str) => sha256_str.clone(),
None => {
return Err((
StatusCode::BAD_REQUEST,
Json(GetError {
sha256: None,
message: "missing sha256 parameter".to_owned(),
}),
));
}
};
let sha256 = crate::sha256::Sha256::from_hex_string(&sha256_str).map_err(|e| {
(
StatusCode::BAD_REQUEST,
Json(GetError {
sha256: Some(sha256_str),
message: e.to_string(),
}),
)
})?;
let internal_error = |message| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(GetError {
sha256: Some(sha256.hex_string()),
message,
}),
)
};
let shard = shards.shard_for(&sha256);
let response = shard
.get(sha256)
.await
.map_err(|e| internal_error(e.to_string()))?;
let sha256_str = sha256.hex_string();
match response {
Some(response) => {
let content_type = HeaderValue::from_str(&response.content_type)
.map_err(|e| internal_error(e.to_string()))?;
let created_at = HeaderValue::from_str(&response.created_at.to_rfc3339())
.map_err(|e| internal_error(e.to_string()))?;
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, content_type);
headers.insert(
header::CACHE_CONTROL,
HeaderValue::from_static("public, max-age=31536000"),
);
headers.insert(header::ETAG, HeaderValue::from_str(&sha256_str).unwrap());
headers.insert(HeaderName::from_static("x-stored-at"), created_at);
Ok((StatusCode::OK, headers, response.data))
}
None => Err((
StatusCode::NOT_FOUND,
GetError {
sha256: Some(sha256_str),
message: "not found".to_owned(),
}
.into(),
)),
}
}

View File

@@ -0,0 +1,46 @@
use crate::shard::Shards;
use axum::{http::StatusCode, Extension, Json};
use tracing::error;
#[derive(serde::Serialize)]
pub struct InfoResponse {
num_shards: usize,
shards: Vec<ShardInfo>,
}
#[derive(serde::Serialize)]
pub struct ShardInfo {
id: usize,
num_entries: usize,
size_bytes: u64,
}
#[axum::debug_handler]
pub async fn info_handler(
Extension(shards): Extension<Shards>,
) -> Result<(StatusCode, Json<InfoResponse>), StatusCode> {
let mut shard_infos = vec![];
for shard in shards.iter() {
let num_entries = shard.num_entries().await.map_err(|e| {
error!("error getting num entries: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let size_bytes = shard.size_bytes().await.map_err(|e| {
error!("error getting size bytes: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
shard_infos.push(ShardInfo {
id: shard.id(),
num_entries,
size_bytes,
});
}
Ok((
StatusCode::OK,
Json(InfoResponse {
num_shards: shards.len(),
shards: shard_infos,
}),
))
}

View File

@@ -1 +1,3 @@
pub mod get_handler;
pub mod info_handler;
pub mod store_handler;

View File

@@ -14,7 +14,6 @@ pub async fn store_handler(
Extension(shards): Extension<Shards>,
TypedMultipart(request): TypedMultipart<StoreRequest>,
) -> AxumJsonResultOf<StoreResponse> {
// compute sha256 of data
let sha256 = Sha256::from_bytes(&request.data.contents);
let sha256_str = sha256.hex_string();
let shard = shards.shard_for(&sha256);

View File

@@ -5,7 +5,10 @@ mod sha256;
mod shard;
mod shutdown_signal;
use crate::shard::Shards;
use axum::{routing::post, Extension, Router};
use axum::{
routing::{get, post},
Extension, Router,
};
use clap::Parser;
use shard::Shard;
use std::{error::Error, path::PathBuf};
@@ -85,6 +88,8 @@ fn main() -> Result<(), Box<dyn Error>> {
async fn server_loop(server: TcpListener, shards: Shards) -> Result<(), Box<dyn Error>> {
let app = Router::new()
.route("/store", post(handlers::store_handler::store_handler))
.route("/get/:sha256", get(handlers::get_handler::get_handler))
.route("/info", get(handlers::info_handler::info_handler))
.layer(Extension(shards));
axum::serve(server, app.into_make_service())

View File

@@ -1,10 +1,39 @@
use std::fmt::LowerHex;
use std::{
error::Error,
fmt::{Display, LowerHex},
};
use sha2::Digest;
#[derive(Debug)]
struct Sha256Error {
message: String,
}
impl Display for Sha256Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
impl Error for Sha256Error {}
#[derive(Clone, Copy)]
pub struct Sha256([u8; 32]);
impl Sha256 {
pub fn from_hex_string(hex: &str) -> Result<Self, Box<dyn Error>> {
if hex.len() != 64 {
return Err(Box::new(Sha256Error {
message: "sha256 wrong length".to_owned(),
}));
}
let mut hash = [0; 32];
hex::decode_to_slice(hex, &mut hash).map_err(|e| Sha256Error {
message: format!("sha256 decode error: {}", e),
})?;
Ok(Self(hash))
}
pub fn from_bytes(bytes: &[u8]) -> Self {
let hash = sha2::Sha256::digest(bytes);
Self(hash.into())

View File

@@ -6,7 +6,10 @@ use crate::{
use axum::{http::StatusCode, Json};
use rusqlite::{params, OptionalExtension};
use std::{error::Error, path::Path};
use std::{
error::Error,
path::{Path, PathBuf},
};
use tokio_rusqlite::Connection;
use tracing::{debug, error, info};
@@ -30,18 +33,37 @@ impl Shards {
}
Ok(())
}
pub fn iter(&self) -> std::slice::Iter<'_, Shard> {
self.0.iter()
}
pub fn len(&self) -> usize {
self.0.len()
}
}
#[derive(Clone)]
pub struct Shard {
id: usize,
sqlite: Connection,
file_path: PathBuf,
}
pub struct GetResult {
pub content_type: String,
pub created_at: UtcDateTime,
pub data: Vec<u8>,
}
impl Shard {
pub async fn open(id: usize, db_path: &Path) -> Result<Self, Box<dyn Error>> {
let sqlite = Connection::open(db_path).await?;
let shard = Self { id, sqlite };
pub async fn open(id: usize, file_path: &Path) -> Result<Self, Box<dyn Error>> {
let sqlite = Connection::open(file_path).await?;
let shard = Self {
id,
sqlite,
file_path: file_path.to_owned(),
};
shard.migrate().await?;
Ok(shard)
}
@@ -55,6 +77,12 @@ impl Shard {
self.id
}
pub async fn size_bytes(&self) -> Result<u64, Box<dyn Error>> {
// stat the file to get its size in bytes
let metadata = tokio::fs::metadata(&self.file_path).await?;
Ok(metadata.len())
}
pub async fn store(
&self,
store_request: StoreRequestWithSha256,
@@ -92,6 +120,34 @@ impl Shard {
})
}
pub async fn get(&self, sha256: Sha256) -> Result<Option<GetResult>, Box<dyn Error>> {
self.sqlite
.call(move |conn| {
let get_result = conn
.query_row(
"SELECT content_type, created_at, data FROM entries WHERE sha256 = ?",
params![sha256.hex_string()],
|row| {
let content_type = row.get(0)?;
let created_at = parse_created_at_str(row.get(1)?)?;
let data = row.get(2)?;
Ok(GetResult {
content_type,
created_at,
data,
})
},
)
.optional()?;
Ok(get_result)
})
.await
.map_err(|e| {
error!("get failed: {}", e);
e.into()
})
}
pub async fn num_entries(&self) -> Result<usize, Box<dyn Error>> {
get_num_entries(&self.sqlite).await.map_err(|e| e.into())
}
@@ -114,10 +170,8 @@ impl Shard {
[],
|row| {
let ver = row.get(0)?;
let created_at_str: String = row.get(1)?;
let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str).map_err(|e| {
rusqlite::Error::ToSqlConversionFailure(e.into())
})?.to_utc();
// let created_at_str: String = row.get(1)?;
let created_at = parse_created_at_str(row.get(1)?)?;
Ok((ver, created_at))
}
).optional()?;
@@ -125,8 +179,7 @@ impl Shard {
if let Some((version, date_time)) = schema_row {
debug!(
"shard {}: latest schema version: {} @ {}",
shard_id,
version, date_time
shard_id, version, date_time
);
if version < 1 {
@@ -143,6 +196,12 @@ impl Shard {
}
}
fn parse_created_at_str(created_at_str: String) -> Result<UtcDateTime, rusqlite::Error> {
let parsed = chrono::DateTime::parse_from_rfc3339(&created_at_str)
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
Ok(parsed.to_utc())
}
fn is_duplicate_entry_err(error: &rusqlite::Error) -> bool {
use rusqlite::*;

View File

@@ -1,29 +1,65 @@
require "rest_client"
require "json"
FILES = {
cat1: -> { File.new("cat1.jpg", mode: "rb") },
}
puts "response without sha256: "
puts RestClient.post("http://localhost:7692/store", {
content_type: "image/jpeg",
data: FILES[:cat1].call,
})
def run
cat1_sha256 = Digest::SHA256.file("cat1.jpg").hexdigest
puts "response with correct sha256:"
puts RestClient.post("http://localhost:7692/store", {
content_type: "image/jpeg",
sha256: "e3705544cbf2fa93e16107d1821b312a7b825fc177fa28180a9c9a9d3ae8af37",
data: FILES[:cat1].call,
})
puts "response with incorrect sha256:"
begin
puts RestClient.post("http://localhost:7692/store", {
puts "store, with sha256"
dump_resp(RestClient.post("http://localhost:7692/store", {
content_type: "image/jpeg",
sha256: "123",
data: FILES[:cat1].call,
})
rescue => e
puts e.response
}))
puts "store, without sha256:"
dump_resp(RestClient.post("http://localhost:7692/store", {
content_type: "image/jpeg",
sha256: cat1_sha256,
data: FILES[:cat1].call,
}))
puts "store, incorrect sha256:"
begin
RestClient.post("http://localhost:7692/store", {
content_type: "image/jpeg",
sha256: "123",
data: FILES[:cat1].call,
})
puts "should have thrown!"
rescue => e
dump_resp(e.response)
end
puts "get, with sha256:"
dump_resp(RestClient.get("http://localhost:7692/get/#{cat1_sha256}"))
puts "get, 404 sha256:"
begin
RestClient.get("http://localhost:7692/get/e3705544cbf2fa93e16107d1821b312a7b825fc177fa28180a9c9a9d3ae8af3c")
raise "should have thrown!"
rescue => e
dump_resp(e.response)
raise "not 404" if e.response.code != 404
end
end
def dump_resp(resp)
puts " -> code: #{resp.code}"
headers = resp.headers
content_type = headers[:content_type]
puts " -> headers: #{headers}"
puts " -> content_type: #{content_type}"
puts " -> size: #{resp.size} bytes"
if content_type == "application/json"
puts " -> body: #{JSON.parse(resp.body)}"
else
body_sha256 = Digest::SHA256.hexdigest(resp.body)
puts " -> body sha256: #{body_sha256}"
end
puts "-" * 80
end
run