initial commit

This commit is contained in:
Dylan Knutson
2024-11-22 00:00:07 -08:00
commit e48651612e
14 changed files with 3099 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
/target
.DS_Store
*.sqlite3

2527
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

31
Cargo.toml Normal file
View File

@@ -0,0 +1,31 @@
[package]
name = "is-this-a-repost-telegram-bot"
version = "0.1.0"
edition = "2021"
[dependencies]
chrono = "0.4.38"
clap = { version = "4.5.21", features = ["derive"] }
humansize = "2.1.3"
image = { version = "0.25.5", default-features = false, features = [
"webp",
"jpeg",
"gif",
"bmp",
"png",
] }
image_hasher = "2.0.0"
log = "0.4.22"
pretty_env_logger = "0.5.0"
rusqlite = { version = "0.32.1", features = [
"vtab",
"modern_sqlite",
"load_extension",
"bundled",
"array",
] }
sqlite-vec = "0.1.5"
teloxide = { version = "0.13.0", features = ["macros"] }
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
tokio-rusqlite = "0.6.0"
zerocopy = "0.8.10"

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

BIN
fixtures/1081-536x354.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

BIN
fixtures/237-536x354.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 226 KiB

BIN
fixtures/866-536x354.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

1
src/async_error.rs Normal file
View File

@@ -0,0 +1 @@
pub type AsyncError = Box<dyn std::error::Error + Send + Sync>;

286
src/db.rs Normal file
View File

