initial commit
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
/target
|
||||
.DS_Store
|
||||
*.sqlite3
|
||||
2527
Cargo.lock
generated
Normal file
2527
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
31
Cargo.toml
Normal file
31
Cargo.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
[package]
|
||||
name = "is-this-a-repost-telegram-bot"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
chrono = "0.4.38"
|
||||
clap = { version = "4.5.21", features = ["derive"] }
|
||||
humansize = "2.1.3"
|
||||
image = { version = "0.25.5", default-features = false, features = [
|
||||
"webp",
|
||||
"jpeg",
|
||||
"gif",
|
||||
"bmp",
|
||||
"png",
|
||||
] }
|
||||
image_hasher = "2.0.0"
|
||||
log = "0.4.22"
|
||||
pretty_env_logger = "0.5.0"
|
||||
rusqlite = { version = "0.32.1", features = [
|
||||
"vtab",
|
||||
"modern_sqlite",
|
||||
"load_extension",
|
||||
"bundled",
|
||||
"array",
|
||||
] }
|
||||
sqlite-vec = "0.1.5"
|
||||
teloxide = { version = "0.13.0", features = ["macros"] }
|
||||
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
|
||||
tokio-rusqlite = "0.6.0"
|
||||
zerocopy = "0.8.10"
|
||||
BIN
fixtures/1060-536x354-blur_2.jpg
Normal file
BIN
fixtures/1060-536x354-blur_2.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
BIN
fixtures/1081-536x354.jpg
Normal file
BIN
fixtures/1081-536x354.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 36 KiB |
BIN
fixtures/237-536x354.jpg
Normal file
BIN
fixtures/237-536x354.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 35 KiB |
BIN
fixtures/237-536x354_low_quality.jpg
Normal file
BIN
fixtures/237-536x354_low_quality.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 20 KiB |
BIN
fixtures/237-536x354_missing_chunk.png
Normal file
BIN
fixtures/237-536x354_missing_chunk.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 226 KiB |
BIN
fixtures/866-536x354.jpg
Normal file
BIN
fixtures/866-536x354.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
1
src/async_error.rs
Normal file
1
src/async_error.rs
Normal file
@@ -0,0 +1 @@
|
||||
/// Boxed, thread-safe error type returned by the fallible functions in this crate.
///
/// Anything implementing `std::error::Error + Send + Sync` (and plain string
/// messages) converts into it through `?` / `.into()`.
pub type AsyncError = Box<dyn std::error::Error + Send + Sync + 'static>;
|
||||
286
src/db.rs
Normal file
286
src/db.rs
Normal file
@@ -0,0 +1,286 @@
|
||||
use rusqlite::{ffi::sqlite3_auto_extension, params};
|
||||
use sqlite_vec::sqlite3_vec_init;
|
||||
use teloxide::types::{ChatId, MessageId, UserId};
|
||||
use tokio_rusqlite::Connection;
|
||||
|
||||
use crate::{
|
||||
async_error::AsyncError,
|
||||
media_hash::{MediaHash, HASH_SIZE},
|
||||
sql_types::MediaId,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Db {
|
||||
connection: Connection,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SearchChatMessagesResult {
|
||||
pub chat_id: ChatId,
|
||||
pub message_id: MessageId,
|
||||
pub media_id: MediaId,
|
||||
pub distance: f32,
|
||||
}
|
||||
|
||||
impl Db {
|
||||
pub async fn insert_chat_message(
|
||||
&self,
|
||||
chat_id: ChatId,
|
||||
message_id: MessageId,
|
||||
media_hash: MediaHash,
|
||||
) -> Result<MediaId, AsyncError> {
|
||||
Ok(self
|
||||
.connection
|
||||
.call(move |conn| {
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
conn.execute(
|
||||
"INSERT INTO medias (media_hash) VALUES (vec_normalize(?))",
|
||||
params![media_hash],
|
||||
)?;
|
||||
let media_id = conn.query_row("SELECT last_insert_rowid()", [], |row| {
|
||||
Ok(MediaId(row.get(0)?))
|
||||
})?;
|
||||
conn.execute(
|
||||
"INSERT INTO chat_posts (chat_id, message_id, media_id, created_at) VALUES (?, ?, ?, ?)",
|
||||
params![chat_id.0, message_id.0, media_id.0, now],
|
||||
)?;
|
||||
Ok(media_id)
|
||||
})
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn insert_channel_ownership(
|
||||
&self,
|
||||
owner_user_id: UserId,
|
||||
chat_id: ChatId,
|
||||
) -> Result<(), AsyncError> {
|
||||
self.connection
|
||||
.call(move |conn| {
|
||||
Ok(conn.execute(
|
||||
"INSERT INTO channel_ownerships (owner_user_id, chat_id) VALUES (?, ?)",
|
||||
params![owner_user_id.0, chat_id.0],
|
||||
)?)
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_channel_ownership(
|
||||
&self,
|
||||
owner_user_id: UserId,
|
||||
chat_id: ChatId,
|
||||
) -> Result<(), AsyncError> {
|
||||
self.connection
|
||||
.call(move |conn| {
|
||||
Ok(conn.execute(
|
||||
"DELETE FROM channel_ownerships WHERE owner_user_id = ? AND chat_id = ?",
|
||||
params![owner_user_id.0, chat_id.0],
|
||||
)?)
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn print_info(&self) -> Result<(), AsyncError> {
|
||||
let vec_version = self
|
||||
.connection
|
||||
.call(|conn| {
|
||||
let version =
|
||||
conn.query_row("select vec_version()", [], |row| row.get::<_, String>(0))?;
|
||||
Ok(version)
|
||||
})
|
||||
.await?;
|
||||
|
||||
log::info!("sqlite_vec version: {}", vec_version);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn search_chat_messages(
|
||||
&self,
|
||||
media_hash: MediaHash,
|
||||
chat_owner_id: UserId,
|
||||
limit: u32,
|
||||
max_distance: f32,
|
||||
) -> Result<Vec<SearchChatMessagesResult>, AsyncError> {
|
||||
Ok(self
|
||||
.connection
|
||||
.call(move |conn| {
|
||||
let mut posts_stmt = conn.prepare(
|
||||
"
|
||||
WITH m AS (
|
||||
SELECT id, distance FROM medias
|
||||
WHERE media_hash match vec_normalize(?)
|
||||
ORDER BY distance
|
||||
LIMIT ?
|
||||
)
|
||||
SELECT DISTINCT
|
||||
cp.chat_id,
|
||||
cp.message_id,
|
||||
cp.media_id,
|
||||
m.distance
|
||||
FROM chat_posts cp
|
||||
LEFT JOIN channel_ownerships co ON cp.chat_id = co.chat_id
|
||||
LEFT JOIN m ON cp.media_id = m.id
|
||||
WHERE co.owner_user_id = ?
|
||||
AND m.distance < ?
|
||||
ORDER BY m.distance
|
||||
",
|
||||
)?;
|
||||
|
||||
let rows = posts_stmt.query_map(
|
||||
params![media_hash, limit, chat_owner_id.0, max_distance],
|
||||
|row| {
|
||||
Ok(SearchChatMessagesResult {
|
||||
chat_id: ChatId(row.get(0)?),
|
||||
message_id: MessageId(row.get(1)?),
|
||||
media_id: MediaId(row.get(2)?),
|
||||
distance: row.get(3)?, // distance
|
||||
})
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(rows.collect::<Result<Vec<_>, _>>()?)
|
||||
})
|
||||
.await?)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_db(db_path: &str) -> Result<Db, AsyncError> {
|
||||
// initialize sqlite_vec extension
|
||||
unsafe {
|
||||
sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ())));
|
||||
}
|
||||
|
||||
let connection = Connection::open(db_path).await?;
|
||||
migrate(&connection).await?;
|
||||
Ok(Db { connection })
|
||||
}
|
||||
|
||||
async fn migrate(connection: &Connection) -> Result<(), AsyncError> {
|
||||
connection
|
||||
.call(|conn| {
|
||||
conn.execute(
|
||||
&format!(
|
||||
"CREATE VIRTUAL TABLE IF NOT EXISTS medias USING vec0 (
|
||||
id INTEGER PRIMARY KEY,
|
||||
media_hash float32[{}]
|
||||
)",
|
||||
HASH_SIZE
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
concat!(
|
||||
"CREATE TABLE IF NOT EXISTS chat_posts (
|
||||
chat_id INTEGER NOT NULL,
|
||||
message_id INTEGER NOT NULL,
|
||||
media_id INTEGER NOT NULL,
|
||||
created_at INTEGER NOT NULL
|
||||
)"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
concat!(
|
||||
"CREATE INDEX IF NOT EXISTS channel_posts_idx_1 ",
|
||||
"ON chat_posts (chat_id, media_id)"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
concat!(
|
||||
"CREATE INDEX IF NOT EXISTS channel_posts_idx_2 ",
|
||||
"ON chat_posts (media_id, chat_id)"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS channels (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
chat_id INTEGER NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS channel_ownerships (
|
||||
owner_user_id INTEGER NOT NULL,
|
||||
chat_id INTEGER NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
conn.execute(
|
||||
concat!(
|
||||
"CREATE INDEX IF NOT EXISTS channel_ownerships_idx_1 ",
|
||||
"ON channel_ownerships (owner_user_id, chat_id)"
|
||||
),
|
||||
[],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Hashes the image stored under `fixtures/<name>`.
    async fn fixture_hash(name: &str) -> Result<MediaHash, AsyncError> {
        let path = format!("fixtures/{}", name);
        MediaHash::from_file(Path::new(&path)).await
    }

    /// The first media inserted into a fresh in-memory database gets id 1.
    #[tokio::test]
    async fn test_insert_media_hash() -> Result<(), AsyncError> {
        let db = get_db(":memory:").await?;
        let media_hash = fixture_hash("237-536x354_low_quality.jpg").await?;

        let media_id = db
            .insert_chat_message(ChatId(1), MessageId(1), media_hash)
            .await?;

        assert_eq!(media_id, MediaId(1));
        Ok(())
    }

    /// Searching as user 1 returns only posts from chats that user owns,
    /// nearest first.
    #[tokio::test]
    async fn test_search_media_hashes() -> Result<(), AsyncError> {
        let db = get_db(":memory:").await?;
        for owned_chat in [ChatId(1), ChatId(2)] {
            db.insert_channel_ownership(UserId(1), owned_chat).await?;
        }

        let hash1 = fixture_hash("237-536x354.jpg").await?;
        let hash2 = fixture_hash("237-536x354_low_quality.jpg").await?;
        let hash3 = fixture_hash("237-536x354_missing_chunk.png").await?;
        let hash4 = fixture_hash("866-536x354.jpg").await?;

        // ChatId(4) has no ownership row, so its post must not be returned.
        let posts = [
            (ChatId(1), MessageId(1), &hash1),
            (ChatId(4), MessageId(1), &hash1),
            (ChatId(2), MessageId(1), &hash2),
            (ChatId(2), MessageId(2), &hash3),
            (ChatId(2), MessageId(3), &hash4),
        ];
        let mut media_ids = Vec::with_capacity(posts.len());
        for (chat_id, message_id, hash) in posts {
            let media_id = db
                .insert_chat_message(chat_id, message_id, hash.clone())
                .await?;
            media_ids.push(media_id);
        }

        let results = db
            .search_chat_messages(hash1.clone(), UserId(1), 10, 1.0)
            .await?;

        println!("{:#?}", results);

        let found: Vec<MediaId> = results.iter().map(|r| r.media_id).collect();
        assert_eq!(found, vec![media_ids[0], media_ids[2], media_ids[3]]);

        Ok(())
    }
}
|
||||
169
src/main.rs
Normal file
169
src/main.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
mod async_error;
|
||||
mod db;
|
||||
mod media_hash;
|
||||
mod sql_types;
|
||||
|
||||
use async_error::AsyncError;
|
||||
use clap::Parser as _;
|
||||
use db::{get_db, Db};
|
||||
use media_hash::MediaHash;
|
||||
use teloxide::{
|
||||
dispatching::UpdateFilterExt,
|
||||
net::Download as _,
|
||||
prelude::*,
|
||||
types::{Chat, ChatMemberStatus, MediaKind, MessageCommon, MessageKind},
|
||||
};
|
||||
|
||||
#[derive(clap::Parser)]
|
||||
struct CommandLineArgs {
|
||||
/// Path to the database file
|
||||
#[clap(short, long)]
|
||||
db_path: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), AsyncError> {
|
||||
let args = CommandLineArgs::parse();
|
||||
|
||||
pretty_env_logger::init();
|
||||
log::info!("Starting bot");
|
||||
|
||||
// open up the database
|
||||
let db = get_db(&args.db_path).await?;
|
||||
db.print_info().await?;
|
||||
|
||||
let bot = Bot::from_env();
|
||||
|
||||
let handler = dptree::entry()
|
||||
.branch(Update::filter_my_chat_member().endpoint(handle_bot_status_change))
|
||||
.branch(Update::filter_channel_post().endpoint(handle_channel_message))
|
||||
.branch(Update::filter_message().endpoint(handle_direct_message));
|
||||
|
||||
log::info!("Starting dispatcher");
|
||||
Dispatcher::builder(bot, handler)
|
||||
.enable_ctrlc_handler()
|
||||
.dependencies(dptree::deps![db])
|
||||
.build()
|
||||
.dispatch()
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn chat_name(chat: &Chat) -> &str {
|
||||
chat.username().or(chat.title()).unwrap_or("<no name>")
|
||||
}
|
||||
|
||||
async fn handle_bot_status_change(updated: ChatMemberUpdated, db: Db) -> Result<(), AsyncError> {
|
||||
let status = updated.new_chat_member.status();
|
||||
let by_user = updated.from.id;
|
||||
let chat_id = updated.chat.id;
|
||||
log::info!(
|
||||
"Bot status in Chat({}) changed by User({}) to {:?}",
|
||||
chat_id,
|
||||
by_user,
|
||||
status
|
||||
);
|
||||
match status {
|
||||
ChatMemberStatus::Administrator => {
|
||||
log::info!(
|
||||
"Bot is now an administrator of Chat({}, {})",
|
||||
chat_id,
|
||||
chat_name(&updated.chat)
|
||||
);
|
||||
db.insert_channel_ownership(by_user, chat_id).await?;
|
||||
}
|
||||
ChatMemberStatus::Left => {
|
||||
log::info!(
|
||||
"Bot is no longer in Chat({}, {})",
|
||||
chat_id,
|
||||
chat_name(&updated.chat)
|
||||
);
|
||||
db.delete_channel_ownership(by_user, chat_id).await?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_direct_message(bot: Bot, message: Message, db: Db) -> Result<(), AsyncError> {
|
||||
let from_user = message.from.clone().ok_or("No sender in direct message")?;
|
||||
log::info!(
|
||||
"Got direct message from User({}, {}): {}",
|
||||
from_user.id,
|
||||
from_user.username.as_deref().unwrap_or("<no username>"),
|
||||
message.text().unwrap_or("<no text>")
|
||||
);
|
||||
|
||||
let media_hash = download_message_photo(&bot, &message).await?;
|
||||
if let Some(media_hash) = media_hash {
|
||||
let results = db
|
||||
.search_chat_messages(media_hash, from_user.id, 1, 1.0)
|
||||
.await?;
|
||||
|
||||
log::info!("Found {} similar medias", results.len());
|
||||
|
||||
bot.send_message(
|
||||
message.chat.id,
|
||||
format!("Found {} similar medias", results.len()),
|
||||
)
|
||||
.await?;
|
||||
|
||||
for result in results {
|
||||
log::info!(" {:?}", result);
|
||||
|
||||
bot.forward_message(message.chat.id, result.chat_id, result.message_id)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_channel_message(bot: Bot, message: Message, db: Db) -> Result<(), AsyncError> {
|
||||
let chat_id = message.chat.id;
|
||||
let media_hash = match download_message_photo(&bot, &message).await? {
|
||||
Some(media_hash) => media_hash,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let media_id = db
|
||||
.insert_chat_message(chat_id, message.id, media_hash)
|
||||
.await?;
|
||||
|
||||
log::info!("Inserted message: {:?}", media_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn download_message_photo(
|
||||
bot: &Bot,
|
||||
message: &Message,
|
||||
) -> Result<Option<MediaHash>, AsyncError> {
|
||||
let photo = match &message.kind {
|
||||
MessageKind::Common(MessageCommon {
|
||||
media_kind: MediaKind::Photo(photo),
|
||||
..
|
||||
}) => {
|
||||
let photo = match photo.photo.iter().max_by_key(|p| p.file.size) {
|
||||
Some(photo) => photo,
|
||||
None => {
|
||||
log::info!("No photo found in message");
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
photo
|
||||
}
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
log::info!("Downloading photo {}...", photo.file.id,);
|
||||
let file_path = bot.get_file(&photo.file.id).await?.path;
|
||||
let mut dst = Vec::new();
|
||||
bot.download_file(&file_path, &mut dst).await?;
|
||||
log::info!(
|
||||
"Downloaded {}.",
|
||||
humansize::format_size(dst.len(), humansize::BINARY)
|
||||
);
|
||||
Ok(Some(MediaHash::from_bytes(&dst)?))
|
||||
}
|
||||
80
src/media_hash.rs
Normal file
80
src/media_hash.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use std::{path::Path, rc::Rc};
|
||||
|
||||
use rusqlite::{
|
||||
types::{FromSql, Value, ValueRef},
|
||||
vtab::array::Array,
|
||||
ToSql,
|
||||
};
|
||||
use tokio_rusqlite::types::ToSqlOutput;
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
use crate::async_error::AsyncError;
|
||||
|
||||
pub const HASH_SIZE: usize = 32;
|
||||
type ImageHash = image_hasher::ImageHash<[u8; HASH_SIZE]>;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct MediaHash {
|
||||
hash: ImageHash,
|
||||
bytes: Vec<f32>,
|
||||
}
|
||||
impl MediaHash {
|
||||
pub async fn from_file(path: &Path) -> Result<Self, AsyncError> {
|
||||
let image_bytes = tokio::fs::read(path).await?;
|
||||
Self::from_bytes(&image_bytes)
|
||||
}
|
||||
|
||||
pub fn from_bytes(image_bytes: &[u8]) -> Result<Self, AsyncError> {
|
||||
let image = image::load_from_memory(&image_bytes)?;
|
||||
let num_bits = HASH_SIZE * 8;
|
||||
let hash_size = (num_bits as f32).sqrt() as u32;
|
||||
let hasher = image_hasher::HasherConfig::with_bytes_type::<[u8; HASH_SIZE]>()
|
||||
.hash_size(hash_size, hash_size)
|
||||
.preproc_dct()
|
||||
.hash_alg(image_hasher::HashAlg::Blockhash)
|
||||
.to_hasher();
|
||||
let media_hash = hasher.hash_image(&image).try_into()?;
|
||||
Ok(media_hash)
|
||||
}
|
||||
|
||||
pub fn distance(&self, other: &Self) -> u32 {
|
||||
self.hash.dist(&other.hash)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for MediaHash {
|
||||
fn to_sql(&self) -> Result<ToSqlOutput<'_>, rusqlite::Error> {
|
||||
Ok(ToSqlOutput::from(self.bytes.as_bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ImageHash> for MediaHash {
|
||||
type Error = AsyncError;
|
||||
|
||||
fn try_from(hash: ImageHash) -> Result<Self, Self::Error> {
|
||||
if hash.as_bytes().len() != HASH_SIZE {
|
||||
return Err("Invalid hash size".into());
|
||||
}
|
||||
|
||||
let bytes = hash
|
||||
.as_bytes()
|
||||
.iter()
|
||||
.map(|b| *b as f32)
|
||||
.collect::<Vec<_>>();
|
||||
Ok(MediaHash { hash, bytes })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A recompressed, lower-quality copy of an image must hash identically
    /// to the original.
    #[tokio::test]
    async fn test_hash_media() -> Result<(), AsyncError> {
        let low_quality = Path::new("fixtures/237-536x354_low_quality.jpg");
        let original = Path::new("fixtures/237-536x354.jpg");

        let hash1 = MediaHash::from_file(low_quality).await?;
        let hash2 = MediaHash::from_file(original).await?;

        assert_eq!(hash1, hash2);
        Ok(())
    }
}
|
||||
2
src/sql_types.rs
Normal file
2
src/sql_types.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
/// Row id of an entry in the `medias` table.
///
/// A newtype over `i64` so media ids cannot be confused with other integer
/// ids. Besides equality it is `Eq` and `Hash`, so it can be used as a
/// `HashMap` / `HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MediaId(pub i64);
|
||||
Reference in New Issue
Block a user