refactoring regression model code

This commit is contained in:
Dylan Knutson
2025-07-11 01:45:39 +00:00
parent 9f1fc93267
commit acc2f9a240
16 changed files with 120 additions and 62 deletions

View File

@@ -127,6 +127,8 @@ class Domain::Fa::Parser::Page < Domain::Fa::Parser::Base
ActiveSupport::TimeZone.new("America/Los_Angeles")
when "ddwhatnow", "vipvillageworker"
ActiveSupport::TimeZone.new("America/New_York")
when "blazeandwish"
ActiveSupport::TimeZone.new("America/Chicago")
else
# server default?
raise("unknown logged in user #{logged_in_user}")

View File

@@ -44,7 +44,7 @@ module Stats::Helpers
records_array.map do |record|
Stats::DataPoint.new(
x: record.fav_fa_id.to_f,
y: T.cast(record.date&.to_time&.to_i&.to_f, Float),
y: T.cast(record.date&.to_f, Float),
)
end
end

View File

@@ -20,10 +20,10 @@ module Stats
end
# Denormalize linear regression coefficients back to original scale
sig do
params(norm_intercept: Float, norm_slope: Float).returns(T::Array[Float])
end
def denormalize_coefficients(norm_intercept, norm_slope)
sig { params(coefficients: [Float, Float]).returns([Float, Float]) }
def denormalize_coefficients(coefficients)
norm_slope = coefficients[0]
norm_intercept = coefficients[1]
slope_orig = norm_slope * @y.scale / @x.scale
intercept_orig =
(norm_intercept * @y.scale + @y.min) - slope_orig * @x.min

View File

@@ -10,9 +10,9 @@ module Stats
.returns(T::Array[Float])
end
def denormalize_regression(regression_x, weight_vec)
norm_a = T.cast(weight_vec[2], Float)
norm_b = T.cast(weight_vec[1], Float)
norm_c = T.cast(weight_vec[0], Float)
norm_c = T.must(weight_vec[0])
norm_b = T.must(weight_vec[1])
norm_a = T.must(weight_vec[2])
regression_x.map do |x|
x_norm = @x.normalize(x)
y_norm = norm_a * x_norm * x_norm + norm_b * x_norm + norm_c
@@ -22,11 +22,12 @@ module Stats
# Denormalize quadratic regression coefficients back to original scale
sig do
params(norm_c: Float, norm_b: Float, norm_a: Float).returns(
T::Array[Float],
)
params(coefficients: [Float, Float, Float]).returns([Float, Float, Float])
end
def denormalize_coefficients(norm_c, norm_b, norm_a)
def denormalize_coefficients(coefficients)
norm_c = coefficients[0]
norm_b = coefficients[1]
norm_a = coefficients[2]
a_orig = norm_a * @y.scale / (@x.scale * @x.scale)
b_orig = norm_b * @y.scale / @x.scale - 2 * a_orig * @x.min
c_orig =

View File

@@ -165,31 +165,26 @@ module Stats
end
def create_equation(equation_class, normalizer, weight_vec)
if equation_class == Stats::PolynomialEquation
case normalizer
when Stats::LinearNormalizer
coefficients =
coefficients =
case normalizer
when Stats::LinearNormalizer
normalizer.denormalize_coefficients(
T.cast(weight_vec[0], Float),
T.cast(weight_vec[1], Float),
T.cast(weight_vec, [Float, Float]),
)
when Stats::QuadraticNormalizer
coefficients =
when Stats::QuadraticNormalizer
normalizer.denormalize_coefficients(
T.cast(weight_vec[0], Float),
T.cast(weight_vec[1], Float),
T.cast(weight_vec[2], Float),
T.cast(weight_vec, [Float, Float, Float]),
)
else
raise "Unsupported normalizer for PolynomialEquation: #{normalizer.class}"
end
else
raise "Unsupported normalizer for PolynomialEquation: #{normalizer.class}"
end
Stats::PolynomialEquation.new(normalizer.x, normalizer.y, coefficients)
elsif equation_class == Stats::LogarithmicEquation ||
equation_class == Stats::SquareRootEquation
equation_class.new(
normalizer.x,
normalizer.y,
T.cast(weight_vec[0], Float),
T.cast(weight_vec[1], Float),
T.cast(weight_vec, [Float, Float]),
)
else
raise "Unsupported equation class: #{equation_class}"

View File

