diff --git a/.semgrep/rpc_endpoint.yml b/.semgrep/rpc_endpoint.yml index 2277a6b19..9f22f67a2 100644 --- a/.semgrep/rpc_endpoint.yml +++ b/.semgrep/rpc_endpoint.yml @@ -30,26 +30,11 @@ rules: # Pattern used by endpoints called exclusively between agents # (server -> server or client -> server) - pattern-not-inside: | + ... := validateTLSCertificateLevel(...) + ... if done, err := $A.$B.forward($METHOD, ...); done { return err } - ... - ... := validateLocalClientTLSCertificate(...) - ... - - pattern-not-inside: | - if done, err := $A.$B.forward($METHOD, ...); done { - return err - } - ... - ... := validateLocalServerTLSCertificate(...) - ... - - pattern-not-inside: | - if done, err := $A.$B.forward($METHOD, ...); done { - return err - } - ... - ... := validateTLSCertificate(...) - ... # Pattern used by some Node endpoints. - pattern-not-inside: | if done, err := $A.$B.forward($METHOD, ...); done { diff --git a/nomad/alloc_endpoint.go b/nomad/alloc_endpoint.go index 6c8231c6b..92abee62f 100644 --- a/nomad/alloc_endpoint.go +++ b/nomad/alloc_endpoint.go @@ -222,16 +222,18 @@ func (a *Alloc) GetAlloc(args *structs.AllocSpecificRequest, // GetAllocs is used to lookup a set of allocations func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest, reply *structs.AllocsGetResponse) error { + + // Ensure the connection was initiated by a client if TLS is used. + err := validateTLSCertificateLevel(a.srv, a.ctx, tlsCertificateLevelClient) + if err != nil { + return err + } + if done, err := a.srv.forward("Alloc.GetAllocs", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now()) - // Ensure the connection was initiated by a client if TLS is used. - if err := validateLocalClientTLSCertificate(a.srv, a.ctx); err != nil { - return fmt.Errorf("invalid client connection in region %s: %v", a.srv.Region(), err) - } - allocs := make([]*structs.Allocation, len(args.AllocIDs)) // Setup the blocking query. 
We wait for at least one of the requested diff --git a/nomad/deployment_endpoint.go b/nomad/deployment_endpoint.go index 0bc073768..2c18de98d 100644 --- a/nomad/deployment_endpoint.go +++ b/nomad/deployment_endpoint.go @@ -504,16 +504,18 @@ func (d *Deployment) Allocations(args *structs.DeploymentSpecificRequest, reply // Reap is used to cleanup terminal deployments func (d *Deployment) Reap(args *structs.DeploymentDeleteRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(d.srv, d.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := d.srv.forward("Deployment.Reap", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "deployment", "reap"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(d.srv, d.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", d.srv.Region(), err) - } - // Update via Raft _, index, err := d.srv.raftApply(structs.DeploymentDeleteRequestType, args) if err != nil { diff --git a/nomad/eval_endpoint.go b/nomad/eval_endpoint.go index 18b83c45d..8a48e27c1 100644 --- a/nomad/eval_endpoint.go +++ b/nomad/eval_endpoint.go @@ -85,16 +85,18 @@ func (e *Eval) GetEval(args *structs.EvalSpecificRequest, // Dequeue is used to dequeue a pending evaluation func (e *Eval) Dequeue(args *structs.EvalDequeueRequest, reply *structs.EvalDequeueResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. 
+ err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Dequeue", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "dequeue"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Ensure there is at least one scheduler if len(args.Schedulers) == 0 { return fmt.Errorf("dequeue requires at least one scheduler type") @@ -175,16 +177,18 @@ func (e *Eval) getWaitIndex(namespace, job string, evalModifyIndex uint64) (uint // Ack is used to acknowledge completion of a dequeued evaluation func (e *Eval) Ack(args *structs.EvalAckRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Ack", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "ack"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Ack the EvalID if err := e.srv.evalBroker.Ack(args.EvalID, args.Token); err != nil { return err @@ -195,16 +199,18 @@ func (e *Eval) Ack(args *structs.EvalAckRequest, // Nack is used to negative acknowledge completion of a dequeued evaluation. func (e *Eval) Nack(args *structs.EvalAckRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. 
+ err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Nack", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "nack"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Nack the EvalID if err := e.srv.evalBroker.Nack(args.EvalID, args.Token); err != nil { return err @@ -215,16 +221,18 @@ func (e *Eval) Nack(args *structs.EvalAckRequest, // Update is used to perform an update of an Eval if it is outstanding. func (e *Eval) Update(args *structs.EvalUpdateRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Update", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "update"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Ensure there is only a single update with token if len(args.Evals) != 1 { return fmt.Errorf("only a single eval can be updated") @@ -250,16 +258,18 @@ func (e *Eval) Update(args *structs.EvalUpdateRequest, // Create is used to make a new evaluation func (e *Eval) Create(args *structs.EvalUpdateRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. 
+ err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Create", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "create"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Ensure there is only a single update with token if len(args.Evals) != 1 { return fmt.Errorf("only a single eval can be created") @@ -300,16 +310,17 @@ func (e *Eval) Create(args *structs.EvalUpdateRequest, // Reblock is used to reinsert an existing blocked evaluation into the blocked // evaluation tracker. func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericResponse) error { + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Reblock", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "reblock"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Ensure there is only a single update with token if len(args.Evals) != 1 { return fmt.Errorf("only a single eval can be reblocked") @@ -347,16 +358,18 @@ func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericRe // Reap is used to cleanup dead evaluations and allocations func (e *Eval) Reap(args *structs.EvalDeleteRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. 
+ err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Reap", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "reap"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Update via Raft _, index, err := e.srv.raftApply(structs.EvalDeleteRequestType, args) if err != nil { diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 8ed43b2dc..d5a1725b4 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -1098,16 +1098,17 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, // UpdateAlloc is used to update the client status of an allocation func (n *Node) UpdateAlloc(args *structs.AllocUpdateRequest, reply *structs.GenericResponse) error { + // Ensure the connection was initiated by another client if TLS is used. + err := validateTLSCertificateLevel(n.srv, n.ctx, tlsCertificateLevelClient) + if err != nil { + return err + } + if done, err := n.srv.forward("Node.UpdateAlloc", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "client", "update_alloc"}, time.Now()) - // Ensure the connection was initiated by a client if TLS is used. 
- if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil { - return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err) - } - // Ensure at least a single alloc if len(args.Alloc) == 0 { return fmt.Errorf("must update at least one allocation") @@ -1920,16 +1921,17 @@ func taskUsesConnect(task *structs.Task) bool { } func (n *Node) EmitEvents(args *structs.EmitNodeEventsRequest, reply *structs.EmitNodeEventsResponse) error { + // Ensure the connection was initiated by another client if TLS is used. + err := validateTLSCertificateLevel(n.srv, n.ctx, tlsCertificateLevelClient) + if err != nil { + return err + } + if done, err := n.srv.forward("Node.EmitEvents", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "client", "emit_events"}, time.Now()) - // Ensure the connection was initiated by a client if TLS is used. - if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil { - return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err) - } - if len(args.NodeEvents) == 0 { return fmt.Errorf("no node events given") } diff --git a/nomad/plan_endpoint.go b/nomad/plan_endpoint.go index a6cd8dbef..4979270e4 100644 --- a/nomad/plan_endpoint.go +++ b/nomad/plan_endpoint.go @@ -21,16 +21,17 @@ type Plan struct { // Submit is used to submit a plan to the leader func (p *Plan) Submit(args *structs.PlanRequest, reply *structs.PlanResponse) error { + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(p.srv, p.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := p.srv.forward("Plan.Submit", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "plan", "submit"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. 
- if err := validateLocalServerTLSCertificate(p.srv, p.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", p.srv.Region(), err) - } - if args.Plan == nil { return fmt.Errorf("cannot submit nil plan") } diff --git a/nomad/rpc.go b/nomad/rpc.go index 37446d53e..96db49b64 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -127,16 +127,16 @@ func (ctx *RPCContext) ValidateCertificateForName(name string) error { if cert == nil { return errors.New("missing certificate information") } - for _, dnsName := range cert.DNSNames { - if dnsName == name { + + validNames := []string{cert.Subject.CommonName} + validNames = append(validNames, cert.DNSNames...) + for _, valid := range validNames { + if name == valid { return nil } } - if cert.Subject.CommonName == name { - return nil - } - return fmt.Errorf("certificate not valid for %q", name) + return fmt.Errorf("invalid certificate, %s not in %s", name, strings.Join(validNames, ",")) } // listen is used to listen for incoming RPC connections diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index 07f2d9492..bd738f279 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -1081,7 +1081,7 @@ func TestRPC_TLS_Enforcement_Raft(t *testing.T) { } t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) { - err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer, cfg) + err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer1, cfg) // the expected error depends on location of failure. // We expect "bad certificate" if connection fails during handshake, @@ -1186,7 +1186,7 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) { name: "local server/clients only rpc", cn: "server.global.nomad", rpcs: localClientsOnlyRPCs, - canRPC: false, + canRPC: true, }, // Local client. 
{ @@ -1274,18 +1274,22 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) { } for method, arg := range tc.rpcs { - t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true", method), func(t *testing.T) { - err := tlsHelper.nomadRPC(t, tlsHelper.mtlsServer, cfg, method, arg) + for _, srv := range []*Server{tlsHelper.mtlsServer1, tlsHelper.mtlsServer2} { + name := fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true leader=%v", method, srv.IsLeader()) + t.Run(name, func(t *testing.T) { + err := tlsHelper.nomadRPC(t, srv, cfg, method, arg) - if tc.canRPC { - if err != nil { - require.NotContains(t, err, "certificate") + if tc.canRPC { + if err != nil { + require.NotContains(t, err, "certificate") + } + } else { + require.Error(t, err) + require.Contains(t, err.Error(), "certificate") } - } else { - require.Error(t, err) - require.Contains(t, err.Error(), "certificate") - } - }) + }) + } + t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=false", method), func(t *testing.T) { err := tlsHelper.nomadRPC(t, tlsHelper.nonVerifyServer, cfg, method, arg) if err != nil { @@ -1301,8 +1305,10 @@ type tlsTestHelper struct { dir string nodeID int - mtlsServer *Server - mtlsServerCleanup func() + mtlsServer1 *Server + mtlsServer1Cleanup func() + mtlsServer2 *Server + mtlsServer2Cleanup func() nonVerifyServer *Server nonVerifyServerCleanup func() @@ -1329,7 +1335,8 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper { // Generate servers and their certificate. 
+ h.serverCert = h.newCert(t, "server.global.nomad") - h.mtlsServer, h.mtlsServerCleanup = TestServer(t, func(c *Config) { + h.mtlsServer1, h.mtlsServer1Cleanup = TestServer(t, func(c *Config) { + c.BootstrapExpect = 2 c.TLSConfig = &config.TLSConfig{ EnableRPC: true, VerifyServerHostname: true, @@ -1338,6 +1345,19 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper { KeyFile: h.serverCert + ".key", } }) + h.mtlsServer2, h.mtlsServer2Cleanup = TestServer(t, func(c *Config) { + c.BootstrapExpect = 2 + c.TLSConfig = &config.TLSConfig{ + EnableRPC: true, + VerifyServerHostname: true, + CAFile: filepath.Join(h.dir, "ca.pem"), + CertFile: h.serverCert + ".pem", + KeyFile: h.serverCert + ".key", + } + }) + TestJoin(t, h.mtlsServer1, h.mtlsServer2) + testutil.WaitForLeader(t, h.mtlsServer1.RPC) + testutil.WaitForLeader(t, h.mtlsServer2.RPC) h.nonVerifyServer, h.nonVerifyServerCleanup = TestServer(t, func(c *Config) { c.TLSConfig = &config.TLSConfig{ @@ -1353,7 +1373,8 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper { } func (h tlsTestHelper) cleanup() { - h.mtlsServerCleanup() + h.mtlsServer1Cleanup() + h.mtlsServer2Cleanup() h.nonVerifyServerCleanup() os.RemoveAll(h.dir) } diff --git a/nomad/util.go b/nomad/util.go index daa6999f8..210a202d9 100644 --- a/nomad/util.go +++ b/nomad/util.go @@ -302,18 +302,56 @@ func getAlloc(state AllocGetter, allocID string) (*structs.Allocation, error) { return alloc, nil } +// tlsCertificateLevel represents a role level for mTLS certificates. +type tlsCertificateLevel int8 + +const ( + tlsCertificateLevelServer tlsCertificateLevel = iota + tlsCertificateLevelClient +) + +// validateTLSCertificateLevel checks if the provided RPC connection was +// initiated with a certificate that matches the given TLS role level. +// +// - tlsCertificateLevelServer requires a server certificate. +// - tlsCertificateLevelClient requires a client or server certificate.
+func validateTLSCertificateLevel(srv *Server, ctx *RPCContext, lvl tlsCertificateLevel) error { + switch lvl { + case tlsCertificateLevelClient: + err := validateLocalClientTLSCertificate(srv, ctx) + if err != nil { + return validateLocalServerTLSCertificate(srv, ctx) + } + return nil + case tlsCertificateLevelServer: + return validateLocalServerTLSCertificate(srv, ctx) + } + + return fmt.Errorf("invalid TLS certificate level %v", lvl) +} + // validateLocalClientTLSCertificate checks if the provided RPC connection was // initiated by a client in the same region as the target server. func validateLocalClientTLSCertificate(srv *Server, ctx *RPCContext) error { expected := fmt.Sprintf("client.%s.nomad", srv.Region()) - return validateTLSCertificate(srv, ctx, expected) + + err := validateTLSCertificate(srv, ctx, expected) + if err != nil { + return fmt.Errorf("invalid client connection in region %s: %v", srv.Region(), err) + } + return nil } // validateLocalServerTLSCertificate checks if the provided RPC connection was // initiated by a server in the same region as the target server. func validateLocalServerTLSCertificate(srv *Server, ctx *RPCContext) error { expected := fmt.Sprintf("server.%s.nomad", srv.Region()) - return validateTLSCertificate(srv, ctx, expected) + + err := validateTLSCertificate(srv, ctx, expected) + if err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", srv.Region(), err) + } + return nil } // validateTLSCertificate checks if the RPC connection mTLS certificates are