refactor more logic into AuxTableConfig

This commit is contained in:
Dylan Knutson
2025-07-18 05:27:38 +00:00
parent d112d8b72d
commit 3a80c2b8dd
5 changed files with 298 additions and 167 deletions

View File

@@ -21,6 +21,14 @@ module HasAuxTable
VERSION = "0.1.0"
included do
T.bind(self, T.class_of(ActiveRecord::Base))
before_create do
T.bind(self, ActiveRecord::Base)
T.unsafe(self).type ||= self.class.name
end
end
module ClassMethods
extend T::Sig
extend T::Helpers
@@ -45,17 +53,17 @@ module HasAuxTable
"Auxiliary '#{aux_name}' on #{self.name} (table '#{self.table_name}') already exists"
end
@aux_table_configs[aux_name] = aux_config = generate_aux_config(aux_name)
setup_attribute_types_hook!(aux_config)
setup_schema_loading_hook!(aux_config)
setup_initialize_hook!(aux_config)
setup_save_hook!(aux_config)
setup_reload_hook!(aux_config)
setup_attributes_hook!(aux_config)
setup_relation_extensions!(aux_config)
setup_attribute_getter_setter_hooks!(aux_config)
@aux_table_configs[aux_name] = config = generate_aux_config(aux_name)
setup_attribute_types_hook!(config)
setup_load_schema_hook!(config)
setup_initialize_hook!(config)
setup_save_hook!(config)
setup_reload_hook!(config)
setup_attributes_hook!(config)
setup_relation_extensions!(config)
setup_attribute_getter_setter_hooks!(config)
aux_config
config
end
private
@@ -126,7 +134,7 @@ module HasAuxTable
AuxTableConfig.new(
aux_table_name:,
model_class: aux_class,
aux_class:,
main_class:,
aux_association_name:,
main_association_name:,
@@ -135,43 +143,49 @@ module HasAuxTable
)
end
sig { params(aux_config: AuxTableConfig).void }
def setup_attribute_types_hook!(aux_config)
original_method = aux_config.main_class.method(:attribute_types)
aux_config
sig { params(config: AuxTableConfig).void }
def setup_attribute_types_hook!(config)
original_method = config.main_class.method(:attribute_types)
config
.main_class
.define_singleton_method(:attribute_types) do
@aux_config_attribute_types_cache ||= {}
@aux_config_attribute_types_cache[aux_config.aux_table_name] ||= begin
original_types = original_method.call.dup
@aux_config_attribute_types_cache ||=
T.let(
{},
T.nilable(
T::Hash[Symbol, T::Hash[String, ActiveModel::Type::Value]]
)
)
aux_types =
aux_config.model_class.attribute_types.filter do |k, _|
aux_config.is_aux_column?(k)
end
@aux_config_attribute_types_cache[config.aux_table_name] ||= begin
original_types =
T.let(
original_method.call,
T::Hash[String, ActiveModel::Type::Value]
)
# move 'created_at', 'updated_at' etc to the end of the list
at_types = {}
original_types.each do |k, v|
timestamp_types = {}
original_types.reject! do |k, v|
if k.end_with?("_at") && v.type == :datetime
at_types[k] = v
timestamp_types[k] = v
original_types.delete(k)
end
end
original_types.merge!(aux_types)
original_types.merge!(at_types)
original_types.merge!(config.aux.attribute_types)
original_types.merge!(timestamp_types)
original_types
end
end
aux_config.main_class.attributes_for_inspect =
Util.attributes_for_inspect(aux_config)
config.main_class.attributes_for_inspect =
Util.attributes_for_inspect(config)
end
# Hook into schema loading to generate attribute accessors when schema is loaded
sig { params(aux_config: AuxTableConfig).void }
def setup_schema_loading_hook!(aux_config)
sig { params(config: AuxTableConfig).void }
def setup_load_schema_hook!(config)
# Override load_schema to also generate auxiliary attribute accessors when schema is loaded
load_schema_method = self.method(:load_schema!)
self.define_singleton_method(:load_schema!) do
@@ -180,66 +194,48 @@ module HasAuxTable
T.all(T.class_of(ActiveRecord::Base), HasAuxTable::ClassMethods)
)
aux_config_load_schema!(load_schema_method, aux_config)
aux_config_load_schema!(load_schema_method, config)
end
self.load_schema! if self.schema_loaded?
end
sig { params(load_schema_method: Method, aux_config: AuxTableConfig).void }
def aux_config_load_schema!(load_schema_method, aux_config)
sig { params(load_schema_method: Method, config: AuxTableConfig).void }
def aux_config_load_schema!(load_schema_method, config)
# first, load the main and aux table schemas like normal
result = load_schema_method.call
aux_config.load_aux_schema
config.load_aux_schema
aux_table_name = aux_config.aux_table_name
main_columns_hash = self.columns_hash
aux_columns_hash =
aux_config.model_class.columns_hash.select do |col|
aux_config.is_aux_column?(col)
end
main_column_names = main_columns_hash.keys
aux_column_names = aux_columns_hash.keys
aux_table_name = config.aux_table_name
check_for_overlapping_columns!(
aux_table_name,
main_column_names,
aux_column_names
config.main.column_names,
config.aux.column_names
)
aux_attributes = aux_config.model_class._default_attributes
aux_table_filtered_attributes =
aux_attributes
.keys
.filter_map do |k|
[k, aux_attributes[k]] if aux_column_names.include?(k)
end
.to_h
# set attributes that exist on the aux table to also exist on this table
aux_table_filtered_attributes.each do |name, attr|
config.aux.default_attributes.each do |name, attr|
@default_attributes[name] = attr
end
# Generate attribute accessors for each auxiliary column
aux_columns_hash.each do |column_name, column|
config.aux.columns_hash.each do |column_name, column|
column_name = column_name.to_sym
if self.method_defined?(column_name.to_sym)
raise "invariant: method #{column_name} already defined"
end
aux_config.define_aux_attribute_delegate(column_name)
aux_config.define_aux_attribute_delegate(:"#{column_name}?")
aux_config.define_aux_attribute_delegate(:"#{column_name}=")
config.define_aux_attribute_delegate(column_name)
config.define_aux_attribute_delegate(:"#{column_name}?")
config.define_aux_attribute_delegate(:"#{column_name}=")
end
result
end
sig { params(aux_config: AuxTableConfig).void }
def setup_attribute_getter_setter_hooks!(aux_config)
sig { params(config: AuxTableConfig).void }
def setup_attribute_getter_setter_hooks!(config)
%i[
_read_attribute
read_attribute
@@ -250,8 +246,8 @@ module HasAuxTable
method = self.instance_method(method_name)
self.define_method(method_name) do |name, *args, **kwargs, &block|
T.bind(self, ActiveRecord::Base)
if aux_config.is_aux_column?(name)
target = aux_config.ensure_aux_target(self)
if config.aux.column_names.include?(name)
target = config.aux_model_for(self)
T.unsafe(target).send(method_name, name, *args, **kwargs, &block)
else
T.unsafe(method).bind(self).call(name, *args, **kwargs, &block)
@@ -260,46 +256,45 @@ module HasAuxTable
end
end
sig { params(aux_config: AuxTableConfig).void }
def setup_initialize_hook!(aux_config)
sig { params(config: AuxTableConfig).void }
def setup_initialize_hook!(config)
initialize_method = self.instance_method(:initialize)
self.define_method(:initialize) do |args, **kwargs, &block|
T.bind(self, ActiveRecord::Base)
aux_args, main_args =
args.partition { |k, _| aux_config.is_aux_column?(k) }.map(&:to_h)
main_args, aux_args = config.aux.partition_by_columns(args)
initialize_method.bind(self).call(main_args, **kwargs, &block)
aux_config.assign_aux_attributes(self, aux_args)
config.aux_model_for(self).assign_attributes(aux_args)
end
end
sig { params(aux_config: AuxTableConfig).void }
def setup_save_hook!(aux_config)
sig { params(config: AuxTableConfig).void }
def setup_save_hook!(config)
%i[save save!].each do |method_name|
save_method = self.instance_method(method_name)
self.define_method(method_name) do |*args, **kwargs|
self.define_method(method_name) do |*args, **kwargs, &block|
T.bind(self, ActiveRecord::Base)
result = save_method.bind(self).call(*args, **kwargs)
result = save_method.bind(self).call(*args, **kwargs, &block)
result &&=
self
.association(aux_config.aux_association_name)
.association(config.aux_association_name)
.target
.send(method_name, *args, **kwargs)
.send(method_name, *args, **kwargs, &block)
result
end
end
end
sig { params(aux_config: AuxTableConfig).void }
def setup_reload_hook!(aux_config)
sig { params(config: AuxTableConfig).void }
def setup_reload_hook!(config)
self.define_method(:reload) do |*args|
T.bind(self, ActiveRecord::Base)
aux_model = aux_config.ensure_aux_target(self)
aux_model = config.aux_model_for(self)
fresh_model = self.class.find(id)
@attributes = fresh_model.instance_variable_get(:@attributes)
aux_model.instance_variable_set(
:@attributes,
fresh_model
.association(aux_config.aux_association_name)
.association(config.aux_association_name)
.target
.instance_variable_get(:@attributes)
)
@@ -307,13 +302,14 @@ module HasAuxTable
end
end
sig { params(aux_config: AuxTableConfig).void }
def setup_attributes_hook!(aux_config)
sig { params(config: AuxTableConfig).void }
def setup_attributes_hook!(config)
attributes_method = self.instance_method(:attributes)
self.define_method(:attributes) do |*args|
T.bind(self, ActiveRecord::Base)
ret = attributes_method.bind(self).call(*args)
ret.merge!(aux_config.aux_attributes(self))
target = config.aux_model_for(self)
ret.merge!(config.aux.attributes_on(target))
ret
end
end

View File

@@ -2,6 +2,114 @@
# frozen_string_literal: true
module HasAuxTable
class ModelClassHelper < T::Struct
extend T::Sig
const :klass, T.class_of(ActiveRecord::Base)
const :rejected_column_names, T::Set[String]
sig { params(name: T.any(String, Symbol)).returns(T::Boolean) }
def is_column?(name)
column_names.include?(name.to_s)
end
sig { returns(T::Array[String]) }
def column_names
@column_names ||=
T.let(
begin
klass
.column_names
.reject { |col| rejected_column_names.include?(col.to_s) }
.map(&:to_s)
end,
T.nilable(T::Array[String])
)
end
sig { returns(T::Hash[String, ActiveRecord::ConnectionAdapters::Column]) }
def columns_hash
@columns_hash ||=
T.let(
slice_by_columns(klass.columns_hash),
T.nilable(T::Hash[String, ActiveRecord::ConnectionAdapters::Column])
)
end
sig { returns(T::Hash[String, ActiveModel::Type]) }
def attribute_types
@attribute_types ||=
T.let(
slice_by_columns(klass.attribute_types),
T.nilable(T::Hash[String, ActiveModel::Type])
)
end
sig { returns(T::Hash[String, ActiveModel::Attribute]) }
def default_attributes
@default_attributes ||=
T.let(
begin
da = klass._default_attributes
da.keys.map { |k, v| [k, da[k]] }.to_h.slice(*self.column_names)
end,
T.nilable(T::Hash[String, ActiveModel::Attribute])
)
end
sig do
params(instance: ActiveRecord::Base).returns(T::Hash[String, T.untyped])
end
def attributes_on(instance)
Util.ensure_is_instance_of!(instance, self.klass)
unless instance.class <= self.klass
raise("#{instance.class.name} not a #{self.klass.name}")
end
slice_by_columns(instance.attributes)
end
sig do
type_parameters(:K, :T)
.params(
hash:
T::Hash[
T.all(T.type_parameter(:K), T.any(String, Symbol)),
T.type_parameter(:T)
]
)
.returns(
[
T::Hash[
T.all(T.type_parameter(:K), T.any(String, Symbol)),
T.type_parameter(:T)
],
T::Hash[
T.all(T.type_parameter(:K), T.any(String, Symbol)),
T.type_parameter(:T)
]
]
)
end
def partition_by_columns(hash)
a, b =
hash
.partition { |k, _| !self.column_names.include?(k.to_s) }
.map(&:to_h)
[T.must(a), T.must(b)]
end
private
sig do
type_parameters(:T)
.params(hash: T::Hash[String, T.type_parameter(:T)])
.returns(T::Hash[String, T.type_parameter(:T)])
end
def slice_by_columns(hash)
T.unsafe(hash).slice(*self.column_names)
end
end
class AuxTableConfig < T::Struct
extend T::Sig
@@ -9,21 +117,48 @@ module HasAuxTable
const :aux_association_name, Symbol
const :main_association_name, Symbol
const :main_class, T.class_of(ActiveRecord::Base)
const :model_class, T.class_of(ActiveRecord::Base)
const :aux_class, T.class_of(ActiveRecord::Base)
const :foreign_key, KeyType
const :primary_key, KeyType
sig { void }
def load_aux_schema
model_class.load_schema
aux_class.load_schema
end
sig { params(main_model: ActiveRecord::Base).returns(ActiveRecord::Base) }
def ensure_aux_target(main_model)
aux_association = main_model.association(self.aux_association_name)
sig { returns(ModelClassHelper) }
def aux
@aux ||=
T.let(
ModelClassHelper.new(
klass: self.aux_class,
rejected_column_names: self.aux_rejected_column_names.to_set
),
T.nilable(ModelClassHelper)
)
end
sig { returns(ModelClassHelper) }
def main
@main ||=
T.let(
ModelClassHelper.new(
klass: self.main_class,
rejected_column_names: Set.new
),
T.nilable(ModelClassHelper)
)
end
sig do
params(main_instance: ActiveRecord::Base).returns(ActiveRecord::Base)
end
def aux_model_for(main_instance)
Util.ensure_is_instance_of!(main_instance, main_class)
aux_association = main_instance.association(self.aux_association_name)
aux_association.target ||=
(
if main_model.persisted?
if main_instance.persisted?
aux_association.load_target || aux_association.build
else
aux_association.build
@@ -43,25 +178,20 @@ module HasAuxTable
).returns(Arel::Nodes::Node)
end
def aux_bind_attribute(name, value, &block)
arel_attr = model_class.arel_table[name]
arel_attr = aux_class.arel_table[name]
aux_bind =
model_class.predicate_builder.build_bind_attribute(
arel_attr.name,
value
)
aux_class.predicate_builder.build_bind_attribute(arel_attr.name, value)
block.call(arel_attr, aux_bind)
end
sig { params(method_name: Symbol).void }
def define_aux_attribute_delegate(method_name)
aux_config = self
aux_config
.main_class
.define_method(method_name) do |*args, **kwargs|
T.bind(self, ActiveRecord::Base)
aux_model = aux_config.ensure_aux_target(self)
T.unsafe(aux_model).public_send(method_name, *args, **kwargs)
end
config = self
main_class.define_method(method_name) do |*args, **kwargs, &block|
T.bind(self, ActiveRecord::Base)
aux_model = config.aux_model_for(self)
T.unsafe(aux_model).public_send(method_name, *args, **kwargs, &block)
end
end
sig do
@@ -72,7 +202,7 @@ module HasAuxTable
end
def apply_split_conditions!(relation, conditions)
main_conditions, aux_conditions =
self.partition_by_aux_columns(conditions)
self.aux.partition_by_columns(conditions)
relation = relation.where(main_conditions) if main_conditions.any?
if aux_conditions.any?
relation = relation.where(aux_association_name => aux_conditions)
@@ -86,66 +216,22 @@ module HasAuxTable
)
end
def remap_conditions(conditions)
main, aux = partition_by_aux_columns(conditions)
main.merge!(aux_association_name => aux) if aux.any?
main
end
sig do
params(
main_model: ActiveRecord::Base,
aux_args: T::Hash[Symbol, T.untyped]
).void
end
def assign_aux_attributes(main_model, aux_args)
aux_model = self.ensure_aux_target(main_model)
aux_model.assign_attributes(aux_args)
end
sig do
params(main_model: ActiveRecord::Base).returns(T::Hash[Symbol, T.untyped])
end
def aux_attributes(main_model)
aux_model = self.ensure_aux_target(main_model)
aux_model.attributes.slice(*self.aux_column_names)
end
sig { returns(T::Array[String]) }
def aux_column_names
@aux_column_names ||=
T.let(
begin
rejected_columns = [
self.foreign_key,
self.primary_key,
"created_at",
"updated_at"
].flatten.map(&:to_s)
model_class
.column_names
.reject { |col| rejected_columns.include?(col.to_s) }
.map(&:to_s)
end,
T.nilable(T::Array[String])
)
end
sig { params(name: T.any(Symbol, String)).returns(T::Boolean) }
def is_aux_column?(name)
aux_column_names.include?(name.to_s)
main_conds, aux_conds = aux.partition_by_columns(conditions)
main_conds.merge!(aux_association_name => aux_conds) if aux_conds.any?
main_conds
end
private
sig do
params(hash: T::Hash[String, T.untyped]).returns(
[T::Hash[String, T.untyped], T::Hash[String, T.untyped]]
)
end
def partition_by_aux_columns(hash)
a, b, _ = hash.partition { |k, _| !self.is_aux_column?(k) }.map(&:to_h)
[T.must(a), T.must(b)]
sig { returns(T::Set[String]) }
def aux_rejected_column_names
@aux_rejected_column_names ||=
T.let(
[foreign_key, primary_key, "created_at", "updated_at"].flatten
.map(&:to_s)
.to_set,
T.nilable(T::Set[String])
)
end
end
end

View File

@@ -73,7 +73,7 @@ module HasAuxTable
:bind_attribute,
true
) do |original, name, value, &block|
if aux_config.is_aux_column?(name)
if aux_config.aux.is_column?(name)
aux_config.aux_bind_attribute(name, value, &block)
else
original.call(name, value, &block)

View File

@@ -51,5 +51,19 @@ module HasAuxTable
main_class_attributes
end
sig do
type_parameters(:T)
.params(
instance: T.all(T.type_parameter(:T), Object),
klass: T::Class[T.type_parameter(:T)]
)
.void
end
def self.ensure_is_instance_of!(instance, klass)
unless instance.class <= klass
Kernel.raise("#{instance.class.name} not a #{klass.name}")
end
end
end
end