diff --git a/client/client.go b/client/client.go index 9da6940bd..e80c0459b 100644 --- a/client/client.go +++ b/client/client.go @@ -2108,14 +2108,15 @@ func (c *Client) retryRegisterNode() { } retryIntv := registerRetryIntv - if err == noServersErr || structs.IsErrNoRegionPath(err) { + if errors.Is(err, noServersErr) || structs.IsErrNoRegionPath(err) { c.logger.Debug("registration waiting on servers") c.triggerDiscovery() retryIntv = noServerRetryIntv - } else if structs.IsErrPermissionDenied(err) { - // any previous cluster state we have here is invalid (ex. client + } else if structs.IsErrPermissionDenied(err) && c.config.IntroToken == "" { + // Any previous cluster state we have here is invalid (ex. client // has been assigned to a new region), so clear the token and local - // state for next pass. + // state for next pass. This is unless the operator has provided an + // intro token, in which case we will retry with that. authToken = "" c.stateDB.PutNodeRegistration(&cstructs.NodeRegistration{HasRegistered: false}) c.logger.Error("error registering", "error", err) @@ -2131,10 +2132,14 @@ func (c *Client) retryRegisterNode() { } } -// getRegistrationToken gets the node secret to use for the Node.Register call. -// Registration is trust-on-first-use so we can't send the auth token with the -// initial request, but we want to add the auth token after that so that we can -// capture metrics. +// getRegistrationToken gets the appropriate authentication token to use for the +// Node.Register call. When a client first registers, it may optionally use an +// intro token to bootstrap the registration. If this is not set, the existing +// behavior of no auth token is used. +// +// If the client has already registered, it will use either the node's secret ID +// or its identity. This detail depends on whether the client is talking to +// upgraded servers that support the new identity system or not. 
func (c *Client) getRegistrationToken() string { select { @@ -2149,12 +2154,34 @@ func (c *Client) getRegistrationToken() string { if err != nil { c.logger.Error("could not determine previous node registration", "error", err) } - if registration != nil && registration.HasRegistered { - c.registeredOnce.Do(func() { close(c.registeredCh) }) - return c.nodeAuthToken() + + // If the state call indicates that we have not registered yet, + // fall-through to the end logic of this function to return any intro + // token. + if registration == nil || !registration.HasRegistered { + break } + + // Attempt to pull and use the node's identity from the state store. The + // state store restore happens asynchronously to this function, so we + // can't rely on it being populated in the client object at this time. + clientIdentity, err := c.stateDB.GetNodeIdentity() + if err != nil { + c.logger.Error("could not determine node identity", "error", err) + } + if clientIdentity != "" { + c.setNodeIdentityToken(clientIdentity) + } + + c.registeredOnce.Do(func() { close(c.registeredCh) }) + return c.nodeAuthToken() } - return "" + + // Reaching this point means we are registering for the first time. If the + // client configuration has a bootstrap token, we can use that to perform + // the initial registration. If this was not supplied, the parameter will be + // an empty string, which is fine and the backwards compatible behavior. 
+ return c.GetConfig().IntroToken } // registerNode is used to register the node or update the registration diff --git a/client/client_test.go b/client/client_test.go index cfc3cd369..f3cba229d 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -27,6 +27,7 @@ import ( regMock "github.com/hashicorp/nomad/client/serviceregistration/mock" "github.com/hashicorp/nomad/client/state" cstate "github.com/hashicorp/nomad/client/state" + cstructs "github.com/hashicorp/nomad/client/structs" ctestutil "github.com/hashicorp/nomad/client/testutil" "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/helper/pluginutils/catalog" @@ -1453,6 +1454,79 @@ func TestClient_ServerList(t *testing.T) { } } +func TestClient_getRegistrationToken(t *testing.T) { + ci.Parallel(t) + + t.Run("no intro initial register", func(t *testing.T) { + testClient, testClientCleanup := TestClient(t, func(c *config.Config) {}) + t.Cleanup(func() { _ = testClientCleanup() }) + must.Eq(t, "", testClient.getRegistrationToken()) + }) + + t.Run("intro initial register", func(t *testing.T) { + testClient, testClientCleanup := TestClient(t, func(c *config.Config) { + c.IntroToken = "my-intro-token" + }) + t.Cleanup(func() { _ = testClientCleanup() }) + must.Eq(t, "my-intro-token", testClient.getRegistrationToken()) + }) + + t.Run("secret id registered", func(t *testing.T) { + testClient, testClientCleanup := TestClient(t, func(c *config.Config) {}) + t.Cleanup(func() { _ = testClientCleanup() }) + + close(testClient.registeredCh) + + must.Eq(t, testClient.Node().SecretID, testClient.getRegistrationToken()) + }) + + t.Run("node identity registered", func(t *testing.T) { + testClient, testClientCleanup := TestClient(t, func(c *config.Config) {}) + t.Cleanup(func() { _ = testClientCleanup() }) + + testClient.identity.Store("mylovelylovelyidentity") + close(testClient.registeredCh) + + must.Eq(t, testClient.identity.Load().(string), testClient.getRegistrationToken()) + }) + + 
t.Run("secret id registered state", func(t *testing.T) { + testClient, testClientCleanup := TestClient(t, func(c *config.Config) { + c.StateDBFactory = func(logger hclog.Logger, stateDir string) (state.StateDB, error) { + return cstate.NewMemDB(logger), nil + } + }) + t.Cleanup(func() { _ = testClientCleanup() }) + + must.NoError(t, testClient.stateDB.PutNodeRegistration( + &cstructs.NodeRegistration{ + HasRegistered: true, + }, + )) + + must.Eq(t, testClient.Node().SecretID, testClient.getRegistrationToken()) + }) + + t.Run("node identity registered state", func(t *testing.T) { + testClient, testClientCleanup := TestClient(t, func(c *config.Config) { + c.StateDBFactory = func(logger hclog.Logger, stateDir string) (state.StateDB, error) { + return cstate.NewMemDB(logger), nil + } + }) + t.Cleanup(func() { _ = testClientCleanup() }) + + must.NoError(t, testClient.stateDB.PutNodeRegistration( + &cstructs.NodeRegistration{ + HasRegistered: true, + }, + )) + + must.NoError(t, testClient.stateDB.PutNodeIdentity("my-identity-token")) + + must.Eq(t, "my-identity-token", testClient.getRegistrationToken()) + }) +} + func TestClient_handleNodeUpdateResponse(t *testing.T) { ci.Parallel(t) diff --git a/client/config/config.go b/client/config/config.go index ebefd532a..50461e23e 100644 --- a/client/config/config.go +++ b/client/config/config.go @@ -108,6 +108,10 @@ type Config struct { // should be owned by root with file mode 0o755. AllocMountsDir string + // IntroToken is the signed JWT token that should be used to introduce this + // client to the servers on first registration. 
+ IntroToken string + // Logger provides a logger to the client Logger log.InterceptLogger diff --git a/command/agent/agent.go b/command/agent/agent.go index f2694375b..69cd0fc4f 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -759,6 +759,47 @@ func (a *Agent) finalizeClientConfig(c *clientconfig.Config) error { to configure Nomad to work with Consul.`) } + // If the operator has not set an intro token via the CLI or an environment + // variable, attempt to read the intro token from the file system. This + // cannot be used as a CLI override. + if c.IntroToken == "" { + if err := a.readIntroTokenFile(c); err != nil { + return err + } + } + + return nil +} + +// readIntroTokenFile attempts to read the intro token from the file system. +func (a *Agent) readIntroTokenFile(cfg *clientconfig.Config) error { + + rootFile, err := os.OpenInRoot(cfg.StateDir, "intro_token.jwt") + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + fileStat, err := rootFile.Stat() + if err != nil { + return fmt.Errorf("failed to stat intro token file: %w", err) + } + + // If the file exists and is a file, attempt to read the contents and set + // the intro token. A directory or a failed read is returned to the caller + // as an error rather than being logged and ignored. 
+ if fileStat.IsDir() { + return fmt.Errorf("intro token file is a directory") + } + + content, err := helper.ReadFileContent(rootFile) + if err != nil { + return fmt.Errorf("failed to read intro token file: %w", err) + } + + cfg.IntroToken = strings.TrimSpace(string(content)) return nil } @@ -775,6 +816,7 @@ func convertClientConfig(agentConfig *Config) (*clientconfig.Config, error) { conf.Servers = agentConfig.Client.Servers conf.DevMode = agentConfig.DevMode conf.EnableDebug = agentConfig.EnableDebug + conf.IntroToken = agentConfig.Client.IntroToken if agentConfig.Region != "" { conf.Region = agentConfig.Region diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index bc62c1019..cd233c85a 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -1401,6 +1401,52 @@ func TestServer_Reload_VaultConfig(t *testing.T) { must.NoError(t, agent.server.Reload(sconf)) } +func TestAgent_readIntroTokenFile(t *testing.T) { + ci.Parallel(t) + + t.Run("no file", func(t *testing.T) { + + tmpDir := t.TempDir() + testAgent := &Agent{logger: testlog.HCLogger(t), config: &Config{}} + + clientConfig := clientconfig.Config{StateDir: tmpDir} + + must.NoError(t, testAgent.readIntroTokenFile(&clientConfig)) + must.Eq(t, "", clientConfig.IntroToken) + }) + + t.Run("file", func(t *testing.T) { + + tmpDir := t.TempDir() + must.NoError( + t, + os.WriteFile( + filepath.Join(tmpDir, "intro_token.jwt"), + []byte("my-intro-token"), + 0600, + ), + ) + testAgent := &Agent{logger: testlog.HCLogger(t), config: &Config{}} + + clientConfig := clientconfig.Config{StateDir: tmpDir} + + must.NoError(t, testAgent.readIntroTokenFile(&clientConfig)) + must.Eq(t, "my-intro-token", clientConfig.IntroToken) + }) + + t.Run("directory", func(t *testing.T) { + + tmpDir := t.TempDir() + must.NoError(t, os.MkdirAll(filepath.Join(tmpDir, "intro_token.jwt"), os.ModeDir)) + + testAgent := &Agent{logger: testlog.HCLogger(t), config: &Config{}} + + clientConfig := 
clientconfig.Config{StateDir: tmpDir} + + must.Error(t, testAgent.readIntroTokenFile(&clientConfig)) + }) +} + func TestServer_ShouldReload_ReturnFalseForNoChanges(t *testing.T) { ci.Parallel(t) assert := assert.New(t) diff --git a/command/agent/command.go b/command/agent/command.go index b3c6ab528..c1f971750 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -118,6 +118,7 @@ func (c *Command) readConfig() *Config { flags.StringVar(&cmdConfig.Client.NetworkInterface, "network-interface", "", "") flags.StringVar((*string)(&cmdConfig.Client.PreferredAddressFamily), "preferred-address-family", "", "ipv4 or ipv6") flags.IntVar(&cmdConfig.Client.NetworkSpeed, "network-speed", 0, "") + flags.StringVar(&cmdConfig.Client.IntroToken, "client-intro-token", "", "") // General options flags.Var((*flaghelper.StringFlag)(&configPath), "config", "config") @@ -220,6 +221,12 @@ func (c *Command) readConfig() *Config { } } + // Perform an environment lookup for the client intro token. If this is + // present, it will override the CLI flag. 
+ if envToken, found := os.LookupEnv("NOMAD_CLIENT_INTRO_TOKEN"); found { + cmdConfig.Client.IntroToken = envToken + } + // Load the configuration var config *Config diff --git a/command/agent/command_test.go b/command/agent/command_test.go index 6b25c4020..56a9830d0 100644 --- a/command/agent/command_test.go +++ b/command/agent/command_test.go @@ -618,6 +618,34 @@ vault { } } +func TestCommand_readConfig_clientIntroToken(t *testing.T) { + + t.Run("env var", func(t *testing.T) { + t.Setenv("NOMAD_CLIENT_INTRO_TOKEN", "test-intro-token") + + cmd := &Command{Ui: cli.NewMockUi(), args: []string{"-dev"}} + outputConfig := cmd.readConfig() + must.Eq(t, "test-intro-token", outputConfig.Client.IntroToken) + }) + + t.Run("cli flag", func(t *testing.T) { + cmd := &Command{Ui: cli.NewMockUi(), args: []string{ + "-dev", + "-client-intro-token=test-intro-token", + }} + outputConfig := cmd.readConfig() + must.Eq(t, "test-intro-token", outputConfig.Client.IntroToken) + }) + + t.Run("none", func(t *testing.T) { + cmd := &Command{Ui: cli.NewMockUi(), args: []string{ + "-dev", + }} + outputConfig := cmd.readConfig() + must.Eq(t, "", outputConfig.Client.IntroToken) + }) +} + func Test_setupLoggers_logFile(t *testing.T) { // Generate a mock UI and temporary log file location to write to. diff --git a/command/agent/config.go b/command/agent/config.go index e35d6005a..e272ec0fa 100644 --- a/command/agent/config.go +++ b/command/agent/config.go @@ -240,6 +240,14 @@ type ClientConfig struct { // HostVolumePluginDir directory contains dynamic host volume plugins HostVolumePluginDir string `hcl:"host_volume_plugin_dir"` + // IntroToken is used to introduce the client to the servers. It is an + // optional parameter that cannot be passed within the configuration file + // object. + // + // It can be passed as a command line argument to the agent, set via an + // environment variable, or read from "intro_token.jwt" in the state dir. 
+ IntroToken string `hcl:"-"` + // Servers is a list of known server addresses. These are as "host:port" Servers []string `hcl:"servers"` @@ -2782,6 +2790,10 @@ func (a *ClientConfig) Merge(b *ClientConfig) *ClientConfig { if b.NodeMaxAllocs != 0 { result.NodeMaxAllocs = b.NodeMaxAllocs } + if b.IntroToken != "" { + result.IntroToken = b.IntroToken + } + return &result } diff --git a/helper/file.go b/helper/file.go new file mode 100644 index 000000000..eef19a3f0 --- /dev/null +++ b/helper/file.go @@ -0,0 +1,48 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package helper + +import ( + "io" + "os" +) + +// ReadFileContent is a helper that mimics the stdlib ReadFile implementation, +// but accepts an already opened file handle. This is useful when using os.Root +// functionality such as OpenInRoot which does not have convenient read methods. +func ReadFileContent(file *os.File) ([]byte, error) { + + var size int + if info, err := file.Stat(); err == nil { + size64 := info.Size() + if int64(int(size64)) == size64 { + size = int(size64) + } + } + size++ // one byte for final read at EOF + + // If a file claims a small size, read at least 512 bytes. In particular, + // files in Linux's /proc claim size 0 but then do not work right if read in + // small pieces, so an initial read of 1 byte would not work correctly. + if size < 512 { + size = 512 + } + + data := make([]byte, 0, size) + for { + n, err := file.Read(data[len(data):cap(data)]) + data = data[:len(data)+n] + if err != nil { + if err == io.EOF { + err = nil + } + return data, err + } + + if len(data) >= cap(data) { + d := append(data[:cap(data)], 0) //nolint:gocritic + data = d[:len(data)] + } + } +} diff --git a/helper/file_test.go b/helper/file_test.go new file mode 100644 index 000000000..ffa42d189 --- /dev/null +++ b/helper/file_test.go @@ -0,0 +1,36 @@ +// Copyright (c) HashiCorp, Inc. 
+// SPDX-License-Identifier: BUSL-1.1 + +package helper + +import ( + "os" + "testing" + + "github.com/hashicorp/nomad/ci" + "github.com/shoenig/test/must" +) + +func Test_ReadFileContent(t *testing.T) { + ci.Parallel(t) + + tmpDir := t.TempDir() + + rootDir, err := os.OpenRoot(tmpDir) + must.NoError(t, err) + t.Cleanup(func() { must.NoError(t, rootDir.Close()) }) + + rootFile, err := rootDir.OpenFile("testfile.txt", os.O_CREATE|os.O_RDWR, 0777) + must.NoError(t, err) + + _, err = rootFile.WriteString("Hello, World!") + must.NoError(t, err) + must.NoError(t, rootFile.Close()) + + // Reopen the file using os.OpenInRoot to simulate reading from a root + // file. + rootFileRead, err := os.OpenInRoot(tmpDir, "testfile.txt") + data, err := ReadFileContent(rootFileRead) + must.NoError(t, err) + must.Eq(t, "Hello, World!", string(data)) +} diff --git a/nomad/acl_endpoint_test.go b/nomad/acl_endpoint_test.go index 941e4187c..7545a3ecd 100644 --- a/nomad/acl_endpoint_test.go +++ b/nomad/acl_endpoint_test.go @@ -4313,7 +4313,6 @@ func TestACL_ClientIntroductionToken(t *testing.T) { must.True(t, nodeWriteClaims.IsNodeIntroduction()) must.Eq(t, nodeWriteReq.NodeName, nodeWriteClaims.NodeIntroductionIdentityClaims.NodeName) must.Eq(t, nodeWriteReq.NodePool, nodeWriteClaims.NodeIntroductionIdentityClaims.NodePool) - must.Eq(t, nodeWriteReq.Region, nodeWriteClaims.NodeIntroductionIdentityClaims.NodeRegion) // The JWT creation happens asynchronously in the RPC handler, so we // need to verify the TTL is set using a bound check. 
@@ -4360,7 +4359,6 @@ func TestACL_ClientIntroductionToken(t *testing.T) { must.True(t, nodeWriteClaims.IsNodeIntroduction()) must.Eq(t, req.NodeName, nodeWriteClaims.NodeIntroductionIdentityClaims.NodeName) must.Eq(t, req.NodePool, nodeWriteClaims.NodeIntroductionIdentityClaims.NodePool) - must.Eq(t, req.Region, nodeWriteClaims.NodeIntroductionIdentityClaims.NodeRegion) // The JWT creation happens asynchronously in the RPC handler, so we // need to verify the TTL is set using a bound check. diff --git a/nomad/auth/auth.go b/nomad/auth/auth.go index afbd5ddfb..89a798d85 100644 --- a/nomad/auth/auth.go +++ b/nomad/auth/auth.go @@ -335,7 +335,7 @@ func (s *Authenticator) AuthenticateNodeIdentityGenerator(ctx RPCContext, args s if err != nil { return err } - if !claims.IsNode() { + if !claims.IsNode() && !claims.IsNodeIntroduction() { return structs.ErrPermissionDenied } identity.Claims = claims @@ -541,6 +541,13 @@ func (s *Authenticator) VerifyClaim(token string) (*structs.IdentityClaims, erro return claims, nil } + // Node introduction claims are a special case where we don't verify them + // against the state store, since they are used to introduce a node that + // does not yet exist. 
+ if claims.IsNodeIntroduction() { + return claims, nil + } + return nil, errors.New("failed to determine claim type") } diff --git a/nomad/auth/auth_test.go b/nomad/auth/auth_test.go index b3059f6b7..ff958771b 100644 --- a/nomad/auth/auth_test.go +++ b/nomad/auth/auth_test.go @@ -746,6 +746,34 @@ func TestAuthenticator_AuthenticateClientRegistration(t *testing.T) { must.True(t, aclObj.AllowClientOp()) }, }, + { + name: "mTLS acl with node introduction", + testFn: func(t *testing.T, store *state.StateStore) { + + claims := structs.GenerateNodeIntroductionIdentityClaims( + "", + "default", + "global", + 1*time.Hour, + ) + + auth := testAuthenticator(t, store, true, true) + token, err := auth.encrypter.(*testEncrypter).signClaim(claims) + must.NoError(t, err) + + ctx := newTestContext(t, "client.global.nomad", "192.168.1.1") + + args := structs.GenericRequest{ + QueryOptions: structs.QueryOptions{ + AuthToken: token, + }, + } + + must.NoError(t, auth.AuthenticateNodeIdentityGenerator(ctx, &args)) + must.NotNil(t, args.GetIdentity().GetClaims()) + must.True(t, args.GetIdentity().GetClaims().IsNodeIntroduction()) + }, + }, } for _, tc := range testCases { diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index d752be090..653b34ffd 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -197,6 +197,10 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp return err } + // If the node has an entry in the state store, we perform a check to ensure + // the secret ID matches the one stored. If there is no entry, we perform a + // check to ensure the node is allowed to register given the request and the + // server introduction enforcement configuration. 
if originalNode != nil { // Check if the SecretID has been tampered with if args.Node.SecretID != originalNode.SecretID && originalNode.SecretID != "" { @@ -208,6 +212,10 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp if originalNode.Status != "" { args.Node.Status = originalNode.Status } + // The called function performs all the required logging and metric + // emitting, so we only need to check the return value. + } else if !n.newRegistrationAllowed(args, authErr) { + return structs.ErrPermissionDenied } // We have a valid node connection, so add the mapping to cache the @@ -311,6 +319,98 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp return nil } +// newRegistrationAllowed determines whether the node registration is allowed to +// proceed based on the node introduction enforcement level and the +// authenticated identity of the request. +// +// The function handles logging and emitting metrics based on the enforcement +// level, so the caller only needs to check the return value. +func (n *Node) newRegistrationAllowed( + args *structs.NodeRegisterRequest, + authErr error, +) bool { + + enforcementLvl := n.srv.config.NodeIntroductionConfig.Enforcement + + // If the enforcement level is set to "none", we allow the registration to + // proceed without any checks. This is the pre-1.11 workflow that won't emit + // any metrics or logs for node registrations. + if enforcementLvl == structs.NodeIntroductionEnforcementNone { + return true + } + + claims := args.GetIdentity().GetClaims() + + // If the request was made with a node introduction identity, check whether + // the claims match the node's claims. 
+ var claimsMatch bool + + if claims.IsNodeIntroduction() { + claimsMatch = claims.NodeIntroductionIdentityClaims.NodePool == args.Node.NodePool && + (claims.NodeIntroductionIdentityClaims.NodeName == "" || + claims.NodeIntroductionIdentityClaims.NodeName == args.Node.Name) + } + + // If there was no authentication error and the identity claims match the + // node's claims, the registration is allowed to proceed. + if authErr == nil && claimsMatch { + return true + } + + // If we have reached this point, we know that the registration is dependent + // on the enforcement level and that this request does not have suitable + // claims. Emit our metric to indicate a node registration not using a valid + // introduction token. + metrics.IncrCounter([]string{"nomad", "client", "introduction_violation_num"}, 1) + + // Build a base set of logging pairs that will be used for the logging + // message. This provides operators with information about the node that is + // attempting to register without a valid introduction token. + loggingPairs := []any{ + "enforcement_level", enforcementLvl, + "node_id", args.Node.ID, + "node_pool", args.Node.NodePool, + "node_name", args.Node.Name, + } + + // If the node used a node introduction identity, add the claims for + // comparison to the logging pairs. + if claims.IsNodeIntroduction() { + loggingPairs = append(loggingPairs, claims.NodeIntroductionIdentityClaims.LoggingPairs()...) + } + + // Make some effort to log a message that indicates why the node is failing + // to introduce itself properly. 
+ msg := "node registration introduction claims mismatch" + + if authErr != nil { + msg = "node registration introduction authentication failure" + loggingPairs = append(loggingPairs, "error", authErr) + } else if args.GetIdentity().ACLToken == structs.AnonymousACLToken { + msg = "node registration without introduction token" + } + + // If there was an authentication error or the claims do not match, the node + // registration is not allowed to proceed. This is considered an invalid + // request and does not take into account the enforcement level. + if authErr != nil || (!claimsMatch && claims.IsNodeIntroduction()) { + n.logger.Error(msg, loggingPairs...) + return false + } + + // Based on the enforcement level, log the message and return whether the + // handler should allow the registration to proceed. The default is a + // catchall that includes the strict enforcement level. + switch enforcementLvl { + case structs.NodeIntroductionEnforcementWarn: + n.logger.Warn(msg, loggingPairs...) + return true + default: + n.logger.Error(msg, loggingPairs...) + return false + } +} + // shouldCreateNodeEval returns true if the node update may result into // allocation updates, so the node should be re-evaluating. 
// diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 92ad8b6a6..0945d72b0 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -5,6 +5,7 @@ package nomad import ( "context" + "errors" "fmt" "net" "net/rpc" @@ -4595,6 +4596,556 @@ func TestClientEndpoint_UpdateAlloc_Evals_ByTrigger(t *testing.T) { } +func TestNode_Register_Introduction(t *testing.T) { + ci.Parallel(t) + + testServer, _, testServerCleanup := TestACLServer(t, nil) + t.Cleanup(testServerCleanup) + rpcCodec := rpcClient(t, testServer) + + testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) + + t.Run("empty auth enforcement none", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementNone + + registerReq := structs.NodeRegisterRequest{ + Node: mock.Node(), + WriteRequest: structs.WriteRequest{ + Region: testServer.Region(), + }, + } + + var resp structs.NodeUpdateResponse + must.NoError( + t, + msgpackrpc.CallWithCodec(rpcCodec, "Node.Register", ®isterReq, &resp), + ) + + nodeResp, err := testServer.State().NodeByID(nil, registerReq.Node.ID) + must.NoError(t, err) + must.NotNil(t, nodeResp) + must.Eq(t, registerReq.Node.SecretID, nodeResp.SecretID) + }) + + t.Run("empty auth enforcement warn", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementWarn + + registerReq := structs.NodeRegisterRequest{ + Node: mock.Node(), + WriteRequest: structs.WriteRequest{ + Region: testServer.Region(), + }, + } + + var resp structs.NodeUpdateResponse + must.NoError( + t, + msgpackrpc.CallWithCodec(rpcCodec, "Node.Register", ®isterReq, &resp), + ) + + nodeResp, err := testServer.State().NodeByID(nil, registerReq.Node.ID) + must.NoError(t, err) + must.NotNil(t, nodeResp) + must.Eq(t, registerReq.Node.SecretID, nodeResp.SecretID) + }) + + t.Run("empty auth enforcement strict", func(t 
*testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementStrict + + registerReq := structs.NodeRegisterRequest{ + Node: mock.Node(), + WriteRequest: structs.WriteRequest{ + Region: testServer.Region(), + }, + } + + var resp structs.NodeUpdateResponse + must.ErrorContains( + t, + msgpackrpc.CallWithCodec(rpcCodec, "Node.Register", ®isterReq, &resp), + "Permission denied", + ) + }) + + t.Run("valid jwt enforcement none", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementNone + + mockNode := mock.Node() + + introClaims := structs.GenerateNodeIntroductionIdentityClaims( + mockNode.Name, + mockNode.NodePool, + testServer.Region(), + testServer.config.NodeIntroductionConfig.DefaultIdentityTTL, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(introClaims) + must.NoError(t, err) + + registerReq := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + Region: testServer.Region(), + }, + } + + var resp structs.NodeUpdateResponse + must.NoError( + t, + msgpackrpc.CallWithCodec(rpcCodec, "Node.Register", ®isterReq, &resp), + ) + + nodeResp, err := testServer.State().NodeByID(nil, registerReq.Node.ID) + must.NoError(t, err) + must.NotNil(t, nodeResp) + must.Eq(t, registerReq.Node.SecretID, nodeResp.SecretID) + }) + + t.Run("valid jwt enforcement warn", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementWarn + + mockNode := mock.Node() + + introClaims := structs.GenerateNodeIntroductionIdentityClaims( + mockNode.Name, + mockNode.NodePool, + testServer.Region(), + testServer.config.NodeIntroductionConfig.DefaultIdentityTTL, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(introClaims) + must.NoError(t, err) + + registerReq := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: 
signedJWT, + Region: testServer.Region(), + }, + } + + var resp structs.NodeUpdateResponse + must.NoError( + t, + msgpackrpc.CallWithCodec(rpcCodec, "Node.Register", ®isterReq, &resp), + ) + + nodeResp, err := testServer.State().NodeByID(nil, registerReq.Node.ID) + must.NoError(t, err) + must.NotNil(t, nodeResp) + must.Eq(t, registerReq.Node.SecretID, nodeResp.SecretID) + }) + + t.Run("valid jwt enforcement strict", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementStrict + + mockNode := mock.Node() + + introClaims := structs.GenerateNodeIntroductionIdentityClaims( + mockNode.Name, + mockNode.NodePool, + testServer.Region(), + testServer.config.NodeIntroductionConfig.DefaultIdentityTTL, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(introClaims) + must.NoError(t, err) + + registerReq := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + Region: testServer.Region(), + }, + } + + var resp structs.NodeUpdateResponse + must.NoError( + t, + msgpackrpc.CallWithCodec(rpcCodec, "Node.Register", ®isterReq, &resp), + ) + + nodeResp, err := testServer.State().NodeByID(nil, registerReq.Node.ID) + must.NoError(t, err) + must.NotNil(t, nodeResp) + must.Eq(t, registerReq.Node.SecretID, nodeResp.SecretID) + }) + + t.Run("invalid jwt enforcement none", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementNone + + mockNode := mock.Node() + + introClaims := structs.GenerateNodeIntroductionIdentityClaims( + mockNode.Name, + mockNode.NodePool, + testServer.Region(), + testServer.config.NodeIntroductionConfig.DefaultIdentityTTL, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(introClaims) + must.NoError(t, err) + + mockNode.Name = "changed-name" + + registerReq := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + 
Region: testServer.Region(), + }, + } + + var resp structs.NodeUpdateResponse + must.NoError( + t, + msgpackrpc.CallWithCodec(rpcCodec, "Node.Register", ®isterReq, &resp), + ) + + nodeResp, err := testServer.State().NodeByID(nil, registerReq.Node.ID) + must.NoError(t, err) + must.NotNil(t, nodeResp) + must.Eq(t, registerReq.Node.SecretID, nodeResp.SecretID) + }) + + t.Run("invalid jwt enforcement warn", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementWarn + + mockNode := mock.Node() + + introClaims := structs.GenerateNodeIntroductionIdentityClaims( + mockNode.Name, + mockNode.NodePool, + testServer.Region(), + testServer.config.NodeIntroductionConfig.DefaultIdentityTTL, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(introClaims) + must.NoError(t, err) + + mockNode.Name = "changed-name" + + registerReq := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + Region: testServer.Region(), + }, + } + + var resp structs.NodeUpdateResponse + must.ErrorContains( + t, + msgpackrpc.CallWithCodec(rpcCodec, "Node.Register", ®isterReq, &resp), + "Permission denied", + ) + }) + + t.Run("invalid jwt enforcement strict", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementStrict + + mockNode := mock.Node() + + introClaims := structs.GenerateNodeIntroductionIdentityClaims( + mockNode.Name, + mockNode.NodePool, + testServer.Region(), + testServer.config.NodeIntroductionConfig.DefaultIdentityTTL, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(introClaims) + must.NoError(t, err) + + mockNode.Name = "changed-name" + + registerReq := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + Region: testServer.Region(), + }, + } + + var resp structs.NodeUpdateResponse + must.ErrorContains( + t, + 
msgpackrpc.CallWithCodec(rpcCodec, "Node.Register", &registerReq, &resp), + "Permission denied", + ) + }) +} + +func TestNode_newRegistrationAllowed(t *testing.T) { + ci.Parallel(t) + + // Generate a stable mock node for testing that includes a populated node + // pool field. + mockNode := structs.MockNode() + mockNode.NodePool = "monitoring" + + // Create a test server, so we can sign JWTs. + testServer, _, testServerCleanup := TestACLServer(t, nil) + t.Cleanup(testServerCleanup) + testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, testServer.config.Region) + + nodeEndpoint := &Node{ + ctx: &RPCContext{}, + logger: testServer.logger, + srv: testServer, + } + + t.Run("enforcement none anonymous", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementNone + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{}, + } + + require.True(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) + + t.Run("enforcement warn anonymous", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementWarn + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: structs.AnonymousACLToken.SecretID, + }, + } + + require.True(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) + + t.Run("enforcement strict anonymous", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementStrict + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{}, + } + + require.False(t, 
nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) + + t.Run("enforcement warn auth error", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementWarn + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{}, + } + + require.False(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + errors.New("jwt: token is expired"), + )) + }) + + t.Run("enforcement strict auth error", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementStrict + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{}, + } + + require.False(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + errors.New("jwt: token is expired"), + )) + }) + + t.Run("enforcement warn claims pool mismatch", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementWarn + + claims := structs.GenerateNodeIntroductionIdentityClaims( + mockNode.Name, + "wrong-node-pool", + testServer.Region(), + 10*time.Minute, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(claims) + must.NoError(t, err) + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + }, + } + + require.False(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) + + t.Run("enforcement strict claims pool mismatch", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementStrict + + claims := structs.GenerateNodeIntroductionIdentityClaims( + mockNode.Name, + "wrong-node-pool", + 
testServer.Region(), + 10*time.Minute, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(claims) + must.NoError(t, err) + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + }, + } + + require.False(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) + + t.Run("enforcement warn claims name mismatch", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementWarn + + claims := structs.GenerateNodeIntroductionIdentityClaims( + "wrong-node-name", + mockNode.NodePool, + testServer.Region(), + 10*time.Minute, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(claims) + must.NoError(t, err) + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + }, + } + + require.False(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) + + t.Run("enforcement strict claims name mismatch", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementStrict + + claims := structs.GenerateNodeIntroductionIdentityClaims( + "wrong-node-name", + mockNode.NodePool, + testServer.Region(), + 10*time.Minute, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(claims) + must.NoError(t, err) + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + }, + } + + require.False(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) + + t.Run("enforcement warn claims match", func(t 
*testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementWarn + + claims := structs.GenerateNodeIntroductionIdentityClaims( + mockNode.Name, + mockNode.NodePool, + testServer.Region(), + 10*time.Minute, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(claims) + must.NoError(t, err) + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + }, + } + + require.True(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) + + t.Run("enforcement strict claims match", func(t *testing.T) { + + testServer.config.NodeIntroductionConfig.Enforcement = structs.NodeIntroductionEnforcementStrict + + claims := structs.GenerateNodeIntroductionIdentityClaims( + mockNode.Name, + mockNode.NodePool, + testServer.Region(), + 10*time.Minute, + ) + + signedJWT, _, err := testServer.encrypter.SignClaims(claims) + must.NoError(t, err) + + nodeRegisterRequest := structs.NodeRegisterRequest{ + Node: mockNode, + WriteRequest: structs.WriteRequest{ + AuthToken: signedJWT, + }, + } + + require.True(t, nodeEndpoint.newRegistrationAllowed( + &nodeRegisterRequest, + testServer.auth.AuthenticateNodeIdentityGenerator(nodeEndpoint.ctx, &nodeRegisterRequest), + )) + }) +} + // TestNode_List_PaginationFiltering asserts that API pagination and filtering // works against the Node.List RPC. func TestNode_List_PaginationFiltering(t *testing.T) { diff --git a/nomad/structs/node.go b/nomad/structs/node.go index a4f0156e0..7c8b26f20 100644 --- a/nomad/structs/node.go +++ b/nomad/structs/node.go @@ -580,6 +580,12 @@ func (n *NodeRegisterRequest) ShouldGenerateNodeIdentity( // node identity. 
claims := n.GetIdentity().GetClaims() + // If the request was made with a node introduction identity, this is an + // initial registration and we should generate a new identity. + if claims.IsNodeIntroduction() { + return true + } + // It is possible that the node has been restarted and had its configuration // updated. In this case, we should generate a new identity for the node, so // it reflects its new claims. @@ -835,9 +841,8 @@ func (n *NodeIntroductionConfig) Validate() error { // NodeIntroductionIdentityClaims contains the claims for node introduction. type NodeIntroductionIdentityClaims struct { - NodeRegion string `json:"nomad_region"` - NodePool string `json:"nomad_node_pool"` - NodeName string `json:"nomad_node_name"` + NodePool string `json:"nomad_node_pool"` + NodeName string `json:"nomad_node_name"` } // GenerateNodeIntroductionIdentityClaims generates a new identity JWT for node @@ -851,9 +856,8 @@ func GenerateNodeIntroductionIdentityClaims(name, pool, region string, ttl time. claims := &IdentityClaims{ NodeIntroductionIdentityClaims: &NodeIntroductionIdentityClaims{ - NodeRegion: region, - NodePool: pool, - NodeName: name, + NodePool: pool, + NodeName: name, }, Claims: jwt.Claims{ ID: uuid.Generate(), @@ -868,3 +872,19 @@ func GenerateNodeIntroductionIdentityClaims(name, pool, region string, ttl time. return claims } + +// LoggingPairs returns a set of key-value pairs that can be used for logging +// purposes. +func (n *NodeIntroductionIdentityClaims) LoggingPairs() []any { + + // The node pool is a required field on the node introduction identity, so + // we can always include it in the logging pairs. + pairs := []any{"claim_node_pool", n.NodePool} + + // The node name is optional, so we only include it if it is set. 
+ if n.NodeName != "" { + pairs = append(pairs, "claim_node_name", n.NodeName) + } + + return pairs +} diff --git a/nomad/structs/node_test.go b/nomad/structs/node_test.go index 010005d74..6917af981 100644 --- a/nomad/structs/node_test.go +++ b/nomad/structs/node_test.go @@ -759,7 +759,6 @@ func TestGenerateNodeIntroductionIdentityClaims(t *testing.T) { must.Eq(t, "node-name-1", claims.NodeIntroductionIdentityClaims.NodeName) must.Eq(t, "custom-pool", claims.NodeIntroductionIdentityClaims.NodePool) - must.Eq(t, "euw", claims.NodeIntroductionIdentityClaims.NodeRegion) must.StrEqFold(t, "node-introduction:euw:custom-pool:node-name-1:default", claims.Subject) must.Eq(t, []string{IdentityDefaultAud}, claims.Audience) must.NotNil(t, claims.ID) @@ -767,3 +766,17 @@ func TestGenerateNodeIntroductionIdentityClaims(t *testing.T) { must.NotNil(t, claims.NotBefore) must.NotNil(t, claims.Expiry) } + +func TestNodeIntroductionIdentityClaims_LoggingPairs(t *testing.T) { + ci.Parallel(t) + + claims := &NodeIntroductionIdentityClaims{ + NodeName: "node-name-1", + NodePool: "custom-pool", + } + + must.SliceContainsAll(t, []any{ + "claim_node_name", "node-name-1", + "claim_node_pool", "custom-pool", + }, claims.LoggingPairs()) +}