initial commit
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
/target
|
||||||
|
.DS_Store
|
||||||
|
*.sqlite3
|
||||||
2527
Cargo.lock
generated
Normal file
2527
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
31
Cargo.toml
Normal file
31
Cargo.toml
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
[package]
|
||||||
|
name = "is-this-a-repost-telegram-bot"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
chrono = "0.4.38"
|
||||||
|
clap = { version = "4.5.21", features = ["derive"] }
|
||||||
|
humansize = "2.1.3"
|
||||||
|
image = { version = "0.25.5", default-features = false, features = [
|
||||||
|
"webp",
|
||||||
|
"jpeg",
|
||||||
|
"gif",
|
||||||
|
"bmp",
|
||||||
|
"png",
|
||||||
|
] }
|
||||||
|
image_hasher = "2.0.0"
|
||||||
|
log = "0.4.22"
|
||||||
|
pretty_env_logger = "0.5.0"
|
||||||
|
rusqlite = { version = "0.32.1", features = [
|
||||||
|
"vtab",
|
||||||
|
"modern_sqlite",
|
||||||
|
"load_extension",
|
||||||
|
"bundled",
|
||||||
|
"array",
|
||||||
|
] }
|
||||||
|
sqlite-vec = "0.1.5"
|
||||||
|
teloxide = { version = "0.13.0", features = ["macros"] }
|
||||||
|
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
|
||||||
|
tokio-rusqlite = "0.6.0"
|
||||||
|
zerocopy = "0.8.10"
|
||||||
BIN
fixtures/1060-536x354-blur_2.jpg
Normal file
BIN
fixtures/1060-536x354-blur_2.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
BIN
fixtures/1081-536x354.jpg
Normal file
BIN
fixtures/1081-536x354.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 36 KiB |
BIN
fixtures/237-536x354.jpg
Normal file
BIN
fixtures/237-536x354.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 35 KiB |
BIN
fixtures/237-536x354_low_quality.jpg
Normal file
BIN
fixtures/237-536x354_low_quality.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 20 KiB |
BIN
fixtures/237-536x354_missing_chunk.png
Normal file
BIN
fixtures/237-536x354_missing_chunk.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 226 KiB |
BIN
fixtures/866-536x354.jpg
Normal file
BIN
fixtures/866-536x354.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
1
src/async_error.rs
Normal file
1
src/async_error.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
/// Boxed error type that is `Send + Sync`, used as the error half of the
/// `Result`s returned by this crate's async functions.
pub type AsyncError = Box<dyn std::error::Error + Send + Sync>;
|
||||||
286
src/db.rs
Normal file
286
src/db.rs
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
use rusqlite::{ffi::sqlite3_auto_extension, params};
|
||||||
|
use sqlite_vec::sqlite3_vec_init;
|
||||||
|
use teloxide::types::{ChatId, MessageId, UserId};
|
||||||
|
use tokio_rusqlite::Connection;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
async_error::AsyncError,
|
||||||
|
media_hash::{MediaHash, HASH_SIZE},
|
||||||
|
sql_types::MediaId,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Db {
|
||||||
|
connection: Connection,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct SearchChatMessagesResult {
|
||||||
|
pub chat_id: ChatId,
|
||||||
|
pub message_id: MessageId,
|
||||||
|
pub media_id: MediaId,
|
||||||
|
pub distance: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Db {
|
||||||
|
pub async fn insert_chat_message(
|
||||||
|
&self,
|
||||||
|
chat_id: ChatId,
|
||||||
|
message_id: MessageId,
|
||||||
|
media_hash: MediaHash,
|
||||||
|
) -> Result<MediaId, AsyncError> {
|
||||||
|
Ok(self
|
||||||
|
.connection
|
||||||
|
.call(move |conn| {
|
||||||
|
let now = chrono::Utc::now().timestamp();
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO medias (media_hash) VALUES (vec_normalize(?))",
|
||||||
|
params![media_hash],
|
||||||
|
)?;
|
||||||
|
let media_id = conn.query_row("SELECT last_insert_rowid()", [], |row| {
|
||||||
|
Ok(MediaId(row.get(0)?))
|
||||||
|
})?;
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO chat_posts (chat_id, message_id, media_id, created_at) VALUES (?, ?, ?, ?)",
|
||||||
|
params![chat_id.0, message_id.0, media_id.0, now],
|
||||||
|
)?;
|
||||||
|
Ok(media_id)
|
||||||
|
})
|
||||||
|
.await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert_channel_ownership(
|
||||||
|
&self,
|
||||||
|
owner_user_id: UserId,
|
||||||
|
chat_id: ChatId,
|
||||||
|
) -> Result<(), AsyncError> {
|
||||||
|
self.connection
|
||||||
|
.call(move |conn| {
|
||||||
|
Ok(conn.execute(
|
||||||
|
"INSERT INTO channel_ownerships (owner_user_id, chat_id) VALUES (?, ?)",
|
||||||
|
params![owner_user_id.0, chat_id.0],
|
||||||
|
)?)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_channel_ownership(
|
||||||
|
&self,
|
||||||
|
owner_user_id: UserId,
|
||||||
|
chat_id: ChatId,
|
||||||
|
) -> Result<(), AsyncError> {
|
||||||
|
self.connection
|
||||||
|
.call(move |conn| {
|
||||||
|
Ok(conn.execute(
|
||||||
|
"DELETE FROM channel_ownerships WHERE owner_user_id = ? AND chat_id = ?",
|
||||||
|
params![owner_user_id.0, chat_id.0],
|
||||||
|
)?)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn print_info(&self) -> Result<(), AsyncError> {
|
||||||
|
let vec_version = self
|
||||||
|
.connection
|
||||||
|
.call(|conn| {
|
||||||
|
let version =
|
||||||
|
conn.query_row("select vec_version()", [], |row| row.get::<_, String>(0))?;
|
||||||
|
Ok(version)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
log::info!("sqlite_vec version: {}", vec_version);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn search_chat_messages(
|
||||||
|
&self,
|
||||||
|
media_hash: MediaHash,
|
||||||
|
chat_owner_id: UserId,
|
||||||
|
limit: u32,
|
||||||
|
max_distance: f32,
|
||||||
|
) -> Result<Vec<SearchChatMessagesResult>, AsyncError> {
|
||||||
|
Ok(self
|
||||||
|
.connection
|
||||||
|
.call(move |conn| {
|
||||||
|
let mut posts_stmt = conn.prepare(
|
||||||
|
"
|
||||||
|
WITH m AS (
|
||||||
|
SELECT id, distance FROM medias
|
||||||
|
WHERE media_hash match vec_normalize(?)
|
||||||
|
ORDER BY distance
|
||||||
|
LIMIT ?
|
||||||
|
)
|
||||||
|
SELECT DISTINCT
|
||||||
|
cp.chat_id,
|
||||||
|
cp.message_id,
|
||||||
|
cp.media_id,
|
||||||
|
m.distance
|
||||||
|
FROM chat_posts cp
|
||||||
|
LEFT JOIN channel_ownerships co ON cp.chat_id = co.chat_id
|
||||||
|
LEFT JOIN m ON cp.media_id = m.id
|
||||||
|
WHERE co.owner_user_id = ?
|
||||||
|
AND m.distance < ?
|
||||||
|
ORDER BY m.distance
|
||||||
|
",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let rows = posts_stmt.query_map(
|
||||||
|
params![media_hash, limit, chat_owner_id.0, max_distance],
|
||||||
|
|row| {
|
||||||
|
Ok(SearchChatMessagesResult {
|
||||||
|
chat_id: ChatId(row.get(0)?),
|
||||||
|
message_id: MessageId(row.get(1)?),
|
||||||
|
media_id: MediaId(row.get(2)?),
|
||||||
|
distance: row.get(3)?, // distance
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(rows.collect::<Result<Vec<_>, _>>()?)
|
||||||
|
})
|
||||||
|
.await?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_db(db_path: &str) -> Result<Db, AsyncError> {
|
||||||
|
// initialize sqlite_vec extension
|
||||||
|
unsafe {
|
||||||
|
sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ())));
|
||||||
|
}
|
||||||
|
|
||||||
|
let connection = Connection::open(db_path).await?;
|
||||||
|
migrate(&connection).await?;
|
||||||
|
Ok(Db { connection })
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn migrate(connection: &Connection) -> Result<(), AsyncError> {
|
||||||
|
connection
|
||||||
|
.call(|conn| {
|
||||||
|
conn.execute(
|
||||||
|
&format!(
|
||||||
|
"CREATE VIRTUAL TABLE IF NOT EXISTS medias USING vec0 (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
media_hash float32[{}]
|
||||||
|
)",
|
||||||
|
HASH_SIZE
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
concat!(
|
||||||
|
"CREATE TABLE IF NOT EXISTS chat_posts (
|
||||||
|
chat_id INTEGER NOT NULL,
|
||||||
|
message_id INTEGER NOT NULL,
|
||||||
|
media_id INTEGER NOT NULL,
|
||||||
|
created_at INTEGER NOT NULL
|
||||||
|
)"
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
concat!(
|
||||||
|
"CREATE INDEX IF NOT EXISTS channel_posts_idx_1 ",
|
||||||
|
"ON chat_posts (chat_id, media_id)"
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
concat!(
|
||||||
|
"CREATE INDEX IF NOT EXISTS channel_posts_idx_2 ",
|
||||||
|
"ON chat_posts (media_id, chat_id)"
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS channels (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
chat_id INTEGER NOT NULL
|
||||||
|
)",
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS channel_ownerships (
|
||||||
|
owner_user_id INTEGER NOT NULL,
|
||||||
|
chat_id INTEGER NOT NULL
|
||||||
|
)",
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
concat!(
|
||||||
|
"CREATE INDEX IF NOT EXISTS channel_ownerships_idx_1 ",
|
||||||
|
"ON channel_ownerships (owner_user_id, chat_id)"
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// The first media stored in a fresh in-memory database gets id 1.
    #[tokio::test]
    async fn test_insert_media_hash() -> Result<(), AsyncError> {
        let db = get_db(":memory:").await?;
        let fixture = Path::new("fixtures/237-536x354_low_quality.jpg");
        let media_hash = MediaHash::from_file(fixture).await?;

        let media_id = db
            .insert_chat_message(ChatId(1), MessageId(1), media_hash)
            .await?;

        assert_eq!(media_id, MediaId(1));
        Ok(())
    }

    /// Searching as user 1 returns the similar posts from the chats that user
    /// owns (1 and 2), and nothing from chat 4.
    #[tokio::test]
    async fn test_search_media_hashes() -> Result<(), AsyncError> {
        let db = get_db(":memory:").await?;
        for owned_chat in [ChatId(1), ChatId(2)] {
            db.insert_channel_ownership(UserId(1), owned_chat).await?;
        }

        let hash1 = MediaHash::from_file(Path::new("fixtures/237-536x354.jpg")).await?;
        let hash2 = MediaHash::from_file(Path::new("fixtures/237-536x354_low_quality.jpg")).await?;
        let hash3 =
            MediaHash::from_file(Path::new("fixtures/237-536x354_missing_chunk.png")).await?;
        let hash4 = MediaHash::from_file(Path::new("fixtures/866-536x354.jpg")).await?;

        // Insertion order matters: the expected ids below index into it.
        let posts = [
            (ChatId(1), MessageId(1), &hash1),
            (ChatId(4), MessageId(1), &hash1),
            (ChatId(2), MessageId(1), &hash2),
            (ChatId(2), MessageId(2), &hash3),
            (ChatId(2), MessageId(3), &hash4),
        ];
        let mut media_ids = Vec::with_capacity(posts.len());
        for (chat_id, message_id, hash) in posts {
            let media_id = db
                .insert_chat_message(chat_id, message_id, hash.clone())
                .await?;
            media_ids.push(media_id);
        }

        let results = db
            .search_chat_messages(hash1.clone(), UserId(1), 10, 1.0)
            .await?;

        println!("{:#?}", results);

        let found: Vec<_> = results.iter().map(|r| r.media_id).collect();
        assert_eq!(found, vec![media_ids[0], media_ids[2], media_ids[3]]);

        Ok(())
    }
}
|
||||||
169
src/main.rs
Normal file
169
src/main.rs
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
mod async_error;
|
||||||
|
mod db;
|
||||||
|
mod media_hash;
|
||||||
|
mod sql_types;
|
||||||
|
|
||||||
|
use async_error::AsyncError;
|
||||||
|
use clap::Parser as _;
|
||||||
|
use db::{get_db, Db};
|
||||||
|
use media_hash::MediaHash;
|
||||||
|
use teloxide::{
|
||||||
|
dispatching::UpdateFilterExt,
|
||||||
|
net::Download as _,
|
||||||
|
prelude::*,
|
||||||
|
types::{Chat, ChatMemberStatus, MediaKind, MessageCommon, MessageKind},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(clap::Parser)]
|
||||||
|
struct CommandLineArgs {
|
||||||
|
/// Path to the database file
|
||||||
|
#[clap(short, long)]
|
||||||
|
db_path: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), AsyncError> {
|
||||||
|
let args = CommandLineArgs::parse();
|
||||||
|
|
||||||
|
pretty_env_logger::init();
|
||||||
|
log::info!("Starting bot");
|
||||||
|
|
||||||
|
// open up the database
|
||||||
|
let db = get_db(&args.db_path).await?;
|
||||||
|
db.print_info().await?;
|
||||||
|
|
||||||
|
let bot = Bot::from_env();
|
||||||
|
|
||||||
|
let handler = dptree::entry()
|
||||||
|
.branch(Update::filter_my_chat_member().endpoint(handle_bot_status_change))
|
||||||
|
.branch(Update::filter_channel_post().endpoint(handle_channel_message))
|
||||||
|
.branch(Update::filter_message().endpoint(handle_direct_message));
|
||||||
|
|
||||||
|
log::info!("Starting dispatcher");
|
||||||
|
Dispatcher::builder(bot, handler)
|
||||||
|
.enable_ctrlc_handler()
|
||||||
|
.dependencies(dptree::deps![db])
|
||||||
|
.build()
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn chat_name(chat: &Chat) -> &str {
|
||||||
|
chat.username().or(chat.title()).unwrap_or("<no name>")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_bot_status_change(updated: ChatMemberUpdated, db: Db) -> Result<(), AsyncError> {
|
||||||
|
let status = updated.new_chat_member.status();
|
||||||
|
let by_user = updated.from.id;
|
||||||
|
let chat_id = updated.chat.id;
|
||||||
|
log::info!(
|
||||||
|
"Bot status in Chat({}) changed by User({}) to {:?}",
|
||||||
|
chat_id,
|
||||||
|
by_user,
|
||||||
|
status
|
||||||
|
);
|
||||||
|
match status {
|
||||||
|
ChatMemberStatus::Administrator => {
|
||||||
|
log::info!(
|
||||||
|
"Bot is now an administrator of Chat({}, {})",
|
||||||
|
chat_id,
|
||||||
|
chat_name(&updated.chat)
|
||||||
|
);
|
||||||
|
db.insert_channel_ownership(by_user, chat_id).await?;
|
||||||
|
}
|
||||||
|
ChatMemberStatus::Left => {
|
||||||
|
log::info!(
|
||||||
|
"Bot is no longer in Chat({}, {})",
|
||||||
|
chat_id,
|
||||||
|
chat_name(&updated.chat)
|
||||||
|
);
|
||||||
|
db.delete_channel_ownership(by_user, chat_id).await?;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_direct_message(bot: Bot, message: Message, db: Db) -> Result<(), AsyncError> {
|
||||||
|
let from_user = message.from.clone().ok_or("No sender in direct message")?;
|
||||||
|
log::info!(
|
||||||
|
"Got direct message from User({}, {}): {}",
|
||||||
|
from_user.id,
|
||||||
|
from_user.username.as_deref().unwrap_or("<no username>"),
|
||||||
|
message.text().unwrap_or("<no text>")
|
||||||
|
);
|
||||||
|
|
||||||
|
let media_hash = download_message_photo(&bot, &message).await?;
|
||||||
|
if let Some(media_hash) = media_hash {
|
||||||
|
let results = db
|
||||||
|
.search_chat_messages(media_hash, from_user.id, 1, 1.0)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
log::info!("Found {} similar medias", results.len());
|
||||||
|
|
||||||
|
bot.send_message(
|
||||||
|
message.chat.id,
|
||||||
|
format!("Found {} similar medias", results.len()),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
for result in results {
|
||||||
|
log::info!(" {:?}", result);
|
||||||
|
|
||||||
|
bot.forward_message(message.chat.id, result.chat_id, result.message_id)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_channel_message(bot: Bot, message: Message, db: Db) -> Result<(), AsyncError> {
|
||||||
|
let chat_id = message.chat.id;
|
||||||
|
let media_hash = match download_message_photo(&bot, &message).await? {
|
||||||
|
Some(media_hash) => media_hash,
|
||||||
|
None => return Ok(()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let media_id = db
|
||||||
|
.insert_chat_message(chat_id, message.id, media_hash)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
log::info!("Inserted message: {:?}", media_id);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn download_message_photo(
|
||||||
|
bot: &Bot,
|
||||||
|
message: &Message,
|
||||||
|
) -> Result<Option<MediaHash>, AsyncError> {
|
||||||
|
let photo = match &message.kind {
|
||||||
|
MessageKind::Common(MessageCommon {
|
||||||
|
media_kind: MediaKind::Photo(photo),
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
let photo = match photo.photo.iter().max_by_key(|p| p.file.size) {
|
||||||
|
Some(photo) => photo,
|
||||||
|
None => {
|
||||||
|
log::info!("No photo found in message");
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
photo
|
||||||
|
}
|
||||||
|
_ => return Ok(None),
|
||||||
|
};
|
||||||
|
|
||||||
|
log::info!("Downloading photo {}...", photo.file.id,);
|
||||||
|
let file_path = bot.get_file(&photo.file.id).await?.path;
|
||||||
|
let mut dst = Vec::new();
|
||||||
|
bot.download_file(&file_path, &mut dst).await?;
|
||||||
|
log::info!(
|
||||||
|
"Downloaded {}.",
|
||||||
|
humansize::format_size(dst.len(), humansize::BINARY)
|
||||||
|
);
|
||||||
|
Ok(Some(MediaHash::from_bytes(&dst)?))
|
||||||
|
}
|
||||||
80
src/media_hash.rs
Normal file
80
src/media_hash.rs
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
use std::{path::Path, rc::Rc};
|
||||||
|
|
||||||
|
use rusqlite::{
|
||||||
|
types::{FromSql, Value, ValueRef},
|
||||||
|
vtab::array::Array,
|
||||||
|
ToSql,
|
||||||
|
};
|
||||||
|
use tokio_rusqlite::types::ToSqlOutput;
|
||||||
|
use zerocopy::IntoBytes;
|
||||||
|
|
||||||
|
use crate::async_error::AsyncError;
|
||||||
|
|
||||||
|
/// Length of a media hash in bytes; also used as the number of `float32`
/// components of the `media_hash` vector column.
pub const HASH_SIZE: usize = 32;
|
||||||
|
/// `image_hasher` hash backed by a fixed `[u8; HASH_SIZE]` buffer.
type ImageHash = image_hasher::ImageHash<[u8; HASH_SIZE]>;
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct MediaHash {
|
||||||
|
hash: ImageHash,
|
||||||
|
bytes: Vec<f32>,
|
||||||
|
}
|
||||||
|
impl MediaHash {
|
||||||
|
pub async fn from_file(path: &Path) -> Result<Self, AsyncError> {
|
||||||
|
let image_bytes = tokio::fs::read(path).await?;
|
||||||
|
Self::from_bytes(&image_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_bytes(image_bytes: &[u8]) -> Result<Self, AsyncError> {
|
||||||
|
let image = image::load_from_memory(&image_bytes)?;
|
||||||
|
let num_bits = HASH_SIZE * 8;
|
||||||
|
let hash_size = (num_bits as f32).sqrt() as u32;
|
||||||
|
let hasher = image_hasher::HasherConfig::with_bytes_type::<[u8; HASH_SIZE]>()
|
||||||
|
.hash_size(hash_size, hash_size)
|
||||||
|
.preproc_dct()
|
||||||
|
.hash_alg(image_hasher::HashAlg::Blockhash)
|
||||||
|
.to_hasher();
|
||||||
|
let media_hash = hasher.hash_image(&image).try_into()?;
|
||||||
|
Ok(media_hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn distance(&self, other: &Self) -> u32 {
|
||||||
|
self.hash.dist(&other.hash)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToSql for MediaHash {
|
||||||
|
fn to_sql(&self) -> Result<ToSqlOutput<'_>, rusqlite::Error> {
|
||||||
|
Ok(ToSqlOutput::from(self.bytes.as_bytes()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<ImageHash> for MediaHash {
|
||||||
|
type Error = AsyncError;
|
||||||
|
|
||||||
|
fn try_from(hash: ImageHash) -> Result<Self, Self::Error> {
|
||||||
|
if hash.as_bytes().len() != HASH_SIZE {
|
||||||
|
return Err("Invalid hash size".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let bytes = hash
|
||||||
|
.as_bytes()
|
||||||
|
.iter()
|
||||||
|
.map(|b| *b as f32)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
Ok(MediaHash { hash, bytes })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A recompressed copy of an image must hash identically to the original.
    #[tokio::test]
    async fn test_hash_media() -> Result<(), AsyncError> {
        let low_quality = Path::new("fixtures/237-536x354_low_quality.jpg");
        let original = Path::new("fixtures/237-536x354.jpg");

        let hash1 = MediaHash::from_file(low_quality).await?;
        let hash2 = MediaHash::from_file(original).await?;

        assert_eq!(hash1, hash2);
        Ok(())
    }
}
|
||||||
2
src/sql_types.rs
Normal file
2
src/sql_types.rs
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
/// Row id of an entry in the `medias` table.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so ids can be used as
/// keys in hash-based sets and maps.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub struct MediaId(pub i64);
|
||||||
Reference in New Issue
Block a user