query by association values

This commit is contained in:
Dylan Knutson
2025-07-24 04:02:31 +00:00
parent a4c9c597e3
commit 59b11c336f
7 changed files with 313 additions and 229 deletions

View File

@@ -12,6 +12,7 @@ require_relative "has_aux_table/version"
require_relative "has_aux_table/key_type" require_relative "has_aux_table/key_type"
require_relative "has_aux_table/util" require_relative "has_aux_table/util"
require_relative "has_aux_table/relation_extensions" require_relative "has_aux_table/relation_extensions"
require_relative "has_aux_table/model_class_helper"
require_relative "has_aux_table/aux_table_config" require_relative "has_aux_table/aux_table_config"
require_relative "has_aux_table/migration_extensions" require_relative "has_aux_table/migration_extensions"
@@ -76,88 +77,73 @@ module HasAuxTable
config config
end end
sig { params(aux_name: Symbol).returns(T.nilable(AuxTableConfig)) }
def aux_table_config(aux_name)
@aux_table_configs&.[](aux_name)
end
private private
# Generate auxiliary model class dynamically # Generate auxiliary model class dynamically
sig do sig { params(aux_name: Symbol).returns(AuxTableConfig) }
params( def generate_aux_config(aux_name)
aux_name: Symbol,
foreign_key: KeyType,
primary_key: KeyType
).returns(AuxTableConfig)
end
def generate_aux_config(
aux_name,
# The column on the aux table that points to the main table
foreign_key: :base_table_id,
primary_key: self.primary_key
)
base_table = self.table_name
aux_table_name = :"#{base_table}_#{aux_name}_aux"
# Generate class name (e.g., :car_aux => "CarAux")
aux_class_name = aux_table_name.to_s.camelize
aux_association_name = aux_table_name.to_s.singularize.to_sym
# Ensure the class name doesn't conflict with existing constants
if Object.const_defined?(aux_class_name)
Object.send(:remove_const, aux_class_name)
end
# Get the current class for the association
main_class = T.cast(self, T.class_of(ActiveRecord::Base)) main_class = T.cast(self, T.class_of(ActiveRecord::Base))
main_association_name = foreign_key.to_s.delete_suffix("_id").to_sym main_table = main_class.table_name
# Create the auxiliary model class aux_table_name = :"#{main_table}_#{aux_name}_aux"
aux_class_name = aux_table_name.to_s.camelize
aux_association_name = :"#{aux_name}_aux"
aux_class = aux_class =
Class.new(ActiveRecord::Base) do Class.new(ActiveRecord::Base) do
self.table_name = aux_table_name.to_s self.table_name = aux_table_name.to_s
self.primary_key = foreign_key self.primary_key = :base_table_id
# Define the association back to the specific STI subclass
# Foreign key points to base STI table (e.g., vehicle_id)
# But association is to the specific subclass (e.g., Car)
self.belongs_to(
main_association_name,
class_name: main_class.name,
foreign_key:,
primary_key:,
inverse_of: aux_association_name
)
end end
if Object.const_defined?(aux_class_name)
Object.send(:remove_const, aux_class_name)
end
Object.const_set(aux_class_name, aux_class)
aux_table_config =
AuxTableConfig.from_models(
main_class:,
aux_class:,
aux_association_name:
)
# Define the association back to the specific STI subclass
# Foreign key points to base STI table (e.g., vehicle_id)
# But association is to the specific subclass (e.g., Car)
aux_class.belongs_to(
:main,
class_name: main_class.name,
foreign_key: aux_class.primary_key,
primary_key: main_class.primary_key,
inverse_of: aux_association_name
)
# set up has_one association to the auxiliary table # set up has_one association to the auxiliary table
self.has_one( self.has_one(
aux_association_name, aux_association_name,
class_name: aux_class_name, class_name: aux_class_name,
foreign_key:, foreign_key: aux_class.primary_key,
primary_key:, primary_key: main_class.primary_key,
inverse_of: main_association_name, inverse_of: :main,
dependent: :destroy dependent: :destroy
) )
# so the aux table is joined against the main table # so the aux table is joined against the main table
self.default_scope { eager_load(aux_association_name) } self.default_scope { eager_load(aux_association_name) }
# Set the constant to make the class accessible aux_table_config
Object.const_set(aux_class_name, aux_class)
AuxTableConfig.new(
aux_table_name:,
aux_class:,
main_class:,
aux_association_name:,
main_association_name:,
foreign_key:,
primary_key:
)
end end
sig { params(config: AuxTableConfig).void } sig { params(config: AuxTableConfig).void }
def setup_attribute_types_hook!(config) def setup_attribute_types_hook!(config)
original_method = config.main_class.method(:attribute_types) original_method = config.main.klass.method(:attribute_types)
config config
.main_class .main
.klass
.define_singleton_method(:attribute_types) do .define_singleton_method(:attribute_types) do
@aux_config_attribute_types_cache ||= @aux_config_attribute_types_cache ||=
T.let( T.let(
@@ -189,7 +175,7 @@ module HasAuxTable
end end
end end
config.main_class.attributes_for_inspect = config.main.klass.attributes_for_inspect =
Util.attributes_for_inspect(config) Util.attributes_for_inspect(config)
end end
@@ -282,7 +268,7 @@ module HasAuxTable
self.define_method(:initialize) do |*args, **kwargs, &block| self.define_method(:initialize) do |*args, **kwargs, &block|
T.bind(self, ActiveRecord::Base) T.bind(self, ActiveRecord::Base)
if args && args.size == 1 && (arg = args.first).is_a?(Hash) if args && args.size == 1 && (arg = args.first).is_a?(Hash)
main_args, aux_args = config.aux.partition_by_columns(args.first) main_args, aux_args = config.partition_by_columns(args.first)
initialize_method.bind(self).call(main_args, **kwargs, &block) initialize_method.bind(self).call(main_args, **kwargs, &block)
config.aux_model_for(self).assign_attributes(aux_args) config.aux_model_for(self).assign_attributes(aux_args)
else else

View File

@@ -2,159 +2,55 @@
# frozen_string_literal: true # frozen_string_literal: true
module HasAuxTable module HasAuxTable
class ModelClassHelper < T::Struct
extend T::Sig
const :klass, T.class_of(ActiveRecord::Base)
const :rejected_column_names, T::Set[String]
sig { params(name: T.any(String, Symbol)).returns(T::Boolean) }
def is_column?(name)
column_names.include?(name.to_s)
end
sig { returns(T::Array[String]) }
def column_names
@column_names ||=
T.let(
begin
klass
.column_names
.reject { |col| rejected_column_names.include?(col.to_s) }
.map(&:to_s)
end,
T.nilable(T::Array[String])
)
end
sig { returns(T::Hash[String, ActiveRecord::ConnectionAdapters::Column]) }
def columns_hash
@columns_hash ||=
T.let(
slice_by_columns(klass.columns_hash),
T.nilable(T::Hash[String, ActiveRecord::ConnectionAdapters::Column])
)
end
sig { returns(T::Hash[String, ActiveModel::Type]) }
def attribute_types
@attribute_types ||=
T.let(
slice_by_columns(klass.attribute_types),
T.nilable(T::Hash[String, ActiveModel::Type])
)
end
sig { returns(T::Hash[String, ActiveModel::Attribute]) }
def default_attributes
@default_attributes ||=
T.let(
begin
da = klass._default_attributes
da.keys.map { |k, v| [k, da[k]] }.to_h.slice(*self.column_names)
end,
T.nilable(T::Hash[String, ActiveModel::Attribute])
)
end
sig do
params(instance: ActiveRecord::Base).returns(T::Hash[String, T.untyped])
end
def attributes_on(instance)
Util.ensure_is_instance_of!(instance, self.klass)
unless instance.class <= self.klass
raise("#{instance.class.name} not a #{self.klass.name}")
end
slice_by_columns(instance.attributes)
end
sig do
type_parameters(:K, :T)
.params(
hash:
T::Hash[
T.all(T.type_parameter(:K), T.any(String, Symbol)),
T.type_parameter(:T)
]
)
.returns(
[
T::Hash[
T.all(T.type_parameter(:K), T.any(String, Symbol)),
T.type_parameter(:T)
],
T::Hash[
T.all(T.type_parameter(:K), T.any(String, Symbol)),
T.type_parameter(:T)
]
]
)
end
def partition_by_columns(hash)
a, b =
hash
.partition { |k, _| !self.column_names.include?(k.to_s) }
.map(&:to_h)
[T.must(a), T.must(b)]
end
private
sig do
type_parameters(:T)
.params(hash: T::Hash[String, T.type_parameter(:T)])
.returns(T::Hash[String, T.type_parameter(:T)])
end
def slice_by_columns(hash)
T.unsafe(hash).slice(*self.column_names)
end
end
class AuxTableConfig < T::Struct class AuxTableConfig < T::Struct
extend T::Sig extend T::Sig
const :aux_table_name, Symbol const :aux_table_name, Symbol
const :aux_association_name, Symbol const :aux_association_name, Symbol
const :main_association_name, Symbol const :main, ModelClassHelper
const :main_class, T.class_of(ActiveRecord::Base) const :aux, ModelClassHelper
const :aux_class, T.class_of(ActiveRecord::Base)
const :foreign_key, KeyType sig do
const :primary_key, KeyType params(
main_class: T.class_of(ActiveRecord::Base),
aux_class: T.class_of(ActiveRecord::Base),
aux_association_name: Symbol
).returns(AuxTableConfig)
end
def self.from_models(main_class:, aux_class:, aux_association_name:)
primary_key = aux_class.primary_key
aux_rejected_column_names = [
primary_key,
"created_at",
"updated_at"
].flatten.map(&:to_s).to_set
new(
aux_table_name: aux_class.table_name.to_sym,
aux_association_name:,
main:
ModelClassHelper.new(
klass: main_class,
rejected_column_names: Set.new
),
aux:
ModelClassHelper.new(
klass: aux_class,
rejected_column_names: aux_rejected_column_names
)
)
end
sig { returns(T.untyped) } sig { returns(T.untyped) }
def load_aux_schema def load_aux_schema
aux_class.load_schema aux.klass.load_schema
end
sig { returns(ModelClassHelper) }
def aux
@aux ||=
T.let(
ModelClassHelper.new(
klass: self.aux_class,
rejected_column_names: self.aux_rejected_column_names.to_set
),
T.nilable(ModelClassHelper)
)
end
sig { returns(ModelClassHelper) }
def main
@main ||=
T.let(
ModelClassHelper.new(
klass: self.main_class,
rejected_column_names: Set.new
),
T.nilable(ModelClassHelper)
)
end end
sig do sig do
params(main_instance: ActiveRecord::Base).returns(ActiveRecord::Base) params(main_instance: ActiveRecord::Base).returns(ActiveRecord::Base)
end end
def aux_model_for(main_instance) def aux_model_for(main_instance)
Util.ensure_is_instance_of!(main_instance, main_class) Util.ensure_is_instance_of!(main_instance, main.klass)
aux_association = main_instance.association(self.aux_association_name) aux_association = main_instance.association(self.aux_association_name)
aux_association.target ||= aux_association.target ||=
( (
@@ -178,22 +74,30 @@ module HasAuxTable
).returns(Arel::Nodes::Node) ).returns(Arel::Nodes::Node)
end end
def aux_bind_attribute(name, value, &block) def aux_bind_attribute(name, value, &block)
arel_attr = aux_class.arel_table[name] arel_attr = aux.klass.arel_table[name]
aux_bind = aux_bind =
aux_class.predicate_builder.build_bind_attribute(arel_attr.name, value) aux.klass.predicate_builder.build_bind_attribute(arel_attr.name, value)
block.call(arel_attr, aux_bind) block.call(arel_attr, aux_bind)
end end
# Forward method call `method_name` to the aux model
sig { params(method_name: Symbol).void } sig { params(method_name: Symbol).void }
def define_aux_attribute_delegate(method_name) def define_aux_attribute_delegate(method_name)
config = self config = self
main_class.define_method(method_name) do |*args, **kwargs, &block| main
T.bind(self, ActiveRecord::Base) .klass
aux_model = config.aux_model_for(self) .define_method(method_name) do |*args, **kwargs, &block|
ret = T.bind(self, ActiveRecord::Base)
T.unsafe(aux_model).public_send(method_name, *args, **kwargs, &block) aux_model = config.aux_model_for(self)
ret ret =
end T.unsafe(aux_model).public_send(
method_name,
*args,
**kwargs,
&block
)
ret
end
end end
sig do sig do
@@ -202,22 +106,58 @@ module HasAuxTable
) )
end end
def remap_conditions(conditions) def remap_conditions(conditions)
main_conds, aux_conds = aux.partition_by_columns(conditions) main_conds, aux_conds = partition_by_columns(conditions)
main_conds.merge!(aux_association_name => aux_conds) if aux_conds.any? main_conds.merge!(aux_association_name => aux_conds) if aux_conds.any?
main_conds main_conds
end end
private sig do
type_parameters(:K)
sig { returns(T::Set[String]) } .params(
def aux_rejected_column_names hash:
@aux_rejected_column_names ||= T::Hash[
T.let( T.all(T.type_parameter(:K), T.any(String, Symbol)),
[foreign_key, primary_key, "created_at", "updated_at"].flatten T.untyped
.map(&:to_s) ]
.to_set,
T.nilable(T::Set[String])
) )
.returns(
[
T::Hash[
T.all(T.type_parameter(:K), T.any(String, Symbol)),
T.untyped
],
T::Hash[
T.all(T.type_parameter(:K), T.any(String, Symbol)),
T.untyped
]
]
)
end
def partition_by_columns(hash)
main = {}
aux = {}
hash.each do |k, v|
if self.aux.column_names.include?(k.to_s)
# attribute is a column on the aux table
aux[k] = v
elsif assoc = self.main.klass.reflect_on_association(k.to_s)
# attribute is an association on the main class
fk = assoc.association_foreign_key
if self.aux.column_names.include?(fk)
# the association is a column on the aux table, `v` is
# a model, get the primary key of the model
aux[fk] = v && v.send(assoc.association_primary_key)
else
# association is on the main table, `v` is a model,
main[k] = v
end
else
# attribute is not a column on the aux table or an association,
# assume it's a column on the main table
main[k] = v
end
end
[main, aux]
end end
end end
end end

