mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
Merge pull request #11089 from hashicorp/b-cve-2021-37218
Apply authZ for nomad Raft RPC layer
This commit is contained in:
3
.changelog/11084.txt
Normal file
3
.changelog/11084.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
```release-note:security
|
||||
Restricted access to the Raft RPC layer, so only servers within the region can issue Raft RPC requests. Previously, local clients and federated servers can issue Raft RPC requests directly. [CVE-2021-37218](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37218)
|
||||
```
|
||||
298
helper/tlsutil/generate.go
Normal file
298
helper/tlsutil/generate.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package tlsutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GenerateSerialNumber returns random bigint generated with crypto/rand
|
||||
func GenerateSerialNumber() (*big.Int, error) {
|
||||
l := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
s, err := rand.Int(rand.Reader, l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GeneratePrivateKey generates a new ecdsa private key
|
||||
func GeneratePrivateKey() (crypto.Signer, string, error) {
|
||||
curve := elliptic.P256()
|
||||
|
||||
pk, err := ecdsa.GenerateKey(curve, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error generating ECDSA private key: %s", err)
|
||||
}
|
||||
|
||||
bs, err := x509.MarshalECPrivateKey(pk)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error marshaling ECDSA private key: %s", err)
|
||||
}
|
||||
|
||||
pemBlock, err := pemEncodeKey(bs, "EC PRIVATE KEY")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return pk, pemBlock, nil
|
||||
}
|
||||
|
||||
func pemEncodeKey(key []byte, blockType string) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := pem.Encode(&buf, &pem.Block{Type: blockType, Bytes: key}); err != nil {
|
||||
return "", fmt.Errorf("error encoding private key: %s", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
type CAOpts struct {
|
||||
Signer crypto.Signer
|
||||
Serial *big.Int
|
||||
Days int
|
||||
PermittedDNSDomains []string
|
||||
Domain string
|
||||
Name string
|
||||
}
|
||||
|
||||
type CertOpts struct {
|
||||
Signer crypto.Signer
|
||||
CA string
|
||||
Serial *big.Int
|
||||
Name string
|
||||
Days int
|
||||
DNSNames []string
|
||||
IPAddresses []net.IP
|
||||
ExtKeyUsage []x509.ExtKeyUsage
|
||||
}
|
||||
|
||||
// GenerateCA generates a new CA for agent TLS (not to be confused with Connect TLS)
|
||||
func GenerateCA(opts CAOpts) (string, string, error) {
|
||||
signer := opts.Signer
|
||||
var pk string
|
||||
if signer == nil {
|
||||
var err error
|
||||
signer, pk, err = GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
id, err := keyID(signer.Public())
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
sn := opts.Serial
|
||||
if sn == nil {
|
||||
var err error
|
||||
sn, err = GenerateSerialNumber()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
name := opts.Name
|
||||
if name == "" {
|
||||
name = fmt.Sprintf("Consul Agent CA %d", sn)
|
||||
}
|
||||
|
||||
days := opts.Days
|
||||
if opts.Days == 0 {
|
||||
days = 365
|
||||
}
|
||||
|
||||
// Create the CA cert
|
||||
template := x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{
|
||||
Country: []string{"US"},
|
||||
PostalCode: []string{"94105"},
|
||||
Province: []string{"CA"},
|
||||
Locality: []string{"San Francisco"},
|
||||
StreetAddress: []string{"101 Second Street"},
|
||||
Organization: []string{"HashiCorp Inc."},
|
||||
CommonName: name,
|
||||
},
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature,
|
||||
IsCA: true,
|
||||
NotAfter: time.Now().AddDate(0, 0, days),
|
||||
NotBefore: time.Now(),
|
||||
AuthorityKeyId: id,
|
||||
SubjectKeyId: id,
|
||||
}
|
||||
|
||||
if len(opts.PermittedDNSDomains) > 0 {
|
||||
template.PermittedDNSDomainsCritical = true
|
||||
template.PermittedDNSDomains = opts.PermittedDNSDomains
|
||||
}
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, &template, signer.Public(), signer)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error generating CA certificate: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
return buf.String(), pk, nil
|
||||
}
|
||||
|
||||
// GenerateCert generates a new certificate for agent TLS (not to be confused with Connect TLS)
|
||||
func GenerateCert(opts CertOpts) (string, string, error) {
|
||||
parent, err := parseCert(opts.CA)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
signee, pk, err := GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
id, err := keyID(signee.Public())
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
sn := opts.Serial
|
||||
if sn == nil {
|
||||
var err error
|
||||
sn, err = GenerateSerialNumber()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{CommonName: opts.Name},
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: opts.ExtKeyUsage,
|
||||
IsCA: false,
|
||||
NotAfter: time.Now().AddDate(0, 0, opts.Days),
|
||||
NotBefore: time.Now(),
|
||||
SubjectKeyId: id,
|
||||
DNSNames: opts.DNSNames,
|
||||
IPAddresses: opts.IPAddresses,
|
||||
}
|
||||
|
||||
bs, err := x509.CreateCertificate(rand.Reader, &template, parent, signee.Public(), opts.Signer)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
return buf.String(), pk, nil
|
||||
}
|
||||
|
||||
// KeyId returns a x509 KeyId from the given signing key.
|
||||
func keyID(raw interface{}) ([]byte, error) {
|
||||
switch raw.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
case *rsa.PublicKey:
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid key type: %T", raw)
|
||||
}
|
||||
|
||||
// This is not standard; RFC allows any unique identifier as long as they
|
||||
// match in subject/authority chains but suggests specific hashing of DER
|
||||
// bytes of public key including DER tags.
|
||||
bs, err := x509.MarshalPKIXPublicKey(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// String formatted
|
||||
kID := sha256.Sum256(bs)
|
||||
return kID[:], nil
|
||||
}
|
||||
|
||||
func parseCert(pemValue string) (*x509.Certificate, error) {
|
||||
// The _ result below is not an error but the remaining PEM bytes.
|
||||
block, _ := pem.Decode([]byte(pemValue))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM-encoded data found")
|
||||
}
|
||||
|
||||
if block.Type != "CERTIFICATE" {
|
||||
return nil, fmt.Errorf("first PEM-block should be CERTIFICATE type")
|
||||
}
|
||||
|
||||
return x509.ParseCertificate(block.Bytes)
|
||||
}
|
||||
|
||||
// ParseSigner parses a crypto.Signer from a PEM-encoded key. The private key
|
||||
// is expected to be the first block in the PEM value.
|
||||
func ParseSigner(pemValue string) (crypto.Signer, error) {
|
||||
// The _ result below is not an error but the remaining PEM bytes.
|
||||
block, _ := pem.Decode([]byte(pemValue))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM-encoded data found")
|
||||
}
|
||||
|
||||
switch block.Type {
|
||||
case "EC PRIVATE KEY":
|
||||
return x509.ParseECPrivateKey(block.Bytes)
|
||||
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
|
||||
case "PRIVATE KEY":
|
||||
signer, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pk, ok := signer.(crypto.Signer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("private key is not a valid format")
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown PEM block type for signing key: %s", block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func Verify(caString, certString, dns string) error {
|
||||
roots := x509.NewCertPool()
|
||||
ok := roots.AppendCertsFromPEM([]byte(caString))
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to parse root certificate")
|
||||
}
|
||||
|
||||
cert, err := parseCert(certString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse certificate")
|
||||
}
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
DNSName: fmt.Sprint(dns),
|
||||
Roots: roots,
|
||||
}
|
||||
|
||||
_, err = cert.Verify(opts)
|
||||
return err
|
||||
}
|
||||
159
helper/tlsutil/generate_test.go
Normal file
159
helper/tlsutil/generate_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package tlsutil
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"strings"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSerialNumber(t *testing.T) {
|
||||
n1, err := GenerateSerialNumber()
|
||||
require.Nil(t, err)
|
||||
|
||||
n2, err := GenerateSerialNumber()
|
||||
require.Nil(t, err)
|
||||
require.NotEqual(t, n1, n2)
|
||||
|
||||
n3, err := GenerateSerialNumber()
|
||||
require.Nil(t, err)
|
||||
require.NotEqual(t, n1, n3)
|
||||
require.NotEqual(t, n2, n3)
|
||||
|
||||
}
|
||||
|
||||
func TestGeneratePrivateKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, p, err := GeneratePrivateKey()
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, p)
|
||||
require.Contains(t, p, "BEGIN EC PRIVATE KEY")
|
||||
require.Contains(t, p, "END EC PRIVATE KEY")
|
||||
|
||||
block, _ := pem.Decode([]byte(p))
|
||||
pk, err := x509.ParseECPrivateKey(block.Bytes)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, pk)
|
||||
require.Equal(t, 256, pk.Params().BitSize)
|
||||
}
|
||||
|
||||
type TestSigner struct {
|
||||
public interface{}
|
||||
}
|
||||
|
||||
func (s *TestSigner) Public() crypto.PublicKey {
|
||||
return s.public
|
||||
}
|
||||
|
||||
func (s *TestSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func TestGenerateCA(t *testing.T) {
|
||||
t.Run("no signer", func(t *testing.T) {
|
||||
ca, pk, err := GenerateCA(CAOpts{Signer: &TestSigner{}})
|
||||
require.Error(t, err)
|
||||
require.Empty(t, ca)
|
||||
require.Empty(t, pk)
|
||||
})
|
||||
|
||||
t.Run("wrong key", func(t *testing.T) {
|
||||
ca, pk, err := GenerateCA(CAOpts{Signer: &TestSigner{public: &rsa.PublicKey{}}})
|
||||
require.Error(t, err)
|
||||
require.Empty(t, ca)
|
||||
require.Empty(t, pk)
|
||||
})
|
||||
|
||||
t.Run("valid key", func(t *testing.T) {
|
||||
ca, pk, err := GenerateCA(CAOpts{})
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, ca)
|
||||
require.NotEmpty(t, pk)
|
||||
|
||||
cert, err := parseCert(ca)
|
||||
require.Nil(t, err)
|
||||
require.True(t, strings.HasPrefix(cert.Subject.CommonName, "Consul Agent CA"))
|
||||
require.Equal(t, true, cert.IsCA)
|
||||
require.Equal(t, true, cert.BasicConstraintsValid)
|
||||
|
||||
require.WithinDuration(t, cert.NotBefore, time.Now(), time.Minute)
|
||||
require.WithinDuration(t, cert.NotAfter, time.Now().AddDate(0, 0, 365), time.Minute)
|
||||
|
||||
require.Equal(t, x509.KeyUsageCertSign|x509.KeyUsageCRLSign|x509.KeyUsageDigitalSignature, cert.KeyUsage)
|
||||
})
|
||||
|
||||
t.Run("RSA key", func(t *testing.T) {
|
||||
ca, pk, err := GenerateCA(CAOpts{})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ca)
|
||||
require.NotEmpty(t, pk)
|
||||
|
||||
cert, err := parseCert(ca)
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.HasPrefix(cert.Subject.CommonName, "Consul Agent CA"))
|
||||
require.Equal(t, true, cert.IsCA)
|
||||
require.Equal(t, true, cert.BasicConstraintsValid)
|
||||
|
||||
require.WithinDuration(t, cert.NotBefore, time.Now(), time.Minute)
|
||||
require.WithinDuration(t, cert.NotAfter, time.Now().AddDate(0, 0, 365), time.Minute)
|
||||
|
||||
require.Equal(t, x509.KeyUsageCertSign|x509.KeyUsageCRLSign|x509.KeyUsageDigitalSignature, cert.KeyUsage)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateCert(t *testing.T) {
|
||||
t.Parallel()
|
||||
signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.Nil(t, err)
|
||||
ca, _, err := GenerateCA(CAOpts{Signer: signer})
|
||||
require.Nil(t, err)
|
||||
|
||||
DNSNames := []string{"server.dc1.consul"}
|
||||
IPAddresses := []net.IP{net.ParseIP("123.234.243.213")}
|
||||
extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
|
||||
name := "Cert Name"
|
||||
certificate, pk, err := GenerateCert(CertOpts{
|
||||
Signer: signer, CA: ca, Name: name, Days: 365,
|
||||
DNSNames: DNSNames, IPAddresses: IPAddresses, ExtKeyUsage: extKeyUsage,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, certificate)
|
||||
require.NotEmpty(t, pk)
|
||||
|
||||
cert, err := parseCert(certificate)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, name, cert.Subject.CommonName)
|
||||
require.Equal(t, true, cert.BasicConstraintsValid)
|
||||
signee, err := ParseSigner(pk)
|
||||
require.Nil(t, err)
|
||||
certID, err := keyID(signee.Public())
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, certID, cert.SubjectKeyId)
|
||||
caID, err := keyID(signer.Public())
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, caID, cert.AuthorityKeyId)
|
||||
require.Contains(t, cert.Issuer.CommonName, "Consul Agent CA")
|
||||
require.Equal(t, false, cert.IsCA)
|
||||
|
||||
require.WithinDuration(t, cert.NotBefore, time.Now(), time.Minute)
|
||||
require.WithinDuration(t, cert.NotAfter, time.Now().AddDate(0, 0, 365), time.Minute)
|
||||
|
||||
require.Equal(t, x509.KeyUsageDigitalSignature|x509.KeyUsageKeyEncipherment, cert.KeyUsage)
|
||||
require.Equal(t, extKeyUsage, cert.ExtKeyUsage)
|
||||
|
||||
// https://github.com/golang/go/blob/10538a8f9e2e718a47633ac5a6e90415a2c3f5f1/src/crypto/x509/verify.go#L414
|
||||
require.Equal(t, DNSNames, cert.DNSNames)
|
||||
require.True(t, IPAddresses[0].Equal(cert.IPAddresses[0]))
|
||||
}
|
||||
40
nomad/rpc.go
40
nomad/rpc.go
@@ -238,6 +238,11 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC
|
||||
|
||||
case pool.RpcRaft:
|
||||
metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1)
|
||||
// Ensure that when TLS is configured, only certificates from `server.<region>.nomad` are accepted for Raft connections.
|
||||
if err := r.validateRaftTLS(rpcCtx); err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
r.raftLayer.Handoff(ctx, conn)
|
||||
|
||||
case pool.RpcMultiplex:
|
||||
@@ -825,3 +830,38 @@ RUN_QUERY:
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcHandler) validateRaftTLS(rpcCtx *RPCContext) error {
|
||||
// TLS is not configured or not to be enforced
|
||||
tlsConf := r.config.TLSConfig
|
||||
if !tlsConf.EnableRPC || !tlsConf.VerifyServerHostname || tlsConf.RPCUpgradeMode {
|
||||
return nil
|
||||
}
|
||||
|
||||
// defensive conditions: these should have already been enforced by handleConn
|
||||
if rpcCtx == nil || !rpcCtx.TLS {
|
||||
return errors.New("non-TLS connection attempted")
|
||||
}
|
||||
if len(rpcCtx.VerifiedChains) == 0 || len(rpcCtx.VerifiedChains[0]) == 0 {
|
||||
// this should never happen, as rpcNameAndRegionValidate should have enforced it
|
||||
return errors.New("missing cert info")
|
||||
}
|
||||
|
||||
// check that `server.<region>.nomad` is present in cert
|
||||
expected := "server." + r.Region() + ".nomad"
|
||||
|
||||
cert := rpcCtx.VerifiedChains[0][0]
|
||||
for _, dnsName := range cert.DNSNames {
|
||||
if dnsName == expected {
|
||||
// Certificate is valid for the expected name
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if cert.Subject.CommonName == expected {
|
||||
// Certificate is valid for the expected name
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Warn("unauthorized raft connection", "remote_addr", rpcCtx.Conn.RemoteAddr(), "required_hostname", expected, "found", cert.DNSNames)
|
||||
return fmt.Errorf("certificate is invalid for expected role or region: %q", expected)
|
||||
}
|
||||
|
||||
@@ -3,13 +3,16 @@ package nomad
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1014,3 +1017,253 @@ func TestRPC_Limits_Streaming(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRPC_TLS_Enforcement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defer func() {
|
||||
//TODO Avoid panics from logging during shutdown
|
||||
time.Sleep(1 * time.Second)
|
||||
}()
|
||||
|
||||
dir := tmpDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodeID := 1
|
||||
newCert := func(t *testing.T, name string) string {
|
||||
t.Helper()
|
||||
|
||||
node := fmt.Sprintf("node%d", nodeID)
|
||||
nodeID++
|
||||
signer, err := tlsutil.ParseSigner(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{
|
||||
Signer: signer,
|
||||
CA: caPEM,
|
||||
Name: name,
|
||||
Days: 5,
|
||||
DNSNames: []string{node + "." + name, name, "localhost"},
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600)
|
||||
require.NoError(t, err)
|
||||
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
return filepath.Join(dir, node+"-"+name)
|
||||
}
|
||||
|
||||
connect := func(t *testing.T, s *Server, c *config.TLSConfig) net.Conn {
|
||||
conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
// configure TLS
|
||||
_, err = conn.Write([]byte{byte(pool.RpcTLS)})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Client TLS verification isn't necessary for
|
||||
// our assertions
|
||||
tlsConf, err := tlsutil.NewTLSConfiguration(c, true, true)
|
||||
require.NoError(t, err)
|
||||
outTLSConf, err := tlsConf.OutgoingTLSConfig()
|
||||
require.NoError(t, err)
|
||||
outTLSConf.InsecureSkipVerify = true
|
||||
|
||||
tlsConn := tls.Client(conn, outTLSConf)
|
||||
require.NoError(t, tlsConn.Handshake())
|
||||
|
||||
return tlsConn
|
||||
}
|
||||
|
||||
nomadRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error {
|
||||
conn := connect(t, s, c)
|
||||
defer conn.Close()
|
||||
_, err := conn.Write([]byte{byte(pool.RpcNomad)})
|
||||
require.NoError(t, err)
|
||||
|
||||
codec := pool.NewClientCodec(conn)
|
||||
|
||||
arg := struct{}{}
|
||||
var out struct{}
|
||||
return msgpackrpc.CallWithCodec(codec, "Status.Ping", arg, &out)
|
||||
}
|
||||
|
||||
raftRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error {
|
||||
conn := connect(t, s, c)
|
||||
defer conn.Close()
|
||||
|
||||
_, err := conn.Write([]byte{byte(pool.RpcRaft)})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = doRaftRPC(conn, s.config.NodeName)
|
||||
return err
|
||||
}
|
||||
|
||||
// generate server cert
|
||||
serverCert := newCert(t, "server.global.nomad")
|
||||
|
||||
mtlsS, cleanup := TestServer(t, func(c *Config) {
|
||||
c.TLSConfig = &config.TLSConfig{
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: true,
|
||||
CAFile: filepath.Join(dir, "ca.pem"),
|
||||
CertFile: serverCert + ".pem",
|
||||
KeyFile: serverCert + ".key",
|
||||
}
|
||||
})
|
||||
defer cleanup()
|
||||
|
||||
nonVerifyS, cleanup := TestServer(t, func(c *Config) {
|
||||
c.TLSConfig = &config.TLSConfig{
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: false,
|
||||
CAFile: filepath.Join(dir, "ca.pem"),
|
||||
CertFile: serverCert + ".pem",
|
||||
KeyFile: serverCert + ".key",
|
||||
}
|
||||
})
|
||||
defer cleanup()
|
||||
|
||||
// When VerifyServerHostname is enabled:
|
||||
// Only all servers and local clients can make RPC requests
|
||||
// Only local servers can connect to the Raft layer
|
||||
cases := []struct {
|
||||
name string
|
||||
cn string
|
||||
canRPC bool
|
||||
canRaft bool
|
||||
}{
|
||||
{
|
||||
name: "local server",
|
||||
cn: "server.global.nomad",
|
||||
canRPC: true,
|
||||
canRaft: true,
|
||||
},
|
||||
{
|
||||
name: "local client",
|
||||
cn: "client.global.nomad",
|
||||
canRPC: true,
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "other region server",
|
||||
cn: "server.other.nomad",
|
||||
canRPC: true,
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "other client server",
|
||||
cn: "client.other.nomad",
|
||||
canRPC: false,
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "irrelevant cert",
|
||||
cn: "nomad.example.com",
|
||||
canRPC: false,
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "globs",
|
||||
cn: "*.global.nomad",
|
||||
canRPC: false,
|
||||
canRaft: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
certPath := newCert(t, tc.cn)
|
||||
|
||||
cfg := &config.TLSConfig{
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: true,
|
||||
CAFile: filepath.Join(dir, "ca.pem"),
|
||||
CertFile: certPath + ".pem",
|
||||
KeyFile: certPath + ".key",
|
||||
}
|
||||
|
||||
t.Run("nomad RPC: verify_hostname=true", func(t *testing.T) {
|
||||
err := nomadRPC(t, mtlsS, cfg)
|
||||
|
||||
if tc.canRPC {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "bad certificate")
|
||||
}
|
||||
})
|
||||
t.Run("nomad RPC: verify_hostname=false", func(t *testing.T) {
|
||||
err := nomadRPC(t, nonVerifyS, cfg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) {
|
||||
err := raftRPC(t, mtlsS, cfg)
|
||||
|
||||
// the expected error depends on location of failure.
|
||||
// We expect "bad certificate" if connection fails during handshake,
|
||||
// or EOF when connection is closed after RaftRPC byte.
|
||||
if tc.canRaft {
|
||||
require.NoError(t, err)
|
||||
} else if !tc.canRPC {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "bad certificate")
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "EOF")
|
||||
}
|
||||
})
|
||||
t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) {
|
||||
err := raftRPC(t, nonVerifyS, cfg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func doRaftRPC(conn net.Conn, leader string) (*raft.AppendEntriesResponse, error) {
|
||||
req := raft.AppendEntriesRequest{
|
||||
RPCHeader: raft.RPCHeader{ProtocolVersion: 3},
|
||||
Term: 0,
|
||||
Leader: []byte(leader),
|
||||
PrevLogEntry: 0,
|
||||
PrevLogTerm: 0xc,
|
||||
LeaderCommitIndex: 50,
|
||||
}
|
||||
|
||||
enc := codec.NewEncoder(conn, &codec.MsgpackHandle{})
|
||||
dec := codec.NewDecoder(conn, &codec.MsgpackHandle{})
|
||||
|
||||
const rpcAppendEntries = 0
|
||||
if _, err := conn.Write([]byte{rpcAppendEntries}); err != nil {
|
||||
return nil, fmt.Errorf("failed to write raft-RPC byte: %w", err)
|
||||
}
|
||||
|
||||
if err := enc.Encode(req); err != nil {
|
||||
return nil, fmt.Errorf("failed to send append entries RPC: %w", err)
|
||||
}
|
||||
|
||||
var rpcError string
|
||||
var resp raft.AppendEntriesResponse
|
||||
if err := dec.Decode(&rpcError); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response error: %w", err)
|
||||
}
|
||||
if rpcError != "" {
|
||||
return nil, fmt.Errorf("rpc error: %v", rpcError)
|
||||
}
|
||||
if err := dec.Decode(&resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user