diff --git a/demo_functionality.rb b/demo_functionality.rb index dcd6a7b..497348a 100755 --- a/demo_functionality.rb +++ b/demo_functionality.rb @@ -47,6 +47,11 @@ ActiveRecord::Schema.define do end create_aux_table :posts, :e621 do |t| + t.references :creator, + foreign_key: { + to_table: :users_e621_aux, + primary_key: :base_table_id + } t.integer :e621_id, index: true t.string :md5, index: true end @@ -102,19 +107,30 @@ class E621Post < Post belongs_to :creator, class_name: "E621User", inverse_of: :created_posts end -# ActiveRecord::Base.logger = Logger.new(STDOUT) - fa_user = FaUser.create!(username: "Alice", url_name: "alice") +fa_user_id = fa_user.id +raise if fa_user.id.nil? +raise unless fa_user.persisted? +raise unless fa_user.username == "Alice" +raise unless fa_user.url_name == "alice" -# puts "Does the post exist? #{FaPost.exists?(fa_id: 1)}" -attrs = { - # creator_id: fa_user.id, - fa_id: 12_345, - title: "Test Post", - species: "Cat", - posted_at: 1.day.ago -} -fa_post = fa_user.created_posts.create!(attrs) +fa_user_found = FaUser.find_by(username: "Alice") +raise unless fa_user_found.id == fa_user_id +raise unless fa_user_found.username == "Alice" +raise unless fa_user_found.url_name == "alice" + +fa_user_found = FaUser.find(fa_user_id) +raise unless fa_user_found.id == fa_user_id +raise unless fa_user_found.username == "Alice" +raise unless fa_user_found.url_name == "alice" + +fa_post = + fa_user.created_posts.create!( + fa_id: 12_345, + title: "Test Post", + species: "Cat", + posted_at: 1.day.ago + ) raise unless fa_post.persisted? raise unless fa_post.creator == fa_user raise unless fa_post.creator_id == fa_user.id @@ -124,11 +140,41 @@ raise unless fa_posts_all.size == 1 raise unless fa_posts_all.first.creator == fa_user raise unless fa_posts_all.first.creator_id == fa_user.id -# e621_user = E621User.create!(username: "bob", e621_id: 67_890) -# e621_post = -# E621Post.create!( -# e621_id: 102_938, -# md5: "DEADBEEF" * 4, -# posted_at: 2.weeks.ago -# ) -# e621_user.favorite_posts << e621_post +raise unless FaPost.exists?(fa_id: 12_345) +raise if FaPost.exists?(fa_id: 12_346) + +e621_user = E621User.create!(username: "bob", e621_id: 67_890) +raise unless e621_user.persisted? +raise unless e621_user.username == "bob" +raise unless e621_user.e621_id == 67_890 + +e621_user_found = E621User.find_by(username: "bob") +raise unless e621_user_found.id == e621_user.id +raise unless e621_user_found.username == "bob" +raise unless e621_user_found.e621_id == 67_890 + +e621_post = + e621_user.created_posts.create!( + e621_id: 102_938, + md5: "DEADBEEF" * 4, + posted_at: 2.weeks.ago + ) +raise unless e621_post.persisted? +raise unless e621_post.creator == e621_user +raise unless e621_post.creator_id == e621_user.id + +e621_user.favorite_posts << e621_post +raise unless e621_user.favorite_posts.size == 1 +raise unless e621_user.favorite_posts.first == e621_post +raise unless e621_user.favorite_posts.first.id == e621_post.id + +e621_fav_joins = e621_user.user_post_fav_joins +raise unless e621_fav_joins.size == 1 +raise unless e621_fav_joins.first.user == e621_user +raise unless e621_fav_joins.first.post == e621_post +raise unless e621_fav_joins.first.post_id == e621_post.id + +e621_posts_all = E621Post.all.to_a +raise unless e621_posts_all.size == 1 +raise unless e621_posts_all.first.creator == e621_user +raise unless e621_posts_all.first.creator_id == e621_user.id diff --git a/lib/has_aux_table.rb b/lib/has_aux_table.rb index aeae965..749cc68 100644 --- a/lib/has_aux_table.rb +++ b/lib/has_aux_table.rb @@ -8,6 +8,7 @@ require "active_support" require "active_support/concern" require "active_model/attribute_set" +require_relative "has_aux_table/relation_extensions" require_relative "has_aux_table/aux_table_config" require_relative "has_aux_table/query_extensions" require_relative "has_aux_table/migration_extensions" @@ -22,6 +23,7 @@ module HasAuxTable module ClassMethods extend T::Sig include QueryExtensions + include RelationExtensions # Main DSL method for defining auxiliary tables sig { params(aux_name: T.any(String, Symbol)).returns(AuxTableConfig) } @@ -39,8 +41,9 @@ module HasAuxTable @aux_table_configs[aux_table_name] = aux_config = generate_aux_config(aux_table_name) - setup_schema_loading_hook!(aux_table_name) - setup_query_extensions!(self, aux_config, with_bind_attribute: false) + setup_schema_loading_hook!(aux_config) + # setup_query_extensions!(self, aux_config, with_bind_attribute: false) + setup_relation_extensions!(aux_config) aux_config end @@ -58,11 +61,9 @@ module HasAuxTable private # Hook into schema loading to generate attribute accessors when schema is loaded - sig { params(aux_table_name: Symbol).void } - def setup_schema_loading_hook!(aux_table_name) - aux_config = - @aux_table_configs[aux_table_name] || - raise("no aux_config for #{aux_table_name}") + sig { params(aux_config: AuxTableConfig).void } + def setup_schema_loading_hook!(aux_config) + aux_table_name = aux_config.table_name # Override load_schema to also generate auxiliary attribute accessors when schema is loaded load_schema_method = self.method(:load_schema!) @@ -127,7 +128,12 @@ module HasAuxTable end end - %i[_read_attribute read_attribute write_attribute].each do |method_name| + %i[ + _read_attribute + read_attribute + _write_attribute + write_attribute + ].each do |method_name| read_attribute_method = self.instance_method(method_name) self.define_method(method_name) do |name, *args, **kwargs| if aux_config.is_aux_column?(name) @@ -255,6 +261,7 @@ module HasAuxTable AuxTableConfig.new( table_name:, model_class: aux_class, + main_class:, aux_association_name:, main_association_name:, foreign_key:, diff --git a/lib/has_aux_table/aux_table_config.rb b/lib/has_aux_table/aux_table_config.rb index 816b579..4d2d876 100644 --- a/lib/has_aux_table/aux_table_config.rb +++ b/lib/has_aux_table/aux_table_config.rb @@ -9,6 +9,7 @@ module HasAuxTable const :table_name, Symbol const :aux_association_name, Symbol const :main_association_name, Symbol + const :main_class, T.class_of(ActiveRecord::Base) const :model_class, T.class_of(ActiveRecord::Base) const :foreign_key, T.any(Symbol, T::Array[Symbol]) const :primary_key, T.any(Symbol, T::Array[Symbol]) @@ -65,7 +66,6 @@ module HasAuxTable if aux_conditions.any? relation = relation.where(aux_association_name => aux_conditions) end - puts "conditions: #{main_conditions} / #{aux_conditions}" relation end @@ -78,6 +78,17 @@ module HasAuxTable conditions.partition { |k, _| !self.is_aux_column?(k) }.map(&:to_h) end + sig do + params(conditions: T::Hash[String, T.untyped]).returns( + T::Hash[String, T.untyped] + ) + end + def remap_conditions(conditions) + main, aux = split_conditions(conditions) + main.merge!(aux_association_name => aux) if aux.any? + main + end + sig do params( main_model: ActiveRecord::Base, diff --git a/lib/has_aux_table/query_extensions.rb b/lib/has_aux_table/query_extensions.rb index 3061f2c..452f0db 100644 --- a/lib/has_aux_table/query_extensions.rb +++ b/lib/has_aux_table/query_extensions.rb @@ -4,23 +4,6 @@ module HasAuxTable module QueryExtensions extend T::Sig - - # Split conditions into main table and aux table conditions - def split_conditions(conditions, aux_config) - main_conditions = {} - aux_conditions = {} - - conditions.each do |key, value| - if aux_config.is_aux_column?(key) - aux_conditions[key] = value - else - main_conditions[key] = value - end - end - - [main_conditions, aux_conditions] - end - sig do params( on: T.any(ActiveRecord::Relation, T.class_of(ActiveRecord::Base)), diff --git a/lib/has_aux_table/relation_extensions.rb b/lib/has_aux_table/relation_extensions.rb new file mode 100644 index 0000000..a0cf258 --- /dev/null +++ b/lib/has_aux_table/relation_extensions.rb @@ -0,0 +1,90 @@ +# typed: false +# frozen_string_literal: true + +module HasAuxTable + module RelationExtensions + extend T::Sig + + sig { params(aux_config: AuxTableConfig).void } + def setup_relation_extensions!(aux_config) + setup_main_class_extensions!(aux_config) + end + + def hook_method(target, method_name, is_instance_method, &hook_block) + define_method = + is_instance_method ? :define_method : :define_singleton_method + + target_method = + ( + if is_instance_method + target.instance_method(method_name) + else + target.method(method_name) + end + ) + + target.send(define_method, method_name) do |*args, **kwargs, &block| + method = is_instance_method ? target_method.bind(self) : target_method + hook_block.call(method, *args, **kwargs, &block) + end + end + + sig { params(aux_config: AuxTableConfig).void } + def setup_main_class_extensions!(aux_config) + main_class = aux_config.main_class + + hook_method(main_class, :where, false) do |original, *args| + if args.length == 1 && args.first.is_a?(Hash) + opts_remapped = aux_config.remap_conditions(args.first) + original.call(opts_remapped) + else + original.call(*args) + end + end + + hook_method( + main_class, + :all, + false + ) do |original, *args, **kwargs, &block| + original.call(*args, **kwargs, &block).eager_load( + aux_config.aux_association_name + ) + end + + hook_method(main_class, :unscoped, false) do |original, *args, **kwargs| + original.call(*args, **kwargs).eager_load( + aux_config.aux_association_name + ) + end + + hook_method(main_class, :find, false) do |original, arg| + original.call(arg) + end + + relation_class = + main_class.relation_delegate_class(ActiveRecord::Relation) + + hook_method(relation_class, :where!, true) do |original, opts, *rest| + if opts.is_a?(Hash) + opts_remapped = aux_config.remap_conditions(opts) + original.call(opts_remapped, *rest) + else + original.call(opts, *rest) + end + end + + hook_method( + relation_class, + :bind_attribute, + true + ) do |original, name, value, &block| + if aux_config.is_aux_column?(name) + aux_config.aux_bind_attribute(name, value, &block) + else + original.call(name, value, &block) + end + end + end + end +end diff --git a/spec/active_record/aux_table_spec.rb b/spec/active_record/aux_table_spec.rb index 1d45c90..bf1929a 100644 --- a/spec/active_record/aux_table_spec.rb +++ b/spec/active_record/aux_table_spec.rb @@ -293,15 +293,14 @@ RSpec.describe HasAuxTable do expect(car_names).to eq(["Tesla Model 3", "Toyota Prius"]) end - it "doesn't add joins for queries without auxiliary columns", - skip: true do + it "doesn't add joins for queries without auxiliary columns" do toyota_cars = Car.where(name: "Toyota Prius") expect(toyota_cars.length).to eq(1) expect(toyota_cars.first.name).to eq("Toyota Prius") expect(toyota_cars.first.fuel_type).to eq("hybrid") end - it "works with chained where clauses", skip: true do + it "works with chained where clauses" do # Chain where clauses with auxiliary columns efficient_cars = Car.where(fuel_type: "hybrid").where(engine_size: 1.8) @@ -529,7 +528,7 @@ RSpec.describe HasAuxTable do ) end - it "works with empty where conditions", skip: true do + it "works with empty where conditions" do # Empty where should not cause issues cars = Car.where({}) expect(cars.length).to eq(3) @@ -629,7 +628,7 @@ RSpec.describe HasAuxTable do end describe "nested associations" do - it "can create a driver through the association", skip: true do + it "can create a driver through the association" do driver = @car.drivers.create!(name: "John Doe") expect(driver.car).to eq(@car) expect(driver.car_id).to eq(@car.id) @@ -637,7 +636,7 @@ RSpec.describe HasAuxTable do expect(driver.car.engine_size).to eq(1.5) end - it "can create a driver directly", skip: true do + it "can create a driver directly" do driver = Driver.create!(car: @car, name: "John Doe") expect(driver.car).to eq(@car) expect(driver.car_id).to eq(@car.id) @@ -645,13 +644,13 @@ RSpec.describe HasAuxTable do expect(driver.car.engine_size).to eq(1.5) end - it "can be accessed through the association", skip: true do + it "can be accessed through the association" do driver = @car.drivers.create!(name: "John Doe") expect(@car.drivers).to eq([driver]) end - it "can be destroyed through the association", skip: true do - driver = @car.drivers.build(name: "John Doe") + it "can be destroyed through the association" do + driver = @car.drivers.create!(name: "John Doe") expect { driver.destroy }.to change { @car.drivers.count }.by(-1) end end