train/eval separately

This commit is contained in:
Dylan Knutson
2025-07-10 20:43:30 +00:00
parent 2a8d618a84
commit aad0cb045d

View File

@@ -25,6 +25,11 @@ namespace :stats do
puts "📈 X-axis range (fav_fa_id): #{base_normalizer.x_range}"
puts "📈 Y-axis range (date): #{base_normalizer.y_range}"
# Split data for plotting
split = StatsHelpers.split_train_test(records_array)
train_normalizer = LinearNormalizer.new(split.training_records)
eval_normalizer = LinearNormalizer.new(split.evaluation_records)
# Run regressions using specialized normalizers
regressions = RegressionAnalyzer.new(records_array).analyze
@@ -32,17 +37,19 @@ namespace :stats do
regressions.each do |name, result|
puts "\n📊 #{name} Regression Results:"
puts " #{result.equation_string}"
puts " R² = #{StatsHelpers.format_r_squared(result.r_squared)}"
puts " #{result.score_summary}"
end
# Generate visualizations
puts "\n🎨 Generating visualizations with UnicodePlot..."
plotter = StatsPlotter.new
plotter.plot_scatter(
"Original Data",
base_normalizer.x_values,
base_normalizer.y_values,
plotter.plot_train_eval_scatter(
"Original Data (Train/Eval)",
train_normalizer.x_values,
train_normalizer.y_values,
eval_normalizer.x_values,
eval_normalizer.y_values,
)
# Plot individual regression results
@@ -105,12 +112,76 @@ module StatsHelpers
records_array
end
# Partition +records+ into a training set and an evaluation set.
#
# The records are shuffled with a fixed-seed generator that is local to this
# call, so the same input always produces the same partition without
# reseeding the process-wide random number generator as a side effect.
#
# @param records [Array<Domain::FaFavIdAndDate>] records to partition
# @param eval_ratio [Float] fraction of the records (0.0..1.0) held out for
#   evaluation; the training set gets the remainder, rounded to whole records
# @return [TrainTestSplit] the shuffled training and evaluation records
# @raise [ArgumentError] if +eval_ratio+ is NaN or outside 0.0..1.0
sig do
  params(
    records: T::Array[Domain::FaFavIdAndDate],
    eval_ratio: Float,
  ).returns(TrainTestSplit)
end
def self.split_train_test(records, eval_ratio = 0.2)
  # An out-of-range ratio would yield a negative or oversized split index and
  # therefore a silently wrong (or nil) partition, so reject it up front.
  unless eval_ratio.between?(0.0, 1.0)
    raise ArgumentError,
          "eval_ratio must be between 0.0 and 1.0, got #{eval_ratio}"
  end

  # Fixed seed for reproducibility; a dedicated Random keeps Kernel#rand and
  # other users of the global generator unaffected.
  shuffled_records = records.shuffle(random: Random.new(42))

  # Index of the first evaluation record.
  split_index = (records.length * (1.0 - eval_ratio)).round

  # first/drop always return arrays (possibly empty), unlike range slicing,
  # which can return nil.
  TrainTestSplit.new(
    training_records: shuffled_records.first(split_index),
    evaluation_records: shuffled_records.drop(split_index),
  )
end
sig { params(value: Float).returns(Float) }
def self.format_r_squared(value)
  # Keep three decimal places; the result is a Float, not a String.
  rounded = value.round(3)
  rounded.to_f
end
end
# Immutable value object holding the two halves of a train/eval split: the
# records used for training and the records held out for evaluation.
class TrainTestSplit < T::ImmutableStruct
  extend T::Sig

  const :training_records, T::Array[Domain::FaFavIdAndDate]
  const :evaluation_records, T::Array[Domain::FaFavIdAndDate]

  # Number of records in the training set.
  sig { returns(Integer) }
  def training_count
    training_records.size
  end

  # Number of records in the evaluation set.
  sig { returns(Integer) }
  def evaluation_count
    evaluation_records.size
  end

  # Number of records across both sets.
  sig { returns(Integer) }
  def total_count
    training_records.size + evaluation_records.size
  end

  # Human-readable one-line description of the set sizes.
  sig { returns(String) }
  def summary
    train = training_count
    evaluation = evaluation_count
    "📊 Split data: #{train} training, #{evaluation} evaluation records"
  end
end
class AxisRange < T::ImmutableStruct
extend T::Sig
@@ -538,7 +609,8 @@ class RegressionResult < T::ImmutableStruct
extend T::Sig
const :equation, Equation
const :r_squared, Float
const :training_r_squared, Float
const :evaluation_r_squared, Float
const :x_values, T::Array[Float]
const :y_values, T::Array[Float]
@@ -546,6 +618,11 @@ class RegressionResult < T::ImmutableStruct
def equation_string
equation.to_s
end
# One-line report of the training and evaluation R² scores, each passed
# through StatsHelpers.format_r_squared before being interpolated.
sig { returns(String) }
def score_summary
  training_score = StatsHelpers.format_r_squared(training_r_squared)
  evaluation_score = StatsHelpers.format_r_squared(evaluation_r_squared)
  "Training R² = #{training_score}, Evaluation R² = #{evaluation_score}"
end
end
# Immutable struct representing the complete analysis results
@@ -567,41 +644,21 @@ class RegressionAnalyzer
@records = records
end
sig { returns(T::Array[[String, RegressionResult]]) }
def analyze
[
[
"Linear",
analyze_regression(LinearNormalizer, PolynomialEquation, degree: 1),
],
[
"Quadratic",
analyze_regression(QuadraticNormalizer, PolynomialEquation, degree: 2),
],
[
"Logarithmic",
analyze_regression(LogarithmicNormalizer, LogarithmicEquation),
],
[
"Square Root",
analyze_regression(SquareRootNormalizer, SquareRootEquation),
],
]
end
private
# Generic regression analysis method to eliminate duplication
sig do
params(
normalizer_class: T.class_of(DataNormalizer),
equation_class: T.class_of(Equation),
split: TrainTestSplit,
degree: Integer,
).returns(RegressionResult)
end
def analyze_regression(normalizer_class, equation_class, degree: 1)
normalizer = normalizer_class.new(@records)
regression_x = normalizer.regression_x_range
def analyze_regression(normalizer_class, equation_class, split, degree: 1)
# Create normalizers for training and evaluation data
training_normalizer = normalizer_class.new(split.training_records)
evaluation_normalizer = normalizer_class.new(split.evaluation_records)
regression_x = training_normalizer.regression_x_range
poly_features = Rumale::Preprocessing::PolynomialFeatures.new(degree:)
regressor = Rumale::LinearModel::LinearRegression.new(fit_bias: true)
pipeline =
@@ -612,28 +669,73 @@ class RegressionAnalyzer
},
)
# Fit the pipeline
x_matrix = normalizer.transformed_x_matrix
y_vector = normalizer.normalized_y_vector
pipeline.fit(x_matrix, y_vector)
r_squared = pipeline.score(x_matrix, y_vector)
# Fit the pipeline on training data
training_x_matrix = training_normalizer.transformed_x_matrix
training_y_vector = training_normalizer.normalized_y_vector
pipeline.fit(training_x_matrix, training_y_vector)
# Score on training data
training_r_squared = pipeline.score(training_x_matrix, training_y_vector)
# Score on evaluation data
evaluation_x_matrix = evaluation_normalizer.transformed_x_matrix
evaluation_y_vector = evaluation_normalizer.normalized_y_vector
evaluation_r_squared =
pipeline.score(evaluation_x_matrix, evaluation_y_vector)
weight_vec = pipeline.steps[:estimator].weight_vec.to_a
# Generate regression line data in original scale
regression_y =
generate_regression_line(normalizer, regression_x, weight_vec)
generate_regression_line(training_normalizer, regression_x, weight_vec)
# Create equation object
equation = create_equation(equation_class, normalizer, weight_vec)
equation = create_equation(equation_class, training_normalizer, weight_vec)
RegressionResult.new(
equation: equation,
r_squared: r_squared,
training_r_squared: training_r_squared,
evaluation_r_squared: evaluation_r_squared,
x_values: regression_x,
y_values: regression_y,
)
end
# Run every supported regression against a single train/eval split of the
# analyzer's records.
#
# Returns an array of [display name, RegressionResult] pairs in the order
# Linear, Quadratic, Logarithmic, Square Root.
#
# NOTE(review): a `private` marker appears earlier in this class, while the
# stats task calls `analyze` on an instance — confirm this method is still
# public at its current position.
sig { returns(T::Array[[String, RegressionResult]]) }
def analyze
  # One split shared by all models so their scores are comparable.
  split = StatsHelpers.split_train_test(@records)

  linear =
    analyze_regression(LinearNormalizer, PolynomialEquation, split, degree: 1)
  quadratic =
    analyze_regression(
      QuadraticNormalizer,
      PolynomialEquation,
      split,
      degree: 2,
    )
  logarithmic =
    analyze_regression(LogarithmicNormalizer, LogarithmicEquation, split)
  square_root =
    analyze_regression(SquareRootNormalizer, SquareRootEquation, split)

  [
    ["Linear", linear],
    ["Quadratic", quadratic],
    ["Logarithmic", logarithmic],
    ["Square Root", square_root],
  ]