@@ -0,0 +1,286 @@
use rusqlite::{ffi::sqlite3_auto_extension, params};
use sqlite_vec::sqlite3_vec_init;
use teloxide::types::{ChatId, MessageId, UserId};
use tokio_rusqlite::Connection;
use crate::{
async_error::AsyncError,
media_hash::{MediaHash, HASH_SIZE},
sql_types::MediaId,
};
#[derive(Clone)]
pub struct Db {
connection: Connection,
}
#[derive(Debug)]
pub struct SearchChatMessagesResult {
pub chat_id: ChatId,
pub message_id: MessageId,
pub media_id: MediaId,
pub distance: f32,
}
impl Db {
pub async fn insert_chat_message(
&self,
chat_id: ChatId,
message_id: MessageId,
media_hash: MediaHash,
) -> Result<MediaId, AsyncError> {
Ok(self
.connection
.call(move |conn| {
let now = chrono::Utc::now().timestamp();
conn.execute(
"INSERT INTO medias (media_hash) VALUES (vec_normalize(?))",
params![media_hash],
)?;
let media_id = conn.query_row("SELECT last_insert_rowid()", [], |row| {
Ok(MediaId(row.get(0)?))
})?;
conn.execute(
"INSERT INTO chat_posts (chat_id, message_id, media_id, created_at) VALUES (?, ?, ?, ?)",
params![chat_id.0, message_id.0, media_id.0, now],
)?;
Ok(media_id)
})
.await?)
}
pub async fn insert_channel_ownership(
&self,
owner_user_id: UserId,
chat_id: ChatId,
) -> Result<(), AsyncError> {
self.connection
.call(move |conn| {
Ok(conn.execute(
"INSERT INTO channel_ownerships (owner_user_id, chat_id) VALUES (?, ?)",
params![owner_user_id.0, chat_id.0],
)?)
})
.await?;
Ok(())
}
pub async fn delete_channel_ownership(
&self,
owner_user_id: UserId,
chat_id: ChatId,
) -> Result<(), AsyncError> {
self.connection
.call(move |conn| {
Ok(conn.execute(
"DELETE FROM channel_ownerships WHERE owner_user_id = ? AND chat_id = ?",
params![owner_user_id.0, chat_id.0],
)?)
})
.await?;
Ok(())
}
pub async fn print_info(&self) -> Result<(), AsyncError> {
let vec_version = self
.connection
.call(|conn| {
let version =
conn.query_row("select vec_version()", [], |row| row.get::<_, String>(0))?;
Ok(version)
})
.await?;
log::info!("sqlite_vec version: {}", vec_version);
Ok(())
}
pub async fn search_chat_messages(
&self,
media_hash: MediaHash,
chat_owner_id: UserId,
limit: u32,
max_distance: f32,
) -> Result<Vec<SearchChatMessagesResult>, AsyncError> {
Ok(self
.connection
.call(move |conn| {
let mut posts_stmt = conn.prepare(
"
WITH m AS (
SELECT id, distance FROM medias
WHERE media_hash match vec_normalize(?)
ORDER BY distance
LIMIT ?
)
SELECT DISTINCT
cp.chat_id,
cp.message_id,
cp.media_id,
m.distance
FROM chat_posts cp
LEFT JOIN channel_ownerships co ON cp.chat_id = co.chat_id
LEFT JOIN m ON cp.media_id = m.id
WHERE co.owner_user_id = ?
AND m.distance < ?
ORDER BY m.distance
",
)?;
let rows = posts_stmt.query_map(
params![media_hash, limit, chat_owner_id.0, max_distance],
|row| {
Ok(SearchChatMessagesResult {
chat_id: ChatId(row.get(0)?),
message_id: MessageId(row.get(1)?),
media_id: MediaId(row.get(2)?),
distance: row.get(3)?, // distance
})
},
)?;
Ok(rows.collect::<Result<Vec<_>, _>>()?)
})
.await?)
}
}
pub async fn get_db(db_path: &str) -> Result<Db, AsyncError> {
// initialize sqlite_vec extension
unsafe {
sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ())));
}
let connection = Connection::open(db_path).await?;
migrate(&connection).await?;
Ok(Db { connection })
}
async fn migrate(connection: &Connection) -> Result<(), AsyncError> {
connection
.call(|conn| {
conn.execute(
&format!(
"CREATE VIRTUAL TABLE IF NOT EXISTS medias USING vec0 (
id INTEGER PRIMARY KEY,
media_hash float32[{}]
)",
HASH_SIZE
),
[],
)?;
conn.execute(
concat!(
"CREATE TABLE IF NOT EXISTS chat_posts (
chat_id INTEGER NOT NULL,
message_id INTEGER NOT NULL,
media_id INTEGER NOT NULL,
created_at INTEGER NOT NULL
)"
),
[],
)?;
conn.execute(
concat!(
"CREATE INDEX IF NOT EXISTS channel_posts_idx_1 ",
"ON chat_posts (chat_id, media_id)"
),
[],
)?;
conn.execute(
concat!(
"CREATE INDEX IF NOT EXISTS channel_posts_idx_2 ",
"ON chat_posts (media_id, chat_id)"
),
[],
)?;
conn.execute(
"CREATE TABLE IF NOT EXISTS channels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
chat_id INTEGER NOT NULL
)",
[],
)?;
conn.execute(
"CREATE TABLE IF NOT EXISTS channel_ownerships (
owner_user_id INTEGER NOT NULL,
chat_id INTEGER NOT NULL
)",
[],
)?;
conn.execute(
concat!(
"CREATE INDEX IF NOT EXISTS channel_ownerships_idx_1 ",
"ON channel_ownerships (owner_user_id, chat_id)"
),
[],
)?;
Ok(())
})
.await?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::Path;
#[tokio::test]
async fn test_insert_media_hash() -> Result<(), AsyncError> {
let db = get_db(":memory:").await?;
let media_hash =
MediaHash::from_file(Path::new("fixtures/237-536x354_low_quality.jpg")).await?;
let media_id = db
.insert_chat_message(ChatId(1), MessageId(1), media_hash)
.await?;
assert_eq!(media_id, MediaId(1));
Ok(())
}
#[tokio::test]
async fn test_search_media_hashes() -> Result<(), AsyncError> {
let db = get_db(":memory:").await?;
db.insert_channel_ownership(UserId(1), ChatId(1)).await?;
db.insert_channel_ownership(UserId(1), ChatId(2)).await?;
let hash1 = MediaHash::from_file(Path::new("fixtures/237-536x354.jpg")).await?;
let hash2 = MediaHash::from_file(Path::new("fixtures/237-536x354_low_quality.jpg")).await?;
let hash3 =
MediaHash::from_file(Path::new("fixtures/237-536x354_missing_chunk.png")).await?;
let hash4 = MediaHash::from_file(Path::new("fixtures/866-536x354.jpg")).await?;
let media_ids = [
db.insert_chat_message(ChatId(1), MessageId(1), hash1.clone())
.await?,
db.insert_chat_message(ChatId(4), MessageId(1), hash1.clone())
.await?,
db.insert_chat_message(ChatId(2), MessageId(1), hash2.clone())
.await?,
db.insert_chat_message(ChatId(2), MessageId(2), hash3.clone())
.await?,
db.insert_chat_message(ChatId(2), MessageId(3), hash4.clone())
.await?,
];
let results = db
.search_chat_messages(hash1.clone(), UserId(1), 10, 1.0)
.await?;
println!("{:#?}", results);
assert_eq!(
results.iter().map(|r| r.media_id).collect::<Vec<_>>(),
vec![media_ids[0], media_ids[2], media_ids[3]]
);
Ok(())
}
}

