mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
Verify TLS certificate on endpoints that are used between agents only (#11956)
This commit is contained in:
3
.changelog/11956.txt
Normal file
3
.changelog/11956.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
```release-note:security
|
||||||
|
server: validate mTLS certificate names on agent to agent endpoints
|
||||||
|
```
|
||||||
80
.semgrep/rpc_endpoint.yml
Normal file
80
.semgrep/rpc_endpoint.yml
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
rules:
|
||||||
|
# Check potentially unauthenticated RPC endpoints
|
||||||
|
- id: "rpc-potentially-unauthenticated"
|
||||||
|
patterns:
|
||||||
|
- pattern: |
|
||||||
|
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
- pattern-not-inside: |
|
||||||
|
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
...
|
||||||
|
... := $X.$Y.ResolveToken(...)
|
||||||
|
...
|
||||||
|
- pattern-not-inside: |
|
||||||
|
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
...
|
||||||
|
... := $U.requestACLToken(...)
|
||||||
|
...
|
||||||
|
- pattern-not-inside: |
|
||||||
|
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
...
|
||||||
|
... := $T.NamespaceValidator(...)
|
||||||
|
...
|
||||||
|
# Pattern used by endpoints called exclusively between agents
|
||||||
|
# (server -> server or client -> server)
|
||||||
|
- pattern-not-inside: |
|
||||||
|
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
...
|
||||||
|
... := validateLocalClientTLSCertificate(...)
|
||||||
|
...
|
||||||
|
- pattern-not-inside: |
|
||||||
|
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
...
|
||||||
|
... := validateLocalServerTLSCertificate(...)
|
||||||
|
...
|
||||||
|
- pattern-not-inside: |
|
||||||
|
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
...
|
||||||
|
... := validateTLSCertificate(...)
|
||||||
|
...
|
||||||
|
# Pattern used by some Node endpoints.
|
||||||
|
- pattern-not-inside: |
|
||||||
|
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
...
|
||||||
|
return $A.deregister(...)
|
||||||
|
...
|
||||||
|
- metavariable-pattern:
|
||||||
|
metavariable: $METHOD
|
||||||
|
patterns:
|
||||||
|
# Endpoints that are expected not to have authentication.
|
||||||
|
- pattern-not: '"ACL.Bootstrap"'
|
||||||
|
- pattern-not: '"ACL.ResolveToken"'
|
||||||
|
- pattern-not: '"ACL.UpsertOneTimeToken"'
|
||||||
|
- pattern-not: '"ACL.ExchangeOneTimeToken"'
|
||||||
|
- pattern-not: '"CSIPlugin.Get"'
|
||||||
|
- pattern-not: '"CSIPlugin.List"'
|
||||||
|
- pattern-not: '"Status.Leader"'
|
||||||
|
- pattern-not: '"Status.Peers"'
|
||||||
|
- pattern-not: '"Status.Version"'
|
||||||
|
message: "RPC method $METHOD appears to be unauthenticated"
|
||||||
|
languages:
|
||||||
|
- "go"
|
||||||
|
severity: "WARNING"
|
||||||
|
paths:
|
||||||
|
include:
|
||||||
|
- "*_endpoint.go"
|
||||||
@@ -20,6 +20,9 @@ import (
|
|||||||
type Alloc struct {
|
type Alloc struct {
|
||||||
srv *Server
|
srv *Server
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
|
||||||
|
// ctx provides context regarding the underlying connection
|
||||||
|
ctx *RPCContext
|
||||||
}
|
}
|
||||||
|
|
||||||
// List is used to list the allocations in the system
|
// List is used to list the allocations in the system
|
||||||
@@ -224,6 +227,11 @@ func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest,
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by a client if TLS is used.
|
||||||
|
if err := validateLocalClientTLSCertificate(a.srv, a.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid client connection in region %s: %v", a.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
allocs := make([]*structs.Allocation, len(args.AllocIDs))
|
allocs := make([]*structs.Allocation, len(args.AllocIDs))
|
||||||
|
|
||||||
// Setup the blocking query. We wait for at least one of the requested
|
// Setup the blocking query. We wait for at least one of the requested
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ import (
|
|||||||
type Deployment struct {
|
type Deployment struct {
|
||||||
srv *Server
|
srv *Server
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
|
||||||
|
// ctx provides context regarding the underlying connection
|
||||||
|
ctx *RPCContext
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeployment is used to request information about a specific deployment
|
// GetDeployment is used to request information about a specific deployment
|
||||||
@@ -506,6 +509,11 @@ func (d *Deployment) Reap(args *structs.DeploymentDeleteRequest,
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "deployment", "reap"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "deployment", "reap"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by another server if TLS is used.
|
||||||
|
if err := validateLocalServerTLSCertificate(d.srv, d.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid server connection in region %s: %v", d.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
// Update via Raft
|
// Update via Raft
|
||||||
_, index, err := d.srv.raftApply(structs.DeploymentDeleteRequestType, args)
|
_, index, err := d.srv.raftApply(structs.DeploymentDeleteRequestType, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -24,6 +24,9 @@ const (
|
|||||||
type Eval struct {
|
type Eval struct {
|
||||||
srv *Server
|
srv *Server
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
|
||||||
|
// ctx provides context regarding the underlying connection
|
||||||
|
ctx *RPCContext
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEval is used to request information about a specific evaluation
|
// GetEval is used to request information about a specific evaluation
|
||||||
@@ -87,6 +90,11 @@ func (e *Eval) Dequeue(args *structs.EvalDequeueRequest,
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "eval", "dequeue"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "eval", "dequeue"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by another server if TLS is used.
|
||||||
|
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure there is at least one scheduler
|
// Ensure there is at least one scheduler
|
||||||
if len(args.Schedulers) == 0 {
|
if len(args.Schedulers) == 0 {
|
||||||
return fmt.Errorf("dequeue requires at least one scheduler type")
|
return fmt.Errorf("dequeue requires at least one scheduler type")
|
||||||
@@ -172,6 +180,11 @@ func (e *Eval) Ack(args *structs.EvalAckRequest,
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "eval", "ack"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "eval", "ack"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by another server if TLS is used.
|
||||||
|
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
// Ack the EvalID
|
// Ack the EvalID
|
||||||
if err := e.srv.evalBroker.Ack(args.EvalID, args.Token); err != nil {
|
if err := e.srv.evalBroker.Ack(args.EvalID, args.Token); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -187,6 +200,11 @@ func (e *Eval) Nack(args *structs.EvalAckRequest,
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "eval", "nack"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "eval", "nack"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by another server if TLS is used.
|
||||||
|
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
// Nack the EvalID
|
// Nack the EvalID
|
||||||
if err := e.srv.evalBroker.Nack(args.EvalID, args.Token); err != nil {
|
if err := e.srv.evalBroker.Nack(args.EvalID, args.Token); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -202,6 +220,11 @@ func (e *Eval) Update(args *structs.EvalUpdateRequest,
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "eval", "update"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "eval", "update"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by another server if TLS is used.
|
||||||
|
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure there is only a single update with token
|
// Ensure there is only a single update with token
|
||||||
if len(args.Evals) != 1 {
|
if len(args.Evals) != 1 {
|
||||||
return fmt.Errorf("only a single eval can be updated")
|
return fmt.Errorf("only a single eval can be updated")
|
||||||
@@ -232,6 +255,11 @@ func (e *Eval) Create(args *structs.EvalUpdateRequest,
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "eval", "create"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "eval", "create"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by another server if TLS is used.
|
||||||
|
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure there is only a single update with token
|
// Ensure there is only a single update with token
|
||||||
if len(args.Evals) != 1 {
|
if len(args.Evals) != 1 {
|
||||||
return fmt.Errorf("only a single eval can be created")
|
return fmt.Errorf("only a single eval can be created")
|
||||||
@@ -277,6 +305,11 @@ func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericRe
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "eval", "reblock"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "eval", "reblock"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by another server if TLS is used.
|
||||||
|
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure there is only a single update with token
|
// Ensure there is only a single update with token
|
||||||
if len(args.Evals) != 1 {
|
if len(args.Evals) != 1 {
|
||||||
return fmt.Errorf("only a single eval can be reblocked")
|
return fmt.Errorf("only a single eval can be reblocked")
|
||||||
@@ -319,6 +352,11 @@ func (e *Eval) Reap(args *structs.EvalDeleteRequest,
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "eval", "reap"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "eval", "reap"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by another server if TLS is used.
|
||||||
|
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
// Update via Raft
|
// Update via Raft
|
||||||
_, index, err := e.srv.raftApply(structs.EvalDeleteRequestType, args)
|
_, index, err := e.srv.raftApply(structs.EvalDeleteRequestType, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -114,8 +114,8 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis
|
|||||||
reply.Warnings = structs.MergeMultierrorWarnings(warnings...)
|
reply.Warnings = structs.MergeMultierrorWarnings(warnings...)
|
||||||
|
|
||||||
// Check job submission permissions
|
// Check job submission permissions
|
||||||
var aclObj *acl.ACL
|
aclObj, err := j.srv.ResolveToken(args.AuthToken)
|
||||||
if aclObj, err = j.srv.ResolveToken(args.AuthToken); err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if aclObj != nil {
|
} else if aclObj != nil {
|
||||||
if !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilitySubmitJob) {
|
if !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilitySubmitJob) {
|
||||||
@@ -1879,9 +1879,8 @@ func (j *Job) Dispatch(args *structs.JobDispatchRequest, reply *structs.JobDispa
|
|||||||
defer metrics.MeasureSince([]string{"nomad", "job", "dispatch"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "job", "dispatch"}, time.Now())
|
||||||
|
|
||||||
// Check for submit-job permissions
|
// Check for submit-job permissions
|
||||||
var aclObj *acl.ACL
|
aclObj, err := j.srv.ResolveToken(args.AuthToken)
|
||||||
var err error
|
if err != nil {
|
||||||
if aclObj, err = j.srv.ResolveToken(args.AuthToken); err != nil {
|
|
||||||
return err
|
return err
|
||||||
} else if aclObj != nil && !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilityDispatchJob) {
|
} else if aclObj != nil && !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilityDispatchJob) {
|
||||||
return structs.ErrPermissionDenied
|
return structs.ErrPermissionDenied
|
||||||
|
|||||||
@@ -1103,6 +1103,11 @@ func (n *Node) UpdateAlloc(args *structs.AllocUpdateRequest, reply *structs.Gene
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "client", "update_alloc"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "client", "update_alloc"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by a client if TLS is used.
|
||||||
|
if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure at least a single alloc
|
// Ensure at least a single alloc
|
||||||
if len(args.Alloc) == 0 {
|
if len(args.Alloc) == 0 {
|
||||||
return fmt.Errorf("must update at least one allocation")
|
return fmt.Errorf("must update at least one allocation")
|
||||||
@@ -1920,6 +1925,11 @@ func (n *Node) EmitEvents(args *structs.EmitNodeEventsRequest, reply *structs.Em
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "client", "emit_events"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "client", "emit_events"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by a client if TLS is used.
|
||||||
|
if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
if len(args.NodeEvents) == 0 {
|
if len(args.NodeEvents) == 0 {
|
||||||
return fmt.Errorf("no node events given")
|
return fmt.Errorf("no node events given")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ import (
|
|||||||
type Plan struct {
|
type Plan struct {
|
||||||
srv *Server
|
srv *Server
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
|
||||||
|
// ctx provides context regarding the underlying connection
|
||||||
|
ctx *RPCContext
|
||||||
}
|
}
|
||||||
|
|
||||||
// Submit is used to submit a plan to the leader
|
// Submit is used to submit a plan to the leader
|
||||||
@@ -23,6 +26,11 @@ func (p *Plan) Submit(args *structs.PlanRequest, reply *structs.PlanResponse) er
|
|||||||
}
|
}
|
||||||
defer metrics.MeasureSince([]string{"nomad", "plan", "submit"}, time.Now())
|
defer metrics.MeasureSince([]string{"nomad", "plan", "submit"}, time.Now())
|
||||||
|
|
||||||
|
// Ensure the connection was initiated by another server if TLS is used.
|
||||||
|
if err := validateLocalServerTLSCertificate(p.srv, p.ctx); err != nil {
|
||||||
|
return fmt.Errorf("invalid server connection in region %s: %v", p.srv.Region(), err)
|
||||||
|
}
|
||||||
|
|
||||||
if args.Plan == nil {
|
if args.Plan == nil {
|
||||||
return fmt.Errorf("cannot submit nil plan")
|
return fmt.Errorf("cannot submit nil plan")
|
||||||
}
|
}
|
||||||
|
|||||||
62
nomad/rpc.go
62
nomad/rpc.go
@@ -107,6 +107,38 @@ type RPCContext struct {
|
|||||||
NodeID string
|
NodeID string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Certificate returns the first certificate available in the chain.
|
||||||
|
func (ctx *RPCContext) Certificate() *x509.Certificate {
|
||||||
|
if ctx == nil || len(ctx.VerifiedChains) == 0 || len(ctx.VerifiedChains[0]) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx.VerifiedChains[0][0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateCertificateForName returns true if the RPC context certificate is valid
|
||||||
|
// for the given domain name.
|
||||||
|
func (ctx *RPCContext) ValidateCertificateForName(name string) error {
|
||||||
|
if ctx == nil || !ctx.TLS {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := ctx.Certificate()
|
||||||
|
if cert == nil {
|
||||||
|
return errors.New("missing certificate information")
|
||||||
|
}
|
||||||
|
for _, dnsName := range cert.DNSNames {
|
||||||
|
if dnsName == name {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cert.Subject.CommonName == name {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("certificate not valid for %q", name)
|
||||||
|
}
|
||||||
|
|
||||||
// listen is used to listen for incoming RPC connections
|
// listen is used to listen for incoming RPC connections
|
||||||
func (r *rpcHandler) listen(ctx context.Context) {
|
func (r *rpcHandler) listen(ctx context.Context) {
|
||||||
defer close(r.listenerCh)
|
defer close(r.listenerCh)
|
||||||
@@ -838,30 +870,18 @@ func (r *rpcHandler) validateRaftTLS(rpcCtx *RPCContext) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// defensive conditions: these should have already been enforced by handleConn
|
|
||||||
if rpcCtx == nil || !rpcCtx.TLS {
|
|
||||||
return errors.New("non-TLS connection attempted")
|
|
||||||
}
|
|
||||||
if len(rpcCtx.VerifiedChains) == 0 || len(rpcCtx.VerifiedChains[0]) == 0 {
|
|
||||||
// this should never happen, as rpcNameAndRegionValidate should have enforced it
|
|
||||||
return errors.New("missing cert info")
|
|
||||||
}
|
|
||||||
|
|
||||||
// check that `server.<region>.nomad` is present in cert
|
// check that `server.<region>.nomad` is present in cert
|
||||||
expected := "server." + r.Region() + ".nomad"
|
expected := "server." + r.Region() + ".nomad"
|
||||||
|
err := rpcCtx.ValidateCertificateForName(expected)
|
||||||
cert := rpcCtx.VerifiedChains[0][0]
|
if err != nil {
|
||||||
for _, dnsName := range cert.DNSNames {
|
cert := rpcCtx.Certificate()
|
||||||
if dnsName == expected {
|
if cert != nil {
|
||||||
// Certificate is valid for the expected name
|
err = fmt.Errorf("request certificate is only valid for %s: %v", cert.DNSNames, err)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if cert.Subject.CommonName == expected {
|
return fmt.Errorf("unauthorized raft connection from %s: %v", rpcCtx.Conn.RemoteAddr(), err)
|
||||||
// Certificate is valid for the expected name
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r.logger.Warn("unauthorized raft connection", "remote_addr", rpcCtx.Conn.RemoteAddr(), "required_hostname", expected, "found", cert.DNSNames)
|
// Certificate is valid for the expected name
|
||||||
return fmt.Errorf("certificate is invalid for expected role or region: %q", expected)
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1018,7 +1018,7 @@ func TestRPC_Limits_Streaming(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRPC_TLS_Enforcement(t *testing.T) {
|
func TestRPC_TLS_Enforcement_Raft(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1026,211 +1026,409 @@ func TestRPC_TLS_Enforcement(t *testing.T) {
|
|||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dir := tmpDir(t)
|
tlsHelper := newTLSTestHelper(t)
|
||||||
defer os.RemoveAll(dir)
|
defer tlsHelper.cleanup()
|
||||||
|
|
||||||
caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
nodeID := 1
|
|
||||||
newCert := func(t *testing.T, name string) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
node := fmt.Sprintf("node%d", nodeID)
|
|
||||||
nodeID++
|
|
||||||
signer, err := tlsutil.ParseSigner(pk)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{
|
|
||||||
Signer: signer,
|
|
||||||
CA: caPEM,
|
|
||||||
Name: name,
|
|
||||||
Days: 5,
|
|
||||||
DNSNames: []string{node + "." + name, name, "localhost"},
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return filepath.Join(dir, node+"-"+name)
|
|
||||||
}
|
|
||||||
|
|
||||||
connect := func(t *testing.T, s *Server, c *config.TLSConfig) net.Conn {
|
|
||||||
conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// configure TLS
|
|
||||||
_, err = conn.Write([]byte{byte(pool.RpcTLS)})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Client TLS verification isn't necessary for
|
|
||||||
// our assertions
|
|
||||||
tlsConf, err := tlsutil.NewTLSConfiguration(c, true, true)
|
|
||||||
require.NoError(t, err)
|
|
||||||
outTLSConf, err := tlsConf.OutgoingTLSConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
outTLSConf.InsecureSkipVerify = true
|
|
||||||
|
|
||||||
tlsConn := tls.Client(conn, outTLSConf)
|
|
||||||
require.NoError(t, tlsConn.Handshake())
|
|
||||||
|
|
||||||
return tlsConn
|
|
||||||
}
|
|
||||||
|
|
||||||
nomadRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error {
|
|
||||||
conn := connect(t, s, c)
|
|
||||||
defer conn.Close()
|
|
||||||
_, err := conn.Write([]byte{byte(pool.RpcNomad)})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
codec := pool.NewClientCodec(conn)
|
|
||||||
|
|
||||||
arg := struct{}{}
|
|
||||||
var out struct{}
|
|
||||||
return msgpackrpc.CallWithCodec(codec, "Status.Ping", arg, &out)
|
|
||||||
}
|
|
||||||
|
|
||||||
raftRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error {
|
|
||||||
conn := connect(t, s, c)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
_, err := conn.Write([]byte{byte(pool.RpcRaft)})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = doRaftRPC(conn, s.config.NodeName)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// generate server cert
|
|
||||||
serverCert := newCert(t, "server.global.nomad")
|
|
||||||
|
|
||||||
mtlsS, cleanup := TestServer(t, func(c *Config) {
|
|
||||||
c.TLSConfig = &config.TLSConfig{
|
|
||||||
EnableRPC: true,
|
|
||||||
VerifyServerHostname: true,
|
|
||||||
CAFile: filepath.Join(dir, "ca.pem"),
|
|
||||||
CertFile: serverCert + ".pem",
|
|
||||||
KeyFile: serverCert + ".key",
|
|
||||||
}
|
|
||||||
})
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
nonVerifyS, cleanup := TestServer(t, func(c *Config) {
|
|
||||||
c.TLSConfig = &config.TLSConfig{
|
|
||||||
EnableRPC: true,
|
|
||||||
VerifyServerHostname: false,
|
|
||||||
CAFile: filepath.Join(dir, "ca.pem"),
|
|
||||||
CertFile: serverCert + ".pem",
|
|
||||||
KeyFile: serverCert + ".key",
|
|
||||||
}
|
|
||||||
})
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
// When VerifyServerHostname is enabled:
|
// When VerifyServerHostname is enabled:
|
||||||
// Only all servers and local clients can make RPC requests
|
|
||||||
// Only local servers can connect to the Raft layer
|
// Only local servers can connect to the Raft layer
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
cn string
|
cn string
|
||||||
canRPC bool
|
|
||||||
canRaft bool
|
canRaft bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "local server",
|
name: "local server",
|
||||||
cn: "server.global.nomad",
|
cn: "server.global.nomad",
|
||||||
canRPC: true,
|
|
||||||
canRaft: true,
|
canRaft: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "local client",
|
name: "local client",
|
||||||
cn: "client.global.nomad",
|
cn: "client.global.nomad",
|
||||||
canRPC: true,
|
|
||||||
canRaft: false,
|
canRaft: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "other region server",
|
name: "other region server",
|
||||||
cn: "server.other.nomad",
|
cn: "server.other.nomad",
|
||||||
canRPC: true,
|
|
||||||
canRaft: false,
|
canRaft: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "other client server",
|
name: "other region client",
|
||||||
cn: "client.other.nomad",
|
cn: "client.other.nomad",
|
||||||
canRPC: false,
|
|
||||||
canRaft: false,
|
canRaft: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "irrelevant cert",
|
name: "irrelevant cert",
|
||||||
cn: "nomad.example.com",
|
cn: "nomad.example.com",
|
||||||
canRPC: false,
|
|
||||||
canRaft: false,
|
canRaft: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "globs",
|
name: "globs",
|
||||||
cn: "*.global.nomad",
|
cn: "*.global.nomad",
|
||||||
canRPC: false,
|
|
||||||
canRaft: false,
|
canRaft: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
certPath := newCert(t, tc.cn)
|
certPath := tlsHelper.newCert(t, tc.cn)
|
||||||
|
|
||||||
cfg := &config.TLSConfig{
|
cfg := &config.TLSConfig{
|
||||||
EnableRPC: true,
|
EnableRPC: true,
|
||||||
VerifyServerHostname: true,
|
VerifyServerHostname: true,
|
||||||
CAFile: filepath.Join(dir, "ca.pem"),
|
CAFile: filepath.Join(tlsHelper.dir, "ca.pem"),
|
||||||
CertFile: certPath + ".pem",
|
CertFile: certPath + ".pem",
|
||||||
KeyFile: certPath + ".key",
|
KeyFile: certPath + ".key",
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("nomad RPC: verify_hostname=true", func(t *testing.T) {
|
|
||||||
err := nomadRPC(t, mtlsS, cfg)
|
|
||||||
|
|
||||||
if tc.canRPC {
|
|
||||||
require.NoError(t, err)
|
|
||||||
} else {
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "bad certificate")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("nomad RPC: verify_hostname=false", func(t *testing.T) {
|
|
||||||
err := nomadRPC(t, nonVerifyS, cfg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) {
|
t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) {
|
||||||
err := raftRPC(t, mtlsS, cfg)
|
err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer, cfg)
|
||||||
|
|
||||||
// the expected error depends on location of failure.
|
// the expected error depends on location of failure.
|
||||||
// We expect "bad certificate" if connection fails during handshake,
|
// We expect "bad certificate" if connection fails during handshake,
|
||||||
// or EOF when connection is closed after RaftRPC byte.
|
// or EOF when connection is closed after RaftRPC byte.
|
||||||
if tc.canRaft {
|
if tc.canRaft {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
} else if !tc.canRPC {
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "bad certificate")
|
|
||||||
} else {
|
} else {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "EOF")
|
require.Regexp(t, "(bad certificate|EOF)", err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) {
|
t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) {
|
||||||
err := raftRPC(t, nonVerifyS, cfg)
|
err := tlsHelper.raftRPC(t, tlsHelper.nonVerifyServer, cfg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRPC_TLS_Enforcement_RPC(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
//TODO Avoid panics from logging during shutdown
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}()
|
||||||
|
|
||||||
|
tlsHelper := newTLSTestHelper(t)
|
||||||
|
defer tlsHelper.cleanup()
|
||||||
|
|
||||||
|
standardRPCs := map[string]interface{}{
|
||||||
|
"Status.Ping": struct{}{},
|
||||||
|
}
|
||||||
|
|
||||||
|
localServersOnlyRPCs := map[string]interface{}{
|
||||||
|
"Eval.Update": &structs.EvalUpdateRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
"Eval.Ack": &structs.EvalAckRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
"Eval.Nack": &structs.EvalAckRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
"Eval.Dequeue": &structs.EvalDequeueRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
"Eval.Create": &structs.EvalUpdateRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
"Eval.Reblock": &structs.EvalUpdateRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
"Eval.Reap": &structs.EvalDeleteRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
"Plan.Submit": &structs.PlanRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
"Deployment.Reap": &structs.DeploymentDeleteRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
localClientsOnlyRPCs := map[string]interface{}{
|
||||||
|
"Alloc.GetAllocs": &structs.AllocsGetRequest{
|
||||||
|
QueryOptions: structs.QueryOptions{Region: "global"},
|
||||||
|
},
|
||||||
|
"Node.EmitEvents": &structs.EmitNodeEventsRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
"Node.UpdateAlloc": &structs.AllocUpdateRequest{
|
||||||
|
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// When VerifyServerHostname is enabled:
|
||||||
|
// All servers can make RPC requests
|
||||||
|
// Only local clients can make RPC requests
|
||||||
|
// Some endpoints can only be called server -> server
|
||||||
|
// Some endpoints can only be called client -> server
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
cn string
|
||||||
|
rpcs map[string]interface{}
|
||||||
|
canRPC bool
|
||||||
|
}{
|
||||||
|
// Local server.
|
||||||
|
{
|
||||||
|
name: "local server/standard rpc",
|
||||||
|
cn: "server.global.nomad",
|
||||||
|
rpcs: standardRPCs,
|
||||||
|
canRPC: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "local server/servers only rpc",
|
||||||
|
cn: "server.global.nomad",
|
||||||
|
rpcs: localServersOnlyRPCs,
|
||||||
|
canRPC: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "local server/clients only rpc",
|
||||||
|
cn: "server.global.nomad",
|
||||||
|
rpcs: localClientsOnlyRPCs,
|
||||||
|
canRPC: false,
|
||||||
|
},
|
||||||
|
// Local client.
|
||||||
|
{
|
||||||
|
name: "local client/standard rpc",
|
||||||
|
cn: "client.global.nomad",
|
||||||
|
rpcs: standardRPCs,
|
||||||
|
canRPC: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "local client/servers only rpc",
|
||||||
|
cn: "client.global.nomad",
|
||||||
|
rpcs: localServersOnlyRPCs,
|
||||||
|
canRPC: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "local client/clients only rpc",
|
||||||
|
cn: "client.global.nomad",
|
||||||
|
rpcs: localClientsOnlyRPCs,
|
||||||
|
canRPC: true,
|
||||||
|
},
|
||||||
|
// Other region server.
|
||||||
|
{
|
||||||
|
name: "other region server/standard rpc",
|
||||||
|
cn: "server.other.nomad",
|
||||||
|
rpcs: standardRPCs,
|
||||||
|
canRPC: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other region server/servers only rpc",
|
||||||
|
cn: "server.other.nomad",
|
||||||
|
rpcs: localServersOnlyRPCs,
|
||||||
|
canRPC: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other region server/clients only rpc",
|
||||||
|
cn: "server.other.nomad",
|
||||||
|
rpcs: localClientsOnlyRPCs,
|
||||||
|
canRPC: false,
|
||||||
|
},
|
||||||
|
// Other region client.
|
||||||
|
{
|
||||||
|
name: "other region client/standard rpc",
|
||||||
|
cn: "client.other.nomad",
|
||||||
|
rpcs: standardRPCs,
|
||||||
|
canRPC: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other region client/servers only rpc",
|
||||||
|
cn: "client.other.nomad",
|
||||||
|
rpcs: localServersOnlyRPCs,
|
||||||
|
canRPC: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other region client/clients only rpc",
|
||||||
|
cn: "client.other.nomad",
|
||||||
|
rpcs: localClientsOnlyRPCs,
|
||||||
|
canRPC: false,
|
||||||
|
},
|
||||||
|
// Wrong certs.
|
||||||
|
{
|
||||||
|
name: "irrelevant cert",
|
||||||
|
cn: "nomad.example.com",
|
||||||
|
rpcs: standardRPCs,
|
||||||
|
canRPC: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "globs",
|
||||||
|
cn: "*.global.nomad",
|
||||||
|
rpcs: standardRPCs,
|
||||||
|
canRPC: false,
|
||||||
|
},
|
||||||
|
{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
certPath := tlsHelper.newCert(t, tc.cn)
|
||||||
|
|
||||||
|
cfg := &config.TLSConfig{
|
||||||
|
EnableRPC: true,
|
||||||
|
VerifyServerHostname: true,
|
||||||
|
CAFile: filepath.Join(tlsHelper.dir, "ca.pem"),
|
||||||
|
CertFile: certPath + ".pem",
|
||||||
|
KeyFile: certPath + ".key",
|
||||||
|
}
|
||||||
|
|
||||||
|
for method, arg := range tc.rpcs {
|
||||||
|
t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true", method), func(t *testing.T) {
|
||||||
|
err := tlsHelper.nomadRPC(t, tlsHelper.mtlsServer, cfg, method, arg)
|
||||||
|
|
||||||
|
if tc.canRPC {
|
||||||
|
if err != nil {
|
||||||
|
require.NotContains(t, err, "certificate")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "certificate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=false", method), func(t *testing.T) {
|
||||||
|
err := tlsHelper.nomadRPC(t, tlsHelper.nonVerifyServer, cfg, method, arg)
|
||||||
|
if err != nil {
|
||||||
|
require.NotContains(t, err, "certificate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type tlsTestHelper struct {
|
||||||
|
dir string
|
||||||
|
nodeID int
|
||||||
|
|
||||||
|
mtlsServer *Server
|
||||||
|
mtlsServerCleanup func()
|
||||||
|
nonVerifyServer *Server
|
||||||
|
nonVerifyServerCleanup func()
|
||||||
|
|
||||||
|
caPEM string
|
||||||
|
pk string
|
||||||
|
serverCert string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTLSTestHelper(t *testing.T) tlsTestHelper {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
h := tlsTestHelper{
|
||||||
|
dir: tmpDir(t),
|
||||||
|
nodeID: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate CA certificate and write it to disk.
|
||||||
|
h.caPEM, h.pk, err = tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = ioutil.WriteFile(filepath.Join(h.dir, "ca.pem"), []byte(h.caPEM), 0600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Generate servers and their certificate.
|
||||||
|
h.serverCert = h.newCert(t, "server.global.nomad")
|
||||||
|
|
||||||
|
h.mtlsServer, h.mtlsServerCleanup = TestServer(t, func(c *Config) {
|
||||||
|
c.TLSConfig = &config.TLSConfig{
|
||||||
|
EnableRPC: true,
|
||||||
|
VerifyServerHostname: true,
|
||||||
|
CAFile: filepath.Join(h.dir, "ca.pem"),
|
||||||
|
CertFile: h.serverCert + ".pem",
|
||||||
|
KeyFile: h.serverCert + ".key",
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
h.nonVerifyServer, h.nonVerifyServerCleanup = TestServer(t, func(c *Config) {
|
||||||
|
c.TLSConfig = &config.TLSConfig{
|
||||||
|
EnableRPC: true,
|
||||||
|
VerifyServerHostname: false,
|
||||||
|
CAFile: filepath.Join(h.dir, "ca.pem"),
|
||||||
|
CertFile: h.serverCert + ".pem",
|
||||||
|
KeyFile: h.serverCert + ".key",
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h tlsTestHelper) cleanup() {
|
||||||
|
h.mtlsServerCleanup()
|
||||||
|
h.nonVerifyServerCleanup()
|
||||||
|
os.RemoveAll(h.dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h tlsTestHelper) newCert(t *testing.T, name string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
node := fmt.Sprintf("node%d", h.nodeID)
|
||||||
|
h.nodeID++
|
||||||
|
signer, err := tlsutil.ParseSigner(h.pk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{
|
||||||
|
Signer: signer,
|
||||||
|
CA: h.caPEM,
|
||||||
|
Name: name,
|
||||||
|
Days: 5,
|
||||||
|
DNSNames: []string{node + "." + name, name, "localhost"},
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = ioutil.WriteFile(filepath.Join(h.dir, node+"-"+name+".pem"), []byte(pem), 0600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = ioutil.WriteFile(filepath.Join(h.dir, node+"-"+name+".key"), []byte(key), 0600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return filepath.Join(h.dir, node+"-"+name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h tlsTestHelper) connect(t *testing.T, s *Server, c *config.TLSConfig) net.Conn {
|
||||||
|
conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// configure TLS
|
||||||
|
_, err = conn.Write([]byte{byte(pool.RpcTLS)})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Client TLS verification isn't necessary for
|
||||||
|
// our assertions
|
||||||
|
tlsConf, err := tlsutil.NewTLSConfiguration(c, true, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
outTLSConf, err := tlsConf.OutgoingTLSConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
outTLSConf.InsecureSkipVerify = true
|
||||||
|
|
||||||
|
tlsConn := tls.Client(conn, outTLSConf)
|
||||||
|
require.NoError(t, tlsConn.Handshake())
|
||||||
|
|
||||||
|
return tlsConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h tlsTestHelper) nomadRPC(t *testing.T, s *Server, c *config.TLSConfig, method string, arg interface{}) error {
|
||||||
|
conn := h.connect(t, s, c)
|
||||||
|
defer conn.Close()
|
||||||
|
_, err := conn.Write([]byte{byte(pool.RpcNomad)})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
codec := pool.NewClientCodec(conn)
|
||||||
|
|
||||||
|
var out struct{}
|
||||||
|
return msgpackrpc.CallWithCodec(codec, method, arg, &out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h tlsTestHelper) raftRPC(t *testing.T, s *Server, c *config.TLSConfig) error {
|
||||||
|
conn := h.connect(t, s, c)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, err := conn.Write([]byte{byte(pool.RpcRaft)})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = doRaftRPC(conn, s.config.NodeName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func doRaftRPC(conn net.Conn, leader string) (*raft.AppendEntriesResponse, error) {
|
func doRaftRPC(conn net.Conn, leader string) (*raft.AppendEntriesResponse, error) {
|
||||||
req := raft.AppendEntriesRequest{
|
req := raft.AppendEntriesRequest{
|
||||||
RPCHeader: raft.RPCHeader{ProtocolVersion: 3},
|
RPCHeader: raft.RPCHeader{ProtocolVersion: 3},
|
||||||
|
|||||||
@@ -265,9 +265,6 @@ type endpoints struct {
|
|||||||
Status *Status
|
Status *Status
|
||||||
Node *Node
|
Node *Node
|
||||||
Job *Job
|
Job *Job
|
||||||
Eval *Eval
|
|
||||||
Plan *Plan
|
|
||||||
Alloc *Alloc
|
|
||||||
CSIVolume *CSIVolume
|
CSIVolume *CSIVolume
|
||||||
CSIPlugin *CSIPlugin
|
CSIPlugin *CSIPlugin
|
||||||
Deployment *Deployment
|
Deployment *Deployment
|
||||||
@@ -1151,18 +1148,13 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
|
|||||||
if s.staticEndpoints.Status == nil {
|
if s.staticEndpoints.Status == nil {
|
||||||
// Initialize the list just once
|
// Initialize the list just once
|
||||||
s.staticEndpoints.ACL = &ACL{srv: s, logger: s.logger.Named("acl")}
|
s.staticEndpoints.ACL = &ACL{srv: s, logger: s.logger.Named("acl")}
|
||||||
s.staticEndpoints.Alloc = &Alloc{srv: s, logger: s.logger.Named("alloc")}
|
|
||||||
s.staticEndpoints.Eval = &Eval{srv: s, logger: s.logger.Named("eval")}
|
|
||||||
s.staticEndpoints.Job = NewJobEndpoints(s)
|
s.staticEndpoints.Job = NewJobEndpoints(s)
|
||||||
s.staticEndpoints.Node = &Node{srv: s, logger: s.logger.Named("client")} // Add but don't register
|
|
||||||
s.staticEndpoints.CSIVolume = &CSIVolume{srv: s, logger: s.logger.Named("csi_volume")}
|
s.staticEndpoints.CSIVolume = &CSIVolume{srv: s, logger: s.logger.Named("csi_volume")}
|
||||||
s.staticEndpoints.CSIPlugin = &CSIPlugin{srv: s, logger: s.logger.Named("csi_plugin")}
|
s.staticEndpoints.CSIPlugin = &CSIPlugin{srv: s, logger: s.logger.Named("csi_plugin")}
|
||||||
s.staticEndpoints.Deployment = &Deployment{srv: s, logger: s.logger.Named("deployment")}
|
|
||||||
s.staticEndpoints.Operator = &Operator{srv: s, logger: s.logger.Named("operator")}
|
s.staticEndpoints.Operator = &Operator{srv: s, logger: s.logger.Named("operator")}
|
||||||
s.staticEndpoints.Operator.register()
|
s.staticEndpoints.Operator.register()
|
||||||
|
|
||||||
s.staticEndpoints.Periodic = &Periodic{srv: s, logger: s.logger.Named("periodic")}
|
s.staticEndpoints.Periodic = &Periodic{srv: s, logger: s.logger.Named("periodic")}
|
||||||
s.staticEndpoints.Plan = &Plan{srv: s, logger: s.logger.Named("plan")}
|
|
||||||
s.staticEndpoints.Region = &Region{srv: s, logger: s.logger.Named("region")}
|
s.staticEndpoints.Region = &Region{srv: s, logger: s.logger.Named("region")}
|
||||||
s.staticEndpoints.Scaling = &Scaling{srv: s, logger: s.logger.Named("scaling")}
|
s.staticEndpoints.Scaling = &Scaling{srv: s, logger: s.logger.Named("scaling")}
|
||||||
s.staticEndpoints.Status = &Status{srv: s, logger: s.logger.Named("status")}
|
s.staticEndpoints.Status = &Status{srv: s, logger: s.logger.Named("status")}
|
||||||
@@ -1171,6 +1163,13 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
|
|||||||
s.staticEndpoints.Namespace = &Namespace{srv: s}
|
s.staticEndpoints.Namespace = &Namespace{srv: s}
|
||||||
s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s)
|
s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s)
|
||||||
|
|
||||||
|
// These endpoints are dynamic because they need access to the
|
||||||
|
// RPCContext, but they also need to be called directly in some cases,
|
||||||
|
// so store them into staticEndpoints for later access, but don't
|
||||||
|
// register them as static.
|
||||||
|
s.staticEndpoints.Deployment = &Deployment{srv: s, logger: s.logger.Named("deployment")}
|
||||||
|
s.staticEndpoints.Node = &Node{srv: s, logger: s.logger.Named("client")}
|
||||||
|
|
||||||
// Client endpoints
|
// Client endpoints
|
||||||
s.staticEndpoints.ClientStats = &ClientStats{srv: s, logger: s.logger.Named("client_stats")}
|
s.staticEndpoints.ClientStats = &ClientStats{srv: s, logger: s.logger.Named("client_stats")}
|
||||||
s.staticEndpoints.ClientAllocations = &ClientAllocations{srv: s, logger: s.logger.Named("client_allocs")}
|
s.staticEndpoints.ClientAllocations = &ClientAllocations{srv: s, logger: s.logger.Named("client_allocs")}
|
||||||
@@ -1191,15 +1190,11 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
|
|||||||
|
|
||||||
// Register the static handlers
|
// Register the static handlers
|
||||||
server.Register(s.staticEndpoints.ACL)
|
server.Register(s.staticEndpoints.ACL)
|
||||||
server.Register(s.staticEndpoints.Alloc)
|
|
||||||
server.Register(s.staticEndpoints.Eval)
|
|
||||||
server.Register(s.staticEndpoints.Job)
|
server.Register(s.staticEndpoints.Job)
|
||||||
server.Register(s.staticEndpoints.CSIVolume)
|
server.Register(s.staticEndpoints.CSIVolume)
|
||||||
server.Register(s.staticEndpoints.CSIPlugin)
|
server.Register(s.staticEndpoints.CSIPlugin)
|
||||||
server.Register(s.staticEndpoints.Deployment)
|
|
||||||
server.Register(s.staticEndpoints.Operator)
|
server.Register(s.staticEndpoints.Operator)
|
||||||
server.Register(s.staticEndpoints.Periodic)
|
server.Register(s.staticEndpoints.Periodic)
|
||||||
server.Register(s.staticEndpoints.Plan)
|
|
||||||
server.Register(s.staticEndpoints.Region)
|
server.Register(s.staticEndpoints.Region)
|
||||||
server.Register(s.staticEndpoints.Scaling)
|
server.Register(s.staticEndpoints.Scaling)
|
||||||
server.Register(s.staticEndpoints.Status)
|
server.Register(s.staticEndpoints.Status)
|
||||||
@@ -1214,10 +1209,18 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
|
|||||||
server.Register(s.staticEndpoints.Namespace)
|
server.Register(s.staticEndpoints.Namespace)
|
||||||
|
|
||||||
// Create new dynamic endpoints and add them to the RPC server.
|
// Create new dynamic endpoints and add them to the RPC server.
|
||||||
|
alloc := &Alloc{srv: s, ctx: ctx, logger: s.logger.Named("alloc")}
|
||||||
|
deployment := &Deployment{srv: s, ctx: ctx, logger: s.logger.Named("deployment")}
|
||||||
|
eval := &Eval{srv: s, ctx: ctx, logger: s.logger.Named("eval")}
|
||||||
node := &Node{srv: s, ctx: ctx, logger: s.logger.Named("client")}
|
node := &Node{srv: s, ctx: ctx, logger: s.logger.Named("client")}
|
||||||
|
plan := &Plan{srv: s, ctx: ctx, logger: s.logger.Named("plan")}
|
||||||
|
|
||||||
// Register the dynamic endpoints
|
// Register the dynamic endpoints
|
||||||
|
server.Register(alloc)
|
||||||
|
server.Register(deployment)
|
||||||
|
server.Register(eval)
|
||||||
server.Register(node)
|
server.Register(node)
|
||||||
|
server.Register(plan)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupRaft is used to setup and initialize Raft
|
// setupRaft is used to setup and initialize Raft
|
||||||
|
|||||||
@@ -301,3 +301,27 @@ func getAlloc(state AllocGetter, allocID string) (*structs.Allocation, error) {
|
|||||||
|
|
||||||
return alloc, nil
|
return alloc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateLocalClientTLSCertificate checks if the provided RPC connection was
|
||||||
|
// initiated by a client in the same region as the target server.
|
||||||
|
func validateLocalClientTLSCertificate(srv *Server, ctx *RPCContext) error {
|
||||||
|
expected := fmt.Sprintf("client.%s.nomad", srv.Region())
|
||||||
|
return validateTLSCertificate(srv, ctx, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateLocalServerTLSCertificate checks if the provided RPC connection was
|
||||||
|
// initiated by a server in the same region as the target server.
|
||||||
|
func validateLocalServerTLSCertificate(srv *Server, ctx *RPCContext) error {
|
||||||
|
expected := fmt.Sprintf("server.%s.nomad", srv.Region())
|
||||||
|
return validateTLSCertificate(srv, ctx, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateTLSCertificate checks if the RPC connection mTLS certificates are
|
||||||
|
// valid for the given name.
|
||||||
|
func validateTLSCertificate(srv *Server, ctx *RPCContext, name string) error {
|
||||||
|
if srv.config.TLSConfig == nil || !srv.config.TLSConfig.VerifyServerHostname {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx.ValidateCertificateForName(name)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user