View File

@@ -0,0 +1,91 @@
# typed: strict
# frozen_string_literal: true
module HasAuxTable
class ModelClassHelper < T::Struct
extend T::Sig
const :klass, T.class_of(ActiveRecord::Base)
const :rejected_column_names, T::Set[String]
sig { params(name: T.any(String, Symbol)).returns(T::Boolean) }
def is_column?(name)
column_names.include?(name.to_s)
end
sig { returns(T::Array[Symbol]) }
def primary_keys
@primary_keys ||=
T.let(
[klass.primary_key].flatten.map(&:to_sym),
T.nilable(T::Array[Symbol])
)
end
sig { returns(T::Array[String]) }
def column_names
@column_names ||=
T.let(
begin
klass
.column_names
.reject { |col| rejected_column_names.include?(col.to_s) }
.map(&:to_s)
end,
T.nilable(T::Array[String])
)
end
sig { returns(T::Hash[String, ActiveRecord::ConnectionAdapters::Column]) }
def columns_hash
@columns_hash ||=
T.let(
slice_by_columns(klass.columns_hash),
T.nilable(T::Hash[String, ActiveRecord::ConnectionAdapters::Column])
)
end
sig { returns(T::Hash[String, ActiveModel::Type]) }
def attribute_types
@attribute_types ||=
T.let(
slice_by_columns(klass.attribute_types),
T.nilable(T::Hash[String, ActiveModel::Type])
)
end
sig { returns(T::Hash[String, ActiveModel::Attribute]) }
def default_attributes
@default_attributes ||=
T.let(
begin
da = klass._default_attributes
da.keys.map { |k, v| [k, da[k]] }.to_h.slice(*self.column_names)
end,
T.nilable(T::Hash[String, ActiveModel::Attribute])
)
end
sig do
params(instance: ActiveRecord::Base).returns(T::Hash[String, T.untyped])
end
def attributes_on(instance)
Util.ensure_is_instance_of!(instance, self.klass)
unless instance.class <= self.klass
raise("#{instance.class.name} not a #{self.klass.name}")
end
slice_by_columns(instance.attributes)
end
private
sig do
type_parameters(:T)
.params(hash: T::Hash[String, T.type_parameter(:T)])
.returns(T::Hash[String, T.type_parameter(:T)])
end
def slice_by_columns(hash)
T.unsafe(hash).slice(*self.column_names)
end
end
end

