refactoring regression model code
This commit is contained in:
@@ -127,6 +127,8 @@ class Domain::Fa::Parser::Page < Domain::Fa::Parser::Base
|
||||
ActiveSupport::TimeZone.new("America/Los_Angeles")
|
||||
when "ddwhatnow", "vipvillageworker"
|
||||
ActiveSupport::TimeZone.new("America/New_York")
|
||||
when "blazeandwish"
|
||||
ActiveSupport::TimeZone.new("America/Chicago")
|
||||
else
|
||||
# server default?
|
||||
raise("unknown logged in user #{logged_in_user}")
|
||||
|
||||
@@ -44,7 +44,7 @@ module Stats::Helpers
|
||||
records_array.map do |record|
|
||||
Stats::DataPoint.new(
|
||||
x: record.fav_fa_id.to_f,
|
||||
y: T.cast(record.date&.to_time&.to_i&.to_f, Float),
|
||||
y: T.cast(record.date&.to_f, Float),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -20,10 +20,10 @@ module Stats
|
||||
end
|
||||
|
||||
# Denormalize linear regression coefficients back to original scale
|
||||
sig do
|
||||
params(norm_intercept: Float, norm_slope: Float).returns(T::Array[Float])
|
||||
end
|
||||
def denormalize_coefficients(norm_intercept, norm_slope)
|
||||
sig { params(coefficients: [Float, Float]).returns([Float, Float]) }
|
||||
def denormalize_coefficients(coefficients)
|
||||
norm_slope = coefficients[0]
|
||||
norm_intercept = coefficients[1]
|
||||
slope_orig = norm_slope * @y.scale / @x.scale
|
||||
intercept_orig =
|
||||
(norm_intercept * @y.scale + @y.min) - slope_orig * @x.min
|
||||
|
||||
@@ -10,9 +10,9 @@ module Stats
|
||||
.returns(T::Array[Float])
|
||||
end
|
||||
def denormalize_regression(regression_x, weight_vec)
|
||||
norm_a = T.cast(weight_vec[2], Float)
|
||||
norm_b = T.cast(weight_vec[1], Float)
|
||||
norm_c = T.cast(weight_vec[0], Float)
|
||||
norm_c = T.must(weight_vec[0])
|
||||
norm_b = T.must(weight_vec[1])
|
||||
norm_a = T.must(weight_vec[2])
|
||||
regression_x.map do |x|
|
||||
x_norm = @x.normalize(x)
|
||||
y_norm = norm_a * x_norm * x_norm + norm_b * x_norm + norm_c
|
||||
@@ -22,11 +22,12 @@ module Stats
|
||||
|
||||
# Denormalize quadratic regression coefficients back to original scale
|
||||
sig do
|
||||
params(norm_c: Float, norm_b: Float, norm_a: Float).returns(
|
||||
T::Array[Float],
|
||||
)
|
||||
params(coefficients: [Float, Float, Float]).returns([Float, Float, Float])
|
||||
end
|
||||
def denormalize_coefficients(norm_c, norm_b, norm_a)
|
||||
def denormalize_coefficients(coefficients)
|
||||
norm_c = coefficients[0]
|
||||
norm_b = coefficients[1]
|
||||
norm_a = coefficients[2]
|
||||
a_orig = norm_a * @y.scale / (@x.scale * @x.scale)
|
||||
b_orig = norm_b * @y.scale / @x.scale - 2 * a_orig * @x.min
|
||||
c_orig =
|
||||
|
||||
@@ -165,31 +165,26 @@ module Stats
|
||||
end
|
||||
def create_equation(equation_class, normalizer, weight_vec)
|
||||
if equation_class == Stats::PolynomialEquation
|
||||
case normalizer
|
||||
when Stats::LinearNormalizer
|
||||
coefficients =
|
||||
coefficients =
|
||||
case normalizer
|
||||
when Stats::LinearNormalizer
|
||||
normalizer.denormalize_coefficients(
|
||||
T.cast(weight_vec[0], Float),
|
||||
T.cast(weight_vec[1], Float),
|
||||
T.cast(weight_vec, [Float, Float]),
|
||||
)
|
||||
when Stats::QuadraticNormalizer
|
||||
coefficients =
|
||||
when Stats::QuadraticNormalizer
|
||||
normalizer.denormalize_coefficients(
|
||||
T.cast(weight_vec[0], Float),
|
||||
T.cast(weight_vec[1], Float),
|
||||
T.cast(weight_vec[2], Float),
|
||||
T.cast(weight_vec, [Float, Float, Float]),
|
||||
)
|
||||
else
|
||||
raise "Unsupported normalizer for PolynomialEquation: #{normalizer.class}"
|
||||
end
|
||||
else
|
||||
raise "Unsupported normalizer for PolynomialEquation: #{normalizer.class}"
|
||||
end
|
||||
Stats::PolynomialEquation.new(normalizer.x, normalizer.y, coefficients)
|
||||
elsif equation_class == Stats::LogarithmicEquation ||
|
||||
equation_class == Stats::SquareRootEquation
|
||||
equation_class.new(
|
||||
normalizer.x,
|
||||
normalizer.y,
|
||||
T.cast(weight_vec[0], Float),
|
||||
T.cast(weight_vec[1], Float),
|
||||
T.cast(weight_vec, [Float, Float]),
|
||||
)
|
||||
else
|
||||
raise "Unsupported equation class: #{equation_class}"
|
||||
|
||||
@@ -5,12 +5,19 @@ module Stats
|
||||
class SquareRootAxis < Axis
|
||||
sig { override.params(value: Float).returns(Float) }
|
||||
# Normalize a value onto the square-root scale of this axis.
#
# The square root of `value` is positioned between the square roots of the
# axis bounds, so `min` maps to 0.0 and `max` maps to 1.0.
#
# Raises Math::DomainError when `value`, `min` or `max` is negative.
# NOTE(review): when min == max the divisor is 0.0 and the result is
# NaN/Infinity; presumably callers guarantee min < max - confirm.
def normalize(value)
  min_sqrt = Math.sqrt(min)
  max_sqrt = Math.sqrt(max)

  (Math.sqrt(value) - min_sqrt) / (max_sqrt - min_sqrt)
end
|
||||
|
||||
sig { override.params(value: Float).returns(Float) }
|
||||
# Map a normalized position back to the original scale of this axis.
#
# `value` is first stretched from the unit interval back onto
# [sqrt(min), sqrt(max)] and the result is then squared, so 0.0 maps to
# `min` and 1.0 maps to `max`.
#
# Raises Math::DomainError when `min` or `max` is negative.
def denormalize(value)
  min_sqrt = Math.sqrt(min)
  max_sqrt = Math.sqrt(max)

  root = value * (max_sqrt - min_sqrt) + min_sqrt
  root * root
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -9,20 +9,14 @@ module Stats
|
||||
abstract!
|
||||
|
||||
sig(:final) do
|
||||
params(
|
||||
x: Stats::Axis,
|
||||
y: Stats::Axis,
|
||||
norm_slope: Float,
|
||||
norm_intercept: Float,
|
||||
).void
|
||||
params(x: Stats::Axis, y: Stats::Axis, coefficients: [Float, Float]).void
|
||||
end
|
||||
# Build the equation from its two axes and a pair of normalized coefficients.
#
# `coefficients` is read positionally: index 0 is stored as the normalized
# intercept and index 1 as the normalized slope.
def initialize(x, y, coefficients)
  super(x, y)
  @norm_intercept = T.let(coefficients[0], Float)
  @norm_slope = T.let(coefficients[1], Float)
end
|
||||
|
||||
# Public method to get coefficients (intercept, slope)
|
||||
sig(:final) { override.returns(T::Array[Float]) }
|
||||
def coefficients
|
||||
[@norm_intercept, @norm_slope]
|
||||
|
||||
@@ -8,12 +8,15 @@ class Tasks::Fa::BackfillFavsAndDatesTask < Tasks::InterruptableTask
|
||||
Both = new("both")
|
||||
OnlyFavs = new("favs")
|
||||
OnlyUserPages = new("profiles")
|
||||
ForUser = new("for-user")
|
||||
end
|
||||
end
|
||||
|
||||
sig { override.returns(String) }
|
||||
# Progress key for this task: a fixed prefix followed by the serialized mode.
# In ForUser mode the target user's url_name is appended as well (the suffix
# is just "-" when no user has been resolved).
def progress_key
  key = "task-fa-backfill-favs-and-dates-#{@mode.serialize}"
  return key unless @mode == Mode::ForUser

  "#{key}-#{@user&.url_name}"
end
|
||||
|
||||
sig do
|
||||
@@ -21,12 +24,22 @@ class Tasks::Fa::BackfillFavsAndDatesTask < Tasks::InterruptableTask
|
||||
mode: Mode,
|
||||
start_at: T.nilable(String),
|
||||
log_sink: T.any(IO, StringIO),
|
||||
user_url_name: T.nilable(String),
|
||||
).void
|
||||
end
|
||||
# Set up the backfill task.
#
# mode          - which log entries to process.
# start_at      - optional starting point, passed to get_progress as a string.
# log_sink      - where log output goes (defaults to $stderr).
# user_url_name - url_name of the user to restrict to; only consulted in
#                 ForUser mode.
#
# Raises RuntimeError when ForUser mode names a user that cannot be found.
# NOTE(review): ForUser mode with a blank user_url_name leaves @user nil
# without raising - confirm that is intended.
def initialize(mode:, start_at:, log_sink: $stderr, user_url_name: nil)
  super(log_sink:)
  @mode = mode

  # Resolve the target user before asking for progress: progress_key embeds
  # @user's url_name in ForUser mode.
  # NOTE(review): get_progress is defined outside this view; presumably it
  # looks up saved progress under progress_key - confirm.
  if @mode == Mode::ForUser && user_url_name.present?
    @user =
      T.let(
        Domain::User::FaUser.find_by(url_name: user_url_name),
        T.nilable(Domain::User::FaUser),
      )
    raise "user not found for #{user_url_name}" unless @user
  end

  @start_at = T.let(get_progress(start_at&.to_s)&.to_i, T.nilable(Integer))
