# typed: strict
# frozen_string_literal: true

# Helpers for loading FaUserPostFav data points, splitting them into
# training/evaluation sets, and formatting numeric results for display.
module Stats::Helpers
  extend T::Sig

  # Seed shared by the sampling and train/test-split helpers so repeated runs
  # over the same data select the same subsets.
  RANDOM_SEED = T.let(42, Integer)

  # Loads the records returned by the `with_explicit_time_and_id` scope,
  # optionally down-samples them, and converts each one into a DataPoint with
  # x = fa_fav_id.to_f and y = explicit_time.to_f.
  # (NOTE(review): y is presumably epoch seconds if explicit_time is a Time — confirm.)
  #
  # Progress messages are printed to stdout. If the scope returns no records,
  # this prints an error and terminates the process with `exit 1`, so it is
  # only suitable for script/CLI use.
  #
  # @param max_points [Integer, nil] maximum number of points to return;
  #   nil means no limit
  # @return [Array<Stats::DataPoint>]
  sig do
    params(max_points: T.nilable(Integer)).returns(T::Array[Stats::DataPoint])
  end
  def self.sample_records(max_points)
    # Materialize the relation once instead of issuing separate
    # EXISTS / COUNT / SELECT queries for empty?, count, and to_a.
    records_array = Domain::UserPostFav::FaUserPostFav.with_explicit_time_and_id.to_a

    if records_array.empty?
      puts "❌ No complete FaUserPostFav records found"
      exit 1
    end

    total_records = records_array.length
    puts "📊 Found #{total_records} complete records"

    if max_points && total_records > max_points
      puts "🎲 Randomly sampling #{max_points} points from #{total_records} total records"
      # A private, fixed-seed generator keeps the sample reproducible without
      # resetting the process-global RNG the way Kernel#srand would.
      rng = Random.new(RANDOM_SEED)
      records_array =
        T.cast(
          records_array.sample(max_points, random: rng),
          T::Array[Domain::UserPostFav::FaUserPostFav],
        )
      puts "📊 Using #{records_array.length} sampled records for analysis"
    else
      message =
        if max_points
          "within max_points limit of #{max_points}"
        else
          "no sampling limit specified"
        end
      puts "📊 Using all #{records_array.length} records (#{message})"
    end

    records_array.map do |record|
      Stats::DataPoint.new(
        x: record.fa_fav_id.to_f,
        y: T.must(record.explicit_time).to_f,
      )
    end
  end

  # Shuffles the records reproducibly (fixed seed) and splits them into
  # training and evaluation sets.
  #
  # The training set receives round(length * (1 - eval_ratio)) records and the
  # evaluation set receives the rest.
  #
  # @param records [Array<Stats::DataPoint>] records to split; not mutated
  # @param eval_ratio [Float] fraction reserved for evaluation, in [0.0, 1.0]
  # @return [Stats::TrainTestSplit]
  # @raise [ArgumentError] if eval_ratio is outside [0.0, 1.0] or is NaN
  sig do
    params(records: T::Array[Stats::DataPoint], eval_ratio: Float).returns(
      Stats::TrainTestSplit,
    )
  end
  def self.split_train_test(records, eval_ratio = 0.2)
    # Outside this range the split index would be negative or past the end,
    # producing nonsensical slices.
    unless eval_ratio.between?(0.0, 1.0)
      raise ArgumentError, "eval_ratio must be between 0.0 and 1.0, got #{eval_ratio}"
    end

    # A private generator keeps the shuffle reproducible without touching
    # the process-global RNG.
    shuffled_records = records.shuffle(random: Random.new(RANDOM_SEED))
    split_index = (records.length * (1.0 - eval_ratio)).round

    Stats::TrainTestSplit.new(
      training_records: shuffled_records.take(split_index),
      evaluation_records: shuffled_records.drop(split_index),
    )
  end

  # Rounds an R² value to three decimal places for display.
  #
  # @param value [Float]
  # @return [Float]
  sig { params(value: Float).returns(Float) }
  def self.format_r_squared(value)
    value.round(3)
  end

  # Formats a number to roughly `sig_figs` significant figures.
  #
  # Uses "<mantissa>e<exponent>" notation when the decimal order of magnitude
  # is >= 6 or <= -3. Otherwise it rounds to the number of decimal places
  # needed for `sig_figs` significant figures (never fewer than 0).
  # Zero returns "0.0"; NaN and ±Infinity return Float#to_s.
  #
  # @param num [Float]
  # @param sig_figs [Integer] significant figures to keep (default 3)
  # @return [String]
  sig { params(num: Float, sig_figs: Integer).returns(String) }
  def self.format_number(num, sig_figs = 3)
    return "0.0" if num.zero?
    # Math.log10(...).floor below raises FloatDomainError for NaN/Infinity.
    return num.to_s unless num.finite?

    order = Math.log10(num.abs).floor

    if order >= 6 || order <= -3
      rounded = (num / (10.0**order)).round(sig_figs - 1)
      # Rounding can carry the mantissa up to 10 (e.g. 9.996e6 -> 10.0);
      # renormalize so the mantissa stays within 1 <= |m| < 10.
      if rounded.abs >= 10
        rounded /= 10
        order += 1
      end
      "#{rounded}e#{order}"
    else
      decimal_places = [sig_figs - (order + 1), 0].max
      num.round(decimal_places).to_s
    end
  end
end