diff --git a/nomad/acl_endpoint.go b/nomad/acl_endpoint.go index ae906819b..defa78d42 100644 --- a/nomad/acl_endpoint.go +++ b/nomad/acl_endpoint.go @@ -1667,3 +1667,278 @@ func (a *ACL) policyNamesFromRoleLinks(roleLinks []*structs.ACLTokenRoleLink) (* return policyNameSet, nil } + +// UpsertAuthMethods is used to create or update a set of auth methods +func (a *ACL) UpsertAuthMethods( + args *structs.ACLAuthMethodUpsertRequest, + reply *structs.ACLAuthMethodUpsertResponse) error { + // Ensure ACLs are enabled, and always flow modification requests to the + // authoritative region + if !a.srv.config.ACLEnabled { + return aclDisabled + } + args.Region = a.srv.config.AuthoritativeRegion + + if done, err := a.srv.forward(structs.ACLUpsertAuthMethodsRPCMethod, args, args, reply); done { + return err + } + defer metrics.MeasureSince([]string{"nomad", "acl", "upsert_auth_methods"}, time.Now()) + + // ACL auth methods can only be used once all servers in all federated + // regions have been upgraded to 1.5.0 or greater. + if !ServersMeetMinimumVersion(a.srv.Members(), AllRegions, minACLAuthMethodVersion, false) { + return fmt.Errorf("all servers should be running version %v or later to use ACL auth methods", + minACLAuthMethodVersion) + } + + // Check management level permissions + if acl, err := a.srv.ResolveToken(args.AuthToken); err != nil { + return err + } else if acl == nil || !acl.IsManagement() { + return structs.ErrPermissionDenied + } + + // Validate non-zero set of auth methods + if len(args.AuthMethods) == 0 { + return structs.NewErrRPCCoded(http.StatusBadRequest, "must specify as least one auth method") + } + + // Validate each auth method, compute hash + for idx, authMethod := range args.AuthMethods { + if err := authMethod.Validate( + a.srv.config.ACLAuthMethodMinExpirationTTL, + a.srv.config.ACLAuthMethodMaxExpirationTTL); err != nil { + return structs.NewErrRPCCodedf(http.StatusBadRequest, "auth method %d invalid: %v", idx, err) + } + authMethod.SetHash() + } + + // Update via Raft + out, index, err := a.srv.raftApply(structs.ACLAuthMethodsUpsertRequestType, args) + if err != nil { + return err + } + + // Check if the FSM response, which is an interface, contains an error. + if err, ok := out.(error); ok && err != nil { + return err + } + + // Update the index + reply.Index = index + return nil +} + +// DeleteAuthMethods is used to delete auth methods +func (a *ACL) DeleteAuthMethods( + args *structs.ACLAuthMethodDeleteRequest, + reply *structs.ACLAuthMethodDeleteResponse) error { + // Ensure ACLs are enabled, and always flow modification requests to the + // authoritative region + if !a.srv.config.ACLEnabled { + return aclDisabled + } + args.Region = a.srv.config.AuthoritativeRegion + + if done, err := a.srv.forward( + structs.ACLDeleteAuthMethodsRPCMethod, args, args, reply); done { + return err + } + defer metrics.MeasureSince([]string{"nomad", "acl", "delete_auth_methods_by_name"}, time.Now()) + + // ACL auth methods can only be used once all servers in all federated + // regions have been upgraded to 1.5.0 or greater. + if !ServersMeetMinimumVersion(a.srv.Members(), AllRegions, minACLRoleVersion, false) { + return fmt.Errorf("all servers should be running version %v or later to use ACL auth methods", + minACLAuthMethodVersion) + } + + // Check management level permissions + if acl, err := a.srv.ResolveToken(args.AuthToken); err != nil { + return err + } else if acl == nil || !acl.IsManagement() { + return structs.ErrPermissionDenied + } + + // Validate non-zero set of auth methods + if len(args.Names) == 0 { + return structs.NewErrRPCCoded(http.StatusBadRequest, "must specify as least one auth method") + } + + // Update via Raft + out, index, err := a.srv.raftApply(structs.ACLAuthMethodsDeleteRequestType, args) + if err != nil { + return err + } + + // Check if the FSM response, which is an interface, contains an error. + if err, ok := out.(error); ok && err != nil { + return err + } + + // Update the index + reply.Index = index + return nil +} + +// ListAuthMethods returns a list of ACL auth methods +func (a *ACL) ListAuthMethods( + args *structs.ACLAuthMethodListRequest, + reply *structs.ACLAuthMethodListResponse) error { + // Only allow operators to list auth methods when ACLs are enabled. + if !a.srv.config.ACLEnabled { + return aclDisabled + } + + if done, err := a.srv.forward( + structs.ACLListAuthMethodsRPCMethod, args, args, reply); done { + return err + } + defer metrics.MeasureSince([]string{"nomad", "acl", "list_auth_methods"}, time.Now()) + + // Resolve the token and ensure it has some form of permissions. + acl, err := a.srv.ResolveToken(args.AuthToken) + if err != nil { + return err + } else if acl == nil { + return structs.ErrPermissionDenied + } + + // Set up and return the blocking query. + return a.srv.blockingRPC(&blockingOptions{ + queryOpts: &args.QueryOptions, + queryMeta: &reply.QueryMeta, + run: func(ws memdb.WatchSet, stateStore *state.StateStore) error { + + // The iteration below appends directly to the reply object, so in + // order for blocking queries to work properly we must ensure the + // auth methods are reset. This allows the blocking query run + // function to work as expected. + reply.AuthMethods = nil + + iter, err := stateStore.GetACLAuthMethods(ws) + if err != nil { + return err + } + + // Iterate all the results and add these to our reply object. + for raw := iter.Next(); raw != nil; raw = iter.Next() { + method := raw.(*structs.ACLAuthMethod) + reply.AuthMethods = append(reply.AuthMethods, method.Stub()) + } + + // Use the index table to populate the query meta + return a.srv.setReplyQueryMeta( + stateStore, state.TableACLAuthMethods, &reply.QueryMeta, + ) + }, + }) +} + +func (a *ACL) GetAuthMethod( + args *structs.ACLAuthMethodGetRequest, + reply *structs.ACLAuthMethodGetResponse) error { + + // Only allow operators to read an auth method when ACLs are enabled. + if !a.srv.config.ACLEnabled { + return aclDisabled + } + + if done, err := a.srv.forward( + structs.ACLGetAuthMethodRPCMethod, args, args, reply); done { + return err + } + defer metrics.MeasureSince([]string{"nomad", "acl", "get_auth_method_name"}, time.Now()) + + // Resolve the token and ensure it has some form of permissions. + acl, err := a.srv.ResolveToken(args.AuthToken) + if err != nil { + return err + } else if acl == nil || !acl.IsManagement() { + return structs.ErrPermissionDenied + } + + // Set up and return the blocking query. + return a.srv.blockingRPC(&blockingOptions{ + queryOpts: &args.QueryOptions, + queryMeta: &reply.QueryMeta, + run: func(ws memdb.WatchSet, stateStore *state.StateStore) error { + + // Perform a lookup + out, err := stateStore.GetACLAuthMethodByName(ws, args.MethodName) + if err != nil { + return err + } + + // Set the index correctly depending on whether the auth method was + // found. + switch out { + case nil: + index, err := stateStore.Index(state.TableACLAuthMethods) + if err != nil { + return err + } + reply.Index = index + default: + reply.Index = out.ModifyIndex + } + + // We didn't encounter an error looking up the index; set the auth + // method on the reply and exit successfully. + reply.AuthMethod = out + return nil + }, + }) +} + +// GetAuthMethods is used to get a set of auth methods +func (a *ACL) GetAuthMethods( + args *structs.ACLAuthMethodsGetRequest, + reply *structs.ACLAuthMethodsGetResponse) error { + if !a.srv.config.ACLEnabled { + return aclDisabled + } + if done, err := a.srv.forward( + structs.ACLGetAuthMethodsRPCMethod, args, args, reply); done { + return err + } + defer metrics.MeasureSince([]string{"nomad", "acl", "get_auth_methods"}, time.Now()) + + // allow only management token holders to query this endpoint + token, err := a.requestACLToken(args.AuthToken) + if err != nil { + return err + } + if token == nil { + return structs.ErrTokenNotFound + } + if token.Type != structs.ACLManagementToken { + return structs.ErrPermissionDenied + } + + // Setup the blocking query + return a.srv.blockingRPC(&blockingOptions{ + queryOpts: &args.QueryOptions, + queryMeta: &reply.QueryMeta, + run: func(ws memdb.WatchSet, statestore *state.StateStore) error { + // Setup the output + reply.AuthMethods = make(map[string]*structs.ACLAuthMethod, len(args.Names)) + + // Look for the auth method + for _, methodName := range args.Names { + out, err := statestore.GetACLAuthMethodByName(ws, methodName) + if err != nil { + return err + } + if out != nil { + reply.AuthMethods[methodName] = out + } + } + + // Use the index table to populate the query meta + return a.srv.setReplyQueryMeta( + statestore, state.TableACLAuthMethods, &reply.QueryMeta, + ) + }}, + ) +} diff --git a/nomad/acl_endpoint_test.go b/nomad/acl_endpoint_test.go index e4ac8d460..003daa567 100644 --- a/nomad/acl_endpoint_test.go +++ b/nomad/acl_endpoint_test.go @@ -2678,3 +2678,367 @@ func TestACL_GetRoleByName(t *testing.T) { err = msgpackrpc.CallWithCodec(codec, structs.ACLGetRoleByNameRPCMethod, aclRoleReq6, &aclRoleResp6) require.ErrorContains(t, err, "Permission denied") } + +func TestACLEndpoint_GetAuthMethod(t *testing.T) { + t.Parallel() + + s1, root, cleanupS1 := TestACLServer(t, nil) + defer cleanupS1() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the register request + authMethod := mock.ACLAuthMethod() + must.NoError(t, s1.fsm.State().UpsertACLAuthMethods(1000, []*structs.ACLAuthMethod{authMethod})) + + anonymousAuthMethod := mock.ACLAuthMethod() + anonymousAuthMethod.Name = "anonymous" + must.NoError(t, s1.fsm.State().UpsertACLAuthMethods(1001, []*structs.ACLAuthMethod{anonymousAuthMethod})) + + // Lookup the authMethod + get := &structs.ACLAuthMethodGetRequest{ + MethodName: authMethod.Name, + QueryOptions: structs.QueryOptions{ + Region: "global", + AuthToken: root.SecretID, + }, + } + var resp structs.ACLAuthMethodGetResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLGetAuthMethodRPCMethod, get, &resp)) + must.Eq(t, uint64(1000), resp.Index) + must.Eq(t, authMethod, resp.AuthMethod) + + // Lookup non-existing authMethod + get.MethodName = uuid.Generate() + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLGetAuthMethodRPCMethod, get, &resp)) + must.Eq(t, uint64(1001), resp.Index) + must.Nil(t, resp.AuthMethod) +} + +func TestACLEndpoint_GetAuthMethod_Blocking(t *testing.T) { + t.Parallel() + + s1, root, cleanupS1 := TestACLServer(t, nil) + defer cleanupS1() + state := s1.fsm.State() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the authMethods + am1 := mock.ACLAuthMethod() + am2 := mock.ACLAuthMethod() + + // First create an unrelated authMethod + time.AfterFunc(100*time.Millisecond, func() { + must.NoError(t, state.UpsertACLAuthMethods(100, []*structs.ACLAuthMethod{am1})) + }) + + // Upsert the authMethod we are watching later + time.AfterFunc(200*time.Millisecond, func() { + must.NoError(t, state.UpsertACLAuthMethods(200, []*structs.ACLAuthMethod{am2})) + }) + + // Lookup the authMethod + req := &structs.ACLAuthMethodGetRequest{ + MethodName: am2.Name, + QueryOptions: structs.QueryOptions{ + Region: "global", + MinQueryIndex: 150, + AuthToken: root.SecretID, + }, + } + var resp structs.ACLAuthMethodGetResponse + start := time.Now() + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLGetAuthMethodRPCMethod, req, &resp)) + + if elapsed := time.Since(start); elapsed < 200*time.Millisecond { + t.Fatalf("should block (returned in %s) %#v", elapsed, resp) + } + must.Eq(t, resp.Index, 200) + must.NotNil(t, resp.AuthMethod) + must.Eq(t, resp.AuthMethod.Name, am2.Name) + + // Auth method delete triggers watches + time.AfterFunc(100*time.Millisecond, func() { + must.NoError(t, state.DeleteACLAuthMethods(300, []string{am2.Name})) + }) + + req.QueryOptions.MinQueryIndex = 250 + var resp2 structs.ACLAuthMethodGetResponse + start = time.Now() + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLGetAuthMethodRPCMethod, req, &resp2)) + + if elapsed := time.Since(start); elapsed < 100*time.Millisecond { + t.Fatalf("should block (returned in %s) %#v", elapsed, resp2) + } + must.Eq(t, resp2.Index, 300) + must.Nil(t, resp2.AuthMethod) +} + +func TestACLEndpoint_GetAuthMethods(t *testing.T) { + t.Parallel() + + s1, root, cleanupS1 := TestACLServer(t, nil) + defer cleanupS1() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the register request + authMethod := mock.ACLAuthMethod() + authMethod2 := mock.ACLAuthMethod() + must.NoError(t, s1.fsm.State().UpsertACLAuthMethods(1000, []*structs.ACLAuthMethod{authMethod, authMethod2})) + + // Lookup the authMethod + get := &structs.ACLAuthMethodsGetRequest{ + Names: []string{authMethod.Name, authMethod2.Name}, + QueryOptions: structs.QueryOptions{ + Region: "global", + AuthToken: root.SecretID, + }, + } + var resp structs.ACLAuthMethodsGetResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLGetAuthMethodsRPCMethod, get, &resp)) + must.Eq(t, uint64(1000), resp.Index) + must.Eq(t, 2, len(resp.AuthMethods)) + must.Eq(t, authMethod, resp.AuthMethods[authMethod.Name]) + must.Eq(t, authMethod2, resp.AuthMethods[authMethod2.Name]) + + // Lookup non-existing authMethod + get.Names = []string{uuid.Generate()} + resp = structs.ACLAuthMethodsGetResponse{} + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLGetAuthMethodsRPCMethod, get, &resp)) + must.Eq(t, uint64(1000), resp.Index) + must.Eq(t, 0, len(resp.AuthMethods)) +} + +func TestACLEndpoint_GetAuthMethods_Blocking(t *testing.T) { + t.Parallel() + + s1, root, cleanupS1 := TestACLServer(t, nil) + defer cleanupS1() + state := s1.fsm.State() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the authMethods + am1 := mock.ACLAuthMethod() + am2 := mock.ACLAuthMethod() + + // First create an unrelated authMethod + time.AfterFunc(100*time.Millisecond, func() { + must.NoError(t, state.UpsertACLAuthMethods(100, []*structs.ACLAuthMethod{am1})) + }) + + // Upsert the authMethod we are watching later + time.AfterFunc(200*time.Millisecond, func() { + must.NoError(t, state.UpsertACLAuthMethods(200, []*structs.ACLAuthMethod{am2})) + }) + + // Lookup the authMethod + req := &structs.ACLAuthMethodsGetRequest{ + Names: []string{am2.Name}, + QueryOptions: structs.QueryOptions{ + Region: "global", + MinQueryIndex: 150, + AuthToken: root.SecretID, + }, + } + var resp structs.ACLAuthMethodsGetResponse + start := time.Now() + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLGetAuthMethodsRPCMethod, req, &resp)) + + if elapsed := time.Since(start); elapsed < 200*time.Millisecond { + t.Fatalf("should block (returned in %s) %#v", elapsed, resp) + } + must.Eq(t, resp.Index, 200) + must.NotEq(t, len(resp.AuthMethods), 0) + must.NotNil(t, resp.AuthMethods[am2.Name]) + + // Auth method delete triggers watches + time.AfterFunc(100*time.Millisecond, func() { + must.NoError(t, state.DeleteACLAuthMethods(300, []string{am2.Name})) + }) + + req.QueryOptions.MinQueryIndex = 250 + var resp2 structs.ACLAuthMethodsGetResponse + start = time.Now() + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLGetAuthMethodsRPCMethod, req, &resp2)) + + if elapsed := time.Since(start); elapsed < 100*time.Millisecond { + t.Fatalf("should block (returned in %s) %#v", elapsed, resp2) + } + must.Eq(t, resp2.Index, 300) + must.Eq(t, len(resp2.AuthMethods), 0) +} + +func TestACLEndpoint_ListAuthMethods(t *testing.T) { + t.Parallel() + + s1, root, cleanupS1 := TestACLServer(t, nil) + defer cleanupS1() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the register request + am1 := mock.ACLAuthMethod() + am2 := mock.ACLAuthMethod() + + am1.Name = "aaaaaaaa-3350-4b4b-d185-0e1992ed43e9" + am2.Name = "aaaabbbb-3350-4b4b-d185-0e1992ed43e9" + must.NoError(t, s1.fsm.State().UpsertACLAuthMethods(1000, []*structs.ACLAuthMethod{am1, am2})) + + // Create a token + token := mock.ACLToken() + must.NoError(t, s1.fsm.State().UpsertACLTokens(structs.MsgTypeTestSetup, 1001, []*structs.ACLToken{token})) + + // Lookup the authMethods with a management token + get := &structs.ACLAuthMethodListRequest{ + QueryOptions: structs.QueryOptions{ + Region: "global", + AuthToken: root.SecretID, + }, + } + var resp structs.ACLAuthMethodListResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLListAuthMethodsRPCMethod, get, &resp)) + must.Eq(t, 1000, resp.Index) + must.Len(t, 2, resp.AuthMethods) + + // List authMethods using the created token + get = &structs.ACLAuthMethodListRequest{ + QueryOptions: structs.QueryOptions{ + Region: "global", + AuthToken: token.SecretID, + }, + } + var resp3 structs.ACLAuthMethodListResponse + if err := msgpackrpc.CallWithCodec(codec, structs.ACLListAuthMethodsRPCMethod, get, &resp3); err != nil { + t.Fatalf("err: %v", err) + } + must.Eq(t, 1000, resp3.Index) + must.Len(t, 2, resp3.AuthMethods) + must.Eq(t, resp3.AuthMethods[0].Name, am1.Name) +} + +func TestACLEndpoint_ListAuthMethods_Blocking(t *testing.T) { + t.Parallel() + + s1, root, cleanupS1 := TestACLServer(t, nil) + defer cleanupS1() + state := s1.fsm.State() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the authMethod + authMethod := mock.ACLAuthMethod() + + // Upsert auth method triggers watches + time.AfterFunc(100*time.Millisecond, func() { + must.NoError(t, state.UpsertACLAuthMethods(2, []*structs.ACLAuthMethod{authMethod})) + }) + + req := &structs.ACLAuthMethodListRequest{ + QueryOptions: structs.QueryOptions{ + Region: "global", + MinQueryIndex: 1, + AuthToken: root.SecretID, + }, + } + start := time.Now() + var resp structs.ACLAuthMethodListResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLListAuthMethodsRPCMethod, req, &resp)) + + if elapsed := time.Since(start); elapsed < 100*time.Millisecond { + t.Fatalf("should block (returned in %s) %#v", elapsed, resp) + } + must.Eq(t, uint64(2), resp.Index) + must.Len(t, 1, resp.AuthMethods) + must.Eq(t, resp.AuthMethods[0].Name, authMethod.Name) + + // Eval deletion triggers watches + time.AfterFunc(100*time.Millisecond, func() { + must.NoError(t, state.DeleteACLAuthMethods(3, []string{authMethod.Name})) + }) + + req.MinQueryIndex = 2 + start = time.Now() + var resp2 structs.ACLAuthMethodListResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLListAuthMethodsRPCMethod, req, &resp2)) + + if elapsed := time.Since(start); elapsed < 100*time.Millisecond { + t.Fatalf("should block (returned in %s) %#v", elapsed, resp2) + } + must.Eq(t, uint64(3), resp2.Index) + must.Eq(t, 0, len(resp2.AuthMethods)) +} + +func TestACLEndpoint_DeleteAuthMethods(t *testing.T) { + t.Parallel() + + s1, root, cleanupS1 := TestACLServer(t, nil) + defer cleanupS1() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the register request + am1 := mock.ACLAuthMethod() + must.NoError(t, s1.fsm.State().UpsertACLAuthMethods(1000, []*structs.ACLAuthMethod{am1})) + + // Lookup the authMethods + req := &structs.ACLAuthMethodDeleteRequest{ + Names: []string{am1.Name}, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: root.SecretID, + }, + } + var resp structs.ACLAuthMethodDeleteResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLDeleteAuthMethodsRPCMethod, req, &resp)) + must.NotEq(t, uint64(0), resp.Index) + + // Try to delete a non-existing auth method + req = &structs.ACLAuthMethodDeleteRequest{ + Names: []string{"non-existing-auth-method"}, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: root.SecretID, + }, + } + var resp2 structs.ACLAuthMethodDeleteResponse + must.Error(t, msgpackrpc.CallWithCodec(codec, structs.ACLDeleteAuthMethodsRPCMethod, req, &resp2)) +} + +func TestACLEndpoint_UpsertACLAuthMethods(t *testing.T) { + t.Parallel() + + s1, root, cleanupS1 := TestACLServer(t, nil) + defer cleanupS1() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + minTTL, _ := time.ParseDuration("10s") + maxTTL, _ := time.ParseDuration("24h") + s1.config.ACLAuthMethodMinExpirationTTL = minTTL + s1.config.ACLAuthMethodMaxExpirationTTL = maxTTL + + // Create the register request + am1 := mock.ACLAuthMethod() + + // Lookup the authMethods + req := &structs.ACLAuthMethodUpsertRequest{ + AuthMethods: []*structs.ACLAuthMethod{am1}, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: root.SecretID, + }, + } + var resp structs.ACLAuthMethodUpsertResponse + if err := msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp); err != nil { + t.Fatalf("err: %v", err) + } + must.NotEq(t, uint64(0), resp.Index) + + // Check we created the authMethod + out, err := s1.fsm.State().GetACLAuthMethodByName(nil, am1.Name) + must.Nil(t, err) + must.NotNil(t, out) +} diff --git a/nomad/config.go b/nomad/config.go index 2bcb8714b..3d9a9bb73 100644 --- a/nomad/config.go +++ b/nomad/config.go @@ -328,6 +328,14 @@ type Config struct { // for ACL token expiration. ACLTokenMaxExpirationTTL time.Duration + // ACLAuthMethodMinExpirationTTL is used to enforce the lowest acceptable + // value for ACL auth method expiration. + ACLAuthMethodMinExpirationTTL time.Duration + + // ACLAuthMethodMaxExpirationTTL is used to enforce the highest acceptable + // value for ACL auth method expiration. + ACLAuthMethodMaxExpirationTTL time.Duration + // SentinelGCInterval is the interval that we GC unused policies. SentinelGCInterval time.Duration diff --git a/nomad/fsm.go b/nomad/fsm.go index b168f384f..e7f5f3711 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -329,6 +329,10 @@ func (n *nomadFSM) Apply(log *raft.Log) interface{} { return n.applyACLRolesUpsert(msgType, buf[1:], log.Index) case structs.ACLRolesDeleteByIDRequestType: return n.applyACLRolesDeleteByID(msgType, buf[1:], log.Index) + case structs.ACLAuthMethodsUpsertRequestType: + return n.applyACLAuthMethodsUpsert(buf[1:], log.Index) + case structs.ACLAuthMethodsDeleteRequestType: + return n.applyACLAuthMethodsDelete(buf[1:], log.Index) } // Check enterprise only message types. @@ -2046,6 +2050,36 @@ func (n *nomadFSM) applyACLRolesDeleteByID(msgType structs.MessageType, buf []by return nil } +func (n *nomadFSM) applyACLAuthMethodsUpsert(buf []byte, index uint64) interface{} { + defer metrics.MeasureSince([]string{"nomad", "fsm", "apply_acl_auth_method_upsert"}, time.Now()) + var req structs.ACLAuthMethodUpsertRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + + if err := n.state.UpsertACLAuthMethods(index, req.AuthMethods); err != nil { + n.logger.Error("UpsertACLAuthMethods failed", "error", err) + return err + } + + return nil +} + +func (n *nomadFSM) applyACLAuthMethodsDelete(buf []byte, index uint64) interface{} { + defer metrics.MeasureSince([]string{"nomad", "fsm", "apply_acl_auth_method_delete"}, time.Now()) + var req structs.ACLAuthMethodDeleteRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + + if err := n.state.DeleteACLAuthMethods(index, req.Names); err != nil { + n.logger.Error("DeleteACLAuthMethods failed", "error", err) + return err + } + + return nil +} + type FSMFilter struct { evaluator *bexpr.Evaluator } diff --git a/nomad/fsm_test.go b/nomad/fsm_test.go index 8299c2257..1bcc72695 100644 --- a/nomad/fsm_test.go +++ b/nomad/fsm_test.go @@ -13,6 +13,7 @@ import ( memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/raft" "github.com/kr/pretty" + "github.com/shoenig/test/must" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -3477,3 +3478,53 @@ func TestFSM_EventBroker_JobRegisterFSMEvents(t *testing.T) { require.Len(t, events, 1) require.Equal(t, structs.TypeJobRegistered, events[0].Type) } + +func TestFSM_UpsertACLAuthMethods(t *testing.T) { + ci.Parallel(t) + fsm := testFSM(t) + + am1 := mock.ACLAuthMethod() + am2 := mock.ACLAuthMethod() + req := structs.ACLAuthMethodUpsertRequest{ + AuthMethods: []*structs.ACLAuthMethod{am1, am2}, + } + buf, err := structs.Encode(structs.ACLAuthMethodsUpsertRequestType, req) + must.Nil(t, err) + must.Nil(t, fsm.Apply(makeLog(buf))) + + // Verify we are registered + ws := memdb.NewWatchSet() + out, err := fsm.State().GetACLAuthMethodByName(ws, am1.Name) + must.Nil(t, err) + must.NotNil(t, out) + + out, err = fsm.State().GetACLAuthMethodByName(ws, am2.Name) + must.Nil(t, err) + must.NotNil(t, out) +} + +func TestFSM_DeleteACLAuthMethods(t *testing.T) { + ci.Parallel(t) + fsm := testFSM(t) + + am1 := mock.ACLAuthMethod() + am2 := mock.ACLAuthMethod() + must.Nil(t, fsm.State().UpsertACLAuthMethods(1000, []*structs.ACLAuthMethod{am1, am2})) + + req := structs.ACLAuthMethodDeleteRequest{ + Names: []string{am1.Name, am2.Name}, + } + buf, err := structs.Encode(structs.ACLAuthMethodsDeleteRequestType, req) + must.Nil(t, err) + must.Nil(t, fsm.Apply(makeLog(buf))) + + // Verify we are NOT registered + ws := memdb.NewWatchSet() + out, err := fsm.State().GetACLAuthMethodByName(ws, am1.Name) + must.Nil(t, err) + must.Nil(t, out) + + out, err = fsm.State().GetACLAuthMethodByName(ws, am2.Name) + must.Nil(t, err) + must.Nil(t, out) +} diff --git a/nomad/leader.go b/nomad/leader.go index b15285d13..5f992e728 100644 --- a/nomad/leader.go +++ b/nomad/leader.go @@ -54,6 +54,14 @@ var minOneTimeAuthenticationTokenVersion = version.Must(version.NewVersion("1.1. // before the feature can be used. var minACLRoleVersion = version.Must(version.NewVersion("1.4.0")) +// minACLAuthMethodVersion is the Nomad version at which the ACL auth methods +// table was introduced. It forms the minimum version all federated servers must +// meet before the feature can be used. +// +// TODO: version constraint will be updated for every beta or rc until we reach +// 1.5, otherwise it's hard to test the functionality +var minACLAuthMethodVersion = version.Must(version.NewVersion("1.4.3-dev")) + // minNomadServiceRegistrationVersion is the Nomad version at which the service // registrations table was introduced. It forms the minimum version all local // servers must meet before the feature can be used. diff --git a/nomad/mock/acl.go b/nomad/mock/acl.go index d152bfa8a..4810affec 100644 --- a/nomad/mock/acl.go +++ b/nomad/mock/acl.go @@ -220,11 +220,12 @@ func ACLManagementToken() *structs.ACLToken { } func ACLAuthMethod() *structs.ACLAuthMethod { + maxTokenTTL, _ := time.ParseDuration("3600s") method := structs.ACLAuthMethod{ Name: fmt.Sprintf("acl-auth-method-%s", uuid.Short()), - Type: "acl-auth-mock-type", - TokenLocality: "locality", - MaxTokenTTL: "3600s", + Type: "OIDC", + TokenLocality: "local", + MaxTokenTTL: maxTokenTTL, Default: true, Config: &structs.ACLAuthMethodConfig{ OIDCDiscoveryURL: "http://example.com", diff --git a/nomad/structs/acl.go b/nomad/structs/acl.go index f33bd708e..4117ba9cd 100644 --- a/nomad/structs/acl.go +++ b/nomad/structs/acl.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "regexp" + "strconv" "time" "github.com/hashicorp/go-multierror" @@ -78,6 +79,40 @@ const ( // Args: ACLRoleByNameRequest // Reply: ACLRoleByNameResponse ACLGetRoleByNameRPCMethod = "ACL.GetRoleByName" + + // ACLUpsertAuthMethodsRPCMethod is the RPC method for batch creating or + // modifying auth methods. + // + // Args: ACLAuthMethodsUpsertRequest + // Reply: ACLAuthMethodUpsertResponse + ACLUpsertAuthMethodsRPCMethod = "ACL.UpsertAuthMethods" + + // ACLDeleteAuthMethodsRPCMethod is the RPC method for batch deleting auth + // methods. + // + // Args: ACLAuthMethodDeleteRequest + // Reply: ACLAuthMethodDeleteResponse + ACLDeleteAuthMethodsRPCMethod = "ACL.DeleteAuthMethods" + + // ACLListAuthMethodsRPCMethod is the RPC method for listing auth methods. + // + // Args: ACLAuthMethodListRequest + // Reply: ACLAuthMethodListResponse + ACLListAuthMethodsRPCMethod = "ACL.ListAuthMethods" + + // ACLGetAuthMethodRPCMethod is the RPC method for detailing an individual + // auth method using its name. + // + // Args: ACLAuthMethodGetRequest + // Reply: ACLAuthMethodGetResponse + ACLGetAuthMethodRPCMethod = "ACL.GetAuthMethod" + + // ACLGetAuthMethodsRPCMethod is the RPC method for getting multiple auth + // methods using their names. + // + // Args: ACLAuthMethodsGetRequest + // Reply: ACLAuthMethodsGetResponse + ACLGetAuthMethodsRPCMethod = "ACL.GetAuthMethods" ) const ( @@ -95,6 +130,9 @@ const ( var ( // validACLRoleName is used to validate an ACL role name. validACLRoleName = regexp.MustCompile("^[a-zA-Z0-9-]{1,128}$") + + // validACLAuthMethodName is used to validate an ACL auth method name. + validACLAuthMethod = regexp.MustCompile("^[a-zA-Z0-9-]{1,128}$") ) // ACLTokenRoleLink is used to link an ACL token to an ACL role. The ACL token @@ -517,3 +555,233 @@ type ACLRoleByNameResponse struct { ACLRole *ACLRole QueryMeta } + +// ACLAuthMethod is used to capture the properties of an authentication method +// used for single sing-on +type ACLAuthMethod struct { + Name string + Type string + TokenLocality string // is the token valid locally or globally? + MaxTokenTTL time.Duration + Default bool + Config *ACLAuthMethodConfig + + Hash []byte + + CreateTime time.Time + ModifyTime time.Time + CreateIndex uint64 + ModifyIndex uint64 +} + +// SetHash is used to compute and set the hash of the ACL auth method. This +// should be called every and each time a user specified field on the method is +// changed before updating the Nomad state store. +func (a *ACLAuthMethod) SetHash() []byte { + + // Initialize a 256bit Blake2 hash (32 bytes). + hash, err := blake2b.New256(nil) + if err != nil { + panic(err) + } + + _, _ = hash.Write([]byte(a.Name)) + _, _ = hash.Write([]byte(a.Type)) + _, _ = hash.Write([]byte(a.TokenLocality)) + _, _ = hash.Write([]byte(a.MaxTokenTTL.String())) + _, _ = hash.Write([]byte(strconv.FormatBool(a.Default))) + + if a.Config != nil { + _, _ = hash.Write([]byte(a.Config.OIDCDiscoveryURL)) + _, _ = hash.Write([]byte(a.Config.OIDCClientID)) + _, _ = hash.Write([]byte(a.Config.OIDCClientSecret)) + for _, ba := range a.Config.BoundAudiences { + _, _ = hash.Write([]byte(ba)) + } + for _, uri := range a.Config.AllowedRedirectURIs { + _, _ = hash.Write([]byte(uri)) + } + for _, pem := range a.Config.DiscoveryCaPem { + _, _ = hash.Write([]byte(pem)) + } + for _, sa := range a.Config.SigningAlgs { + _, _ = hash.Write([]byte(sa)) + } + for k, v := range a.Config.ClaimMappings { + _, _ = hash.Write([]byte(k)) + _, _ = hash.Write([]byte(v)) + } + for k, v := range a.Config.ListClaimMappings { + _, _ = hash.Write([]byte(k)) + _, _ = hash.Write([]byte(v)) + } + } + + // Finalize the hash. + hashVal := hash.Sum(nil) + + // Set and return the hash. + a.Hash = hashVal + return hashVal +} + +func (a *ACLAuthMethod) Stub() *ACLAuthMethodStub { + return &ACLAuthMethodStub{ + Name: a.Name, + Default: a.Default, + } +} + +func (a *ACLAuthMethod) Equal(other *ACLAuthMethod) bool { + if a == nil || other == nil { + return a == other + } + if len(a.Hash) == 0 { + a.SetHash() + } + if len(other.Hash) == 0 { + other.SetHash() + } + return bytes.Equal(a.Hash, other.Hash) + +} + +// Copy creates a deep copy of the ACL auth method. This copy can then be safely +// modified. It handles nil objects. +func (a *ACLAuthMethod) Copy() *ACLAuthMethod { + if a == nil { + return nil + } + + c := new(ACLAuthMethod) + *c = *a + + c.Hash = slices.Clone(a.Hash) + c.Config = a.Config.Copy() + + return c +} + +// Validate returns an error is the ACLAuthMethod is invalid. +// +// TODO revisit possible other validity conditions in the future +func (a *ACLAuthMethod) Validate(minTTL, maxTTL time.Duration) error { + var mErr multierror.Error + + if !validACLAuthMethod.MatchString(a.Name) { + mErr.Errors = append(mErr.Errors, fmt.Errorf("invalid name '%s'", a.Name)) + } + + if !slices.Contains([]string{"local", "global"}, a.TokenLocality) { + mErr.Errors = append( + mErr.Errors, fmt.Errorf("invalid token locality '%s'", a.TokenLocality)) + } + + if a.Type != "OIDC" { + mErr.Errors = append( + mErr.Errors, fmt.Errorf("invalid token type '%s'", a.Type)) + } + + if minTTL > a.MaxTokenTTL || a.MaxTokenTTL > maxTTL { + mErr.Errors = append(mErr.Errors, fmt.Errorf( + "invalid MaxTokenTTL value '%s' (should be between %s and %s)", + a.MaxTokenTTL.String(), minTTL.String(), maxTTL.String())) + } + + return mErr.ErrorOrNil() +} + +// ACLAuthMethodConfig is used to store configuration of an auth method +type ACLAuthMethodConfig struct { + OIDCDiscoveryURL string + OIDCClientID string + OIDCClientSecret string + BoundAudiences []string + AllowedRedirectURIs []string + DiscoveryCaPem []string + SigningAlgs []string + ClaimMappings map[string]string + ListClaimMappings map[string]string +} + +func (a *ACLAuthMethodConfig) Copy() *ACLAuthMethodConfig { + if a == nil { + return nil + } + + c := new(ACLAuthMethodConfig) + *c = *a + + c.BoundAudiences = slices.Clone(a.BoundAudiences) + c.AllowedRedirectURIs = slices.Clone(a.AllowedRedirectURIs) + c.DiscoveryCaPem = slices.Clone(a.DiscoveryCaPem) + c.SigningAlgs = slices.Clone(a.SigningAlgs) + + return c +} + +// ACLAuthMethodStub is used for listing ACL auth methods +type ACLAuthMethodStub struct { + Name string + Default bool +} + +// ACLAuthMethodListRequest is used to list auth methods +type ACLAuthMethodListRequest struct { + QueryOptions +} + +// ACLAuthMethodListResponse is used to list auth methods +type ACLAuthMethodListResponse struct { + AuthMethods []*ACLAuthMethodStub + QueryMeta +} + +// ACLAuthMethodGetRequest is used to query a specific auth method +type ACLAuthMethodGetRequest struct { + MethodName string + QueryOptions +} + +// ACLAuthMethodGetResponse is used to return a single auth method +type ACLAuthMethodGetResponse struct { + AuthMethod *ACLAuthMethod + QueryMeta +} + +// ACLAuthMethodsGetRequest is used to query a set of auth methods +type ACLAuthMethodsGetRequest struct { + Names []string + QueryOptions +} + +// ACLAuthMethodsGetResponse is used to return a set of auth methods +type ACLAuthMethodsGetResponse struct { + AuthMethods map[string]*ACLAuthMethod + QueryMeta +} + +// ACLAuthMethodUpsertRequest is used to upsert a set of auth methods +type ACLAuthMethodUpsertRequest struct { + AuthMethods []*ACLAuthMethod + WriteRequest +} + +// ACLAuthMethodUpsertResponse is a response of the upsert ACL auth methods +// operation +type ACLAuthMethodUpsertResponse struct { + WriteMeta +} + +// ACLAuthMethodDeleteRequest is used to delete a set of auth methods by their +// name +type ACLAuthMethodDeleteRequest struct { + Names []string + WriteRequest +} + +// ACLAuthMethodDeleteResponse is a response of the delete ACL auth methods +// operation +type ACLAuthMethodDeleteResponse struct { + WriteMeta +} diff --git a/nomad/structs/acl_test.go b/nomad/structs/acl_test.go index 05a497efd..ca51a4564 100644 --- a/nomad/structs/acl_test.go +++ b/nomad/structs/acl_test.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/helper/pointer" "github.com/hashicorp/nomad/helper/uuid" + "github.com/shoenig/test/must" "github.com/stretchr/testify/require" ) @@ -831,3 +832,224 @@ func Test_ACLRoleByNameRequest(t *testing.T) { req := ACLRoleByNameRequest{} require.True(t, req.IsRead()) } + +func Test_ACLAuthMethodListRequest(t *testing.T) { + req := ACLAuthMethodListRequest{} + must.True(t, req.IsRead()) +} + +func Test_ACLAuthMethodGetRequest(t *testing.T) { + req := ACLAuthMethodGetRequest{} + must.True(t, req.IsRead()) +} + +func TestACLAuthMethodSetHash(t *testing.T) { + ci.Parallel(t) + + am := &ACLAuthMethod{ + Name: "foo", + Type: "bad type", + } + out1 := am.SetHash() + must.NotNil(t, out1) + must.NotNil(t, am.Hash) + must.Eq(t, out1, am.Hash) + + am.Type = "good type" + out2 := am.SetHash() + must.NotNil(t, out2) + must.NotNil(t, am.Hash) + must.Eq(t, out2, am.Hash) + must.NotEq(t, out1, out2) +} + +func TestACLAuthMethod_Stub(t *testing.T) { + ci.Parallel(t) + + maxTokenTTL, _ := time.ParseDuration("3600s") + am := ACLAuthMethod{ + Name: fmt.Sprintf("acl-auth-method-%s", uuid.Short()), + Type: "acl-auth-mock-type", + TokenLocality: "locality", + MaxTokenTTL: maxTokenTTL, + Default: true, + Config: &ACLAuthMethodConfig{ + OIDCDiscoveryURL: "http://example.com", + OIDCClientID: "mock", + OIDCClientSecret: "very secret secret", + BoundAudiences: []string{"audience1", "audience2"}, + AllowedRedirectURIs: []string{"foo", "bar"}, + DiscoveryCaPem: []string{"foo"}, + SigningAlgs: []string{"bar"}, + ClaimMappings: map[string]string{"foo": "bar"}, + ListClaimMappings: map[string]string{"foo": "bar"}, + }, + CreateTime: time.Now().UTC(), + CreateIndex: 10, + ModifyIndex: 10, + } + am.SetHash() + must.Eq(t, am.Stub(), &ACLAuthMethodStub{am.Name, am.Default}) + + nilAuthMethod := &ACLAuthMethod{} + must.Eq(t, nilAuthMethod.Stub(), &ACLAuthMethodStub{}) +} + +func TestACLAuthMethod_Equal(t *testing.T) { + ci.Parallel(t) + + maxTokenTTL, _ := time.ParseDuration("3600s") + am1 := &ACLAuthMethod{ + Name: fmt.Sprintf("acl-auth-method-%s", uuid.Short()), + Type: "acl-auth-mock-type", + TokenLocality: "locality", + MaxTokenTTL: maxTokenTTL, + Default: true, + Config: &ACLAuthMethodConfig{ + OIDCDiscoveryURL: "http://example.com", + OIDCClientID: "mock", + OIDCClientSecret: "very secret secret", + BoundAudiences: []string{"audience1", "audience2"}, + AllowedRedirectURIs: []string{"foo", "bar"}, + DiscoveryCaPem: []string{"foo"}, + SigningAlgs: []string{"bar"}, + ClaimMappings: map[string]string{"foo": "bar"}, + ListClaimMappings: map[string]string{"foo": "bar"}, + }, + CreateTime: time.Now().UTC(), + CreateIndex: 10, + ModifyIndex: 10, + } + am1.SetHash() + + // am2 differs from am1 by 1 nested conf field + am2 := am1.Copy() + am2.Config.OIDCClientID = "mock2" + am2.SetHash() + + tests := []struct { + name string + method1 *ACLAuthMethod + method2 *ACLAuthMethod + want bool + }{ + {"one nil", am1, &ACLAuthMethod{}, false}, + {"both nil", &ACLAuthMethod{}, &ACLAuthMethod{}, true}, + {"one is different than the other", am1, am2, false}, + {"equal", am1, am1.Copy(), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.method1.Equal(tt.method2) + must.Eq(t, got, tt.want, must.Sprintf( + "ACLAuthMethod.Equal() got %v, want %v, test case: %s", got, tt.want, tt.name)) + }) + } +} + +func TestACLAuthMethod_Copy(t *testing.T) { + ci.Parallel(t) + + maxTokenTTL, _ := time.ParseDuration("3600s") + am1 := &ACLAuthMethod{ + Name: fmt.Sprintf("acl-auth-method-%s", uuid.Short()), + Type: "acl-auth-mock-type", + TokenLocality: "locality", + MaxTokenTTL: maxTokenTTL, + Default: true, + Config: &ACLAuthMethodConfig{ + OIDCDiscoveryURL: "http://example.com", + OIDCClientID: "mock", + OIDCClientSecret: "very secret secret", + BoundAudiences: []string{"audience1", "audience2"}, + AllowedRedirectURIs: []string{"foo", "bar"}, + DiscoveryCaPem: []string{"foo"}, + SigningAlgs: []string{"bar"}, + ClaimMappings: map[string]string{"foo": "bar"}, + ListClaimMappings: map[string]string{"foo": "bar"}, + }, + CreateTime: time.Now().UTC(), + CreateIndex: 10, + ModifyIndex: 10, + } + am1.SetHash() + + am2 := am1.Copy() + am2.SetHash() + must.Eq(t, am1, am2) + + am3 := am1.Copy() + am3.Config.AllowedRedirectURIs = []string{"new", "urls"} + am3.SetHash() + must.NotEq(t, am1, am3) +} + +func TestACLAuthMethod_Validate(t *testing.T) { + ci.Parallel(t) + + goodTTL, _ := time.ParseDuration("3600s") + badTTL, _ := time.ParseDuration("3600h") + + tests := []struct { + name string + method *ACLAuthMethod + wantErr bool + errContains string + }{ + { + "valid method", + &ACLAuthMethod{ + Name: "mock-auth-method", + Type: "OIDC", + TokenLocality: "local", + MaxTokenTTL: goodTTL, + }, + false, + "", + }, + {"invalid name", &ACLAuthMethod{Name: "is this name invalid?"}, true, "invalid name"}, + {"invalid token locality", &ACLAuthMethod{TokenLocality: "regional"}, true, "invalid token locality"}, + {"invalid type", &ACLAuthMethod{Type: "groovy"}, true, "invalid token type"}, + {"invalid max ttl", &ACLAuthMethod{MaxTokenTTL: badTTL}, true, "invalid token type"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + minTTL, _ := time.ParseDuration("10s") + maxTTL, _ := time.ParseDuration("10h") + got := tt.method.Validate(minTTL, maxTTL) + if tt.wantErr { + must.Error(t, got, must.Sprintf( + "ACLAuthMethod.Validate() got error, didn't expect it; test case: %s", tt.name)) + must.StrContains(t, got.Error(), tt.errContains, must.Sprintf( + "ACLAuthMethod.Validate() got %v error message, expected %v; test case: %s", + got, tt.errContains, tt.name)) + } else { + must.NoError(t, got, must.Sprintf( + "ACLAuthMethod.Validate() expected an error but didn't get one; test case: %s", tt.name)) + } + }) + } +} + +func TestACLAuthMethodConfig_Copy(t *testing.T) { + ci.Parallel(t) + + amc1 := &ACLAuthMethodConfig{ + OIDCDiscoveryURL: "http://example.com", + OIDCClientID: "mock", + OIDCClientSecret: "very secret secret", + BoundAudiences: []string{"audience1", "audience2"}, + AllowedRedirectURIs: []string{"foo", "bar"}, + DiscoveryCaPem: []string{"foo"}, + SigningAlgs: []string{"bar"}, + ClaimMappings: map[string]string{"foo": "bar"}, + ListClaimMappings: map[string]string{"foo": "bar"}, + } + + amc2 := amc1.Copy() + must.Eq(t, amc1, amc2) + + amc3 := amc1.Copy() + amc3.AllowedRedirectURIs = []string{"new", "urls"} + must.NotEq(t, amc1, amc3) +} diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index dc3d83b1e..c53b6bbae 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -12223,134 +12223,6 @@ type ACLTokenUpsertResponse struct { WriteMeta } -// ACLAuthMethod is used to capture the properties of an authentication method -// used for single sing-on -type ACLAuthMethod struct { - Name string - Type string - TokenLocality string // is the token valid locally or globally? - MaxTokenTTL string - Default bool - Config *ACLAuthMethodConfig - - Hash []byte - - CreateTime time.Time - ModifyTime time.Time - CreateIndex uint64 - ModifyIndex uint64 -} - -// SetHash is used to compute and set the hash of the ACL auth method. This -// should be called every and each time a user specified field on the method is -// changed before updating the Nomad state store. -func (a *ACLAuthMethod) SetHash() []byte { - - // Initialize a 256bit Blake2 hash (32 bytes). - hash, err := blake2b.New256(nil) - if err != nil { - panic(err) - } - - _, _ = hash.Write([]byte(a.Name)) - _, _ = hash.Write([]byte(a.Type)) - _, _ = hash.Write([]byte(a.TokenLocality)) - _, _ = hash.Write([]byte(a.MaxTokenTTL)) - _, _ = hash.Write([]byte(strconv.FormatBool(a.Default))) - - if a.Config != nil { - _, _ = hash.Write([]byte(a.Config.OIDCDiscoveryURL)) - _, _ = hash.Write([]byte(a.Config.OIDCClientID)) - _, _ = hash.Write([]byte(a.Config.OIDCClientSecret)) - for _, ba := range a.Config.BoundAudiences { - _, _ = hash.Write([]byte(ba)) - } - for _, uri := range a.Config.AllowedRedirectURIs { - _, _ = hash.Write([]byte(uri)) - } - for _, pem := range a.Config.DiscoveryCaPem { - _, _ = hash.Write([]byte(pem)) - } - for _, sa := range a.Config.SigningAlgs { - _, _ = hash.Write([]byte(sa)) - } - for k, v := range a.Config.ClaimMappings { - _, _ = hash.Write([]byte(k)) - _, _ = hash.Write([]byte(v)) - } - for k, v := range a.Config.ListClaimMappings { - _, _ = hash.Write([]byte(k)) - _, _ = hash.Write([]byte(v)) - } - } - - // Finalize the hash. - hashVal := hash.Sum(nil) - - // Set and return the hash. - a.Hash = hashVal - return hashVal -} - -func (a *ACLAuthMethod) Equal(other *ACLAuthMethod) bool { - if a == nil || other == nil { - return a == other - } - if len(a.Hash) == 0 { - a.SetHash() - } - if len(other.Hash) == 0 { - other.SetHash() - } - return bytes.Equal(a.Hash, other.Hash) - -} - -// Copy creates a deep copy of the ACL auth method. This copy can then be safely -// modified. It handles nil objects. -func (a *ACLAuthMethod) Copy() *ACLAuthMethod { - if a == nil { - return nil - } - - c := new(ACLAuthMethod) - *c = *a - - c.Hash = slices.Clone(a.Hash) - c.Config = a.Config.Copy() - - return c -} - -// ACLAuthMethodConfig is used to store configuration of an auth method -type ACLAuthMethodConfig struct { - OIDCDiscoveryURL string - OIDCClientID string - OIDCClientSecret string - BoundAudiences []string - AllowedRedirectURIs []string - DiscoveryCaPem []string - SigningAlgs []string - ClaimMappings map[string]string - ListClaimMappings map[string]string -} - -func (a *ACLAuthMethodConfig) Copy() *ACLAuthMethodConfig { - if a == nil { - return nil - } - - c := new(ACLAuthMethodConfig) - *c = *a - - c.BoundAudiences = slices.Clone(a.BoundAudiences) - c.AllowedRedirectURIs = slices.Clone(a.AllowedRedirectURIs) - c.DiscoveryCaPem = slices.Clone(a.DiscoveryCaPem) - c.SigningAlgs = slices.Clone(a.SigningAlgs) - - return c -} - // OneTimeToken is used to log into the web UI using a token provided by the // command line. type OneTimeToken struct { diff --git a/nomad/structs/structs_test.go b/nomad/structs/structs_test.go index 9399db789..d00cddca8 100644 --- a/nomad/structs/structs_test.go +++ b/nomad/structs/structs_test.go @@ -6117,26 +6117,6 @@ func TestACLPolicySetHash(t *testing.T) { assert.NotEqual(t, out1, out2) } -func TestACLAuthMethodSetHash(t *testing.T) { - ci.Parallel(t) - - am := &ACLAuthMethod{ - Name: "foo", - Type: "bad type", - } - out1 := am.SetHash() - assert.NotNil(t, out1) - assert.NotNil(t, am.Hash) - assert.Equal(t, out1, am.Hash) - - am.Type = "good type" - out2 := am.SetHash() - assert.NotNil(t, out2) - assert.NotNil(t, am.Hash) - assert.Equal(t, out2, am.Hash) - assert.NotEqual(t, out1, out2) -} - func TestTaskEventPopulate(t *testing.T) { ci.Parallel(t)