diff --git a/lib/has_aux_table.rb b/lib/has_aux_table.rb index ba66f39..4ae06e4 100644 --- a/lib/has_aux_table.rb +++ b/lib/has_aux_table.rb @@ -12,6 +12,7 @@ require_relative "has_aux_table/version" require_relative "has_aux_table/key_type" require_relative "has_aux_table/util" require_relative "has_aux_table/relation_extensions" +require_relative "has_aux_table/model_class_helper" require_relative "has_aux_table/aux_table_config" require_relative "has_aux_table/migration_extensions" @@ -76,88 +77,73 @@ module HasAuxTable config end + sig { params(aux_name: Symbol).returns(T.nilable(AuxTableConfig)) } + def aux_table_config(aux_name) + @aux_table_configs&.[](aux_name) + end + private # Generate auxiliary model class dynamically - sig do - params( - aux_name: Symbol, - foreign_key: KeyType, - primary_key: KeyType - ).returns(AuxTableConfig) - end - def generate_aux_config( - aux_name, - # The column on the aux table that points to the main table - foreign_key: :base_table_id, - primary_key: self.primary_key - ) - base_table = self.table_name - aux_table_name = :"#{base_table}_#{aux_name}_aux" - - # Generate class name (e.g., :car_aux => "CarAux") - aux_class_name = aux_table_name.to_s.camelize - aux_association_name = aux_table_name.to_s.singularize.to_sym - - # Ensure the class name doesn't conflict with existing constants - if Object.const_defined?(aux_class_name) - Object.send(:remove_const, aux_class_name) - end - - # Get the current class for the association + sig { params(aux_name: Symbol).returns(AuxTableConfig) } + def generate_aux_config(aux_name) main_class = T.cast(self, T.class_of(ActiveRecord::Base)) - main_association_name = foreign_key.to_s.delete_suffix("_id").to_sym + main_table = main_class.table_name - # Create the auxiliary model class + aux_table_name = :"#{main_table}_#{aux_name}_aux" + aux_class_name = aux_table_name.to_s.camelize + aux_association_name = :"#{aux_name}_aux" aux_class = Class.new(ActiveRecord::Base) do self.table_name = aux_table_name.to_s - self.primary_key = foreign_key - - # Define the association back to the specific STI subclass - # Foreign key points to base STI table (e.g., vehicle_id) - # But association is to the specific subclass (e.g., Car) - self.belongs_to( - main_association_name, - class_name: main_class.name, - foreign_key:, - primary_key:, - inverse_of: aux_association_name - ) + self.primary_key = :base_table_id end + if Object.const_defined?(aux_class_name) + Object.send(:remove_const, aux_class_name) + end + Object.const_set(aux_class_name, aux_class) + + aux_table_config = + AuxTableConfig.from_models( + main_class:, + aux_class:, + aux_association_name: + ) + + # Define the association back to the specific STI subclass + # Foreign key points to base STI table (e.g., vehicle_id) + # But association is to the specific subclass (e.g., Car) + aux_class.belongs_to( + :main, + class_name: main_class.name, + foreign_key: aux_class.primary_key, + primary_key: main_class.primary_key, + inverse_of: aux_association_name + ) + # set up has_one association to the auxiliary table self.has_one( aux_association_name, class_name: aux_class_name, - foreign_key:, - primary_key:, - inverse_of: main_association_name, + foreign_key: aux_class.primary_key, + primary_key: main_class.primary_key, + inverse_of: :main, dependent: :destroy ) # so the aux table is joined against the main table self.default_scope { eager_load(aux_association_name) } - # Set the constant to make the class accessible - Object.const_set(aux_class_name, aux_class) - - AuxTableConfig.new( - aux_table_name:, - aux_class:, - main_class:, - aux_association_name:, - main_association_name:, - foreign_key:, - primary_key: - ) + aux_table_config end sig { params(config: AuxTableConfig).void } def setup_attribute_types_hook!(config) - original_method = config.main_class.method(:attribute_types) + original_method = config.main.klass.method(:attribute_types) config - .main_class + .main + .klass .define_singleton_method(:attribute_types) do @aux_config_attribute_types_cache ||= T.let( @@ -189,7 +175,7 @@ module HasAuxTable end end - config.main_class.attributes_for_inspect = + config.main.klass.attributes_for_inspect = Util.attributes_for_inspect(config) end @@ -282,7 +268,7 @@ module HasAuxTable self.define_method(:initialize) do |*args, **kwargs, &block| T.bind(self, ActiveRecord::Base) if args && args.size == 1 && (arg = args.first).is_a?(Hash) - main_args, aux_args = config.aux.partition_by_columns(args.first) + main_args, aux_args = config.partition_by_columns(args.first) initialize_method.bind(self).call(main_args, **kwargs, &block) config.aux_model_for(self).assign_attributes(aux_args) else diff --git a/lib/has_aux_table/aux_table_config.rb b/lib/has_aux_table/aux_table_config.rb index 17c70d5..37d492d 100644 --- a/lib/has_aux_table/aux_table_config.rb +++ b/lib/has_aux_table/aux_table_config.rb @@ -2,159 +2,55 @@ # frozen_string_literal: true module HasAuxTable - class ModelClassHelper < T::Struct - extend T::Sig - - const :klass, T.class_of(ActiveRecord::Base) - const :rejected_column_names, T::Set[String] - - sig { params(name: T.any(String, Symbol)).returns(T::Boolean) } - def is_column?(name) - column_names.include?(name.to_s) - end - - sig { returns(T::Array[String]) } - def column_names - @column_names ||= - T.let( - begin - klass - .column_names - .reject { |col| rejected_column_names.include?(col.to_s) } - .map(&:to_s) - end, - T.nilable(T::Array[String]) - ) - end - - sig { returns(T::Hash[String, ActiveRecord::ConnectionAdapters::Column]) } - def columns_hash - @columns_hash ||= - T.let( - slice_by_columns(klass.columns_hash), - T.nilable(T::Hash[String, ActiveRecord::ConnectionAdapters::Column]) - ) - end - - sig { returns(T::Hash[String, ActiveModel::Type]) } - def attribute_types - @attribute_types ||= - T.let( - slice_by_columns(klass.attribute_types), - T.nilable(T::Hash[String, ActiveModel::Type]) - ) - end - - sig { returns(T::Hash[String, ActiveModel::Attribute]) } - def default_attributes - @default_attributes ||= - T.let( - begin - da = klass._default_attributes - da.keys.map { |k, v| [k, da[k]] }.to_h.slice(*self.column_names) - end, - T.nilable(T::Hash[String, ActiveModel::Attribute]) - ) - end - - sig do - params(instance: ActiveRecord::Base).returns(T::Hash[String, T.untyped]) - end - def attributes_on(instance) - Util.ensure_is_instance_of!(instance, self.klass) - unless instance.class <= self.klass - raise("#{instance.class.name} not a #{self.klass.name}") - end - slice_by_columns(instance.attributes) - end - - sig do - type_parameters(:K, :T) - .params( - hash: - T::Hash[ - T.all(T.type_parameter(:K), T.any(String, Symbol)), - T.type_parameter(:T) - ] - ) - .returns( - [ - T::Hash[ - T.all(T.type_parameter(:K), T.any(String, Symbol)), - T.type_parameter(:T) - ], - T::Hash[ - T.all(T.type_parameter(:K), T.any(String, Symbol)), - T.type_parameter(:T) - ] - ] - ) - end - def partition_by_columns(hash) - a, b = - hash - .partition { |k, _| !self.column_names.include?(k.to_s) } - .map(&:to_h) - [T.must(a), T.must(b)] - end - - private - - sig do - type_parameters(:T) - .params(hash: T::Hash[String, T.type_parameter(:T)]) - .returns(T::Hash[String, T.type_parameter(:T)]) - end - def slice_by_columns(hash) - T.unsafe(hash).slice(*self.column_names) - end - end - class AuxTableConfig < T::Struct extend T::Sig const :aux_table_name, Symbol const :aux_association_name, Symbol - const :main_association_name, Symbol - const :main_class, T.class_of(ActiveRecord::Base) - const :aux_class, T.class_of(ActiveRecord::Base) - const :foreign_key, KeyType - const :primary_key, KeyType + const :main, ModelClassHelper + const :aux, ModelClassHelper + + sig do + params( + main_class: T.class_of(ActiveRecord::Base), + aux_class: T.class_of(ActiveRecord::Base), + aux_association_name: Symbol + ).returns(AuxTableConfig) + end + def self.from_models(main_class:, aux_class:, aux_association_name:) + primary_key = aux_class.primary_key + aux_rejected_column_names = [ + primary_key, + "created_at", + "updated_at" + ].flatten.map(&:to_s).to_set + + new( + aux_table_name: aux_class.table_name.to_sym, + aux_association_name:, + main: + ModelClassHelper.new( + klass: main_class, + rejected_column_names: Set.new + ), + aux: + ModelClassHelper.new( + klass: aux_class, + rejected_column_names: aux_rejected_column_names + ) + ) + end sig { returns(T.untyped) } def load_aux_schema - aux_class.load_schema - end - - sig { returns(ModelClassHelper) } - def aux - @aux ||= - T.let( - ModelClassHelper.new( - klass: self.aux_class, - rejected_column_names: self.aux_rejected_column_names.to_set - ), - T.nilable(ModelClassHelper) - ) - end - - sig { returns(ModelClassHelper) } - def main - @main ||= - T.let( - ModelClassHelper.new( - klass: self.main_class, - rejected_column_names: Set.new - ), - T.nilable(ModelClassHelper) - ) + aux.klass.load_schema end sig do params(main_instance: ActiveRecord::Base).returns(ActiveRecord::Base) end def aux_model_for(main_instance) - Util.ensure_is_instance_of!(main_instance, main_class) + Util.ensure_is_instance_of!(main_instance, main.klass) aux_association = main_instance.association(self.aux_association_name) aux_association.target ||= ( @@ -178,22 +74,30 @@ module HasAuxTable ).returns(Arel::Nodes::Node) end def aux_bind_attribute(name, value, &block) - arel_attr = aux_class.arel_table[name] + arel_attr = aux.klass.arel_table[name] aux_bind = - aux_class.predicate_builder.build_bind_attribute(arel_attr.name, value) + aux.klass.predicate_builder.build_bind_attribute(arel_attr.name, value) block.call(arel_attr, aux_bind) end + # Forward method call `method_name` to the aux model sig { params(method_name: Symbol).void } def define_aux_attribute_delegate(method_name) config = self - main_class.define_method(method_name) do |*args, **kwargs, &block| - T.bind(self, ActiveRecord::Base) - aux_model = config.aux_model_for(self) - ret = - T.unsafe(aux_model).public_send(method_name, *args, **kwargs, &block) - ret - end + main + .klass + .define_method(method_name) do |*args, **kwargs, &block| + T.bind(self, ActiveRecord::Base) + aux_model = config.aux_model_for(self) + ret = + T.unsafe(aux_model).public_send( + method_name, + *args, + **kwargs, + &block + ) + ret + end end sig do @@ -202,22 +106,58 @@ module HasAuxTable ) end def remap_conditions(conditions) - main_conds, aux_conds = aux.partition_by_columns(conditions) + main_conds, aux_conds = partition_by_columns(conditions) main_conds.merge!(aux_association_name => aux_conds) if aux_conds.any? main_conds end - private - - sig { returns(T::Set[String]) } - def aux_rejected_column_names - @aux_rejected_column_names ||= - T.let( - [foreign_key, primary_key, "created_at", "updated_at"].flatten - .map(&:to_s) - .to_set, - T.nilable(T::Set[String]) + sig do + type_parameters(:K) + .params( + hash: + T::Hash[ + T.all(T.type_parameter(:K), T.any(String, Symbol)), + T.untyped + ] ) + .returns( + [ + T::Hash[ + T.all(T.type_parameter(:K), T.any(String, Symbol)), + T.untyped + ], + T::Hash[ + T.all(T.type_parameter(:K), T.any(String, Symbol)), + T.untyped + ] + ] + ) + end + def partition_by_columns(hash) + main = {} + aux = {} + hash.each do |k, v| + if self.aux.column_names.include?(k.to_s) + # attribute is a column on the aux table + aux[k] = v + elsif assoc = self.main.klass.reflect_on_association(k.to_s) + # attribute is an association on the main class + fk = assoc.association_foreign_key + if self.aux.column_names.include?(fk) + # the association is a column on the aux table, `v` is + # a model, get the primary key of the model + aux[fk] = v && v.send(assoc.association_primary_key) + else + # association is on the main table, `v` is a model, + main[k] = v + end + else + # attribute is not a column on the aux table or an association, + # assume it's a column on the main table + main[k] = v + end + end + [main, aux] end end end diff --git a/lib/has_aux_table/model_class_helper.rb b/lib/has_aux_table/model_class_helper.rb new file mode 100644 index 0000000..711eeef --- /dev/null +++ b/lib/has_aux_table/model_class_helper.rb @@ -0,0 +1,91 @@ +# typed: strict +# frozen_string_literal: true + +module HasAuxTable + class ModelClassHelper < T::Struct + extend T::Sig + + const :klass, T.class_of(ActiveRecord::Base) + const :rejected_column_names, T::Set[String] + + sig { params(name: T.any(String, Symbol)).returns(T::Boolean) } + def is_column?(name) + column_names.include?(name.to_s) + end + + sig { returns(T::Array[Symbol]) } + def primary_keys + @primary_keys ||= + T.let( + [klass.primary_key].flatten.map(&:to_sym), + T.nilable(T::Array[Symbol]) + ) + end + + sig { returns(T::Array[String]) } + def column_names + @column_names ||= + T.let( + begin + klass + .column_names + .reject { |col| rejected_column_names.include?(col.to_s) } + .map(&:to_s) + end, + T.nilable(T::Array[String]) + ) + end + + sig { returns(T::Hash[String, ActiveRecord::ConnectionAdapters::Column]) } + def columns_hash + @columns_hash ||= + T.let( + slice_by_columns(klass.columns_hash), + T.nilable(T::Hash[String, ActiveRecord::ConnectionAdapters::Column]) + ) + end + + sig { returns(T::Hash[String, ActiveModel::Type]) } + def attribute_types + @attribute_types ||= + T.let( + slice_by_columns(klass.attribute_types), + T.nilable(T::Hash[String, ActiveModel::Type]) + ) + end + + sig { returns(T::Hash[String, ActiveModel::Attribute]) } + def default_attributes + @default_attributes ||= + T.let( + begin + da = klass._default_attributes + da.keys.map { |k, v| [k, da[k]] }.to_h.slice(*self.column_names) + end, + T.nilable(T::Hash[String, ActiveModel::Attribute]) + ) + end + + sig do + params(instance: ActiveRecord::Base).returns(T::Hash[String, T.untyped]) + end + def attributes_on(instance) + Util.ensure_is_instance_of!(instance, self.klass) + unless instance.class <= self.klass + raise("#{instance.class.name} not a #{self.klass.name}") + end + slice_by_columns(instance.attributes) + end + + private + + sig do + type_parameters(:T) + .params(hash: T::Hash[String, T.type_parameter(:T)]) + .returns(T::Hash[String, T.type_parameter(:T)]) + end + def slice_by_columns(hash) + T.unsafe(hash).slice(*self.column_names) + end + end +end diff --git a/lib/has_aux_table/relation_extensions.rb b/lib/has_aux_table/relation_extensions.rb index 80f5c49..c7b08cc 100644 --- a/lib/has_aux_table/relation_extensions.rb +++ b/lib/has_aux_table/relation_extensions.rb @@ -53,9 +53,16 @@ class ActiveRecord::Associations::AssociationScope next unless aux_config = klass.aux_table_for(refl.join_primary_key) aux_table = aux_config.aux.klass.table_name main_table = aux_config.main.klass.table_name - fkey = "'#{aux_table}'.'#{aux_config.foreign_key}'" - pkey = "'#{main_table}'.'#{aux_config.primary_key}'" - scope.joins!("INNER JOIN '#{main_table}' ON #{fkey} = #{pkey}") + main_keys = + aux_config.main.primary_keys.map { |key| "'#{main_table}'.'#{key}'" } + aux_keys = + aux_config.aux.primary_keys.map { |key| "'#{aux_table}'.'#{key}'" } + join_clause = + main_keys + .zip(aux_keys) + .map { |(main_key, aux_key)| "#{main_key} = #{aux_key}" } + .join(" AND ") + scope.joins!("INNER JOIN '#{main_table}' ON (#{join_clause})") end if association.is_a?( ActiveRecord::Associations::HasManyThroughAssociation ) @@ -70,12 +77,7 @@ module HasAuxTable sig { params(aux_config: AuxTableConfig).void } def setup_relation_extensions!(aux_config) - setup_main_class_extensions!(aux_config) - end - - sig { params(aux_config: AuxTableConfig).void } - def setup_main_class_extensions!(aux_config) - main_class = aux_config.main_class + main_class = aux_config.main.klass Util.hook_method(main_class, :where, false) do |original, *args| if args.length == 1 && args.first.is_a?(Hash) diff --git a/lib/has_aux_table/util.rb b/lib/has_aux_table/util.rb index db97a3d..d60598b 100644 --- a/lib/has_aux_table/util.rb +++ b/lib/has_aux_table/util.rb @@ -40,7 +40,7 @@ module HasAuxTable params(aux_config: HasAuxTable::AuxTableConfig).returns(T::Array[String]) end def self.attributes_for_inspect(aux_config) - main_class = aux_config.main_class + main_class = aux_config.main.klass main_class_attributes = if main_class.attributes_for_inspect == :all diff --git a/spec/aux_table_config_spec.rb b/spec/aux_table_config_spec.rb new file mode 100644 index 0000000..bbb9454 --- /dev/null +++ b/spec/aux_table_config_spec.rb @@ -0,0 +1,45 @@ +# typed: false +# frozen_string_literal: true + +RSpec.describe HasAuxTable::AuxTableConfig do + # const :aux_table_name, Symbol + # const :aux_association_name, Symbol + # const :main_association_name, Symbol + # const :main_class, T.class_of(ActiveRecord::Base) + # const :aux_class, T.class_of(ActiveRecord::Base) + # const :foreign_key, KeyType + # const :primary_key, KeyType + + it "identifies columns on the aux table" do + driver_aux_config = Driver.aux_table_config(:driver) + expect(driver_aux_config.aux.column_names).to contain_exactly( + "license_number", + "car_id" + ) + + expect(driver_aux_config.main.column_names).to contain_exactly( + "type", + "id", + "name", + "created_at", + "updated_at" + ) + end + + describe "#remap_conditions" do + it "works with simple conditions" do + driver_aux_config = Driver.aux_table_config(:driver) + conditions = { name: "John Doe", car_id: 1 } + conditions = driver_aux_config.remap_conditions(conditions) + expect(conditions).to eq({ name: "John Doe", driver_aux: { car_id: 1 } }) + end + + it "partitions columns referring to associations" do + car = Car.create!(name: "Toyota Prius") + driver_aux_config = Driver.aux_table_config(:driver) + conditions = { car: car } + conditions = driver_aux_config.remap_conditions(conditions) + expect(conditions).to eq({ driver_aux: { "car_id" => car.id } }) + end + end +end diff --git a/spec/active_record/has_aux_table_spec.rb b/spec/has_aux_table_spec.rb similarity index 97% rename from spec/active_record/has_aux_table_spec.rb rename to spec/has_aux_table_spec.rb index 98b7fd7..c06def9 100644 --- a/spec/active_record/has_aux_table_spec.rb +++ b/spec/has_aux_table_spec.rb @@ -714,7 +714,7 @@ RSpec.describe HasAuxTable do end end - describe "nested associations" do + describe "associations" do it "can create a driver through the association" do driver = @car.drivers.create!(name: "John Doe", license_number: 123_456) expect(driver.car).to eq(@car) @@ -764,6 +764,26 @@ RSpec.describe HasAuxTable do d = drivers.find_by(license_number: 123_456) expect(d.id).to eq(driver.id) end + + it "can have the association queried when fk is on the main table" do + lot = VehicleLot.create!(name: "Lot 1") + nolot_car = @car + lot_car = Car.create!(name: "Car 1", vehicle_lot: lot) + + expect(Car.where(vehicle_lot: lot)).to eq([lot_car]) + expect(Car.where(vehicle_lot: nil)).to eq([nolot_car]) + end + + it "can have the association queried when fk is on the aux table" do + driver1 = + Driver.create!(name: "John Doe", license_number: 123, car: @car) + driver2 = Driver.create!(name: "Jane Goodall", license_number: 456) + nodriver_car = Car.create!(name: "No Driver Car") + + expect(Driver.where(car: @car)).to eq([driver1]) + expect(Driver.where(car: nil)).to eq([driver2]) + expect(Driver.where(car: nodriver_car)).to eq([]) + end end describe "#reload" do