View File

@@ -53,9 +53,16 @@ class ActiveRecord::Associations::AssociationScope
next unless aux_config = klass.aux_table_for(refl.join_primary_key) next unless aux_config = klass.aux_table_for(refl.join_primary_key)
aux_table = aux_config.aux.klass.table_name aux_table = aux_config.aux.klass.table_name
main_table = aux_config.main.klass.table_name main_table = aux_config.main.klass.table_name
fkey = "'#{aux_table}'.'#{aux_config.foreign_key}'" main_keys =
pkey = "'#{main_table}'.'#{aux_config.primary_key}'" aux_config.main.primary_keys.map { |key| "'#{main_table}'.'#{key}'" }
scope.joins!("INNER JOIN '#{main_table}' ON #{fkey} = #{pkey}") aux_keys =
aux_config.aux.primary_keys.map { |key| "'#{aux_table}'.'#{key}'" }
join_clause =
main_keys
.zip(aux_keys)
.map { |(main_key, aux_key)| "#{main_key} = #{aux_key}" }
.join(" AND ")
scope.joins!("INNER JOIN '#{main_table}' ON (#{join_clause})")
end if association.is_a?( end if association.is_a?(
ActiveRecord::Associations::HasManyThroughAssociation ActiveRecord::Associations::HasManyThroughAssociation
) )
@@ -70,12 +77,7 @@ module HasAuxTable
sig { params(aux_config: AuxTableConfig).void } sig { params(aux_config: AuxTableConfig).void }
def setup_relation_extensions!(aux_config) def setup_relation_extensions!(aux_config)
setup_main_class_extensions!(aux_config) main_class = aux_config.main.klass
end
sig { params(aux_config: AuxTableConfig).void }
def setup_main_class_extensions!(aux_config)
main_class = aux_config.main_class
Util.hook_method(main_class, :where, false) do |original, *args| Util.hook_method(main_class, :where, false) do |original, *args|
if args.length == 1 && args.first.is_a?(Hash) if args.length == 1 && args.first.is_a?(Hash)