end
|
||||
|
||||
class Stats < T::ImmutableStruct
|
||||
@@ -66,6 +79,8 @@ class Tasks::Fa::BackfillFavsAndDatesTask < Tasks::InterruptableTask
|
||||
"uri_path like '/favorites/%'"
|
||||
when Mode::OnlyUserPages
|
||||
"uri_path like '/user/%'"
|
||||
when Mode::ForUser
|
||||
"uri_path like '/user/#{@user&.url_name}/%' or uri_path like '/favorites/#{@user&.url_name}/%'"
|
||||
end
|
||||
|
||||
query =
|
||||
@@ -74,9 +89,13 @@ class Tasks::Fa::BackfillFavsAndDatesTask < Tasks::InterruptableTask
|
||||
.where(query_string)
|
||||
.where(status_code: 200)
|
||||
|
||||
log("counting relevant log entries...")
|
||||
total = query.where(id: @start_at..).count
|
||||
pb = create_progress_bar(total)
|
||||
if @mode != Mode::ForUser
|
||||
log("counting relevant log entries...")
|
||||
total = query.where(id: @start_at..).count
|
||||
pb = create_progress_bar(total)
|
||||
else
|
||||
pb = create_progress_bar(nil)
|
||||
end
|
||||
|
||||
query
|
||||
.includes(:response)
|
||||
|
||||
@@ -13,4 +13,26 @@ class Domain::FaFavIdAndDate < ReduxApplicationRecord
|
||||
primary_key: :fa_id,
|
||||
optional: true
|
||||
validates :post_fa_id, presence: true
|
||||
|
||||
sig { returns(T.nilable(ActiveSupport::TimeWithZone)) }
|
||||
# Best-effort timestamp for this fav.
#
# Returns the recorded `date` when one is present. Otherwise the date is
# predicted from `fav_fa_id` using the stored "fa_fav_id_and_date" regression
# model of type "square_root"; the prediction is truncated to an integer,
# passed to Time.at, and memoized in @infer_date.
#
# Returns nil when there is no fav id or no such trained model. A nil result
# is not memoized, so the model lookup is repeated on the next call.
def infer_date
  return date if date.present?

  fa_fav_id = fav_fa_id
  return nil if fa_fav_id.nil?

  @infer_date ||=
    T.let(
      begin
        regression_model =
          TrainedRegressionModel.find_by(
            name: "fa_fav_id_and_date",
            model_type: "square_root",
          )
        # Leaves the whole method, skipping the memoizing assignment.
        return nil if regression_model.nil?

        predicted_epoch = regression_model.predict(fa_fav_id.to_f).to_i
        Time.at(predicted_epoch).in_time_zone
      end,
      T.nilable(ActiveSupport::TimeWithZone),
    )
