allow redefinition of methods

This commit is contained in:
Dylan Knutson
2025-07-26 00:32:12 +00:00
parent 6df1fe8053
commit 8f610b8fa7
3 changed files with 85 additions and 11 deletions

View File

@@ -51,18 +51,26 @@ module HasAuxTable
end
# Main DSL method for defining auxiliary tables
sig { params(aux_name: T.any(String, Symbol)).returns(AuxTableConfig) }
def aux_table(aux_name)
sig do
params(
aux_name: T.any(String, Symbol),
allow_redefining: T.nilable(T.any(Symbol, T::Array[Symbol]))
).returns(AuxTableConfig)
end
def aux_table(aux_name, allow_redefining: nil)
@aux_table_configs ||=
T.let({}, T.nilable(T::Hash[Symbol, AuxTableConfig]))
allow_redefining = [allow_redefining].flatten.compact
aux_name = aux_name.to_sym
if @aux_table_configs.key?(aux_name)
Kernel.raise ArgumentError,
"Auxiliary '#{aux_name}' on #{self.name} (table '#{self.table_name}') already exists"
end
@aux_table_configs[aux_name] = config = generate_aux_config(aux_name)
@aux_table_configs[aux_name] = config =
generate_aux_config(aux_name, allow_redefining)
setup_attribute_types_hook!(config)
setup_load_schema_hook!(config)
setup_initialize_hook!(config)
@@ -85,8 +93,12 @@ module HasAuxTable
private
# Generate auxiliary model class dynamically
sig { params(aux_name: Symbol).returns(AuxTableConfig) }
def generate_aux_config(aux_name)
sig do
params(aux_name: Symbol, allow_redefining: T::Array[Symbol]).returns(
AuxTableConfig
)
end
def generate_aux_config(aux_name, allow_redefining)
main_class = T.cast(self, T.class_of(ActiveRecord::Base))
main_table = main_class.table_name
@@ -108,7 +120,8 @@ module HasAuxTable
AuxTableConfig.from_models(
main_class:,
aux_class:,
aux_association_name:
aux_association_name:,
allow_redefining:
)
# Define the association back to the specific STI subclass
@@ -217,7 +230,8 @@ module HasAuxTable
# Generate attribute accessors for each auxiliary column
config.aux.columns_hash.each do |column_name, column|
column_name = column_name.to_sym
if self.method_defined?(column_name.to_sym)
if self.method_defined?(column_name.to_sym) &&
!config.allow_redefining.include?(column_name.to_sym)
raise "invariant: method #{column_name} already defined"
end
[

View File

@@ -9,15 +9,21 @@ module HasAuxTable
const :aux_association_name, Symbol
const :main, ModelClassHelper
const :aux, ModelClassHelper
const :allow_redefining, T::Array[Symbol]
sig do
params(
main_class: T.class_of(ActiveRecord::Base),
aux_class: T.class_of(ActiveRecord::Base),
aux_association_name: Symbol
aux_association_name: Symbol,
allow_redefining: T::Array[Symbol]
).returns(AuxTableConfig)
end
def self.from_models(main_class:, aux_class:, aux_association_name:)
def self.from_models(
main_class:,
aux_class:,
aux_association_name:,
allow_redefining: []
)
primary_key = aux_class.primary_key
aux_rejected_column_names = [
primary_key,
@@ -37,7 +43,8 @@ module HasAuxTable
ModelClassHelper.new(
klass: aux_class,
rejected_column_names: aux_rejected_column_names
)
),
allow_redefining:
)
end

View File

@@ -1182,4 +1182,57 @@ RSpec.describe HasAuxTable do
expect(patient.doctors.count).to eq(1)
end
end
describe "allowing redefining of methods" do
it "allows method redefining with `allow_method_redefinition`" do
ActiveRecord::Schema.define do
create_base_table :test_model2s do |t|
t.string :on_base
t.create_aux :specific do |t|
t.string :on_aux
end
end
end
class TestModel2 < ActiveRecord::Base
include HasAuxTable
def on_base
"on_base #{super} #{id}"
end
end
expect {
class TestModel2A < TestModel2
aux_table :specific, allow_redefining: :on_base
def on_base
"2a_on_base_override #{super}"
end
def on_aux
"2a_on_aux_override #{super}"
end
end
}.not_to raise_error
expect {
class TestModel2B < TestModel2
aux_table :specific, allow_redefining: :on_base
end
}.not_to raise_error
base_model = TestModel2.create!(on_base: "base")
expect(base_model.on_base).to eq("on_base base #{base_model.id}")
specific_a = TestModel2A.create!(on_base: "2a_base", on_aux: "2a_aux")
expect(specific_a.on_base).to eq(
"2a_on_base_override on_base 2a_base #{specific_a.id}"
)
expect(specific_a.on_aux).to eq("2a_on_aux_override 2a_aux")
specific_b = TestModel2B.create!(on_base: "2b_base", on_aux: "2b_aux")
expect(specific_b.on_base).to eq("on_base 2b_base #{specific_b.id}")
expect(specific_b.on_aux).to eq("2b_aux")
end
end
end