169
src/main.rs Normal file
View File

@@ -0,0 +1,169 @@
mod async_error;
mod db;
mod media_hash;
mod sql_types;
use async_error::AsyncError;
use clap::Parser as _;
use db::{get_db, Db};
use media_hash::MediaHash;
use teloxide::{
dispatching::UpdateFilterExt,
net::Download as _,
prelude::*,
types::{Chat, ChatMemberStatus, MediaKind, MessageCommon, MessageKind},
};
#[derive(clap::Parser)]
struct CommandLineArgs {
/// Path to the database file
#[clap(short, long)]
db_path: String,
}
#[tokio::main]
async fn main() -> Result<(), AsyncError> {
let args = CommandLineArgs::parse();
pretty_env_logger::init();
log::info!("Starting bot");
// open up the database
let db = get_db(&args.db_path).await?;
db.print_info().await?;
let bot = Bot::from_env();
let handler = dptree::entry()
.branch(Update::filter_my_chat_member().endpoint(handle_bot_status_change))
.branch(Update::filter_channel_post().endpoint(handle_channel_message))
.branch(Update::filter_message().endpoint(handle_direct_message));
log::info!("Starting dispatcher");
Dispatcher::builder(bot, handler)
.enable_ctrlc_handler()
.dependencies(dptree::deps![db])
.build()
.dispatch()
.await;
Ok(())
}
fn chat_name(chat: &Chat) -> &str {
chat.username().or(chat.title()).unwrap_or("<no name>")
}
async fn handle_bot_status_change(updated: ChatMemberUpdated, db: Db) -> Result<(), AsyncError> {
let status = updated.new_chat_member.status();
let by_user = updated.from.id;
let chat_id = updated.chat.id;
log::info!(
"Bot status in Chat({}) changed by User({}) to {:?}",
chat_id,
by_user,
status
);
match status {
ChatMemberStatus::Administrator => {
log::info!(
"Bot is now an administrator of Chat({}, {})",
chat_id,
chat_name(&updated.chat)
);
db.insert_channel_ownership(by_user, chat_id).await?;
}
ChatMemberStatus::Left => {
log::info!(
"Bot is no longer in Chat({}, {})",
chat_id,
chat_name(&updated.chat)
);
db.delete_channel_ownership(by_user, chat_id).await?;
}
_ => {}
}
Ok(())
}
async fn handle_direct_message(bot: Bot, message: Message, db: Db) -> Result<(), AsyncError> {
let from_user = message.from.clone().ok_or("No sender in direct message")?;
log::info!(
"Got direct message from User({}, {}): {}",
from_user.id,
from_user.username.as_deref().unwrap_or("<no username>"),
message.text().unwrap_or("<no text>")
);
let media_hash = download_message_photo(&bot, &message).await?;
if let Some(media_hash) = media_hash {
let results = db
.search_chat_messages(media_hash, from_user.id, 1, 1.0)
.await?;
log::info!("Found {} similar medias", results.len());
bot.send_message(
message.chat.id,
format!("Found {} similar medias", results.len()),
)
.await?;
for result in results {
log::info!(" {:?}", result);
bot.forward_message(message.chat.id, result.chat_id, result.message_id)
.await?;
}
}
Ok(())
}
async fn handle_channel_message(bot: Bot, message: Message, db: Db) -> Result<(), AsyncError> {
let chat_id = message.chat.id;
let media_hash = match download_message_photo(&bot, &message).await? {
Some(media_hash) => media_hash,
None => return Ok(()),
};
let media_id = db
.insert_chat_message(chat_id, message.id, media_hash)
.await?;
log::info!("Inserted message: {:?}", media_id);
Ok(())
}
async fn download_message_photo(
bot: &Bot,
message: &Message,
) -> Result<Option<MediaHash>, AsyncError> {
let photo = match &message.kind {
MessageKind::Common(MessageCommon {
media_kind: MediaKind::Photo(photo),
..
}) => {
let photo = match photo.photo.iter().max_by_key(|p| p.file.size) {
Some(photo) => photo,
None => {
log::info!("No photo found in message");
return Ok(None);
}
};
photo
}
_ => return Ok(None),
};
log::info!("Downloading photo {}...", photo.file.id,);
let file_path = bot.get_file(&photo.file.id).await?.path;
let mut dst = Vec::new();
bot.download_file(&file_path, &mut dst).await?;
log::info!(
"Downloaded {}.",
humansize::format_size(dst.len(), humansize::BINARY)
);
Ok(Some(MediaHash::from_bytes(&dst)?))
}