@@ -5,12 +5,19 @@ module Stats
class SquareRootAxis < Axis
sig { override.params(value: Float).returns(Float) }
def normalize(value)
Math.sqrt(value)
min_sqrt = Math.sqrt(min)
max_sqrt = Math.sqrt(max)
value_sqrt = Math.sqrt(value)
value_normalized = (value_sqrt - min_sqrt) / (max_sqrt - min_sqrt)
value_normalized
end
sig { override.params(value: Float).returns(Float) }
def denormalize(value)
value * value
min_sqrt = Math.sqrt(min)
max_sqrt = Math.sqrt(max)
value_denormalized = value * (max_sqrt - min_sqrt) + min_sqrt
value_denormalized * value_denormalized
end
end
end

View File

@@ -9,20 +9,14 @@ module Stats
abstract!
sig(:final) do
params(
x: Stats::Axis,
y: Stats::Axis,
norm_slope: Float,
norm_intercept: Float,
).void
params(x: Stats::Axis, y: Stats::Axis, coefficients: [Float, Float]).void
end
def initialize(x, y, norm_slope, norm_intercept)
def initialize(x, y, coefficients)
super(x, y)
@norm_slope = norm_slope
@norm_intercept = norm_intercept
@norm_intercept = T.let(coefficients[0], Float)
@norm_slope = T.let(coefficients[1], Float)
end
# Public method to get coefficients (intercept, slope)
sig(:final) { override.returns(T::Array[Float]) }
def coefficients
[@norm_intercept, @norm_slope]

View File

@@ -8,12 +8,15 @@ class Tasks::Fa::BackfillFavsAndDatesTask < Tasks::InterruptableTask
Both = new("both")
OnlyFavs = new("favs")
OnlyUserPages = new("profiles")
ForUser = new("for-user")
end
end
sig { override.returns(String) }
def progress_key
"task-fa-backfill-favs-and-dates-#{@mode.serialize}"
tag = "task-fa-backfill-favs-and-dates-#{@mode.serialize}"
tag += "-#{@user&.url_name}" if @mode == Mode::ForUser
tag
end
sig do
@@ -21,12 +24,22 @@ class Tasks::Fa::BackfillFavsAndDatesTask < Tasks::InterruptableTask
mode: Mode,
start_at: T.nilable(String),
log_sink: T.any(IO, StringIO),
user_url_name: T.nilable(String),
).void
end
def initialize(mode:, start_at:, log_sink: $stderr)
def initialize(mode:, start_at:, log_sink: $stderr, user_url_name: nil)
super(log_sink:)
@mode = mode
@start_at = T.let(get_progress(start_at&.to_s)&.to_i, T.nilable(Integer))
if @mode == Mode::ForUser && user_url_name.present?
@user =
T.let(
Domain::User::FaUser.find_by(url_name: user_url_name),
T.nilable(Domain::User::FaUser),
)
raise "user not found for #{user_url_name}" unless @user
end
end
class Stats < T::ImmutableStruct
@@ -66,6 +79,8 @@ class Tasks::Fa::BackfillFavsAndDatesTask < Tasks::InterruptableTask
"uri_path like '/favorites/%'"
when Mode::OnlyUserPages
"uri_path like '/user/%'"
when Mode::ForUser
"uri_path like '/user/#{@user&.url_name}/%' or uri_path like '/favorites/#{@user&.url_name}/%'"
end
query =
@@ -74,9 +89,13 @@ class Tasks::Fa::BackfillFavsAndDatesTask < Tasks::InterruptableTask
.where(query_string)
.where(status_code: 200)
log("counting relevant log entries...")
total = query.where(id: @start_at..).count
pb = create_progress_bar(total)
if @mode != Mode::ForUser
log("counting relevant log entries...")
total = query.where(id: @start_at..).count
pb = create_progress_bar(total)
else
pb = create_progress_bar(nil)
end
query
.includes(:response)

View File

@@ -13,4 +13,26 @@ class Domain::FaFavIdAndDate < ReduxApplicationRecord
primary_key: :fa_id,
optional: true
validates :post_fa_id, presence: true
sig { returns(T.nilable(ActiveSupport::TimeWithZone)) }
def infer_date
return date if date.present?
fa_fav_id = self.fav_fa_id
return nil if fa_fav_id.nil?
@infer_date ||=
T.let(
begin
regression_model =
TrainedRegressionModel.find_by(
name: "fa_fav_id_and_date",
model_type: "square_root",
)
return nil if regression_model.nil?
date_i = regression_model.predict(fa_fav_id.to_f).to_i
Time.at(date_i).in_time_zone
end,
T.nilable(ActiveSupport::TimeWithZone),
)
end
end

View File