end
|
||||
end
|
||||
|
||||
@@ -16,4 +16,19 @@ class Domain::UserPostFav < ReduxApplicationRecord
|
||||
post_klass = T.cast(post_klass, T.class_of(Domain::Post))
|
||||
joins(:post).where(post: { type: post_klass.name })
|
||||
end
|
||||
|
||||
sig { returns(T.nilable(Time)) }
|
||||
# Time at which this fav is considered to have happened.
#
# For FA posts, the matching Domain::FaFavIdAndDate row (looked up by the
# post's fa_id and this record's user_id) is asked for its inferred date.
# When there is no row, the row yields no date, or the post is of another
# type, the post's posted_at is used instead.
#
# Returns nil when there is no post or no usable timestamp.
def faved_at
  faved_post = post
  return nil if faved_post.nil?

  if faved_post.is_a?(Domain::Post::FaPost)
    fav_model =
      Domain::FaFavIdAndDate.find_by(
        post_fa_id: faved_post.fa_id,
        user_id: user_id,
      )
    inferred = fav_model&.infer_date
    return inferred if inferred
  end

  faved_post.posted_at&.to_time
end
|
||||
end
|
||||
|
||||
@@ -5,7 +5,7 @@ class TrainedRegressionModel < ReduxApplicationRecord
|
||||
extend T::Sig
|
||||
|
||||
# Validations
|
||||
validates :name, presence: true, uniqueness: true
|
||||
validates :name, presence: true
|
||||
validates :model_type,
|
||||
presence: true,
|
||||
inclusion: {
|
||||
@@ -159,24 +159,20 @@ class TrainedRegressionModel < ReduxApplicationRecord
|
||||
coefficients_array,
|
||||
)
|
||||
when "logarithmic"
|
||||
# For transformed equations: slope, intercept order in constructor
|
||||
slope = T.cast(coefficients_array[1], Float)
|
||||
intercept = T.cast(coefficients_array[0], Float)
|
||||
Stats::LogarithmicEquation.new(
|
||||
Stats::LogarithmicAxis.new(min: x_min, max: x_max),
|
||||
Stats::LogarithmicAxis.new(min: y_min, max: y_max),
|
||||
slope,
|
||||
intercept,
|
||||
T.cast(coefficients_array, [Float, Float]),
|
||||
)
|
||||
when "square_root"
|
||||
# For transformed equations: slope, intercept order in constructor
|
||||
slope = T.cast(coefficients_array[1], Float)
|
||||
intercept = T.cast(coefficients_array[0], Float)
|
||||
# For square root regression y = a*√x + b:
|
||||
# - x-axis uses SquareRootAxis to transform x -> √x
|
||||
# - y-axis uses LinearAxis to keep y linear
|
||||
# Stored coefficients are [intercept, slope], constructor expects (slope, intercept)
|
||||
Stats::SquareRootEquation.new(
|
||||
Stats::SquareRootAxis.new(min: x_min, max: x_max),
|
||||
Stats::SquareRootAxis.new(min: y_min, max: y_max),
|
||||
slope,
|
||||
intercept,
|
||||
T.cast(coefficients_array, [Float, Float]),
|
||||
)
|
||||
else
|
||||
raise "Unsupported model type: #{model_type}"
|
||||
|
||||
@@ -9,7 +9,7 @@ Rails.application.configure do
|
||||
# it changes. This slows down response time but is perfect for development
|
||||
# since you don't have to restart the web server when you make code changes.
|
||||
config.cache_classes = false
|
||||
config.action_view.cache_template_loading = true
|
||||
config.action_view.cache_template_loading = false
|
||||
|
||||
# Do not eager load code on boot.
|
||||
config.eager_load = false
|
||||
|
||||
@@ -63,11 +63,13 @@ namespace :stats do
|
||||
|
||||
puts "\n✅ Graph generation completed!"
|
||||
|
||||
# remove old regressions for this model
|
||||
model_name = "fa_fav_id_and_date"
|
||||
TrainedRegressionModel.where(name: model_name).destroy_all
|
||||
|
||||
# Save each regression model to the database
|
||||
regressions.each do |name, result|
|
||||
equation = result.equation
|
||||
model_name = "fa_fav_id_and_date_#{name.downcase}"
|
||||
TrainedRegressionModel.find_by(name: model_name)&.destroy
|
||||
TrainedRegressionModel.create!(
|
||||
name: model_name,
|
||||
model_type: name.downcase.tr(" ", "_"),
|
||||
|
||||
@@ -136,7 +136,12 @@ namespace :fa do
|
||||
start_at = ENV["start_at"]
|
||||
mode = ENV["mode"] || "both"
|
||||
mode = Tasks::Fa::BackfillFavsAndDatesTask::Mode.deserialize(mode)
|
||||
Tasks::Fa::BackfillFavsAndDatesTask.new(mode:, start_at:).run
|
||||
user_url_name = ENV["user_url_name"]
|
||||
Tasks::Fa::BackfillFavsAndDatesTask.new(
|
||||
mode:,
|
||||
start_at:,
|
||||
user_url_name:,
|
||||
).run
|
||||
end
|
||||
|
||||
# task export_to_sqlite: %i[environment set_logger_stdout] do
|
||||
|
||||
@@ -638,7 +638,7 @@ RSpec.describe Stats::Equation do
|
||||
end
|
||||
|
||||
it "stores normalized slope and intercept" do
|
||||
expect(equation.coefficients).to eq([1.0, 2.0])
|
||||
expect(equation.coefficients).to eq([2.0, 1.0]) # [slope, intercept]
|
||||
end
|
||||
|
||||
it "evaluates using the normalized transformation" do
|
||||
|
||||
@@ -177,7 +177,7 @@ RSpec.describe TrainedRegressionModel, type: :model do
|
||||
|
||||
equation = model.equation
|
||||
expect(equation).to be_a(Stats::LogarithmicEquation)
|
||||
expect(equation.coefficients).to eq([0.1, 0.5])
|
||||
expect(equation.coefficients).to eq([0.5, 0.1]) # [slope, intercept]
|
||||
end
|
||||
|
||||
it "constructs and caches a Stats::SquareRootEquation for square root models" do
|
||||
@@ -202,7 +202,7 @@ RSpec.describe TrainedRegressionModel, type: :model do
|
||||
|
||||
equation = model.equation
|
||||
expect(equation).to be_a(Stats::SquareRootEquation)
|
||||
expect(equation.coefficients).to eq([0.2, 0.8])
|
||||
expect(equation.coefficients).to eq([0.8, 0.2]) # [slope, intercept]
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user