diff --git a/lib/liquid/errors.rb b/lib/liquid/errors.rb index 762828c..6dcd05e 100644 --- a/lib/liquid/errors.rb +++ b/lib/liquid/errors.rb @@ -59,4 +59,5 @@ module Liquid UndefinedVariable = Class.new(Error) UndefinedDropMethod = Class.new(Error) UndefinedFilter = Class.new(Error) + MethodOverrideError = Class.new(Error) end diff --git a/lib/liquid/strainer.rb b/lib/liquid/strainer.rb index c79a586..7ecb31a 100644 --- a/lib/liquid/strainer.rb +++ b/lib/liquid/strainer.rb @@ -28,8 +28,13 @@ module Liquid def self.add_filter(filter) raise ArgumentError, "Expected module but got: #{filter.class}" unless filter.is_a?(Module) unless self.class.include?(filter) - send(:include, filter) - @filter_methods.merge(filter.public_instance_methods.map(&:to_s)) + invokable_private_methods = filter.private_instance_methods.select { |m| invokable?(m) } + if invokable_private_methods.any? + raise MethodOverrideError, "Filter overrides registered public methods as private: #{invokable_private_methods.join(', ')}" + else + send(:include, filter) + @filter_methods.merge(filter.public_instance_methods.map(&:to_s)) + end end end diff --git a/test/unit/strainer_unit_test.rb b/test/unit/strainer_unit_test.rb index d06d30a..ef32c5f 100644 --- a/test/unit/strainer_unit_test.rb +++ b/test/unit/strainer_unit_test.rb @@ -87,4 +87,33 @@ class StrainerUnitTest < Minitest::Test s.class.add_filter(wrong_filter) end end + + module PrivateMethodOverrideFilter + private + + def public_filter + "overriden as private" + end + end + + def test_add_filter_raises_when_module_privately_overrides_registered_public_methods + strainer = Context.new.strainer + + error = assert_raises(Liquid::MethodOverrideError) do + strainer.class.add_filter(PrivateMethodOverrideFilter) + end + assert_equal 'Liquid error: Filter overrides registered public methods as private: public_filter', error.message + end + + module PublicMethodOverrideFilter + def public_filter + "public" + end + end + + def test_add_filter_does_not_raise_when_module_overrides_previously_registered_method + strainer = Context.new.strainer + strainer.class.add_filter(PublicMethodOverrideFilter) + assert strainer.class.filter_methods.include?('public_filter') + end end # StrainerTest