From db635bf81176b986c644a71bd237b2c36052cd4f Mon Sep 17 00:00:00 2001 From: Danielle Tomlinson Date: Tue, 11 Dec 2018 17:35:51 +0100 Subject: [PATCH] fixup: Correctly sort based on distance, use iradix for ordering --- acl/acl.go | 36 ++++++++++++++++++++++-------------- acl/acl_test.go | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/acl/acl.go b/acl/acl.go index fcb46d230..25f111152 100644 --- a/acl/acl.go +++ b/acl/acl.go @@ -48,7 +48,8 @@ type ACL struct { namespaces *iradix.Tree // wildcardNamespaces maps a glob pattern of a namespace to a capabilitySet - wildcardNamespaces map[string]capabilitySet + // We use an iradix for the purposes of ordered iteration. + wildcardNamespaces *iradix.Tree agent string node string @@ -81,7 +82,7 @@ func NewACL(management bool, policies []*Policy) (*ACL, error) { // Create the ACL object acl := &ACL{} nsTxn := iradix.New().Txn() - wns := make(map[string]capabilitySet) + wnsTxn := iradix.New().Txn() for _, policy := range policies { NAMESPACES: @@ -93,12 +94,12 @@ func NewACL(management bool, policies []*Policy) (*ACL, error) { var capabilities capabilitySet if globDefinition { - raw, ok := wns[ns.Name] + raw, ok := wnsTxn.Get([]byte(ns.Name)) if ok { - capabilities = raw + capabilities = raw.(capabilitySet) } else { capabilities = make(capabilitySet) - wns[ns.Name] = capabilities + wnsTxn.Insert([]byte(ns.Name), capabilities) } } else { raw, ok := nsTxn.Get([]byte(ns.Name)) @@ -144,7 +145,7 @@ func NewACL(management bool, policies []*Policy) (*ACL, error) { // Finalize the namespaces acl.namespaces = nsTxn.Commit() - acl.wildcardNamespaces = wns + acl.wildcardNamespaces = wnsTxn.Commit() return acl, nil } @@ -211,7 +212,7 @@ func (a *ACL) matchingCapabilitySet(ns string) (capabilitySet, bool) { type matchingGlob struct { ns string - nsLen int + difference int capabilitySet capabilitySet } @@ -229,13 +230,11 @@ func (a *ACL) findClosestMatchingGlob(ns string) (capabilitySet, bool) { return matchingGlobs[0].capabilitySet, true } - nsLen := len(ns) - // Stable sort the matched globs, based on the character difference between // the glob definition and the requested namespace. This allows us to be // more consistent about results based on the policy definition. sort.SliceStable(matchingGlobs, func(i, j int) bool { - return (matchingGlobs[i].nsLen - nsLen) >= (matchingGlobs[j].nsLen - nsLen) + return matchingGlobs[i].difference <= matchingGlobs[j].difference }) return matchingGlobs[0].capabilitySet, true @@ -244,17 +243,26 @@ func (a *ACL) findClosestMatchingGlob(ns string) (capabilitySet, bool) { func (a *ACL) findAllMatchingWildcards(ns string) []matchingGlob { var matches []matchingGlob - for k, v := range a.wildcardNamespaces { - isMatch := glob.Glob(string(k), ns) + nsLen := len(ns) + + a.wildcardNamespaces.Root().Walk(func(bk []byte, iv interface{}) bool { + k := string(bk) + v := iv.(capabilitySet) + + isMatch := glob.Glob(k, ns) if isMatch { + globLen := len(strings.Replace(k, glob.GLOB, "", -1)) pair := matchingGlob{ ns: k, - nsLen: len(k), + difference: nsLen - globLen, capabilitySet: v, } matches = append(matches, pair) } - } + + // We always want to walk the entire tree, never terminate early. + return false + }) return matches } diff --git a/acl/acl_test.go b/acl/acl_test.go index 9a6c5904f..e1cad58ae 100644 --- a/acl/acl_test.go +++ b/acl/acl_test.go @@ -309,7 +309,7 @@ func TestWildcardNamespaceMatching(t *testing.T) { } } -func TestACL_matchingCapabilitySet(t *testing.T) { +func TestACL_matchingCapabilitySet_returnsAllMatches(t *testing.T) { tests := []struct { Policy string NS string @@ -353,5 +353,40 @@ func TestACL_matchingCapabilitySet(t *testing.T) { assert.Equal(tc.MatchingGlobs, namespaces) }) } +} + +func TestACL_matchingCapabilitySet_difference(t *testing.T) { + tests := []struct { + Policy string + NS string + Difference int + }{ + { + Policy: `namespace "production-*" { policy = "write" }`, + NS: "production-api", + Difference: 3, + }, + { + Policy: `namespace "production-*" { policy = "write" }`, + NS: "production-admin-api", + Difference: 9, + }, + } + + for _, tc := range tests { + t.Run(tc.Policy, func(t *testing.T) { + assert := assert.New(t) + + policy, err := Parse(tc.Policy) + assert.NoError(err) + assert.NotNil(policy.Namespaces) + + acl, err := NewACL(false, []*Policy{policy}) + assert.Nil(err) + + matches := acl.findAllMatchingWildcards(tc.NS) + assert.Equal(tc.Difference, matches[0].difference) + }) + } }