split common visual search logic out

This commit is contained in:
Dylan Knutson
2025-08-14 19:11:13 +00:00
parent 90d2cce076
commit 2a8d631b29
4 changed files with 173 additions and 112 deletions

View File

@@ -147,21 +147,33 @@ class Domain::PostsController < DomainController
authorize Domain::Post
# Process the uploaded image or URL
image_result = process_image_input
return unless image_result
image_path, content_type = image_result
file_result = process_image_input
return unless file_result
file_path, content_type = file_result
# Create thumbnail for the view if possible
@uploaded_image_data_uri = create_thumbnail(image_path, content_type)
@uploaded_hash_value = generate_fingerprint(image_path)
@uploaded_detail_hash_value = generate_detail_fingerprint(image_path)
tmp_dir = Dir.mktmpdir("visual-search")
thumbs_and_fingerprints =
helpers.generate_fingerprints(file_path, content_type, tmp_dir)
first_thumb_and_fingerprint = thumbs_and_fingerprints&.first
if thumbs_and_fingerprints.nil? || first_thumb_and_fingerprint.nil?
flash.now[:error] = "Error generating fingerprints"
render :visual_search
return
end
logger.info("generated #{thumbs_and_fingerprints.length} thumbs")
@uploaded_image_data_uri =
helpers.create_image_thumbnail_data_uri(
first_thumb_and_fingerprint.thumb_path,
"image/jpeg",
)
@uploaded_detail_hash_value = first_thumb_and_fingerprint.detail_fingerprint
before = Time.now
similar_fingerprints =
helpers.find_similar_fingerprints(
fingerprint_value: @uploaded_hash_value,
fingerprint_detail_value: @uploaded_detail_hash_value,
).take(10)
helpers.find_similar_fingerprints(thumbs_and_fingerprints).take(10)
@time_taken = Time.now - before
@matches = similar_fingerprints
@@ -173,10 +185,7 @@ class Domain::PostsController < DomainController
@matches = @good_matches if @good_matches.any?
ensure
# Clean up any temporary files
if @temp_file
@temp_file.unlink
@temp_file = nil
end
FileUtils.rm_rf(tmp_dir) if tmp_dir
end
private
@@ -240,27 +249,6 @@ class Domain::PostsController < DomainController
nil
end
# Create a thumbnail from the image at image_path and return it as a data URI.
#
# Thin wrapper: both arguments are forwarded unchanged to
# helpers.create_image_thumbnail_data_uri and its result is returned as-is.
# Per the signature, the result may be nil.
sig do
params(image_path: String, content_type: String).returns(T.nilable(String))
end
def create_thumbnail(image_path, content_type)
helpers.create_image_thumbnail_data_uri(image_path, content_type)
end
# Generate the fingerprint string for the image stored at image_path.
#
# Delegates to Domain::PostFile::BitFingerprint.from_file_path and returns
# its result unchanged.
sig { params(image_path: String).returns(String) }
def generate_fingerprint(image_path)
# The fingerprint is computed directly from the file on disk.
Domain::PostFile::BitFingerprint.from_file_path(image_path)
end
# Generate the detail fingerprint string for the image stored at image_path.
#
# Delegates to Domain::PostFile::BitFingerprint.detail_from_file_path and
# returns its result unchanged.
sig { params(image_path: String).returns(String) }
def generate_detail_fingerprint(image_path)
Domain::PostFile::BitFingerprint.detail_from_file_path(image_path)
end
sig { override.returns(DomainController::DomainParamConfig) }
def self.param_config
DomainController::DomainParamConfig.new(

View File

@@ -76,42 +76,97 @@ module Domain
# Find stored fingerprints that are similar to any of the query fingerprints.
#
# Every entry of +fingerprints+ is searched on its own:
#   1. the +limit * oversearch+ nearest stored rows are fetched, ordered by
#      the `<~>` distance to the entry's +fingerprint+
#      (NOTE(review): presumably pgvector's bit-vector distance - confirm),
#   2. rows sharing a post_file_id are collapsed to the first one,
#   3. each row is scored by comparing the entry's +detail_fingerprint+ with
#      the row's fingerprint_detail_value,
#   4. the best +limit+ rows of that entry are kept.
#
# The per-entry results are then merged: for each post file only its
# highest-scoring match survives, the survivors are ordered by descending
# similarity, and at most +limit+ of them are returned. An empty
# +fingerprints+ array yields an empty result.
#
# +includes+ is passed straight to ActiveRecord's +includes+ for eager
# loading of associations on the returned fingerprint rows.
sig do
  params(
    fingerprints: T::Array[GenerateFingerprintsResult],
    limit: Integer,
    oversearch: Integer,
    includes: T.untyped,
  ).returns(T::Array[SimilarFingerprintResult])
end
def find_similar_fingerprints(
  fingerprints,
  limit: 32,
  oversearch: 2,
  includes: {}
)
  ActiveRecord::Base.connection.execute("SET ivfflat.probes = 20")

  per_query_matches =
    fingerprints.flat_map do |f|
      # The value is escaped before being spliced into the ORDER BY fragment.
      quoted_value = ActiveRecord::Base.connection.quote_string(f.fingerprint)

      Domain::PostFile::BitFingerprint
        .order(Arel.sql("(fingerprint_value <~> '#{quoted_value}')"))
        .limit(limit * oversearch)
        .includes(includes)
        .to_a
        .uniq(&:post_file_id)
        .map do |other_fingerprint|
          SimilarFingerprintResult.new(
            fingerprint: other_fingerprint,
            similarity_percentage:
              calculate_similarity_percentage(
                f.detail_fingerprint,
                T.must(other_fingerprint.fingerprint_detail_value),
              ),
          )
        end
        .sort { |a, b| b.similarity_percentage <=> a.similarity_percentage }
        .take(limit)
    end

  # Several query fingerprints can hit the same post file; keep only the best
  # hit per file, then enforce +limit+ on the merged list as well. Without
  # the final cap the result could hold up to fingerprints.length * limit
  # entries.
  per_query_matches
    .group_by { |match| T.must(match.fingerprint.post_file_id) }
    .values
    .map { |matches| T.must(matches.max_by(&:similarity_percentage)) }
    .sort_by(&:similarity_percentage)
    .reverse
    .first(limit)
end
# One thumbnail produced by generate_fingerprints, together with the two
# fingerprints computed from that thumbnail file.
class GenerateFingerprintsResult < T::Struct
# Path of the thumbnail file both fingerprints were computed from.
const :thumb_path, String
# Result of BitFingerprint.from_file_path for the thumbnail; this is the
# value find_similar_fingerprints orders stored rows by.
const :fingerprint, String
# Result of BitFingerprint.detail_from_file_path for the thumbnail; this is
# the value find_similar_fingerprints passes to
# calculate_similarity_percentage.
const :detail_fingerprint, String
end
# Write a thumbnail for a sample of the media file's frames and fingerprint
# each thumbnail.
#
# The media is loaded through LoadedMedia.from_file. Frames are sampled at
# fixed fractions of the last frame index (0%, 10%, 50%, 90%, 100%);
# fractions that land on the same index are collapsed, so media with a single
# frame yields exactly one result.
#
# Each thumbnail is written to "<tmp_dir>/frame-<index>.jpg". This method
# does not delete those files; +tmp_dir+ has to be cleaned up by the caller.
#
# Returns one GenerateFingerprintsResult per sampled frame in ascending frame
# order, or nil when the media cannot be loaded or reports no frames.
sig do
  params(image_path: String, content_type: String, tmp_dir: String).returns(
    T.nilable(T::Array[GenerateFingerprintsResult]),
  )
end
def generate_fingerprints(image_path, content_type, tmp_dir)
  media = LoadedMedia.from_file(content_type, image_path)
  return nil unless media

  # With zero frames the last index would be -1 and the sampling below would
  # ask for frames -1 and 0, neither of which exists.
  last_frame_index = media.num_frames - 1
  return nil if last_frame_index.negative?

  thumbnail_options =
    LoadedMedia::ThumbnailOptions.new(
      width: 128,
      height: 128,
      quality: 95,
      size: :force,
      interlace: false,
      for_frames: [0.0, 0.1, 0.5, 0.9, 1.0],
    )

  frame_nums =
    thumbnail_options
      .for_frames
      .map { |frame_fraction| (frame_fraction * last_frame_index).to_i }
      .uniq
      .sort

  frame_nums.map do |frame_num|
    thumb_path = File.join(tmp_dir, "frame-#{frame_num}.jpg")
    media.write_frame_thumbnail(frame_num, thumb_path, thumbnail_options)

    # Both fingerprints are taken from the thumbnail that was just written,
    # not from the original upload.
    GenerateFingerprintsResult.new(
      thumb_path: thumb_path,
      fingerprint:
        Domain::PostFile::BitFingerprint.from_file_path(thumb_path),
      detail_fingerprint:
        Domain::PostFile::BitFingerprint.detail_from_file_path(thumb_path),
    )
  end
end
end
end

View File

@@ -75,7 +75,7 @@ module Tasks
).void
end
def handle_message(bot, message)
return unless message.photo || message.document
return unless message.photo || message.document || message.video
# Start timing the total request
total_request_timer = Stopwatch.start
@@ -89,14 +89,14 @@ module Tasks
response_message =
bot.api.send_message(
chat_id: chat_id,
text: "🔍 Analyzing image... Please wait...",
text: "🔍 Analyzing... Please wait...",
reply_to_message_id: message.message_id,
)
begin
# Process the image and perform visual search
search_result, processed_blob =
process_image_message_with_logging(bot, message, telegram_log)
process_media_message_with_logging(bot, message, telegram_log)
if search_result
if search_result.empty?
@@ -114,7 +114,7 @@ module Tasks
)
else
result_text =
"❌ Could not process the image. Please make sure it's a valid image file."
"❌ Could not process the file. Please make sure it's a valid image or video file."
# Update log with invalid image
update_telegram_log_invalid_image(telegram_log, result_text)
@@ -133,7 +133,7 @@ module Tasks
telegram_log.update!(total_request_time: total_request_time)
log("⏱️ Total request completed in #{total_request_timer.elapsed_s}")
rescue StandardError => e
log("Error processing image: #{e.message}")
log("Error processing file: #{e.message}")
# Update log with error
update_telegram_log_error(telegram_log, e)
@@ -143,7 +143,7 @@ module Tasks
chat_id: chat_id,
message_id: response_message.message_id,
text:
"❌ An error occurred while processing your image. Please try again.",
"❌ An error occurred while processing your file. Please try again.",
)
# Record total request time even for errors
@@ -169,21 +169,21 @@ module Tasks
],
)
end
def process_image_message_with_logging(bot, message, telegram_log)
log("📥 Received image message from chat #{message.chat.id}")
def process_media_message_with_logging(bot, message, telegram_log)
log("📥 Received message from chat #{message.chat.id}")
# Get the largest photo or document
image_file = get_image_file_from_message(message)
return nil, nil unless image_file
media_file = get_media_file_from_message(message)
return nil, nil unless media_file
# Download the image to a temporary file
download_stopwatch = Stopwatch.start
temp_file = download_telegram_image(bot, image_file)
temp_file = download_telegram_file(bot, media_file)
download_time = download_stopwatch.elapsed
return nil, nil unless temp_file
log("📥 Downloaded image in #{download_stopwatch.elapsed_s}")
log("📥 Downloaded file in #{download_stopwatch.elapsed_s}")
processed_blob = nil
@@ -196,11 +196,13 @@ module Tasks
# Create BlobFile for the processed image
content_type =
case image_file
case media_file
when Telegram::Bot::Types::Document
image_file.mime_type || "application/octet-stream"
media_file.mime_type || "application/octet-stream"
when Telegram::Bot::Types::PhotoSize
"image/jpeg" # Telegram photos are typically JPEG
when Telegram::Bot::Types::Video
media_file.mime_type || "video/mp4"
else
"application/octet-stream"
end
@@ -213,16 +215,19 @@ module Tasks
image_processing_time = image_processing_stopwatch.elapsed
log("🔧 Processed image in #{image_processing_stopwatch.elapsed_s}")
log("🔧 Processed file in #{image_processing_stopwatch.elapsed_s}")
# Time fingerprint generation
fingerprint_stopwatch = Stopwatch.start
fingerprint_value =
Domain::PostFile::BitFingerprint.from_file_path(file_path)
detail_fingerprint_value =
Domain::PostFile::BitFingerprint.detail_from_file_path(file_path)
temp_dir = Dir.mktmpdir("telegram-bot-task-visual-search")
fingerprints = generate_fingerprints(file_path, content_type, temp_dir)
fingerprint_computation_time = fingerprint_stopwatch.elapsed
if fingerprints.nil?
log("❌ Error generating fingerprints")
return nil, nil
end
log(
"🔍 Generated fingerprints in #{fingerprint_stopwatch.elapsed_s}, searching for similar images...",
)
@@ -231,8 +236,7 @@ module Tasks
search_stopwatch = Stopwatch.start
similar_results =
find_similar_fingerprints(
fingerprint_value: fingerprint_value,
fingerprint_detail_value: detail_fingerprint_value,
fingerprints,
limit: 10,
oversearch: 3,
includes: {
@@ -259,10 +263,11 @@ module Tasks
[high_quality_matches, processed_blob]
rescue StandardError => e
log("❌ Error processing image: #{e.message}")
log("❌ Error processing file: #{e.message}")
[nil, processed_blob]
ensure
# Clean up temp file
# Clean up temp files
FileUtils.rm_rf(temp_dir) if temp_dir
temp_file.unlink if temp_file
end
end
@@ -331,10 +336,10 @@ module Tasks
telegram_log.update!(
status: :invalid_image,
search_results_count: 0,
error_message: "Invalid or unsupported image format",
error_message: "Invalid or unsupported file format",
response_data: {
response_text: response_text,
error: "Invalid image format",
error: "Invalid file format",
},
)
end
@@ -408,26 +413,29 @@ module Tasks
T.nilable(
T.any(
Telegram::Bot::Types::PhotoSize,
Telegram::Bot::Types::Video,
Telegram::Bot::Types::Document,
),
),
)
end
# Pick the file attached to a Telegram message that should be searched.
#
# Checked in this order:
#   1. message.photo - the variant with the largest file_size is returned
#      (a variant without a file_size counts as 0),
#   2. message.video,
#   3. message.document - only when its MIME type starts with "image/" or
#      "video/".
#
# Returns nil, after logging the reason, when the message carries none of
# these or the document has an unsupported or missing MIME type.
def get_media_file_from_message(message)
  if message.photo&.any?
    message.photo.max_by { |photo| photo.file_size || 0 }
  elsif message.video
    message.video
  elsif message.document
    # A video attached as a document must be accepted too; the rejection
    # message below already promises "image or video".
    content_type = message.document.mime_type
    if content_type&.start_with?("image/", "video/")
      message.document
    else
      log("❌ Document is not an image or video: #{content_type}")
      nil
    end
  else
    log("❌ No image or video found in message")
    nil
  end
end
@@ -439,11 +447,12 @@ module Tasks
file_info:
T.any(
Telegram::Bot::Types::PhotoSize,
Telegram::Bot::Types::Video,
Telegram::Bot::Types::Document,
),
).returns(T.nilable(Tempfile))
end
def download_telegram_image(bot, file_info)
def download_telegram_file(bot, file_info)
bot_token = get_bot_token
return nil unless bot_token
@@ -459,7 +468,7 @@ module Tasks
# Download the file
file_url = "https://api.telegram.org/file/bot#{bot_token}/#{file_path}"
log("📥 Downloading image from: #{file_url}...")
log("📥 Downloading file from: #{file_url}...")
uri = URI(file_url)
downloaded_data = Net::HTTP.get(uri)
@@ -468,15 +477,15 @@ module Tasks
extension = File.extname(file_path)
extension = ".jpg" if extension.empty?
temp_file = Tempfile.new(["telegram_image", extension])
temp_file = Tempfile.new(["telegram_file", extension])
temp_file.binmode
temp_file.write(downloaded_data)
temp_file.close
log("✅ Downloaded image to: #{temp_file.path}")
log("✅ Downloaded file to: #{temp_file.path}")
temp_file
rescue StandardError => e
log("❌ Error downloading image: #{e.message}")
log("❌ Error downloading file: #{e.message}")
nil
end
end