80
src/media_hash.rs Normal file
View File

@@ -0,0 +1,80 @@
use std::{path::Path, rc::Rc};
use rusqlite::{
types::{FromSql, Value, ValueRef},
vtab::array::Array,
ToSql,
};
use tokio_rusqlite::types::ToSqlOutput;
use zerocopy::IntoBytes;
use crate::async_error::AsyncError;
pub const HASH_SIZE: usize = 32;
type ImageHash = image_hasher::ImageHash<[u8; HASH_SIZE]>;
#[derive(Debug, PartialEq, Clone)]
pub struct MediaHash {
hash: ImageHash,
bytes: Vec<f32>,
}
impl MediaHash {
pub async fn from_file(path: &Path) -> Result<Self, AsyncError> {
let image_bytes = tokio::fs::read(path).await?;
Self::from_bytes(&image_bytes)
}
pub fn from_bytes(image_bytes: &[u8]) -> Result<Self, AsyncError> {
let image = image::load_from_memory(&image_bytes)?;
let num_bits = HASH_SIZE * 8;
let hash_size = (num_bits as f32).sqrt() as u32;
let hasher = image_hasher::HasherConfig::with_bytes_type::<[u8; HASH_SIZE]>()
.hash_size(hash_size, hash_size)
.preproc_dct()
.hash_alg(image_hasher::HashAlg::Blockhash)
.to_hasher();
let media_hash = hasher.hash_image(&image).try_into()?;
Ok(media_hash)
}
pub fn distance(&self, other: &Self) -> u32 {
self.hash.dist(&other.hash)
}
}
impl ToSql for MediaHash {
fn to_sql(&self) -> Result<ToSqlOutput<'_>, rusqlite::Error> {
Ok(ToSqlOutput::from(self.bytes.as_bytes()))
}
}
impl TryFrom<ImageHash> for MediaHash {
type Error = AsyncError;
fn try_from(hash: ImageHash) -> Result<Self, Self::Error> {
if hash.as_bytes().len() != HASH_SIZE {
return Err("Invalid hash size".into());
}
let bytes = hash
.as_bytes()
.iter()
.map(|b| *b as f32)
.collect::<Vec<_>>();
Ok(MediaHash { hash, bytes })
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_hash_media() -> Result<(), AsyncError> {
let hash1 = MediaHash::from_file(Path::new("fixtures/237-536x354_low_quality.jpg")).await?;
let hash2 = MediaHash::from_file(Path::new("fixtures/237-536x354.jpg")).await?;
assert_eq!(hash1, hash2);
Ok(())
}
}

2
src/sql_types.rs Normal file
View File

@@ -0,0 +1,2 @@
#[derive(Debug, PartialEq, Copy, Clone)]
pub struct MediaId(pub i64);