mirror of
https://github.com/kemko/nomad.git
synced 2026-01-08 03:15:42 +03:00
Merge pull request #5954 from hashicorp/b-fix-streaming-rpc-tls
rpc: use tls wrapped connection for streaming rpc
This commit is contained in:
22
nomad/rpc.go
22
nomad/rpc.go
@@ -540,18 +540,14 @@ func (r *rpcHandler) streamingRpc(server *serverParts, method string) (net.Conn,
|
||||
tcp.SetNoDelay(true)
|
||||
}
|
||||
|
||||
if err := r.streamingRpcImpl(conn, server.Region, method); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
return r.streamingRpcImpl(conn, server.Region, method)
|
||||
}
|
||||
|
||||
// streamingRpcImpl takes a pre-established connection to a server and conducts
|
||||
// the handshake to establish a streaming RPC for the given method. If an error
|
||||
// is returned, the underlying connection has been closed. Otherwise it is
|
||||
// assumed that the connection has been hijacked by the RPC method.
|
||||
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) error {
|
||||
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) (net.Conn, error) {
|
||||
// Check if TLS is enabled
|
||||
r.tlsWrapLock.RLock()
|
||||
tlsWrap := r.tlsWrap
|
||||
@@ -561,14 +557,14 @@ func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) erro
|
||||
// Switch the connection into TLS mode
|
||||
if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wrap the connection in a TLS client
|
||||
tlsConn, err := tlsWrap(region, conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
conn = tlsConn
|
||||
}
|
||||
@@ -576,7 +572,7 @@ func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) erro
|
||||
// Write the multiplex byte to set the mode
|
||||
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send the header
|
||||
@@ -587,22 +583,22 @@ func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) erro
|
||||
}
|
||||
if err := encoder.Encode(header); err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait for the acknowledgement
|
||||
var ack structs.StreamingRpcAck
|
||||
if err := decoder.Decode(&ack); err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ack.Error != "" {
|
||||
conn.Close()
|
||||
return errors.New(ack.Error)
|
||||
return nil, errors.New(ack.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// raftApplyFuture is used to encode a message, run it through raft, and return the Raft future.
|
||||
|
||||
@@ -10,8 +10,10 @@ import (
|
||||
"time"
|
||||
|
||||
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
cstructs "github.com/hashicorp/nomad/client/structs"
|
||||
"github.com/hashicorp/nomad/helper/pool"
|
||||
"github.com/hashicorp/nomad/helper/testlog"
|
||||
"github.com/hashicorp/nomad/helper/uuid"
|
||||
"github.com/hashicorp/nomad/nomad/mock"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/nomad/nomad/structs/config"
|
||||
@@ -20,6 +22,7 @@ import (
|
||||
"github.com/hashicorp/yamux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ugorji/go/codec"
|
||||
)
|
||||
|
||||
// rpcClient is a test helper method to return a ClientCodec to use to make rpc
|
||||
@@ -267,6 +270,135 @@ func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) {
|
||||
require.True(structs.IsErrUnknownMethod(err))
|
||||
}
|
||||
|
||||
func TestRPC_streamingRpcConn_goodMethod_Plaintext(t *testing.T) {
|
||||
t.Parallel()
|
||||
require := require.New(t)
|
||||
dir := tmpDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
s1 := TestServer(t, func(c *Config) {
|
||||
c.Region = "regionFoo"
|
||||
c.BootstrapExpect = 2
|
||||
c.DevMode = false
|
||||
c.DevDisableBootstrap = true
|
||||
c.DataDir = path.Join(dir, "node1")
|
||||
})
|
||||
defer s1.Shutdown()
|
||||
|
||||
s2 := TestServer(t, func(c *Config) {
|
||||
c.Region = "regionFoo"
|
||||
c.BootstrapExpect = 2
|
||||
c.DevMode = false
|
||||
c.DevDisableBootstrap = true
|
||||
c.DataDir = path.Join(dir, "node2")
|
||||
})
|
||||
defer s2.Shutdown()
|
||||
|
||||
TestJoin(t, s1, s2)
|
||||
testutil.WaitForLeader(t, s1.RPC)
|
||||
|
||||
s1.peerLock.RLock()
|
||||
ok, parts := isNomadServer(s2.LocalMember())
|
||||
require.True(ok)
|
||||
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
|
||||
require.NotNil(server)
|
||||
s1.peerLock.RUnlock()
|
||||
|
||||
conn, err := s1.streamingRpc(server, "FileSystem.Logs")
|
||||
require.NotNil(conn)
|
||||
require.NoError(err)
|
||||
|
||||
decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
|
||||
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
|
||||
|
||||
allocID := uuid.Generate()
|
||||
require.NoError(encoder.Encode(cstructs.FsStreamRequest{
|
||||
AllocID: allocID,
|
||||
QueryOptions: structs.QueryOptions{
|
||||
Region: "regionFoo",
|
||||
},
|
||||
}))
|
||||
|
||||
var result cstructs.StreamErrWrapper
|
||||
require.NoError(decoder.Decode(&result))
|
||||
require.Empty(result.Payload)
|
||||
require.True(structs.IsErrUnknownAllocation(result.Error))
|
||||
}
|
||||
|
||||
func TestRPC_streamingRpcConn_goodMethod_TLS(t *testing.T) {
|
||||
t.Parallel()
|
||||
require := require.New(t)
|
||||
const (
|
||||
cafile = "../helper/tlsutil/testdata/ca.pem"
|
||||
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
|
||||
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
|
||||
)
|
||||
dir := tmpDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
s1 := TestServer(t, func(c *Config) {
|
||||
c.Region = "regionFoo"
|
||||
c.BootstrapExpect = 2
|
||||
c.DevMode = false
|
||||
c.DevDisableBootstrap = true
|
||||
c.DataDir = path.Join(dir, "node1")
|
||||
c.TLSConfig = &config.TLSConfig{
|
||||
EnableHTTP: true,
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: true,
|
||||
CAFile: cafile,
|
||||
CertFile: foocert,
|
||||
KeyFile: fookey,
|
||||
}
|
||||
})
|
||||
defer s1.Shutdown()
|
||||
|
||||
s2 := TestServer(t, func(c *Config) {
|
||||
c.Region = "regionFoo"
|
||||
c.BootstrapExpect = 2
|
||||
c.DevMode = false
|
||||
c.DevDisableBootstrap = true
|
||||
c.DataDir = path.Join(dir, "node2")
|
||||
c.TLSConfig = &config.TLSConfig{
|
||||
EnableHTTP: true,
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: true,
|
||||
CAFile: cafile,
|
||||
CertFile: foocert,
|
||||
KeyFile: fookey,
|
||||
}
|
||||
})
|
||||
defer s2.Shutdown()
|
||||
|
||||
TestJoin(t, s1, s2)
|
||||
testutil.WaitForLeader(t, s1.RPC)
|
||||
|
||||
s1.peerLock.RLock()
|
||||
ok, parts := isNomadServer(s2.LocalMember())
|
||||
require.True(ok)
|
||||
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
|
||||
require.NotNil(server)
|
||||
s1.peerLock.RUnlock()
|
||||
|
||||
conn, err := s1.streamingRpc(server, "FileSystem.Logs")
|
||||
require.NotNil(conn)
|
||||
require.NoError(err)
|
||||
|
||||
decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
|
||||
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
|
||||
|
||||
allocID := uuid.Generate()
|
||||
require.NoError(encoder.Encode(cstructs.FsStreamRequest{
|
||||
AllocID: allocID,
|
||||
QueryOptions: structs.QueryOptions{
|
||||
Region: "regionFoo",
|
||||
},
|
||||
}))
|
||||
|
||||
var result cstructs.StreamErrWrapper
|
||||
require.NoError(decoder.Decode(&result))
|
||||
require.Empty(result.Payload)
|
||||
require.True(structs.IsErrUnknownAllocation(result.Error))
|
||||
}
|
||||
|
||||
// COMPAT: Remove in 0.10
|
||||
// This is a very low level test to assert that the V2 handling works. It is
|
||||
// making manual RPC calls since no helpers exist at this point since we are
|
||||
@@ -321,7 +453,7 @@ func TestRPC_handleMultiplexV2(t *testing.T) {
|
||||
require.NotEmpty(l)
|
||||
|
||||
// Make a streaming RPC
|
||||
err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
|
||||
_, err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
|
||||
require.NotNil(err)
|
||||
require.Contains(err.Error(), "Bogus")
|
||||
require.True(structs.IsErrUnknownMethod(err))
|
||||
|
||||
Reference in New Issue
Block a user