157 lines
4.1 KiB
Ruby
157 lines
4.1 KiB
Ruby
# typed: strict
|
|
# frozen_string_literal: true
|
|
|
|
# Persisted (via ReduxApplicationRecord) one-variable regression model: which
# regression family was fitted, how the data was split for training and
# evaluation, the fitted coefficients, the x/y ranges and the R² scores.
# `#predict` evaluates the fitted curve through a Stats::Equation built from
# these columns.
class TrainedRegressionModel < ReduxApplicationRecord
  extend T::Sig

  # Supported regression families. The serialized string is the value stored
  # in the `model_type` column (see the `enum` declaration below).
  class ModelType < T::Enum
    enums do
      Linear = new("linear")
      Quadratic = new("quadratic")
      Logarithmic = new("logarithmic")
      SquareRoot = new("square_root")
    end
  end

  # Validations
  validates :name, presence: true
  # Declared once only: a second identical `validates :model_type` would add
  # every model_type error message twice.
  validates :model_type,
            presence: true,
            inclusion: {
              in: ModelType.values.map(&:serialize),
            }
  validates :total_records_count,
            presence: true,
            numericality: {
              greater_than: 0,
            }
  validates :training_records_count,
            presence: true,
            numericality: {
              greater_than: 0,
            }
  validates :evaluation_records_count,
            presence: true,
            numericality: {
              greater_than_or_equal_to: 0,
            }
  # Fraction of records used for training; strictly between 0 and 1.
  validates :train_test_split_ratio,
            presence: true,
            numericality: {
              greater_than: 0,
              less_than: 1,
            }
  validates :random_seed, presence: true
  validates :x_min, presence: true
  validates :x_max, presence: true
  validates :y_min, presence: true
  validates :y_max, presence: true
  validates :coefficients, presence: true
  validates :training_r_squared,
            presence: true,
            numericality: {
              greater_than_or_equal_to: 0,
              less_than_or_equal_to: 1,
            }
  validates :evaluation_r_squared,
            presence: true,
            numericality: {
              greater_than_or_equal_to: 0,
              less_than_or_equal_to: 1,
            }
  validates :equation_string, presence: true

  # Enums
  # Each serialized ModelType maps to itself, so the column stores the plain
  # string; `prefix: true` generates predicates such as `model_type_linear?`.
  enum :model_type,
       ModelType.values.to_h { |type| [type.serialize, type.serialize] },
       prefix: true

  # Evaluates the fitted equation at +x_value+.
  sig { params(x_value: Float).returns(Float) }
  def predict(x_value)
    equation.evaluate(x_value)
  end

  # The Stats::Equation for this record. It is built on first access and
  # memoized in @equation.
  # NOTE(review): the memo is not cleared when attributes change or on
  # `reload` — confirm records are not mutated after the first prediction.
  #
  # Raises RuntimeError when model_type is not a supported ModelType value.
  sig { returns(Stats::Equation) }
  def equation
    @equation ||= T.let(build_equation, T.nilable(Stats::Equation))
  end

  # Human-readable training/evaluation R², each rounded to 3 decimals.
  # Raises (via T.must) if either score is nil.
  sig { returns(String) }
  def performance_summary
    "Training R² = #{T.must(training_r_squared).round(3)}, Evaluation R² = #{T.must(evaluation_r_squared).round(3)}"
  end

  # Human-readable record counts for the total/training/evaluation split.
  sig { returns(String) }
  def data_summary
    "Total: #{total_records_count}, Training: #{training_records_count}, Evaluation: #{evaluation_records_count}"
  end

  private

  # Stored coefficients, or an empty array when the column is nil.
  sig { returns(T::Array[Float]) }
  def coefficients_array
    coefficients || []
  end

  # Non-nil axis bounds for building the equation.
  #
  # These deliberately do NOT override the `x_min`/`x_max`/`y_min`/`y_max`
  # attribute readers. Presence validation reads attributes through
  # `read_attribute_for_validation` (i.e. `send`). An override that raises on
  # nil would therefore make `valid?` raise TypeError instead of recording a
  # "can't be blank" error.
  sig { returns(Float) }
  def required_x_min
    T.must(x_min)
  end

  sig { returns(Float) }
  def required_x_max
    T.must(x_max)
  end

  sig { returns(Float) }
  def required_y_min
    T.must(y_min)
  end

  sig { returns(Float) }
  def required_y_max
    T.must(y_max)
  end

  # Builds the Stats::Equation matching model_type from the stored bounds and
  # coefficients. Raises TypeError (T.must) if a bound is nil, and
  # RuntimeError for an unknown model_type.
  sig { returns(Stats::Equation) }
  def build_equation
    case model_type
    when "linear", "quadratic"
      # Both families are polynomial fits. The original branches were
      # identical, so the degree is presumably implied by the number of
      # stored coefficients.
      Stats::PolynomialEquation.new(
        Stats::LinearAxis.new(min: required_x_min, max: required_x_max),
        Stats::LinearAxis.new(min: required_y_min, max: required_y_max),
        coefficients_array,
      )
    when "logarithmic"
      # NOTE(review): both axes are LogarithmicAxis here; confirm y is meant
      # to be log-transformed too (the square-root branch keeps y linear).
      Stats::LogarithmicEquation.new(
        Stats::LogarithmicAxis.new(min: required_x_min, max: required_x_max),
        Stats::LogarithmicAxis.new(min: required_y_min, max: required_y_max),
        T.cast(coefficients_array, [Float, Float]),
      )
    when "square_root"
      # Square-root regression y = a*√x + b: only x is transformed
      # (x -> √x), so the x-axis is a SquareRootAxis and the y-axis stays a
      # LinearAxis to keep y linear.
      # NOTE(review): an earlier comment stated the stored coefficients are
      # [intercept, slope] while the constructor expects (slope, intercept).
      # They are passed through unchanged here, exactly as in the logarithmic
      # branch. Confirm the order Stats::SquareRootEquation expects.
      Stats::SquareRootEquation.new(
        Stats::SquareRootAxis.new(min: required_x_min, max: required_x_max),
        Stats::LinearAxis.new(min: required_y_min, max: required_y_max),
        T.cast(coefficients_array, [Float, Float]),
      )
    else
      raise "Unsupported model type: `#{model_type}`"
    end
  end
end
|