end
# Generate regression line using appropriate denormalization method
sig do
params(
@@ -691,6 +793,33 @@ end
class StatsPlotter
extend T::Sig
# Draw the training points and the evaluation points as two named series
# ("Training Data" / "Evaluation Data") on one scatter plot.
#
# The whole drawing step runs inside plot_with_error_handling, and the
# y-axis label is built from the y values of both series combined.
sig do
  params(
    title: String,
    train_x: T::Array[Float],
    train_y: T::Array[Float],
    eval_x: T::Array[Float],
    eval_y: T::Array[Float],
  ).void
end
def plot_train_eval_scatter(title, train_x, train_y, eval_x, eval_y)
  plot_with_error_handling(title) do
    y_label = date_axis_label(train_y + eval_y)

    # The training series creates the plot.
    chart =
      UnicodePlot.scatterplot(
        train_x,
        train_y,
        title: title,
        name: "Training Data",
        width: 80,
        height: 20,
        xlabel: "fav_fa_id",
        ylabel: y_label,
      )

    # The evaluation series is layered onto the same plot.
    UnicodePlot.scatterplot!(chart, eval_x, eval_y, name: "Evaluation Data")
    chart
  end
end
sig do
params(
title: String,
@@ -714,7 +843,8 @@ class StatsPlotter
sig { params(title: String, result: RegressionResult).void }
def plot_regression(title, result)
subtitle = "#{title.split.first} fit (R² = #{result.r_squared.round(3)})"
subtitle =
"#{title.split.first} fit (Training R² = #{result.training_r_squared.round(3)}, Evaluation R² = #{result.evaluation_r_squared.round(3)})"
plot_with_error_handling("#{title} - #{subtitle}") do
UnicodePlot.lineplot(
result.x_values,
@@ -756,7 +886,8 @@ class StatsPlotter
plot,
result.x_values,
result.y_values,
name: "#{name} (R²=#{result.r_squared.round(3)})",
name:
"#{name} (Train R²=#{result.training_r_squared.round(3)}, Eval R²=#{result.evaluation_r_squared.round(3)})",
)
end
plot
@@ -770,7 +901,7 @@ class StatsPlotter
y_min, y_max = y_values.minmax
start_date = Time.at(y_min).strftime("%Y-%m-%d")
end_date = Time.at(y_max).strftime("%Y-%m-%d")
"Date (#{start_date} to #{end_date})"
"#{start_date} to #{end_date}"
end
sig { params(title: String, block: T.proc.returns(T.untyped)).void }