View File

@@ -40,7 +40,7 @@ module HasAuxTable
params(aux_config: HasAuxTable::AuxTableConfig).returns(T::Array[String]) params(aux_config: HasAuxTable::AuxTableConfig).returns(T::Array[String])
end end
def self.attributes_for_inspect(aux_config) def self.attributes_for_inspect(aux_config)
main_class = aux_config.main_class main_class = aux_config.main.klass
main_class_attributes = main_class_attributes =
if main_class.attributes_for_inspect == :all if main_class.attributes_for_inspect == :all

View File

@@ -0,0 +1,45 @@
# typed: false
# frozen_string_literal: true
RSpec.describe HasAuxTable::AuxTableConfig do
# const :aux_table_name, Symbol
# const :aux_association_name, Symbol
# const :main_association_name, Symbol
# const :main_class, T.class_of(ActiveRecord::Base)
# const :aux_class, T.class_of(ActiveRecord::Base)
# const :foreign_key, KeyType
# const :primary_key, KeyType
it "identifies columns on the aux table" do
driver_aux_config = Driver.aux_table_config(:driver)
expect(driver_aux_config.aux.column_names).to contain_exactly(
"license_number",
"car_id"
)
expect(driver_aux_config.main.column_names).to contain_exactly(
"type",
"id",
"name",
"created_at",
"updated_at"
)
end
describe "#remap_conditions" do
it "works with simple conditions" do
driver_aux_config = Driver.aux_table_config(:driver)
conditions = { name: "John Doe", car_id: 1 }
conditions = driver_aux_config.remap_conditions(conditions)
expect(conditions).to eq({ name: "John Doe", driver_aux: { car_id: 1 } })
end
it "partitions columns referring to associations" do
car = Car.create!(name: "Toyota Prius")
driver_aux_config = Driver.aux_table_config(:driver)
conditions = { car: car }
conditions = driver_aux_config.remap_conditions(conditions)
expect(conditions).to eq({ driver_aux: { "car_id" => car.id } })
end
end
end