@@ -16,4 +16,19 @@ class Domain::UserPostFav < ReduxApplicationRecord
post_klass = T.cast(post_klass, T.class_of(Domain::Post))
joins(:post).where(post: { type: post_klass.name })
end
sig { returns(T.nilable(Time)) }
def faved_at
post = self.post
return nil if post.nil?
case post
when Domain::Post::FaPost
fav_model =
Domain::FaFavIdAndDate.find_by(post_fa_id: post.fa_id, user_id: user_id)
fav_model&.infer_date || post.posted_at&.to_time
else
post.posted_at&.to_time
end
end
end

View File

@@ -5,7 +5,7 @@ class TrainedRegressionModel < ReduxApplicationRecord
extend T::Sig
# Validations
validates :name, presence: true, uniqueness: true
validates :name, presence: true
validates :model_type,
presence: true,
inclusion: {
@@ -159,24 +159,20 @@ class TrainedRegressionModel < ReduxApplicationRecord
coefficients_array,
)
when "logarithmic"
# For transformed equations: slope, intercept order in constructor
slope = T.cast(coefficients_array[1], Float)
intercept = T.cast(coefficients_array[0], Float)
Stats::LogarithmicEquation.new(
Stats::LogarithmicAxis.new(min: x_min, max: x_max),
Stats::LogarithmicAxis.new(min: y_min, max: y_max),
slope,
intercept,
T.cast(coefficients_array, [Float, Float]),
)
when "square_root"
# For transformed equations: slope, intercept order in constructor
slope = T.cast(coefficients_array[1], Float)
intercept = T.cast(coefficients_array[0], Float)
# For square root regression y = a*√x + b:
# - x-axis uses SquareRootAxis to transform x -> √x
# - y-axis uses LinearAxis to keep y linear
# Stored coefficients are [intercept, slope], constructor expects (slope, intercept)
Stats::SquareRootEquation.new(
Stats::SquareRootAxis.new(min: x_min, max: x_max),
Stats::SquareRootAxis.new(min: y_min, max: y_max),
slope,
intercept,
T.cast(coefficients_array, [Float, Float]),
)
else
raise "Unsupported model type: #{model_type}"

View File

@@ -9,7 +9,7 @@ Rails.application.configure do
# it changes. This slows down response time but is perfect for development
# since you don't have to restart the web server when you make code changes.
config.cache_classes = false
config.action_view.cache_template_loading = true
config.action_view.cache_template_loading = false
# Do not eager load code on boot.
config.eager_load = false

View File

@@ -63,11 +63,13 @@ namespace :stats do
puts "\n✅ Graph generation completed!"
# remove old regressions for this model
model_name = "fa_fav_id_and_date"
TrainedRegressionModel.where(name: model_name).destroy_all
# Save each regression model to the database
regressions.each do |name, result|
equation = result.equation
model_name = "fa_fav_id_and_date_#{name.downcase}"
TrainedRegressionModel.find_by(name: model_name)&.destroy
TrainedRegressionModel.create!(
name: model_name,
model_type: name.downcase.tr(" ", "_"),

View File

@@ -136,7 +136,12 @@ namespace :fa do
start_at = ENV["start_at"]
mode = ENV["mode"] || "both"
mode = Tasks::Fa::BackfillFavsAndDatesTask::Mode.deserialize(mode)
Tasks::Fa::BackfillFavsAndDatesTask.new(mode:, start_at:).run
user_url_name = ENV["user_url_name"]
Tasks::Fa::BackfillFavsAndDatesTask.new(
mode:,
start_at:,
user_url_name:,
).run
end
# task export_to_sqlite: %i[environment set_logger_stdout] do

View File

@@ -638,7 +638,7 @@ RSpec.describe Stats::Equation do
end
it "stores normalized slope and intercept" do
expect(equation.coefficients).to eq([1.0, 2.0])
expect(equation.coefficients).to eq([2.0, 1.0]) # [slope, intercept]
end
it "evaluates using the normalized transformation" do

View File

@@ -177,7 +177,7 @@ RSpec.describe TrainedRegressionModel, type: :model do
equation = model.equation
expect(equation).to be_a(Stats::LogarithmicEquation)
expect(equation.coefficients).to eq([0.1, 0.5])
expect(equation.coefficients).to eq([0.5, 0.1]) # [slope, intercept]
end
it "constructs and caches a Stats::SquareRootEquation for square root models" do
@@ -202,7 +202,7 @@ RSpec.describe TrainedRegressionModel, type: :model do
equation = model.equation
expect(equation).to be_a(Stats::SquareRootEquation)
expect(equation.coefficients).to eq([0.2, 0.8])
expect(equation.coefficients).to eq([0.8, 0.2]) # [slope, intercept]
end
end