more ar hacking

This commit is contained in:
Dylan Knutson
2025-07-14 05:44:01 +00:00
parent bb5c22b070
commit 30b017906f
9 changed files with 371 additions and 579 deletions

View File

@@ -1,9 +1,12 @@
# typed: strict
# typed: false
# frozen_string_literal: true
require "sorbet-runtime"
require "active_record"
require "active_record/base"
require "active_support"
require "active_support/concern"
require "active_model/attribute_set"
require_relative "aux_table/auto_join_queries"
module ActiveRecord
@@ -15,57 +18,44 @@ module ActiveRecord
extend ActiveSupport::Concern
# Configuration class to store auxiliary table definition
class Configuration
# AuxTable class to store auxiliary table definition
class AuxTableConfig < T::Struct
extend T::Sig
sig { returns(Symbol) }
attr_reader :table_name
const :table_name, Symbol
const :aux_association_name, Symbol
const :main_association_name, Symbol
const :model_class, T.class_of(ActiveRecord::Base)
const :foreign_key, T.any(Symbol, T::Array[Symbol])
const :primary_key, T.any(Symbol, T::Array[Symbol])
sig { returns(T.nilable(Proc)) }
attr_reader :block
sig { returns(T::Array[T.untyped]) }
attr_reader :columns
sig { returns(T::Array[T.untyped]) }
attr_reader :indexes
sig { returns(T.untyped) }
attr_reader :model_class
sig do
params(table_name: T.any(String, Symbol), block: T.nilable(Proc)).void
end
def initialize(table_name, block = nil)
@table_name = T.let(table_name.to_sym, Symbol)
@block = T.let(block, T.nilable(Proc))
@columns = T.let([], T::Array[T.untyped])
@indexes = T.let([], T::Array[T.untyped])
@model_class = T.let(nil, T.untyped)
def load_aux_schema
model_class.load_schema
end
sig { params(model_class: T.untyped).void }
def model_class=(model_class)
@model_class = model_class
def ensure_aux_target(main_model)
aux_association = main_model.association(self.aux_association_name)
aux_association.target ||= aux_association.build
end
sig { returns(T::Hash[Symbol, T.untyped]) }
def to_hash
{
table_name: table_name,
block: block,
columns: columns,
indexes: indexes,
model_class: model_class
}
def define_aux_attribute_delegate(main_model, method_name)
aux_config = self
main_model.define_method(method_name) do |*args|
aux_model = aux_config.ensure_aux_target(self)
aux_model.public_send(method_name, *args)
end
end
def assign_aux_attributes(main_model, aux_args)
aux_model = self.ensure_aux_target(main_model)
aux_model.assign_attributes(aux_args)
end
end
included do
# Initialize aux table configurations for this class
@aux_table_configurations =
T.let({}, T.nilable(T::Hash[Symbol, Configuration]))
T.let({}, T.nilable(T::Hash[Symbol, AuxTableConfig]))
end
module ClassMethods
@@ -73,49 +63,31 @@ module ActiveRecord
include AutoJoinQueries
# Accessor methods for aux table configurations
sig { returns(T::Hash[Symbol, Configuration]) }
sig { returns(T::Hash[Symbol, AuxTableConfig]) }
def aux_table_configurations
@aux_table_configurations ||=
T.let({}, T.nilable(T::Hash[Symbol, Configuration]))
T.let({}, T.nilable(T::Hash[Symbol, AuxTableConfig]))
end
sig { params(value: T::Hash[Symbol, Configuration]).void }
sig { params(value: T::Hash[Symbol, AuxTableConfig]).void }
def aux_table_configurations=(value)
@aux_table_configurations = value
end
# Main DSL method for defining auxiliary tables
sig do
params(
table_name: T.any(String, Symbol),
block: T.nilable(T.proc.void)
).returns(Configuration)
end
def aux_table(table_name, &block)
table_name_sym = table_name.to_sym
sig { params(table_name: T.any(String, Symbol)).returns(AuxTableConfig) }
def aux_table(table_name)
table_name = table_name.to_sym
# Check for duplicate table definitions
if aux_table_configurations.key?(table_name_sym)
if aux_table_configurations.key?(table_name)
Kernel.raise ArgumentError,
"Auxiliary table '#{table_name}' is already defined"
end
# Store the configuration
config = Configuration.new(table_name, block)
aux_table_configurations[table_name_sym] = config
# Generate the auxiliary model class
model_class = generate_aux_model_class(table_name_sym)
config.model_class = model_class
# Hook into schema loading to generate attribute accessors
setup_schema_loading_hook(table_name_sym)
# Set up automatic auxiliary record creation and loading
setup_automatic_aux_record_handling(table_name_sym)
# Set up query extensions for automatic joins
setup_auto_join_queries(table_name_sym)
aux_table_configurations[table_name] = config =
generate_aux_model_class(table_name)
setup_schema_loading_hook(table_name)
setup_auto_join_queries(table_name)
config
end
@@ -123,145 +95,153 @@ module ActiveRecord
# Helper method to get auxiliary table configuration
sig do
params(table_name: T.any(String, Symbol)).returns(
T.nilable(Configuration)
T.nilable(AuxTableConfig)
)
end
def aux_table_configuration(table_name)
aux_table_configurations[table_name.to_sym]
end
# Check if class has auxiliary tables configured
sig { returns(T::Boolean) }
def has_aux_tables?
aux_table_configurations.any?
end
private
# Hook into schema loading to generate attribute accessors when schema is loaded
sig { params(table_name: Symbol).void }
def setup_schema_loading_hook(table_name)
association_name = table_name.to_s.singularize.to_sym
sig { params(aux_table_name: Symbol).void }
def setup_schema_loading_hook(aux_table_name)
aux_config =
aux_table_configurations[aux_table_name] ||
raise("no aux_config for #{aux_table_name}")
# Override load_schema to also generate auxiliary attribute accessors when schema is loaded
original_load_schema = T.unsafe(self).method(:load_schema)
load_schema_method = self.method(:load_schema!)
self.define_singleton_method(:load_schema!) do
# first, load the main and aux table schemas like normal
result = load_schema_method.call
aux_config.load_aux_schema
T
.unsafe(self)
.define_singleton_method(:load_schema) do
# Call the original load_schema method
result = original_load_schema.call
# After schema is loaded, generate auxiliary attribute accessors
aux_config = aux_table_configurations[table_name]
if aux_config && aux_config.model_class
# Force the auxiliary model to load its schema too
aux_config.model_class.load_schema
# Validate no column overlaps between main table and auxiliary table
main_columns = T.unsafe(self).column_names
aux_columns = aux_config.model_class.column_names
# Find overlapping columns (excluding system columns and foreign keys)
overlapping_columns =
aux_columns.select do |col|
main_columns.include?(col) &&
!%w[id created_at updated_at].include?(col) &&
!col.to_s.end_with?("_id")
end
if overlapping_columns.any?
column_list =
overlapping_columns.map { |col| "'#{col}'" }.join(", ")
Kernel.raise ArgumentError,
"Auxiliary table '#{aux_config.model_class.table_name}' defines column(s) #{column_list} " \
"that already exist(s) in main table '#{T.unsafe(self).table_name}'. " \
"Auxiliary table columns must not overlap with main table columns."
end
# Get auxiliary columns (excluding system columns and foreign keys)
aux_columns =
aux_config.model_class.column_names.reject do |col|
%w[id created_at updated_at].include?(col) ||
col.to_s.end_with?("_id")
end
# Generate attribute accessors for each auxiliary column
aux_columns.each do |column_name|
unless T.unsafe(self).method_defined?(column_name)
define_aux_attribute_getter(column_name, association_name)
define_aux_attribute_setter(column_name, association_name)
define_aux_attribute_presence_check(
column_name,
association_name
)
end
end
# `columns_hash` is populated by `load_schema!` so we can use it to
# validate no column overlaps between main table and auxiliary table
main_columns_hash = self.columns_hash
aux_columns_hash =
aux_config.model_class.columns_hash.reject do |col|
%w[id created_at updated_at].include?(col) ||
col == aux_config.foreign_key.to_s
end
result
main_column_names = main_columns_hash.keys
aux_column_names = aux_columns_hash.keys
check_for_overlapping_columns!(
aux_table_name,
main_column_names,
aux_column_names
)
aux_attributes = aux_config.model_class._default_attributes
aux_table_filtered_attributes =
aux_attributes
.keys
.filter_map do |k|
[k, aux_attributes[k]] if aux_column_names.include?(k)
end
.to_h
aux_table_filtered_attributes.each do |name, attr|
@default_attributes[name] = attr
end
end
# Set up automatic auxiliary record creation and loading
sig { params(table_name: Symbol).void }
def setup_automatic_aux_record_handling(table_name)
association_name = table_name.to_s.singularize.to_sym
# Use after_save to ensure aux record is persisted when main record is saved
T
.unsafe(self)
.after_save do
aux_record = T.unsafe(self).send(association_name)
aux_record.save! if aux_record && aux_record.changed?
end
end
# Define getter method for auxiliary attribute
sig { params(column_name: String, association_name: Symbol).void }
def define_aux_attribute_getter(column_name, association_name)
T
.unsafe(self)
.define_method(column_name) do
aux_record = T.unsafe(self).send(association_name)
aux_record&.send(column_name)
end
end
# Define setter method for auxiliary attribute
sig { params(column_name: String, association_name: Symbol).void }
def define_aux_attribute_setter(column_name, association_name)
T
.unsafe(self)
.define_method("#{column_name}=") do |value|
# Ensure auxiliary record exists (should exist due to automatic creation)
aux_record = T.unsafe(self).send(association_name)
unless aux_record
aux_record = T.unsafe(self).send("build_#{association_name}")
# Generate attribute accessors for each auxiliary column
aux_columns_hash.each do |column_name, column|
if self.method_defined?(column_name)
raise "invariant: method #{column_name} already defined"
end
aux_record.send("#{column_name}=", value)
# Save the auxiliary record if the main record is persisted
aux_record.save! if T.unsafe(self).persisted? && aux_record.changed?
aux_config.define_aux_attribute_delegate(self, column_name)
aux_config.define_aux_attribute_delegate(self, "#{column_name}?")
aux_config.define_aux_attribute_delegate(self, "#{column_name}=")
end
%i[_read_attribute read_attribute].each do |method_name|
# override _read_attribute to delegate auxiliary columns to the auxiliary table
read_attribute_method = self.instance_method(method_name)
self.define_method(method_name) do |name|
if aux_columns_hash.include?(name.to_s)
target = aux_config.ensure_aux_target(self)
target.send(method_name, name)
else
read_attribute_method.bind(self).call(name)
end
end
end
initialize_method = self.instance_method(:initialize)
self.define_method(:initialize) do |args|
aux_args, main_args =
args
.partition { |k, _| aux_columns_hash.key?(k.to_s) }
.map(&:to_h)
initialize_method.bind(self).call(main_args)
aux_config.assign_aux_attributes(self, aux_args)
end
# reload_method = self.instance_method(:reload)
self.define_method(:reload) do |*args|
result = nil
aux_model = aux_config.ensure_aux_target(self)
fresh_model = self.class.find(id)
@attributes = fresh_model.instance_variable_get(:@attributes)
aux_model.instance_variable_set(
:@attributes,
fresh_model
.association(aux_config.aux_association_name)
.target
.instance_variable_get(:@attributes)
)
# ActiveRecord::Base.transaction do
# aux_model = aux_config.ensure_aux_target(self)
# result = reload_method.bind(self).call(*args)
# self.send(:"#{aux_config.aux_association_name}=", aux_model)
# end
# fresh_model =
# result
self
end
result
end
end
# Define presence check method for auxiliary attribute
sig { params(column_name: String, association_name: Symbol).void }
def define_aux_attribute_presence_check(column_name, association_name)
T
.unsafe(self)
.define_method("#{column_name}?") do
aux_record = T.unsafe(self).send(association_name)
aux_record&.send(column_name).present?
end
sig do
params(
aux_table_name: Symbol,
main_columns: T::Array[String],
aux_columns: T::Array[String]
).void
end
def check_for_overlapping_columns!(
aux_table_name,
main_columns,
aux_columns
)
# Find overlapping columns (excluding system columns and foreign keys)
overlapping_columns =
aux_columns.select { |col| main_columns.include?(col) }
if overlapping_columns.any?
column_list = overlapping_columns.map { |col| "'#{col}'" }.join(", ")
Kernel.raise ArgumentError,
"Auxiliary table '#{aux_table_name}' defines column(s) #{column_list} " \
"that already exist(s) in main table '#{self.table_name}'. " \
"Auxiliary table columns must not overlap with main table columns."
end
end
# Generate auxiliary model class dynamically
sig { params(table_name: Symbol).returns(T.untyped) }
sig { params(table_name: Symbol).returns(AuxTableConfig) }
def generate_aux_model_class(table_name)
# Generate class name (e.g., :car_aux => "CarAux")
class_name = table_name.to_s.camelize
aux_association_name = table_name.to_s.singularize.to_sym
# Ensure the class name doesn't conflict with existing constants
if Object.const_defined?(class_name)
@@ -270,66 +250,56 @@ module ActiveRecord
# Get the base class name for the foreign key (e.g., Vehicle -> vehicle_id)
# In STI, all subclasses share the same table, so we need the base class
base_class = T.unsafe(self).base_class
base_class = self.base_class
base_class_name = base_class.name.underscore
foreign_key = "#{base_class_name}_id".to_sym
# Get the current class for the association
current_class = T.unsafe(self)
current_class_name = current_class.name.underscore
main_class = self
main_association_name = main_class.name.underscore.to_sym
primary_key = :id
# Create the auxiliary model class
aux_model_class =
model_class =
Class.new(ActiveRecord::Base) do
# Set the table name
T.unsafe(self).table_name = table_name.to_s
self.table_name = table_name.to_s
# Define the association back to the specific STI subclass
# Foreign key points to base STI table (e.g., vehicle_id)
# But association is to the specific subclass (e.g., Car)
T.unsafe(self).belongs_to(
current_class_name.to_sym,
class_name: current_class.name,
foreign_key: "#{base_class_name}_id"
self.belongs_to(
main_association_name,
class_name: main_class.name,
foreign_key:,
primary_key:,
inverse_of: aux_association_name
)
end
# set up has_one association to the auxiliary table
T.unsafe(self).has_one(
table_name.to_s.singularize.to_sym,
self.has_one(
aux_association_name,
class_name: class_name,
foreign_key: "#{base_class_name}_id"
foreign_key:,
primary_key:,
inverse_of: main_association_name,
autosave: true
)
# Set the constant to make the class accessible
Object.const_set(class_name, aux_model_class)
Object.const_set(class_name, model_class)
aux_model_class
AuxTableConfig.new(
table_name:,
model_class:,
aux_association_name:,
main_association_name:,
foreign_key:,
primary_key:
)
end
end
mixes_in_class_methods(ClassMethods)
# Instance methods for working with auxiliary tables
sig { params(table_name: T.any(String, Symbol)).returns(T.untyped) }
def aux_table_record(table_name)
association_name = table_name.to_s.singularize.to_sym
# Get the existing auxiliary record
aux_record = T.unsafe(self).send(association_name)
# If it doesn't exist and this is the correct class and the record is persisted, create it lazily
if aux_record.nil? && T.unsafe(self).persisted? &&
T.unsafe(self).class.aux_table_configurations.key?(table_name.to_sym)
aux_record = T.unsafe(self).send("build_#{association_name}")
aux_record.save!
end
aux_record
end
sig { params(table_name: T.any(String, Symbol)).returns(T.untyped) }
def build_aux_table_record(table_name)
association_name = table_name.to_s.singularize.to_sym
T.unsafe(self).send("build_#{association_name}")
end
end
end

View File

@@ -4,32 +4,18 @@
module ActiveRecord
module AuxTable
module AutoJoinQueries
# Since users explicitly opt-in with `aux_table`, we can provide automatic behavior:
# 1. after_initialize to load aux data (prevents N+1)
# 2. Automatic query handling for aux columns (including chained where)
# 3. Transparent attribute access
def setup_auto_join_queries(table_name)
association_name = table_name.to_s.singularize.to_sym
# Set up automatic loading to prevent N+1 queries
setup_auto_loading(association_name)
# Set up automatic query extensions
setup_query_extensions(association_name)
end
# Helper method to check if model has aux tables
def has_aux_tables?
self.respond_to?(:aux_table_configurations) &&
self.aux_table_configurations.any?
def setup_auto_join_queries(aux_table_name)
association_name = aux_table_name.to_s.singularize.to_sym
ActiveRecord::AuxTable::AutoJoinQueries.setup_query_extensions!(
self,
association_name
)
self
end
# Get all aux column names for this model
def get_aux_column_names
return [] unless has_aux_tables?
config = self.aux_table_configurations.values.first
def get_aux_column_names(aux_table_name)
config = self.aux_table_configurations[aux_table_name]
return [] unless config&.model_class
config.model_class.column_names.reject do |col|
@@ -41,16 +27,11 @@ module ActiveRecord
def split_conditions(conditions, aux_columns)
main_conditions = {}
aux_conditions = {}
main_columns = self.column_names
conditions.each do |key, value|
key_str = key.to_s
if aux_columns.include?(key_str)
if aux_columns.include?(key.to_s)
aux_conditions[key] = value
elsif main_columns.include?(key_str)
main_conditions[key] = value
else
# Unknown column - let ActiveRecord handle the error by putting in main
main_conditions[key] = value
end
end
@@ -58,158 +39,40 @@ module ActiveRecord
[main_conditions, aux_conditions]
end
private
def setup_auto_loading(association_name)
# Auto-loading is now handled via includes in query methods
# No need for after_initialize callback since we use includes
# which properly loads associations
end
def setup_query_extensions(association_name)
# Override class methods to handle aux column queries
model_class = self
# Store original methods
original_where = model_class.method(:where)
original_find_by = model_class.method(:find_by)
original_find = model_class.method(:find)
# Define a helper method for processing where conditions
model_class.define_singleton_method(:process_aux_where) do |*args|
# Only handle hash conditions that might contain aux columns
if args.first.is_a?(Hash) && has_aux_tables?
conditions = args.first
aux_columns = get_aux_column_names
# Check if any aux columns are referenced
if conditions.keys.any? { |key| aux_columns.include?(key.to_s) }
# Split conditions and build query with eager_load
main_conditions, aux_conditions =
split_conditions(conditions, aux_columns)
relation = self.eager_load(association_name)
relation = relation.where(main_conditions) if main_conditions.any?
relation =
relation.where(
association_name => aux_conditions
) if aux_conditions.any?
relation
else
# No aux columns, check for unknown columns and raise error
check_unknown_columns(conditions, original_where)
original_where.call(*args)
end
def self.setup_query_extensions!(on, association_name)
on.define_singleton_method(:where) do |*args|
if args.first.is_a?(Hash)
relation = self.eager_load(association_name)
self.apply_split_conditions!(relation, args)
ActiveRecord::AuxTable::AutoJoinQueries.setup_query_extensions!(
relation,
association_name
)
relation
else
# Not a hash or no aux tables, use original method
original_where.call(*args)
super(*args)
end
end
model_class.define_singleton_method(:where) do |*args|
process_aux_where(*args)
on.define_singleton_method(:find_by) do |*args|
relation = self.eager_load(association_name)
self.apply_split_conditions!(relation, args)
relation.first
end
model_class.define_singleton_method(:find_by) do |*args|
# Handle hash conditions for aux columns
if args.first.is_a?(Hash) && has_aux_tables?
conditions = args.first
aux_columns = get_aux_column_names
if conditions.keys.any? { |key| aux_columns.include?(key.to_s) }
# Use the enhanced where method with eager_load and get first result
self.eager_load(association_name).where(conditions).first
else
# Use eager_load for non-aux queries to preload auxiliary data
self.eager_load(association_name).find_by(*args)
end
else
original_find_by.call(*args)
on.define_singleton_method(:apply_split_conditions!) do |relation, args|
conditions = args.first
aux_columns = self.get_aux_column_names(association_name)
main_conditions, aux_conditions =
self.split_conditions(conditions, aux_columns)
relation.where!(main_conditions) if main_conditions.any?
if aux_conditions.any?
relation.where!(association_name => aux_conditions)
end
end
model_class.define_singleton_method(:find) do |*args|
# Override find to automatically include aux table joins
if has_aux_tables?
# Use eager_load to get both main and auxiliary data in single query
# and properly populate the association
self.eager_load(association_name).find(*args)
else
original_find.call(*args)
end
end
# Also override where on ActiveRecord::Relation for chained calls
setup_relation_where_override(association_name)
end
def setup_relation_where_override(association_name)
# Prepend to ActiveRecord::Relation to handle chained where calls
relation_extension =
Module.new do
define_method(:where) do |*args|
# Check if this relation's model has aux tables and needs processing
if args.first.is_a?(Hash) &&
klass.respond_to?(:has_aux_tables?) && klass.has_aux_tables?
conditions = args.first
aux_columns = klass.get_aux_column_names
# Check if any aux columns are referenced
if conditions.keys.any? { |key| aux_columns.include?(key.to_s) }
# Split conditions and ensure aux includes exists
main_conditions, aux_conditions =
klass.split_conditions(conditions, aux_columns)
# Ensure we have the aux eager_load
relation = self
unless relation.eager_load_values.any? { |eager_load_val|
eager_load_val == association_name
}
relation = relation.eager_load(association_name)
end
# Apply conditions
relation =
relation.where(main_conditions) if main_conditions.any?
relation =
relation.where(
association_name => aux_conditions
) if aux_conditions.any?
relation
else
# No aux columns, use original method
super(*args)
end
else
# Not applicable, use original method
super(*args)
end
end
end
ActiveRecord::Relation.prepend(relation_extension)
end
# Check for unknown columns and raise error
def check_unknown_columns(conditions, original_where = nil)
return unless has_aux_tables?
aux_columns = get_aux_column_names
main_columns = self.column_names
all_valid_columns = (main_columns + aux_columns).to_set
conditions.keys.each do |key|
key_str = key.to_s
unless all_valid_columns.include?(key_str)
# Force execution to trigger error by calling original where with bad column
if original_where
original_where.call(key => conditions[key]).load
else
# Fallback - this will trigger the error when executed
raise ActiveRecord::StatementInvalid, "unknown column: #{key}"
end
return
end
on.define_singleton_method(:find) do |*args|
self.eager_load(association_name).find(*args)
end
end
end