View File

@@ -714,7 +714,7 @@ RSpec.describe HasAuxTable do
end end
end end
describe "nested associations" do describe "associations" do
it "can create a driver through the association" do it "can create a driver through the association" do
driver = @car.drivers.create!(name: "John Doe", license_number: 123_456) driver = @car.drivers.create!(name: "John Doe", license_number: 123_456)
expect(driver.car).to eq(@car) expect(driver.car).to eq(@car)
@@ -764,6 +764,26 @@ RSpec.describe HasAuxTable do
d = drivers.find_by(license_number: 123_456) d = drivers.find_by(license_number: 123_456)
expect(d.id).to eq(driver.id) expect(d.id).to eq(driver.id)
end end
it "can have the association queried when fk is on the main table" do
lot = VehicleLot.create!(name: "Lot 1")
nolot_car = @car
lot_car = Car.create!(name: "Car 1", vehicle_lot: lot)
expect(Car.where(vehicle_lot: lot)).to eq([lot_car])
expect(Car.where(vehicle_lot: nil)).to eq([nolot_car])
end
it "can have the association queried when fk is on the aux table" do
driver1 =
Driver.create!(name: "John Doe", license_number: 123, car: @car)
driver2 = Driver.create!(name: "Jane Goodall", license_number: 456)
nodriver_car = Car.create!(name: "No Driver Car")
expect(Driver.where(car: @car)).to eq([driver1])
expect(Driver.where(car: nil)).to eq([driver2])
expect(Driver.where(car: nodriver_car)).to eq([])
end
end end
describe "#reload" do describe "#reload" do