Files
redux-scraper/app/models/trained_regression_model.rb
2025-07-12 08:53:49 +00:00

157 lines
4.1 KiB
Ruby

# typed: strict
# frozen_string_literal: true
# A persisted regression fit: the model family, the stored coefficients, the
# axis bounds the equation is built over, and train/evaluation statistics.
class TrainedRegressionModel < ReduxApplicationRecord
  extend T::Sig

  # Supported regression families. The serialized string is the value stored
  # in (and compared against) the `model_type` column.
  class ModelType < T::Enum
    enums do
      Linear = new("linear")
      Quadratic = new("quadratic")
      Logarithmic = new("logarithmic")
      SquareRoot = new("square_root")
    end
  end

  # Validations
  validates :name, presence: true
  # Declared once; a second identical `validates :model_type` would report
  # every failure twice.
  validates :model_type,
            presence: true,
            inclusion: {
              in: ModelType.values.map(&:serialize),
            }
  validates :total_records_count,
            presence: true,
            numericality: {
              greater_than: 0,
            }
  validates :training_records_count,
            presence: true,
            numericality: {
              greater_than: 0,
            }
  validates :evaluation_records_count,
            presence: true,
            numericality: {
              greater_than_or_equal_to: 0,
            }
  validates :train_test_split_ratio,
            presence: true,
            numericality: {
              greater_than: 0,
              less_than: 1,
            }
  validates :random_seed, presence: true
  validates :x_min, presence: true
  validates :x_max, presence: true
  validates :y_min, presence: true
  validates :y_max, presence: true
  validates :coefficients, presence: true
  validates :training_r_squared,
            presence: true,
            numericality: {
              greater_than_or_equal_to: 0,
              less_than_or_equal_to: 1,
            }
  # NOTE(review): R² measured on held-out data is 1 - SS_res/SS_tot and can be
  # negative when the model predicts worse than the mean; this lower bound
  # rejects such models. Confirm that is intended.
  validates :evaluation_r_squared,
            presence: true,
            numericality: {
              greater_than_or_equal_to: 0,
              less_than_or_equal_to: 1,
            }
  validates :equation_string, presence: true

  # Enums
  # Maps each serialized ModelType to itself, so the column stores the same
  # strings used by `build_equation`; `prefix: true` yields helpers such as
  # `model_type_linear?`.
  enum :model_type,
       ModelType.values.map(&:serialize).map { |v| [v, v] }.to_h,
       prefix: true

  # Evaluates the fitted equation at `x_value`.
  sig { params(x_value: Float).returns(Float) }
  def predict(x_value)
    equation.evaluate(x_value)
  end

  # The equation object for this model, built on first use.
  #
  # The result is cached on this instance: later changes to `model_type`,
  # `coefficients` or the x/y bounds on the same object are not reflected
  # until a fresh instance is loaded.
  sig { returns(Stats::Equation) }
  def equation
    @equation ||= T.let(build_equation, T.nilable(Stats::Equation))
  end

  # Human-readable R² summary, each value rounded to 3 decimal places.
  # Raises (via T.must) if either R² value is nil.
  sig { returns(String) }
  def performance_summary
    "Training R² = #{T.must(training_r_squared).round(3)}, Evaluation R² = #{T.must(evaluation_r_squared).round(3)}"
  end

  # Human-readable summary of the total/training/evaluation record counts.
  sig { returns(String) }
  def data_summary
    "Total: #{total_records_count}, Training: #{training_records_count}, Evaluation: #{evaluation_records_count}"
  end

  private

  # Stored coefficients, or an empty array when the column is nil.
  sig { returns(T::Array[Float]) }
  def coefficients_array
    coefficients || []
  end

  # Builds the Stats equation matching `model_type` from the stored axis
  # bounds and coefficients.
  #
  # Bounds are unwrapped with T.must here instead of by overriding the
  # `x_min`/`x_max`/`y_min`/`y_max` readers. ActiveModel validators read
  # attributes via `send`, so raising readers would turn a missing bound into
  # a TypeError during `valid?` rather than a normal validation error.
  #
  # Raises RuntimeError for an unrecognized `model_type`.
  sig { returns(Stats::Equation) }
  def build_equation
    case model_type
    when "linear", "quadratic"
      # Both are polynomial fits on linear axes; the degree is implied by the
      # number of stored coefficients.
      Stats::PolynomialEquation.new(
        Stats::LinearAxis.new(min: T.must(x_min), max: T.must(x_max)),
        Stats::LinearAxis.new(min: T.must(y_min), max: T.must(y_max)),
        coefficients_array,
      )
    when "logarithmic"
      # T.cast is not checked at runtime; the stored array is assumed to hold
      # exactly two coefficients.
      Stats::LogarithmicEquation.new(
        Stats::LogarithmicAxis.new(min: T.must(x_min), max: T.must(x_max)),
        Stats::LogarithmicAxis.new(min: T.must(y_min), max: T.must(y_max)),
        T.cast(coefficients_array, [Float, Float]),
      )
    when "square_root"
      # NOTE(review): a y = a*√x + b fit would normally keep y linear, yet both
      # axes here use SquareRootAxis (mirroring the logarithmic branch), and
      # the stored coefficient array is passed through without reordering.
      # Confirm both against Stats::SquareRootEquation's constructor.
      Stats::SquareRootEquation.new(
        Stats::SquareRootAxis.new(min: T.must(x_min), max: T.must(x_max)),
        Stats::SquareRootAxis.new(min: T.must(y_min), max: T.must(y_max)),
        T.cast(coefficients_array, [Float, Float]),
      )
    else
      raise "Unsupported model type: `#{model_type}`"
    end
  end
end