split common visual search logic out
This commit is contained in:
@@ -147,21 +147,33 @@ class Domain::PostsController < DomainController
|
||||
authorize Domain::Post
|
||||
|
||||
# Process the uploaded image or URL
|
||||
image_result = process_image_input
|
||||
return unless image_result
|
||||
|
||||
image_path, content_type = image_result
|
||||
file_result = process_image_input
|
||||
return unless file_result
|
||||
file_path, content_type = file_result
|
||||
|
||||
# Create thumbnail for the view if possible
|
||||
@uploaded_image_data_uri = create_thumbnail(image_path, content_type)
|
||||
@uploaded_hash_value = generate_fingerprint(image_path)
|
||||
@uploaded_detail_hash_value = generate_detail_fingerprint(image_path)
|
||||
tmp_dir = Dir.mktmpdir("visual-search")
|
||||
thumbs_and_fingerprints =
|
||||
helpers.generate_fingerprints(file_path, content_type, tmp_dir)
|
||||
first_thumb_and_fingerprint = thumbs_and_fingerprints&.first
|
||||
if thumbs_and_fingerprints.nil? || first_thumb_and_fingerprint.nil?
|
||||
flash.now[:error] = "Error generating fingerprints"
|
||||
render :visual_search
|
||||
return
|
||||
end
|
||||
logger.info("generated #{thumbs_and_fingerprints.length} thumbs")
|
||||
|
||||
@uploaded_image_data_uri =
|
||||
helpers.create_image_thumbnail_data_uri(
|
||||
first_thumb_and_fingerprint.thumb_path,
|
||||
"image/jpeg",
|
||||
)
|
||||
@uploaded_detail_hash_value = first_thumb_and_fingerprint.detail_fingerprint
|
||||
before = Time.now
|
||||
|
||||
similar_fingerprints =
|
||||
helpers.find_similar_fingerprints(
|
||||
fingerprint_value: @uploaded_hash_value,
|
||||
fingerprint_detail_value: @uploaded_detail_hash_value,
|
||||
).take(10)
|
||||
helpers.find_similar_fingerprints(thumbs_and_fingerprints).take(10)
|
||||
|
||||
@time_taken = Time.now - before
|
||||
|
||||
@matches = similar_fingerprints
|
||||
@@ -173,10 +185,7 @@ class Domain::PostsController < DomainController
|
||||
@matches = @good_matches if @good_matches.any?
|
||||
ensure
|
||||
# Clean up any temporary files
|
||||
if @temp_file
|
||||
@temp_file.unlink
|
||||
@temp_file = nil
|
||||
end
|
||||
FileUtils.rm_rf(tmp_dir) if tmp_dir
|
||||
end
|
||||
|
||||
private
|
||||
@@ -240,27 +249,6 @@ class Domain::PostsController < DomainController
|
||||
nil
|
||||
end
|
||||
|
||||
# Create a thumbnail from the image and return the data URI.
# Thin wrapper: forwards both arguments unchanged to
# helpers.create_image_thumbnail_data_uri and returns its result, which the
# signature declares as possibly nil.
|
||||
sig do
|
||||
params(image_path: String, content_type: String).returns(T.nilable(String))
|
||||
end
|
||||
def create_thumbnail(image_path, content_type)
|
||||
helpers.create_image_thumbnail_data_uri(image_path, content_type)
|
||||
end
|
||||
|
||||
# Generate a fingerprint from the image path.
# Returns whatever Domain::PostFile::BitFingerprint.from_file_path produces for
# the file (declared as a String by the signature).
|
||||
sig { params(image_path: String).returns(String) }
|
||||
def generate_fingerprint(image_path)
|
||||
# Delegate to BitFingerprint.from_file_path; no processing happens here
|
||||
Domain::PostFile::BitFingerprint.from_file_path(image_path)
|
||||
end
|
||||
|
||||
# Generate a detail fingerprint from the image path.
# Returns whatever Domain::PostFile::BitFingerprint.detail_from_file_path
# produces for the file (declared as a String by the signature).
|
||||
sig { params(image_path: String).returns(String) }
|
||||
def generate_detail_fingerprint(image_path)
|
||||
Domain::PostFile::BitFingerprint.detail_from_file_path(image_path)
|
||||
end
|
||||
|
||||
sig { override.returns(DomainController::DomainParamConfig) }
|
||||
def self.param_config
|
||||
DomainController::DomainParamConfig.new(
|
||||
|
||||
@@ -76,42 +76,97 @@ module Domain
|
||||
# Find similar images based on the fingerprint
|
||||
sig do
|
||||
params(
|
||||
fingerprint_value: String,
|
||||
fingerprint_detail_value: String,
|
||||
fingerprints: T::Array[GenerateFingerprintsResult],
|
||||
limit: Integer,
|
||||
oversearch: Integer,
|
||||
includes: T.untyped,
|
||||
).returns(T::Array[SimilarFingerprintResult])
|
||||
end
|
||||
def find_similar_fingerprints(
|
||||
fingerprint_value:,
|
||||
fingerprint_detail_value:,
|
||||
fingerprints,
|
||||
limit: 32,
|
||||
oversearch: 2,
|
||||
includes: {}
|
||||
)
|
||||
ActiveRecord::Base.connection.execute("SET ivfflat.probes = 20")
|
||||
|
||||
Domain::PostFile::BitFingerprint
|
||||
.order(
|
||||
Arel.sql "(fingerprint_value <~> '#{ActiveRecord::Base.connection.quote_string(fingerprint_value)}')"
|
||||
)
|
||||
.limit(limit * oversearch)
|
||||
.includes(includes)
|
||||
.to_a
|
||||
.uniq(&:post_file_id)
|
||||
.map do |other_fingerprint|
|
||||
SimilarFingerprintResult.new(
|
||||
fingerprint: other_fingerprint,
|
||||
similarity_percentage:
|
||||
calculate_similarity_percentage(
|
||||
fingerprint_detail_value,
|
||||
T.must(other_fingerprint.fingerprint_detail_value),
|
||||
),
|
||||
)
|
||||
results =
|
||||
fingerprints.flat_map do |f|
|
||||
Domain::PostFile::BitFingerprint
|
||||
.order(
|
||||
Arel.sql "(fingerprint_value <~> '#{ActiveRecord::Base.connection.quote_string(f.fingerprint)}')"
|
||||
)
|
||||
.limit(limit * oversearch)
|
||||
.includes(includes)
|
||||
.to_a
|
||||
.uniq(&:post_file_id)
|
||||
.map do |other_fingerprint|
|
||||
SimilarFingerprintResult.new(
|
||||
fingerprint: other_fingerprint,
|
||||
similarity_percentage:
|
||||
calculate_similarity_percentage(
|
||||
f.detail_fingerprint,
|
||||
T.must(other_fingerprint.fingerprint_detail_value),
|
||||
),
|
||||
)
|
||||
end
|
||||
.sort { |a, b| b.similarity_percentage <=> a.similarity_percentage }
|
||||
.take(limit)
|
||||
end
|
||||
.sort { |a, b| b.similarity_percentage <=> a.similarity_percentage }
|
||||
.take(limit)
|
||||
|
||||
results
|
||||
.group_by { |s| T.must(s.fingerprint.post_file_id) }
|
||||
.map do |post_file_id, similar_fingerprints|
|
||||
T.must(similar_fingerprints.max_by(&:similarity_percentage))
|
||||
end
|
||||
.sort_by(&:similarity_percentage)
|
||||
.reverse
|
||||
end
|
||||
|
||||
# Thumbnail path plus the two fingerprints computed from that thumbnail.
# generate_fingerprints builds one of these per sampled frame, and
# find_similar_fingerprints accepts an array of them.
class GenerateFingerprintsResult < T::Struct
|
||||
# Path of the thumbnail file the fingerprints were computed from
const :thumb_path, String
|
||||
# Value of BitFingerprint.from_file_path for the thumbnail
const :fingerprint, String
|
||||
# Value of BitFingerprint.detail_from_file_path for the thumbnail
const :detail_fingerprint, String
|
||||
end
|
||||
|
||||
# Generate thumbnails and fingerprints for a sample of frames of a media file.
#
# Loads the file with LoadedMedia.from_file, writes a 128x128 JPEG thumbnail
# ("frame-<n>.jpg") into tmp_dir for each sampled frame, and computes the
# fingerprint and detail fingerprint of every thumbnail.
#
# image_path   - path of the media file to load
# content_type - content type handed to LoadedMedia.from_file
# tmp_dir      - existing directory that receives the thumbnails; it is not
#                removed here, so the caller is responsible for cleanup
#
# Returns one GenerateFingerprintsResult per distinct sampled frame, in
# ascending frame order, or nil when the media cannot be loaded or reports no
# frames.
sig do
  params(image_path: String, content_type: String, tmp_dir: String).returns(
    T.nilable(T::Array[GenerateFingerprintsResult]),
  )
end
def generate_fingerprints(image_path, content_type, tmp_dir)
  media = LoadedMedia.from_file(content_type, image_path)
  return nil unless media

  # With zero frames `num_frames - 1` is negative, and the 1.0 fraction below
  # would yield frame index -1. Report failure the same way as unloadable media.
  last_frame = media.num_frames - 1
  return nil if last_frame.negative?

  thumbnail_options =
    LoadedMedia::ThumbnailOptions.new(
      width: 128,
      height: 128,
      quality: 95,
      size: :force,
      interlace: false,
      for_frames: [0.0, 0.1, 0.5, 0.9, 1.0],
    )

  # Convert each fraction of the media's length into a frame index. Short
  # media can map several fractions onto the same frame, hence the uniq.
  frame_nums =
    thumbnail_options
      .for_frames
      .map { |frame_fraction| (frame_fraction * last_frame).to_i }
      .uniq
      .sort

  frame_nums.map do |frame_num|
    tmp_file = File.join(tmp_dir, "frame-#{frame_num}.jpg")
    media.write_frame_thumbnail(frame_num, tmp_file, thumbnail_options)

    # Both fingerprints are computed from the thumbnail just written, not from
    # the original media file.
    GenerateFingerprintsResult.new(
      thumb_path: tmp_file,
      fingerprint: Domain::PostFile::BitFingerprint.from_file_path(tmp_file),
      detail_fingerprint:
        Domain::PostFile::BitFingerprint.detail_from_file_path(tmp_file),
    )
  end
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -75,7 +75,7 @@ module Tasks
|
||||
).void
|
||||
end
|
||||
def handle_message(bot, message)
|
||||
return unless message.photo || message.document
|
||||
return unless message.photo || message.document || message.video
|
||||
|
||||
# Start timing the total request
|
||||
total_request_timer = Stopwatch.start
|
||||
@@ -89,14 +89,14 @@ module Tasks
|
||||
response_message =
|
||||
bot.api.send_message(
|
||||
chat_id: chat_id,
|
||||
text: "🔍 Analyzing image... Please wait...",
|
||||
text: "🔍 Analyzing... Please wait...",
|
||||
reply_to_message_id: message.message_id,
|
||||
)
|
||||
|
||||
begin
|
||||
# Process the image and perform visual search
|
||||
search_result, processed_blob =
|
||||
process_image_message_with_logging(bot, message, telegram_log)
|
||||
process_media_message_with_logging(bot, message, telegram_log)
|
||||
|
||||
if search_result
|
||||
if search_result.empty?
|
||||
@@ -114,7 +114,7 @@ module Tasks
|
||||
)
|
||||
else
|
||||
result_text =
|
||||
"❌ Could not process the image. Please make sure it's a valid image file."
|
||||
"❌ Could not process the file. Please make sure it's a valid image or video file."
|
||||
|
||||
# Update log with invalid image
|
||||
update_telegram_log_invalid_image(telegram_log, result_text)
|
||||
@@ -133,7 +133,7 @@ module Tasks
|
||||
telegram_log.update!(total_request_time: total_request_time)
|
||||
log("⏱️ Total request completed in #{total_request_timer.elapsed_s}")
|
||||
rescue StandardError => e
|
||||
log("Error processing image: #{e.message}")
|
||||
log("Error processing file: #{e.message}")
|
||||
|
||||
# Update log with error
|
||||
update_telegram_log_error(telegram_log, e)
|
||||
@@ -143,7 +143,7 @@ module Tasks
|
||||
chat_id: chat_id,
|
||||
message_id: response_message.message_id,
|
||||
text:
|
||||
"❌ An error occurred while processing your image. Please try again.",
|
||||
"❌ An error occurred while processing your file. Please try again.",
|
||||
)
|
||||
|
||||
# Record total request time even for errors
|
||||
@@ -169,21 +169,21 @@ module Tasks
|
||||
],
|
||||
)
|
||||
end
|
||||
def process_image_message_with_logging(bot, message, telegram_log)
|
||||
log("📥 Received image message from chat #{message.chat.id}")
|
||||
def process_media_message_with_logging(bot, message, telegram_log)
|
||||
log("📥 Received message from chat #{message.chat.id}")
|
||||
|
||||
# Get the largest photo or document
|
||||
image_file = get_image_file_from_message(message)
|
||||
return nil, nil unless image_file
|
||||
media_file = get_media_file_from_message(message)
|
||||
return nil, nil unless media_file
|
||||
|
||||
# Download the image to a temporary file
|
||||
download_stopwatch = Stopwatch.start
|
||||
temp_file = download_telegram_image(bot, image_file)
|
||||
temp_file = download_telegram_file(bot, media_file)
|
||||
download_time = download_stopwatch.elapsed
|
||||
|
||||
return nil, nil unless temp_file
|
||||
|
||||
log("📥 Downloaded image in #{download_stopwatch.elapsed_s}")
|
||||
log("📥 Downloaded file in #{download_stopwatch.elapsed_s}")
|
||||
|
||||
processed_blob = nil
|
||||
|
||||
@@ -196,11 +196,13 @@ module Tasks
|
||||
|
||||
# Create BlobFile for the processed image
|
||||
content_type =
|
||||
case image_file
|
||||
case media_file
|
||||
when Telegram::Bot::Types::Document
|
||||
image_file.mime_type || "application/octet-stream"
|
||||
media_file.mime_type || "application/octet-stream"
|
||||
when Telegram::Bot::Types::PhotoSize
|
||||
"image/jpeg" # Telegram photos are typically JPEG
|
||||
when Telegram::Bot::Types::Video
|
||||
media_file.mime_type || "video/mp4"
|
||||
else
|
||||
"application/octet-stream"
|
||||
end
|
||||
@@ -213,16 +215,19 @@ module Tasks
|
||||
|
||||
image_processing_time = image_processing_stopwatch.elapsed
|
||||
|
||||
log("🔧 Processed image in #{image_processing_stopwatch.elapsed_s}")
|
||||
log("🔧 Processed file in #{image_processing_stopwatch.elapsed_s}")
|
||||
|
||||
# Time fingerprint generation
|
||||
fingerprint_stopwatch = Stopwatch.start
|
||||
fingerprint_value =
|
||||
Domain::PostFile::BitFingerprint.from_file_path(file_path)
|
||||
detail_fingerprint_value =
|
||||
Domain::PostFile::BitFingerprint.detail_from_file_path(file_path)
|
||||
temp_dir = Dir.mktmpdir("telegram-bot-task-visual-search")
|
||||
fingerprints = generate_fingerprints(file_path, content_type, temp_dir)
|
||||
fingerprint_computation_time = fingerprint_stopwatch.elapsed
|
||||
|
||||
if fingerprints.nil?
|
||||
log("❌ Error generating fingerprints")
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
log(
|
||||
"🔍 Generated fingerprints in #{fingerprint_stopwatch.elapsed_s}, searching for similar images...",
|
||||
)
|
||||
@@ -231,8 +236,7 @@ module Tasks
|
||||
search_stopwatch = Stopwatch.start
|
||||
similar_results =
|
||||
find_similar_fingerprints(
|
||||
fingerprint_value: fingerprint_value,
|
||||
fingerprint_detail_value: detail_fingerprint_value,
|
||||
fingerprints,
|
||||
limit: 10,
|
||||
oversearch: 3,
|
||||
includes: {
|
||||
@@ -259,10 +263,11 @@ module Tasks
|
||||
|
||||
[high_quality_matches, processed_blob]
|
||||
rescue StandardError => e
|
||||
log("❌ Error processing image: #{e.message}")
|
||||
log("❌ Error processing file: #{e.message}")
|
||||
[nil, processed_blob]
|
||||
ensure
|
||||
# Clean up temp file
|
||||
# Clean up temp files
|
||||
FileUtils.rm_rf(temp_dir) if temp_dir
|
||||
temp_file.unlink if temp_file
|
||||
end
|
||||
end
|
||||
@@ -331,10 +336,10 @@ module Tasks
|
||||
telegram_log.update!(
|
||||
status: :invalid_image,
|
||||
search_results_count: 0,
|
||||
error_message: "Invalid or unsupported image format",
|
||||
error_message: "Invalid or unsupported file format",
|
||||
response_data: {
|
||||
response_text: response_text,
|
||||
error: "Invalid image format",
|
||||
error: "Invalid file format",
|
||||
},
|
||||
)
|
||||
end
|
||||
@@ -408,26 +413,29 @@ module Tasks
|
||||
T.nilable(
|
||||
T.any(
|
||||
Telegram::Bot::Types::PhotoSize,
|
||||
Telegram::Bot::Types::Video,
|
||||
Telegram::Bot::Types::Document,
|
||||
),
|
||||
),
|
||||
)
|
||||
end
|
||||
def get_image_file_from_message(message)
|
||||
def get_media_file_from_message(message)
|
||||
if message.photo && message.photo.any?
|
||||
# Get the largest photo variant
|
||||
message.photo.max_by { |photo| photo.file_size || 0 }
|
||||
elsif message.video
|
||||
message.video
|
||||
elsif message.document
|
||||
# Check if document is an image
|
||||
content_type = message.document.mime_type
|
||||
if content_type&.start_with?("image/")
|
||||
message.document
|
||||
else
|
||||
log("❌ Document is not an image: #{content_type}")
|
||||
log("❌ Document is not an image or video: #{content_type}")
|
||||
nil
|
||||
end
|
||||
else
|
||||
log("❌ No image found in message")
|
||||
log("❌ No image or video found in message")
|
||||
nil
|
||||
end
|
||||
end
|
||||
@@ -439,11 +447,12 @@ module Tasks
|
||||
file_info:
|
||||
T.any(
|
||||
Telegram::Bot::Types::PhotoSize,
|
||||
Telegram::Bot::Types::Video,
|
||||
Telegram::Bot::Types::Document,
|
||||
),
|
||||
).returns(T.nilable(Tempfile))
|
||||
end
|
||||
def download_telegram_image(bot, file_info)
|
||||
def download_telegram_file(bot, file_info)
|
||||
bot_token = get_bot_token
|
||||
return nil unless bot_token
|
||||
|
||||
@@ -459,7 +468,7 @@ module Tasks
|
||||
|
||||
# Download the file
|
||||
file_url = "https://api.telegram.org/file/bot#{bot_token}/#{file_path}"
|
||||
log("📥 Downloading image from: #{file_url}...")
|
||||
log("📥 Downloading file from: #{file_url}...")
|
||||
|
||||
uri = URI(file_url)
|
||||
downloaded_data = Net::HTTP.get(uri)
|
||||
@@ -468,15 +477,15 @@ module Tasks
|
||||
extension = File.extname(file_path)
|
||||
extension = ".jpg" if extension.empty?
|
||||
|
||||
temp_file = Tempfile.new(["telegram_image", extension])
|
||||
temp_file = Tempfile.new(["telegram_file", extension])
|
||||
temp_file.binmode
|
||||
temp_file.write(downloaded_data)
|
||||
temp_file.close
|
||||
|
||||
log("✅ Downloaded image to: #{temp_file.path}")
|
||||
log("✅ Downloaded file to: #{temp_file.path}")
|
||||
temp_file
|
||||
rescue StandardError => e
|
||||
log("❌ Error downloading image: #{e.message}")
|
||||
log("❌ Error downloading file: #{e.message}")
|
||||
nil
|
||||
end
|
||||
end
|
||||
|
||||
@@ -130,33 +130,46 @@ RSpec.describe Domain::PostsController, type: :controller do
|
||||
context "with an image URL" do
|
||||
let(:mock_hash_value) { "1010101010101010" }
|
||||
let(:mock_detail_hash_value) { "0101010101010101" }
|
||||
let(:mock_fingerprints) { Domain::PostFile::BitFingerprint.none }
|
||||
let(:temp_file_path) { "/tmp/test_image.jpg" }
|
||||
|
||||
# Mock the GenerateFingerprintsResult structure that generate_fingerprints returns
|
||||
let(:mock_generate_fingerprints_results) do
|
||||
[
|
||||
Domain::VisualSearchHelper::GenerateFingerprintsResult.new(
|
||||
thumb_path: "/tmp/thumb1.jpg",
|
||||
fingerprint: mock_hash_value,
|
||||
detail_fingerprint: mock_detail_hash_value,
|
||||
),
|
||||
]
|
||||
end
|
||||
|
||||
# Mock the final similar fingerprints results
|
||||
let(:mock_similar_fingerprints) { Domain::PostFile::BitFingerprint.none }
|
||||
|
||||
it "uses Phash::Fingerprint model methods for fingerprinting and finding similar images" do
|
||||
# We need to mock the image downloading and processing since we can't do that in tests
|
||||
allow(controller).to receive(:process_image_input).and_return(
|
||||
[temp_file_path, "image/jpeg"],
|
||||
)
|
||||
allow(controller).to receive(:create_thumbnail).and_return(
|
||||
"data:image/jpeg;base64,FAKE",
|
||||
)
|
||||
|
||||
# Set up expectations for our model methods - this is what we're really testing
|
||||
expect(Domain::PostFile::BitFingerprint).to receive(
|
||||
:from_file_path,
|
||||
).with(temp_file_path).and_return(mock_hash_value)
|
||||
|
||||
# Add expectation for detail fingerprint
|
||||
expect(Domain::PostFile::BitFingerprint).to receive(
|
||||
:detail_from_file_path,
|
||||
).with(temp_file_path).and_return(mock_detail_hash_value)
|
||||
expect(controller.helpers).to receive(:generate_fingerprints).with(
|
||||
temp_file_path,
|
||||
"image/jpeg",
|
||||
anything,
|
||||
).and_return(mock_generate_fingerprints_results)
|
||||
|
||||
# Mock the similar fingerprints search
|
||||
expect(controller.helpers).to receive(:find_similar_fingerprints).with(
|
||||
fingerprint_value: mock_hash_value,
|
||||
fingerprint_detail_value: mock_detail_hash_value,
|
||||
).and_return(mock_fingerprints)
|
||||
mock_generate_fingerprints_results,
|
||||
).and_return(mock_similar_fingerprints)
|
||||
|
||||
# Mock the thumbnail data URI creation
|
||||
expect(controller.helpers).to receive(
|
||||
:create_image_thumbnail_data_uri,
|
||||
).with("/tmp/thumb1.jpg", "image/jpeg").and_return(
|
||||
"data:image/jpeg;base64,FAKE",
|
||||
)
|
||||
|
||||
post :visual_results,
|
||||
params: {
|
||||
@@ -170,11 +183,7 @@ RSpec.describe Domain::PostsController, type: :controller do
|
||||
expect(assigns(:uploaded_image_data_uri)).to eq(
|
||||
"data:image/jpeg;base64,FAKE",
|
||||
)
|
||||
expect(assigns(:uploaded_hash_value)).to eq(mock_hash_value)
|
||||
expect(assigns(:uploaded_detail_hash_value)).to eq(
|
||||
mock_detail_hash_value,
|
||||
)
|
||||
expect(assigns(:matches)).to eq(mock_fingerprints.to_a)
|
||||
expect(assigns(:matches)).to eq(mock_similar_fingerprints.to_a)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user