Merge pull request #1629 from hashicorp/f-derive-token
Server Deriving Tokens on behalf of Clients
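The flow this change adds: a client calls the new Node.DeriveVaultToken RPC, the server validates the node, allocation, and task Vault policies, creates wrapped tokens against Vault, commits the token accessors through Raft, and only then returns the wrapped tokens. A rough client-side sketch of the call, using the request/response types and RPC name from this change (the rpcCall helper and the "web" task name are placeholders, not part of the diff):

// deriveTokensForAlloc is a hypothetical helper showing the client-side
// shape of the new endpoint; rpcCall stands in for the client's msgpack
// RPC plumbing and is not part of this change.
func deriveTokensForAlloc(node *structs.Node, alloc *structs.Allocation,
	rpcCall func(method string, args, reply interface{}) error) (map[string]string, error) {

	req := &structs.DeriveVaultTokenRequest{
		NodeID:       node.ID,       // node the allocation runs on
		SecretID:     node.SecretID, // proves the caller is that node
		AllocID:      alloc.ID,
		Tasks:        []string{"web"}, // tasks that declare Vault policies
		QueryOptions: structs.QueryOptions{Region: "global"},
	}

	var resp structs.DeriveVaultTokenResponse
	if err := rpcCall("Node.DeriveVaultToken", req, &resp); err != nil {
		return nil, err
	}

	// resp.Tasks maps each task name to a wrapped Vault token that the
	// client unwraps on behalf of the task.
	return resp.Tasks, nil
}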
nomad/fsm.go (57 lines changed)
@@ -35,6 +35,7 @@ const (
TimeTableSnapshot
PeriodicLaunchSnapshot
JobSummarySnapshot
VaultAccessorSnapshot
)
|
||||
|
||||
// nomadFSM implements a finite state machine that is used
|
||||
@@ -137,6 +138,8 @@ func (n *nomadFSM) Apply(log *raft.Log) interface{} {
|
||||
return n.applyAllocClientUpdate(buf[1:], log.Index)
|
||||
case structs.ReconcileJobSummariesRequestType:
|
||||
return n.applyReconcileSummaries(buf[1:], log.Index)
|
||||
case structs.VaultAccessorRegisterRequestType:
|
||||
return n.applyUpsertVaultAccessor(buf[1:], log.Index)
|
||||
default:
|
||||
if ignoreUnknown {
|
||||
n.logger.Printf("[WARN] nomad.fsm: ignoring unknown message type (%d), upgrade to newer version", msgType)
|
||||
@@ -454,6 +457,23 @@ func (n *nomadFSM) applyReconcileSummaries(buf []byte, index uint64) interface{}
|
||||
return n.reconcileQueuedAllocations(index)
|
||||
}
|
||||
|
||||
// applyUpsertVaultAccessor stores the Vault accessors for a given allocation
|
||||
// and task
|
||||
func (n *nomadFSM) applyUpsertVaultAccessor(buf []byte, index uint64) interface{} {
|
||||
defer metrics.MeasureSince([]string{"nomad", "fsm", "upsert_vault_accessor"}, time.Now())
|
||||
var req structs.VaultAccessorRegisterRequest
|
||||
if err := structs.Decode(buf, &req); err != nil {
|
||||
panic(fmt.Errorf("failed to decode request: %v", err))
|
||||
}
|
||||
|
||||
if err := n.state.UpsertVaultAccessor(index, req.Accessors); err != nil {
|
||||
n.logger.Printf("[ERR] nomad.fsm: UpsertVaultAccessor failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
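This handler is reached through the one-byte message-type framing the FSM already uses: structs.Encode prepends byte(structs.VaultAccessorRegisterRequestType), Apply switches on that byte, and the handler decodes buf[1:]. A minimal sketch of producing such a log entry, the same round trip TestFSM_UpsertVaultAccessor exercises below (helper name is made up):

// encodeVaultAccessorMsg is illustrative only: buf[0] carries the message
// type and buf[1:] is the msgpack-encoded request that
// applyUpsertVaultAccessor decodes.
func encodeVaultAccessorMsg(accessors []*structs.VaultAccessor) ([]byte, error) {
	req := structs.VaultAccessorRegisterRequest{Accessors: accessors}
	return structs.Encode(structs.VaultAccessorRegisterRequestType, req)
}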
|
||||
|
||||
func (n *nomadFSM) Snapshot() (raft.FSMSnapshot, error) {
|
||||
// Create a new snapshot
|
||||
snap, err := n.state.Snapshot()
|
||||
@@ -583,6 +603,15 @@ func (n *nomadFSM) Restore(old io.ReadCloser) error {
|
||||
return err
|
||||
}
|
||||
|
||||
case VaultAccessorSnapshot:
|
||||
accessor := new(structs.VaultAccessor)
|
||||
if err := dec.Decode(accessor); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := restore.VaultAccessorRestore(accessor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("Unrecognized snapshot type: %v", msgType)
|
||||
}
|
||||
@@ -756,6 +785,10 @@ func (s *nomadSnapshot) Persist(sink raft.SnapshotSink) error {
|
||||
sink.Cancel()
|
||||
return err
|
||||
}
|
||||
if err := s.persistVaultAccessors(sink, encoder); err != nil {
|
||||
sink.Cancel()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -945,6 +978,30 @@ func (s *nomadSnapshot) persistJobSummaries(sink raft.SnapshotSink,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *nomadSnapshot) persistVaultAccessors(sink raft.SnapshotSink,
|
||||
encoder *codec.Encoder) error {
|
||||
|
||||
accessors, err := s.snap.VaultAccessors()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
raw := accessors.Next()
|
||||
if raw == nil {
|
||||
break
|
||||
}
|
||||
|
||||
accessor := raw.(*structs.VaultAccessor)
|
||||
|
||||
sink.Write([]byte{byte(VaultAccessorSnapshot)})
|
||||
if err := encoder.Encode(accessor); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
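Each persist helper writes one snapshot-type byte and then the encoded record, which is the framing Restore switches on above. A standalone sketch of that convention, with encoding/json standing in for Nomad's msgpack codec (bytes and encoding/json assumed imported; not the actual snapshot code):

// writeFramedRecord shows the record framing only; it is not the codec the
// snapshot actually uses.
func writeFramedRecord(buf *bytes.Buffer, recordType byte, record interface{}) error {
	buf.WriteByte(recordType) // e.g. byte(VaultAccessorSnapshot)
	return json.NewEncoder(buf).Encode(record)
}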
|
||||
|
||||
// Release is a no-op, as we just need to GC the pointer
|
||||
// to the state store snapshot. There is nothing to explicitly
|
||||
// cleanup.
|
||||
|
||||
@@ -770,6 +770,54 @@ func TestFSM_UpdateAllocFromClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_UpsertVaultAccessor(t *testing.T) {
|
||||
fsm := testFSM(t)
|
||||
fsm.blockedEvals.SetEnabled(true)
|
||||
|
||||
va := mock.VaultAccessor()
|
||||
va2 := mock.VaultAccessor()
|
||||
req := structs.VaultAccessorRegisterRequest{
|
||||
Accessors: []*structs.VaultAccessor{va, va2},
|
||||
}
|
||||
buf, err := structs.Encode(structs.VaultAccessorRegisterRequestType, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
resp := fsm.Apply(makeLog(buf))
|
||||
if resp != nil {
|
||||
t.Fatalf("resp: %v", resp)
|
||||
}
|
||||
|
||||
// Verify we are registered
|
||||
out1, err := fsm.State().VaultAccessor(va.Accessor)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if out1 == nil {
|
||||
t.Fatalf("not found!")
|
||||
}
|
||||
if out1.CreateIndex != 1 {
|
||||
t.Fatalf("bad index: %d", out1.CreateIndex)
|
||||
}
|
||||
out2, err := fsm.State().VaultAccessor(va2.Accessor)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if out2 == nil {
|
||||
t.Fatalf("not found!")
|
||||
}
|
||||
if out2.CreateIndex != 1 {
|
||||
t.Fatalf("bad index: %d", out2.CreateIndex)
|
||||
}
|
||||
|
||||
tt := fsm.TimeTable()
|
||||
index := tt.NearestIndex(time.Now().UTC())
|
||||
if index != 1 {
|
||||
t.Fatalf("bad: %d", index)
|
||||
}
|
||||
}
|
||||
|
||||
func testSnapshotRestore(t *testing.T, fsm *nomadFSM) *nomadFSM {
|
||||
// Snapshot
|
||||
snap, err := fsm.Snapshot()
|
||||
@@ -976,6 +1024,27 @@ func TestFSM_SnapshotRestore_JobSummary(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_SnapshotRestore_VaultAccessors(t *testing.T) {
|
||||
// Add some state
|
||||
fsm := testFSM(t)
|
||||
state := fsm.State()
|
||||
a1 := mock.VaultAccessor()
|
||||
a2 := mock.VaultAccessor()
|
||||
state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{a1, a2})
|
||||
|
||||
// Verify the contents
|
||||
fsm2 := testSnapshotRestore(t, fsm)
|
||||
state2 := fsm2.State()
|
||||
out1, _ := state2.VaultAccessor(a1.Accessor)
|
||||
out2, _ := state2.VaultAccessor(a2.Accessor)
|
||||
if !reflect.DeepEqual(a1, out1) {
|
||||
t.Fatalf("bad: \n%#v\n%#v", out1, a1)
|
||||
}
|
||||
if !reflect.DeepEqual(a2, out2) {
|
||||
t.Fatalf("bad: \n%#v\n%#v", out2, a2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_SnapshotRestore_AddMissingSummary(t *testing.T) {
|
||||
// Add some state
|
||||
fsm := testFSM(t)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -83,7 +84,7 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis
|
||||
}
|
||||
|
||||
vault := j.srv.vault
|
||||
s, err := vault.LookupToken(args.Job.VaultToken)
|
||||
s, err := vault.LookupToken(context.Background(), args.Job.VaultToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -290,6 +290,16 @@ func Alloc() *structs.Allocation {
|
||||
return alloc
|
||||
}
|
||||
|
||||
func VaultAccessor() *structs.VaultAccessor {
|
||||
return &structs.VaultAccessor{
|
||||
Accessor: structs.GenerateUUID(),
|
||||
NodeID: structs.GenerateUUID(),
|
||||
AllocID: structs.GenerateUUID(),
|
||||
CreationTTL: 86400,
|
||||
Task: "foo",
|
||||
}
|
||||
}
|
||||
|
||||
func Plan() *structs.Plan {
|
||||
return &structs.Plan{
|
||||
Priority: 50,
|
||||
|
||||
@@ -1,21 +1,29 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/hashicorp/nomad/nomad/state"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/nomad/nomad/watch"
|
||||
vapi "github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
const (
|
||||
// batchUpdateInterval is how long we wait to batch updates
|
||||
batchUpdateInterval = 50 * time.Millisecond
|
||||
|
||||
// maxParallelRequestsPerDerive is the maximum number of parallel Vault
|
||||
// create token requests that may be outstanding per derive request
|
||||
maxParallelRequestsPerDerive = 16
|
||||
)
|
||||
|
||||
// Node endpoint is used for client interactions
|
||||
@@ -868,3 +876,176 @@ func (b *batchFuture) Respond(index uint64, err error) {
|
||||
b.err = err
|
||||
close(b.doneCh)
|
||||
}
|
||||
|
||||
// DeriveVaultToken is used by the clients to request wrapped Vault tokens for
|
||||
// tasks
|
||||
func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
|
||||
reply *structs.DeriveVaultTokenResponse) error {
|
||||
if done, err := n.srv.forward("Node.DeriveVaultToken", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "client", "derive_vault_token"}, time.Now())
|
||||
|
||||
// Verify the arguments
|
||||
if args.NodeID == "" {
|
||||
return fmt.Errorf("missing node ID")
|
||||
}
|
||||
if args.SecretID == "" {
|
||||
return fmt.Errorf("missing node SecretID")
|
||||
}
|
||||
if args.AllocID == "" {
|
||||
return fmt.Errorf("missing allocation ID")
|
||||
}
|
||||
if len(args.Tasks) == 0 {
|
||||
return fmt.Errorf("no tasks specified")
|
||||
}
|
||||
|
||||
// Verify the following:
|
||||
// * The Node exists and has the correct SecretID
|
||||
// * The Allocation exists on the specified node
|
||||
// * The allocation contains the given tasks and they each require Vault
|
||||
// tokens
|
||||
snap, err := n.srv.fsm.State().Snapshot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node, err := snap.NodeByID(args.NodeID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if node == nil {
|
||||
return fmt.Errorf("Node %q does not exist", args.NodeID)
|
||||
}
|
||||
if node.SecretID != args.SecretID {
|
||||
return fmt.Errorf("SecretID mismatch")
|
||||
}
|
||||
|
||||
alloc, err := snap.AllocByID(args.AllocID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if alloc == nil {
|
||||
return fmt.Errorf("Allocation %q does not exist", args.AllocID)
|
||||
}
|
||||
if alloc.NodeID != args.NodeID {
|
||||
return fmt.Errorf("Allocation %q not running on Node %q", args.AllocID, args.NodeID)
|
||||
}
|
||||
if alloc.TerminalStatus() {
|
||||
return fmt.Errorf("Can't request Vault token for terminal allocation")
|
||||
}
|
||||
|
||||
// Check the policies
|
||||
policies := alloc.Job.VaultPolicies()
|
||||
if policies == nil {
|
||||
return fmt.Errorf("Job doesn't require Vault policies")
|
||||
}
|
||||
tg, ok := policies[alloc.TaskGroup]
|
||||
if !ok {
|
||||
return fmt.Errorf("Task group does not require Vault policies")
|
||||
}
|
||||
|
||||
var unneeded []string
|
||||
for _, task := range args.Tasks {
|
||||
taskVault := tg[task]
|
||||
if taskVault == nil || len(taskVault.Policies) == 0 {
|
||||
unneeded = append(unneeded, task)
|
||||
}
|
||||
}
|
||||
|
||||
if len(unneeded) != 0 {
|
||||
return fmt.Errorf("Requested Vault tokens for tasks without defined Vault policies: %s",
|
||||
strings.Join(unneeded, ", "))
|
||||
}
|
||||
|
||||
// At this point the request is valid and we should contact Vault for
|
||||
// tokens.
|
||||
|
||||
// Create an error group where we will spin up a fixed set of goroutines to
|
||||
// handle deriving tokens but where if any fails the whole group is
|
||||
// canceled.
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
|
||||
// Cap the handlers
|
||||
handlers := len(args.Tasks)
|
||||
if handlers > maxParallelRequestsPerDerive {
|
||||
handlers = maxParallelRequestsPerDerive
|
||||
}
|
||||
|
||||
// Create the Vault Tokens
|
||||
input := make(chan string, handlers)
|
||||
results := make(map[string]*vapi.Secret, len(args.Tasks))
|
||||
for i := 0; i < handlers; i++ {
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case task, ok := <-input:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
secret, err := n.srv.vault.CreateToken(ctx, alloc, task)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token for task %q: %v", task, err)
|
||||
}
|
||||
|
||||
results[task] = secret
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Send the input
|
||||
go func() {
|
||||
defer close(input)
|
||||
for _, task := range args.Tasks {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case input <- task:
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
// Wait for everything to complete or for an error
|
||||
err = g.Wait()
|
||||
if err != nil {
|
||||
// TODO Revoke any created token
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit to Raft before returning any of the tokens
|
||||
accessors := make([]*structs.VaultAccessor, 0, len(results))
|
||||
tokens := make(map[string]string, len(results))
|
||||
for task, secret := range results {
|
||||
w := secret.WrapInfo
|
||||
if w == nil {
|
||||
return fmt.Errorf("Vault returned Secret without WrapInfo")
|
||||
}
|
||||
|
||||
tokens[task] = w.Token
|
||||
accessor := &structs.VaultAccessor{
|
||||
Accessor: w.WrappedAccessor,
|
||||
Task: task,
|
||||
NodeID: alloc.NodeID,
|
||||
AllocID: alloc.ID,
|
||||
CreationTTL: w.TTL,
|
||||
}
|
||||
|
||||
accessors = append(accessors, accessor)
|
||||
}
|
||||
|
||||
req := structs.VaultAccessorRegisterRequest{Accessors: accessors}
|
||||
_, index, err := n.srv.raftApply(structs.VaultAccessorRegisterRequestType, &req)
|
||||
if err != nil {
|
||||
n.srv.logger.Printf("[ERR] nomad.client: Register Vault accessors failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index = index
|
||||
reply.Tasks = tokens
|
||||
n.srv.setQueryMeta(&reply.QueryMeta)
|
||||
return nil
|
||||
}
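The handler fans the per-task Vault calls out over a capped pool of errgroup workers fed from a channel, and the first failure cancels the shared context so the remaining workers stop. A stripped-down, standalone sketch of that pattern (names are made up; the sketch reports results through the callback rather than a shared map, which would otherwise need locking when several workers write it):

// deriveAll runs work for every task with at most maxWorkers in flight and
// returns the first error, canceling the rest. Illustrative only.
func deriveAll(tasks []string, maxWorkers int, work func(ctx context.Context, task string) error) error {
	g, ctx := errgroup.WithContext(context.Background())

	input := make(chan string, maxWorkers)
	for i := 0; i < maxWorkers; i++ {
		g.Go(func() error {
			for task := range input {
				if err := work(ctx, task); err != nil {
					return err // cancels ctx for the other workers
				}
			}
			return nil
		})
	}

	go func() {
		defer close(input)
		for _, t := range tasks {
			select {
			case <-ctx.Done():
				return
			case input <- t:
			}
		}
	}()

	return g.Wait()
}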
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/hashicorp/nomad/nomad/mock"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/nomad/testutil"
|
||||
vapi "github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
func TestClientEndpoint_Register(t *testing.T) {
|
||||
@@ -1597,3 +1598,160 @@ func TestBatchFuture(t *testing.T) {
|
||||
t.Fatalf("bad: %d", bf.Index())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) {
|
||||
s1 := testServer(t, nil)
|
||||
defer s1.Shutdown()
|
||||
state := s1.fsm.State()
|
||||
codec := rpcClient(t, s1)
|
||||
testutil.WaitForLeader(t, s1.RPC)
|
||||
|
||||
// Create the node
|
||||
node := mock.Node()
|
||||
if err := state.UpsertNode(2, node); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Create an alloc
|
||||
alloc := mock.Alloc()
|
||||
task := alloc.Job.TaskGroups[0].Tasks[0]
|
||||
tasks := []string{task.Name}
|
||||
if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
req := &structs.DeriveVaultTokenRequest{
|
||||
NodeID: node.ID,
|
||||
SecretID: structs.GenerateUUID(),
|
||||
AllocID: alloc.ID,
|
||||
Tasks: tasks,
|
||||
QueryOptions: structs.QueryOptions{
|
||||
Region: "global",
|
||||
},
|
||||
}
|
||||
|
||||
var resp structs.DeriveVaultTokenResponse
|
||||
err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
|
||||
if err == nil || !strings.Contains(err.Error(), "SecretID mismatch") {
|
||||
t.Fatalf("Expected SecretID mismatch: %v", err)
|
||||
}
|
||||
|
||||
// Put the correct SecretID
|
||||
req.SecretID = node.SecretID
|
||||
|
||||
// Now we should get an error about the allocation not running on the node
|
||||
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
|
||||
if err == nil || !strings.Contains(err.Error(), "not running on Node") {
|
||||
t.Fatalf("Expected not running on node error: %v", err)
|
||||
}
|
||||
|
||||
// Update to be running on the node
|
||||
alloc.NodeID = node.ID
|
||||
if err := state.UpsertAllocs(4, []*structs.Allocation{alloc}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Now we should get an error about the job not needing any Vault secrets
|
||||
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
|
||||
if err == nil || !strings.Contains(err.Error(), "without defined Vault") {
|
||||
t.Fatalf("Expected no policies error: %v", err)
|
||||
}
|
||||
|
||||
// Update to be terminal
|
||||
alloc.DesiredStatus = structs.AllocDesiredStatusStop
|
||||
if err := state.UpsertAllocs(5, []*structs.Allocation{alloc}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Now we should get an error about the allocation being terminal
|
||||
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
|
||||
if err == nil || !strings.Contains(err.Error(), "terminal") {
|
||||
t.Fatalf("Expected terminal allocation error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientEndpoint_DeriveVaultToken(t *testing.T) {
|
||||
s1 := testServer(t, nil)
|
||||
defer s1.Shutdown()
|
||||
state := s1.fsm.State()
|
||||
codec := rpcClient(t, s1)
|
||||
testutil.WaitForLeader(t, s1.RPC)
|
||||
|
||||
// Enable Vault and allow unauthenticated token submission
|
||||
s1.config.VaultConfig.Enabled = true
|
||||
s1.config.VaultConfig.AllowUnauthenticated = true
|
||||
|
||||
// Replace the Vault Client on the server
|
||||
tvc := &TestVaultClient{}
|
||||
s1.vault = tvc
|
||||
|
||||
// Create the node
|
||||
node := mock.Node()
|
||||
if err := state.UpsertNode(2, node); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Create an allocation that requires Vault policies
|
||||
alloc := mock.Alloc()
|
||||
alloc.NodeID = node.ID
|
||||
task := alloc.Job.TaskGroups[0].Tasks[0]
|
||||
tasks := []string{task.Name}
|
||||
task.Vault = &structs.Vault{Policies: []string{"a", "b"}}
|
||||
if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Return a secret for the task
|
||||
token := structs.GenerateUUID()
|
||||
accessor := structs.GenerateUUID()
|
||||
ttl := 10
|
||||
secret := &vapi.Secret{
|
||||
WrapInfo: &vapi.SecretWrapInfo{
|
||||
Token: token,
|
||||
WrappedAccessor: accessor,
|
||||
TTL: ttl,
|
||||
},
|
||||
}
|
||||
tvc.SetCreateTokenSecret(alloc.ID, task.Name, secret)
|
||||
|
||||
req := &structs.DeriveVaultTokenRequest{
|
||||
NodeID: node.ID,
|
||||
SecretID: node.SecretID,
|
||||
AllocID: alloc.ID,
|
||||
Tasks: tasks,
|
||||
QueryOptions: structs.QueryOptions{
|
||||
Region: "global",
|
||||
},
|
||||
}
|
||||
|
||||
var resp structs.DeriveVaultTokenResponse
|
||||
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
|
||||
// Check the state store and ensure that we created a VaultAccessor
|
||||
va, err := state.VaultAccessor(accessor)
|
||||
if err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
if va == nil {
|
||||
t.Fatalf("bad: %v", va)
|
||||
}
|
||||
|
||||
if va.CreateIndex == 0 {
|
||||
t.Fatalf("bad: %v", va)
|
||||
}
|
||||
|
||||
va.CreateIndex = 0
|
||||
expected := &structs.VaultAccessor{
|
||||
AllocID: alloc.ID,
|
||||
Task: task.Name,
|
||||
NodeID: alloc.NodeID,
|
||||
Accessor: accessor,
|
||||
CreationTTL: ttl,
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, va) {
|
||||
t.Fatalf("Got %#v; want %#v", va, expected)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ func stateStoreSchema() *memdb.DBSchema {
|
||||
periodicLaunchTableSchema,
|
||||
evalTableSchema,
|
||||
allocTableSchema,
|
||||
vaultAccessorTableSchema,
|
||||
}
|
||||
|
||||
// Add each of the tables
|
||||
@@ -291,3 +292,41 @@ func allocTableSchema() *memdb.TableSchema {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// vaultAccessorTableSchema returns the MemDB schema for the Vault Accessor
|
||||
// Table. This table tracks Vault accessors for tokens created on behalf of
|
||||
// allocations requiring Vault tokens.
|
||||
func vaultAccessorTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: "vault_accessors",
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
// The primary index is the accessor id
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "Accessor",
|
||||
},
|
||||
},
|
||||
|
||||
"alloc_id": &memdb.IndexSchema{
|
||||
Name: "alloc_id",
|
||||
AllowMissing: false,
|
||||
Unique: false,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "AllocID",
|
||||
},
|
||||
},
|
||||
|
||||
"node_id": &memdb.IndexSchema{
|
||||
Name: "node_id",
|
||||
AllowMissing: false,
|
||||
Unique: false,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "NodeID",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1113,6 +1113,124 @@ func (s *StateStore) Allocs() (memdb.ResultIterator, error) {
|
||||
return iter, nil
|
||||
}
|
||||
|
||||
// UpsertVaultAccessor is used to register a set of Vault accessors
|
||||
func (s *StateStore) UpsertVaultAccessor(index uint64, accessors []*structs.VaultAccessor) error {
|
||||
txn := s.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
|
||||
for _, accessor := range accessors {
|
||||
// Set the create index
|
||||
accessor.CreateIndex = index
|
||||
|
||||
// Insert the accessor
|
||||
if err := txn.Insert("vault_accessors", accessor); err != nil {
|
||||
return fmt.Errorf("accessor insert failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := txn.Insert("index", &IndexEntry{"vault_accessors", index}); err != nil {
|
||||
return fmt.Errorf("index update failed: %v", err)
|
||||
}
|
||||
|
||||
txn.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteVaultAccessor is used to delete a Vault Accessor
|
||||
func (s *StateStore) DeleteVaultAccessor(index uint64, accessor string) error {
|
||||
txn := s.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
|
||||
// Lookup the accessor
|
||||
existing, err := txn.First("vault_accessors", "id", accessor)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accessor lookup failed: %v", err)
|
||||
}
|
||||
if existing == nil {
|
||||
return fmt.Errorf("vault_accessor not found")
|
||||
}
|
||||
|
||||
// Delete the accessor
|
||||
if err := txn.Delete("vault_accessors", existing); err != nil {
|
||||
return fmt.Errorf("accessor delete failed: %v", err)
|
||||
}
|
||||
if err := txn.Insert("index", &IndexEntry{"vault_accessors", index}); err != nil {
|
||||
return fmt.Errorf("index update failed: %v", err)
|
||||
}
|
||||
|
||||
txn.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// VaultAccessor returns the given Vault accessor
|
||||
func (s *StateStore) VaultAccessor(accessor string) (*structs.VaultAccessor, error) {
|
||||
txn := s.db.Txn(false)
|
||||
|
||||
existing, err := txn.First("vault_accessors", "id", accessor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("accessor lookup failed: %v", err)
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
return existing.(*structs.VaultAccessor), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// VaultAccessors returns an iterator of Vault accessors.
|
||||
func (s *StateStore) VaultAccessors() (memdb.ResultIterator, error) {
|
||||
txn := s.db.Txn(false)
|
||||
|
||||
iter, err := txn.Get("vault_accessors", "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return iter, nil
|
||||
}
|
||||
|
||||
// VaultAccessorsByAlloc returns all the Vault accessors by alloc id
|
||||
func (s *StateStore) VaultAccessorsByAlloc(allocID string) ([]*structs.VaultAccessor, error) {
|
||||
txn := s.db.Txn(false)
|
||||
|
||||
// Get an iterator over the accessors
|
||||
iter, err := txn.Get("vault_accessors", "alloc_id", allocID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out []*structs.VaultAccessor
|
||||
for {
|
||||
raw := iter.Next()
|
||||
if raw == nil {
|
||||
break
|
||||
}
|
||||
out = append(out, raw.(*structs.VaultAccessor))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// VaultAccessorsByNode returns all the Vault accessors by node id
|
||||
func (s *StateStore) VaultAccessorsByNode(nodeID string) ([]*structs.VaultAccessor, error) {
|
||||
txn := s.db.Txn(false)
|
||||
|
||||
// Get an iterator over the accessors
|
||||
iter, err := txn.Get("vault_accessors", "node_id", nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out []*structs.VaultAccessor
|
||||
for {
|
||||
raw := iter.Next()
|
||||
if raw == nil {
|
||||
break
|
||||
}
|
||||
out = append(out, raw.(*structs.VaultAccessor))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// LatestIndex returns the greatest index value for all indexes
|
||||
func (s *StateStore) LatestIndex() (uint64, error) {
|
||||
indexes, err := s.Indexes()
|
||||
@@ -1627,6 +1745,14 @@ func (r *StateRestore) JobSummaryRestore(jobSummary *structs.JobSummary) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// VaultAccessorRestore is used to restore a vault accessor
|
||||
func (r *StateRestore) VaultAccessorRestore(accessor *structs.VaultAccessor) error {
|
||||
if err := r.txn.Insert("vault_accessors", accessor); err != nil {
|
||||
return fmt.Errorf("vault accessor insert failed: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// stateWatch holds shared state for watching updates. This is
|
||||
// outside of StateStore so it can be shared with snapshots.
|
||||
type stateWatch struct {
|
||||
|
||||
@@ -2833,6 +2833,206 @@ func TestJobSummary_UpdateClientStatus(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateStore_UpsertVaultAccessors(t *testing.T) {
|
||||
state := testStateStore(t)
|
||||
a := mock.VaultAccessor()
|
||||
a2 := mock.VaultAccessor()
|
||||
|
||||
err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{a, a2})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
out, err := state.VaultAccessor(a.Accessor)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(a, out) {
|
||||
t.Fatalf("bad: %#v %#v", a, out)
|
||||
}
|
||||
|
||||
out, err = state.VaultAccessor(a2.Accessor)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(a2, out) {
|
||||
t.Fatalf("bad: %#v %#v", a2, out)
|
||||
}
|
||||
|
||||
iter, err := state.VaultAccessors()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
for {
|
||||
raw := iter.Next()
|
||||
if raw == nil {
|
||||
break
|
||||
}
|
||||
|
||||
count++
|
||||
accessor := raw.(*structs.VaultAccessor)
|
||||
|
||||
if !reflect.DeepEqual(accessor, a) && !reflect.DeepEqual(accessor, a2) {
|
||||
t.Fatalf("bad: %#v", accessor)
|
||||
}
|
||||
}
|
||||
|
||||
if count != 2 {
|
||||
t.Fatalf("bad: %d", count)
|
||||
}
|
||||
|
||||
index, err := state.Index("vault_accessors")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if index != 1000 {
|
||||
t.Fatalf("bad: %d", index)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateStore_DeleteVaultAccessor(t *testing.T) {
|
||||
state := testStateStore(t)
|
||||
accessor := mock.VaultAccessor()
|
||||
|
||||
err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{accessor})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
err = state.DeleteVaultAccessor(1001, accessor.Accessor)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
out, err := state.VaultAccessor(accessor.Accessor)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if out != nil {
|
||||
t.Fatalf("bad: %#v %#v", accessor, out)
|
||||
}
|
||||
|
||||
index, err := state.Index("vault_accessors")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if index != 1001 {
|
||||
t.Fatalf("bad: %d", index)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateStore_VaultAccessorsByAlloc(t *testing.T) {
|
||||
state := testStateStore(t)
|
||||
alloc := mock.Alloc()
|
||||
var accessors []*structs.VaultAccessor
|
||||
var expected []*structs.VaultAccessor
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
accessor := mock.VaultAccessor()
|
||||
accessor.AllocID = alloc.ID
|
||||
expected = append(expected, accessor)
|
||||
accessors = append(accessors, accessor)
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
accessor := mock.VaultAccessor()
|
||||
accessors = append(accessors, accessor)
|
||||
}
|
||||
|
||||
err := state.UpsertVaultAccessor(1000, accessors)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
out, err := state.VaultAccessorsByAlloc(alloc.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if len(expected) != len(out) {
|
||||
t.Fatalf("bad: %#v %#v", len(expected), len(out))
|
||||
}
|
||||
|
||||
index, err := state.Index("vault_accessors")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if index != 1000 {
|
||||
t.Fatalf("bad: %d", index)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateStore_VaultAccessorsByNode(t *testing.T) {
|
||||
state := testStateStore(t)
|
||||
node := mock.Node()
|
||||
var accessors []*structs.VaultAccessor
|
||||
var expected []*structs.VaultAccessor
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
accessor := mock.VaultAccessor()
|
||||
accessor.NodeID = node.ID
|
||||
expected = append(expected, accessor)
|
||||
accessors = append(accessors, accessor)
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
accessor := mock.VaultAccessor()
|
||||
accessors = append(accessors, accessor)
|
||||
}
|
||||
|
||||
err := state.UpsertVaultAccessor(1000, accessors)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
out, err := state.VaultAccessorsByNode(node.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if len(expected) != len(out) {
|
||||
t.Fatalf("bad: %#v %#v", len(expected), len(out))
|
||||
}
|
||||
|
||||
index, err := state.Index("vault_accessors")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if index != 1000 {
|
||||
t.Fatalf("bad: %d", index)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateStore_RestoreVaultAccessor(t *testing.T) {
|
||||
state := testStateStore(t)
|
||||
a := mock.VaultAccessor()
|
||||
|
||||
restore, err := state.Restore()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
err = restore.VaultAccessorRestore(a)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
restore.Commit()
|
||||
|
||||
out, err := state.VaultAccessor(a.Accessor)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(out, a) {
|
||||
t.Fatalf("Bad: %#v %#v", out, a)
|
||||
}
|
||||
}
|
||||
|
||||
// setupNotifyTest takes a state store and a set of watch items, then creates
|
||||
// and subscribes a notification channel for each item.
|
||||
func setupNotifyTest(state *StateStore, items ...watch.Item) notifyTest {
|
||||
|
||||
@@ -253,12 +253,12 @@ func SliceStringIsSubset(larger, smaller []string) (bool, []string) {
|
||||
|
||||
// VaultPoliciesSet takes the structure returned by VaultPolicies and returns
|
||||
// the set of required policies
|
||||
func VaultPoliciesSet(policies map[string]map[string][]string) []string {
|
||||
func VaultPoliciesSet(policies map[string]map[string]*Vault) []string {
|
||||
set := make(map[string]struct{})
|
||||
|
||||
for _, tgp := range policies {
|
||||
for _, tp := range tgp {
|
||||
for _, p := range tp {
|
||||
for _, p := range tp.Policies {
|
||||
set[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ const (
|
||||
AllocUpdateRequestType
|
||||
AllocClientUpdateRequestType
|
||||
ReconcileJobSummariesRequestType
|
||||
VaultAccessorRegisterRequestType
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -354,6 +355,41 @@ type PeriodicForceRequest struct {
|
||||
WriteRequest
|
||||
}
|
||||
|
||||
// DeriveVaultTokenRequest is used to request wrapped Vault tokens for the
|
||||
// following tasks in the given allocation
|
||||
type DeriveVaultTokenRequest struct {
|
||||
NodeID string
|
||||
SecretID string
|
||||
AllocID string
|
||||
Tasks []string
|
||||
QueryOptions
|
||||
}
|
||||
|
||||
// VaultAccessorRegisterRequest is used to register a set of Vault accessors
|
||||
type VaultAccessorRegisterRequest struct {
|
||||
Accessors []*VaultAccessor
|
||||
}
|
||||
|
||||
// VaultAccessor is a reference to a created Vault token on behalf of
|
||||
// an allocation's task.
|
||||
type VaultAccessor struct {
|
||||
AllocID string
|
||||
Task string
|
||||
NodeID string
|
||||
Accessor string
|
||||
CreationTTL int
|
||||
|
||||
// Raft Indexes
|
||||
CreateIndex uint64
|
||||
}
|
||||
|
||||
// DeriveVaultTokenResponse returns the wrapped tokens for each requested task
|
||||
type DeriveVaultTokenResponse struct {
|
||||
// Tasks is a mapping between the task name and the wrapped token
|
||||
Tasks map[string]string
|
||||
QueryMeta
|
||||
}
|
||||
|
||||
// GenericRequest is used to request where no
|
||||
// specific information is needed.
|
||||
type GenericRequest struct {
|
||||
@@ -1239,11 +1275,11 @@ func (j *Job) IsPeriodic() bool {
|
||||
}
|
||||
|
||||
// VaultPolicies returns the set of Vault policies per task group, per task
|
||||
func (j *Job) VaultPolicies() map[string]map[string][]string {
|
||||
policies := make(map[string]map[string][]string, len(j.TaskGroups))
|
||||
func (j *Job) VaultPolicies() map[string]map[string]*Vault {
|
||||
policies := make(map[string]map[string]*Vault, len(j.TaskGroups))
|
||||
|
||||
for _, tg := range j.TaskGroups {
|
||||
tgPolicies := make(map[string][]string, len(tg.Tasks))
|
||||
tgPolicies := make(map[string]*Vault, len(tg.Tasks))
|
||||
policies[tg.Name] = tgPolicies
|
||||
|
||||
for _, task := range tg.Tasks {
|
||||
@@ -1251,7 +1287,7 @@ func (j *Job) VaultPolicies() map[string]map[string][]string {
|
||||
continue
|
||||
}
|
||||
|
||||
tgPolicies[task.Name] = task.Vault.Policies
|
||||
tgPolicies[task.Name] = task.Vault
|
||||
}
|
||||
}
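With the return type now map[task group]map[task]*Vault, callers get each task's full Vault block instead of just the policy list; VaultPoliciesSet (updated in this same change) flattens that structure back into the distinct policy names. A small illustrative helper combining the two (the helper itself is not part of the diff):

// requiredPolicies is a sketch: it returns the distinct Vault policies a job
// needs, treating a nil VaultPolicies result as "no Vault tokens required",
// the same way Node.DeriveVaultToken does.
func requiredPolicies(j *Job) []string {
	policies := j.VaultPolicies()
	if policies == nil {
		return nil
	}
	return VaultPoliciesSet(policies)
}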
|
||||
|
||||
|
||||
@@ -224,8 +224,25 @@ func TestJob_SystemJob_Validate(t *testing.T) {
|
||||
|
||||
func TestJob_VaultPolicies(t *testing.T) {
|
||||
j0 := &Job{}
|
||||
e0 := make(map[string]map[string][]string, 0)
|
||||
e0 := make(map[string]map[string]*Vault, 0)
|
||||
|
||||
vj1 := &Vault{
|
||||
Policies: []string{
|
||||
"p1",
|
||||
"p2",
|
||||
},
|
||||
}
|
||||
vj2 := &Vault{
|
||||
Policies: []string{
|
||||
"p3",
|
||||
"p4",
|
||||
},
|
||||
}
|
||||
vj3 := &Vault{
|
||||
Policies: []string{
|
||||
"p5",
|
||||
},
|
||||
}
|
||||
j1 := &Job{
|
||||
TaskGroups: []*TaskGroup{
|
||||
&TaskGroup{
|
||||
@@ -235,13 +252,8 @@ func TestJob_VaultPolicies(t *testing.T) {
|
||||
Name: "t1",
|
||||
},
|
||||
&Task{
|
||||
Name: "t2",
|
||||
Vault: &Vault{
|
||||
Policies: []string{
|
||||
"p1",
|
||||
"p2",
|
||||
},
|
||||
},
|
||||
Name: "t2",
|
||||
Vault: vj1,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -249,40 +261,31 @@ func TestJob_VaultPolicies(t *testing.T) {
|
||||
Name: "bar",
|
||||
Tasks: []*Task{
|
||||
&Task{
|
||||
Name: "t3",
|
||||
Vault: &Vault{
|
||||
Policies: []string{
|
||||
"p3",
|
||||
"p4",
|
||||
},
|
||||
},
|
||||
Name: "t3",
|
||||
Vault: vj2,
|
||||
},
|
||||
&Task{
|
||||
Name: "t4",
|
||||
Vault: &Vault{
|
||||
Policies: []string{
|
||||
"p5",
|
||||
},
|
||||
},
|
||||
Name: "t4",
|
||||
Vault: vj3,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
e1 := map[string]map[string][]string{
|
||||
"foo": map[string][]string{
|
||||
"t2": []string{"p1", "p2"},
|
||||
e1 := map[string]map[string]*Vault{
|
||||
"foo": map[string]*Vault{
|
||||
"t2": vj1,
|
||||
},
|
||||
"bar": map[string][]string{
|
||||
"t3": []string{"p3", "p4"},
|
||||
"t4": []string{"p5"},
|
||||
"bar": map[string]*Vault{
|
||||
"t3": vj2,
|
||||
"t4": vj3,
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
Job *Job
|
||||
Expected map[string]map[string][]string
|
||||
Expected map[string]map[string]*Vault
|
||||
}{
|
||||
{
|
||||
Job: j0,
|
||||
|
||||
nomad/vault.go (103 lines changed)
@@ -1,6 +1,7 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -12,6 +13,8 @@ import (
|
||||
"github.com/hashicorp/nomad/nomad/structs/config"
|
||||
vapi "github.com/hashicorp/vault/api"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,16 +24,25 @@ const (
|
||||
|
||||
// minimumTokenTTL is the minimum Token TTL allowed for child tokens.
|
||||
minimumTokenTTL = 5 * time.Minute
|
||||
|
||||
// defaultTokenTTL is the default Token TTL used when the passed token is a
|
||||
// root token such that child tokens aren't being created against a role
|
||||
// that has defined a TTL
|
||||
defaultTokenTTL = "72h"
|
||||
|
||||
// requestRateLimit is the maximum number of requests per second Nomad will
|
||||
// make against Vault
|
||||
requestRateLimit rate.Limit = 500.0
|
||||
)
|
||||
|
||||
// VaultClient is the Server's interface for interfacing with Vault
|
||||
type VaultClient interface {
|
||||
// CreateToken takes an allocation and task and returns an appropriate Vault
|
||||
// Secret
|
||||
CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error)
|
||||
CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error)
|
||||
|
||||
// LookupToken takes a token string and returns its capabilities.
|
||||
LookupToken(token string) (*vapi.Secret, error)
|
||||
LookupToken(ctx context.Context, token string) (*vapi.Secret, error)
|
||||
|
||||
// Stop is used to stop token renewal.
|
||||
Stop()
|
||||
@@ -52,6 +64,9 @@ type tokenData struct {
|
||||
// the Server with the ability to create child tokens and lookup the permissions
|
||||
// of tokens.
|
||||
type vaultClient struct {
|
||||
// limiter is used to rate limit requests to Vault
|
||||
limiter *rate.Limiter
|
||||
|
||||
// client is the Vault API client
|
||||
client *vapi.Client
|
||||
|
||||
@@ -104,6 +119,7 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er
|
||||
enabled: c.Enabled,
|
||||
config: c,
|
||||
logger: logger,
|
||||
limiter: rate.NewLimiter(requestRateLimit, int(requestRateLimit)),
|
||||
}
|
||||
|
||||
// If vault is not enabled do not configure an API client or start any token
|
||||
@@ -131,6 +147,9 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er
|
||||
}
|
||||
|
||||
v.childTTL = c.TaskTokenTTL
|
||||
} else {
|
||||
// Default the TaskTokenTTL
|
||||
v.childTTL = defaultTokenTTL
|
||||
}
|
||||
|
||||
// Get the Vault API configuration
|
||||
@@ -157,6 +176,11 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// setLimit is used to update the rate limit
|
||||
func (v *vaultClient) setLimit(l rate.Limit) {
|
||||
v.limiter = rate.NewLimiter(l, int(l))
|
||||
}
|
||||
|
||||
// establishConnection is used to make first contact with Vault. This should be
|
||||
// called in a goroutine since the connection is retried until the Vault client
|
||||
// is stopped or the connection is successfully made at which point the renew
|
||||
@@ -397,7 +421,7 @@ func (v *vaultClient) Stop() {
|
||||
|
||||
v.l.Lock()
|
||||
defer v.l.Unlock()
|
||||
if !v.renewalRunning || !v.establishingConn {
|
||||
if !v.renewalRunning && !v.establishingConn {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -414,12 +438,9 @@ func (v *vaultClient) ConnectionEstablished() bool {
|
||||
return v.connEstablished
|
||||
}
|
||||
|
||||
func (v *vaultClient) CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// LookupToken takes a Vault token and does a lookup against Vault
|
||||
func (v *vaultClient) LookupToken(token string) (*vapi.Secret, error) {
|
||||
// CreateToken takes the allocation and task and returns an appropriate Vault
|
||||
// token. The call is rate limited and may be canceled with the passed context.
|
||||
func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) {
|
||||
// Nothing to do
|
||||
if !v.enabled {
|
||||
return nil, fmt.Errorf("Vault integration disabled")
|
||||
@@ -430,6 +451,70 @@ func (v *vaultClient) LookupToken(token string) (*vapi.Secret, error) {
|
||||
return nil, fmt.Errorf("Connection to Vault has not been established. Retry")
|
||||
}
|
||||
|
||||
// Retrieve the Vault block for the task
|
||||
policies := a.Job.VaultPolicies()
|
||||
if policies == nil {
|
||||
return nil, fmt.Errorf("Job doesn't require Vault policies")
|
||||
}
|
||||
tg, ok := policies[a.TaskGroup]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Task group does not require Vault policies")
|
||||
}
|
||||
taskVault, ok := tg[task]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Task does not require Vault policies")
|
||||
}
|
||||
|
||||
// Build the creation request
|
||||
req := &vapi.TokenCreateRequest{
|
||||
Policies: taskVault.Policies,
|
||||
Metadata: map[string]string{
|
||||
"AllocationID": a.ID,
|
||||
"Task": task,
|
||||
"NodeID": a.NodeID,
|
||||
},
|
||||
TTL: v.childTTL,
|
||||
DisplayName: fmt.Sprintf("%s: %s", a.ID, task),
|
||||
}
|
||||
|
||||
// Ensure we are under our rate limit
|
||||
if err := v.limiter.Wait(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make the request and switch depending on whether we are using a root
|
||||
// token or a role based token
|
||||
var secret *vapi.Secret
|
||||
var err error
|
||||
if v.token.Root {
|
||||
req.Period = v.childTTL
|
||||
secret, err = v.auth.Create(req)
|
||||
} else {
|
||||
// Make the token using the role
|
||||
secret, err = v.auth.CreateWithRole(req, v.token.Role)
|
||||
}
|
||||
|
||||
return secret, err
|
||||
}
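CreateToken and LookupToken now both block on the client's shared rate.Limiter before talking to Vault, so the cap applies across all server-side callers and a canceled context aborts the wait rather than the Vault call. The pattern in isolation, as a hedged sketch rather than the actual client code:

// callWithLimit waits for a limiter slot, or for ctx to be canceled, before
// running the Vault call; the limiter would be shared the way
// vaultClient.limiter is above, e.g. rate.NewLimiter(requestRateLimit, int(requestRateLimit)).
func callWithLimit(ctx context.Context, limiter *rate.Limiter, call func() error) error {
	if err := limiter.Wait(ctx); err != nil {
		return err // canceled or deadline exceeded while throttled
	}
	return call()
}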
|
||||
|
||||
// LookupToken takes a Vault token and does a lookup against Vault. The call is
|
||||
// rate limited and may be canceled with the passed context.
|
||||
func (v *vaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secret, error) {
|
||||
// Nothing to do
|
||||
if !v.enabled {
|
||||
return nil, fmt.Errorf("Vault integration disabled")
|
||||
}
|
||||
|
||||
// Check if we have established a connection with Vault
|
||||
if !v.ConnectionEstablished() {
|
||||
return nil, fmt.Errorf("Connection to Vault has not been established. Retry")
|
||||
}
|
||||
|
||||
// Ensure we are under our rate limit
|
||||
if err := v.limiter.Wait(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Lookup the token
|
||||
return v.auth.Lookup(token)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"os"
|
||||
@@ -9,12 +10,22 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/hashicorp/nomad/nomad/mock"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/nomad/nomad/structs/config"
|
||||
"github.com/hashicorp/nomad/testutil"
|
||||
vapi "github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
const (
|
||||
// authPolicy is a policy that allows token creation operations
|
||||
authPolicy = `path "auth/token/create/*" {
|
||||
capabilities = ["create", "read", "update", "delete", "list"]
|
||||
}`
|
||||
)
|
||||
|
||||
func TestVaultClient_BadConfig(t *testing.T) {
|
||||
conf := &config.VaultConfig{}
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
@@ -24,6 +35,7 @@ func TestVaultClient_BadConfig(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
defer client.Stop()
|
||||
|
||||
if client.ConnectionEstablished() {
|
||||
t.Fatalf("bad")
|
||||
@@ -75,15 +87,20 @@ func TestVaultClient_EstablishConnection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultClient_RenewalLoop(t *testing.T) {
|
||||
v := testutil.NewTestVault(t).Start()
|
||||
defer v.Stop()
|
||||
// testVaultRoleAndToken creates a test Vault role where children are created
|
||||
// with the passed period. A token created in that role is returned
|
||||
func testVaultRoleAndToken(v *testutil.TestVault, t *testing.T, rolePeriod int) string {
|
||||
// Build the auth policy
|
||||
sys := v.Client.Sys()
|
||||
if err := sys.PutPolicy("auth", authPolicy); err != nil {
|
||||
t.Fatalf("failed to create auth policy: %v", err)
|
||||
}
|
||||
|
||||
// Build a role
|
||||
l := v.Client.Logical()
|
||||
d := make(map[string]interface{}, 2)
|
||||
d["allowed_policies"] = "default"
|
||||
d["period"] = 5
|
||||
d["allowed_policies"] = "default,auth"
|
||||
d["period"] = rolePeriod
|
||||
l.Write("auth/token/roles/test", d)
|
||||
|
||||
// Create a new token with the role
|
||||
@@ -99,8 +116,15 @@ func TestVaultClient_RenewalLoop(t *testing.T) {
|
||||
t.Fatalf("bad secret response: %+v", s)
|
||||
}
|
||||
|
||||
// Set the configs token
|
||||
v.Config.Token = s.Auth.ClientToken
|
||||
return s.Auth.ClientToken
|
||||
}
|
||||
|
||||
func TestVaultClient_RenewalLoop(t *testing.T) {
|
||||
v := testutil.NewTestVault(t).Start()
|
||||
defer v.Stop()
|
||||
|
||||
// Set the configs token in a new test role
|
||||
v.Config.Token = testVaultRoleAndToken(v, t, 5)
|
||||
|
||||
// Start the client
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
@@ -114,6 +138,7 @@ func TestVaultClient_RenewalLoop(t *testing.T) {
|
||||
time.Sleep(8 * time.Second)
|
||||
|
||||
// Get the current TTL
|
||||
a := v.Client.Auth().Token()
|
||||
s2, err := a.Lookup(v.Config.Token)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to lookup token: %v", err)
|
||||
@@ -160,8 +185,9 @@ func TestVaultClient_LookupToken_Invalid(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build vault client: %v", err)
|
||||
}
|
||||
defer client.Stop()
|
||||
|
||||
_, err = client.LookupToken("foo")
|
||||
_, err = client.LookupToken(context.Background(), "foo")
|
||||
if err == nil || !strings.Contains(err.Error(), "disabled") {
|
||||
t.Fatalf("Expected error because Vault is disabled: %v", err)
|
||||
}
|
||||
@@ -175,7 +201,7 @@ func TestVaultClient_LookupToken_Invalid(t *testing.T) {
|
||||
t.Fatalf("failed to build vault client: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.LookupToken("foo")
|
||||
_, err = client.LookupToken(context.Background(), "foo")
|
||||
if err == nil || !strings.Contains(err.Error(), "established") {
|
||||
t.Fatalf("Expected error because connection to Vault hasn't been made: %v", err)
|
||||
}
|
||||
@@ -198,11 +224,12 @@ func TestVaultClient_LookupToken(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build vault client: %v", err)
|
||||
}
|
||||
defer client.Stop()
|
||||
|
||||
waitForConnection(client, t)
|
||||
|
||||
// Lookup ourselves
|
||||
s, err := client.LookupToken(v.Config.Token)
|
||||
s, err := client.LookupToken(context.Background(), v.Config.Token)
|
||||
if err != nil {
|
||||
t.Fatalf("self lookup failed: %v", err)
|
||||
}
|
||||
@@ -233,7 +260,7 @@ func TestVaultClient_LookupToken(t *testing.T) {
|
||||
}
|
||||
|
||||
// Lookup new child
|
||||
s, err = client.LookupToken(s.Auth.ClientToken)
|
||||
s, err = client.LookupToken(context.Background(), s.Auth.ClientToken)
|
||||
if err != nil {
|
||||
t.Fatalf("self lookup failed: %v", err)
|
||||
}
|
||||
@@ -247,3 +274,145 @@ func TestVaultClient_LookupToken(t *testing.T) {
|
||||
t.Fatalf("Unexpected policies; got %v; want %v", policies, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultClient_LookupToken_RateLimit(t *testing.T) {
|
||||
v := testutil.NewTestVault(t).Start()
|
||||
defer v.Stop()
|
||||
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
client, err := NewVaultClient(v.Config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build vault client: %v", err)
|
||||
}
|
||||
defer client.Stop()
|
||||
client.setLimit(rate.Limit(1.0))
|
||||
|
||||
waitForConnection(client, t)
|
||||
|
||||
// Spin up many requests. These should block
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cancels := 0
|
||||
numRequests := 10
|
||||
unblock := make(chan struct{})
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
// Ensure all the goroutines are made
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Lookup ourselves
|
||||
_, err := client.LookupToken(ctx, v.Config.Token)
|
||||
if err != nil {
|
||||
if err == context.Canceled {
|
||||
cancels += 1
|
||||
return
|
||||
}
|
||||
t.Fatalf("self lookup failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel the context
|
||||
cancel()
|
||||
time.AfterFunc(1*time.Second, func() { close(unblock) })
|
||||
}()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("timeout")
|
||||
case <-unblock:
|
||||
}
|
||||
|
||||
desired := numRequests - 1
|
||||
if cancels != desired {
|
||||
t.Fatalf("Incorrect number of cancels; got %d; want %d", cancels, desired)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultClient_CreateToken_Root(t *testing.T) {
|
||||
v := testutil.NewTestVault(t).Start()
|
||||
defer v.Stop()
|
||||
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
client, err := NewVaultClient(v.Config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build vault client: %v", err)
|
||||
}
|
||||
defer client.Stop()
|
||||
|
||||
waitForConnection(client, t)
|
||||
|
||||
// Create an allocation that requires a Vault policy
|
||||
a := mock.Alloc()
|
||||
task := a.Job.TaskGroups[0].Tasks[0]
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
s, err := client.CreateToken(context.Background(), a, task.Name)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateToken failed: %v", err)
|
||||
}
|
||||
|
||||
// Ensure that created secret is a wrapped token
|
||||
if s == nil || s.WrapInfo == nil {
|
||||
t.Fatalf("Bad secret: %#v", s)
|
||||
}
|
||||
|
||||
d, err := time.ParseDuration(vaultTokenCreateTTL)
|
||||
if err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
|
||||
if s.WrapInfo.WrappedAccessor == "" {
|
||||
t.Fatalf("Bad accessor: %v", s.WrapInfo.WrappedAccessor)
|
||||
} else if s.WrapInfo.Token == "" {
|
||||
t.Fatalf("Bad token: %v", s.WrapInfo.WrappedAccessor)
|
||||
} else if s.WrapInfo.TTL != int(d.Seconds()) {
|
||||
t.Fatalf("Bad ttl: %v", s.WrapInfo.WrappedAccessor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultClient_CreateToken_Role(t *testing.T) {
|
||||
v := testutil.NewTestVault(t).Start()
|
||||
defer v.Stop()
|
||||
|
||||
// Set the configs token in a new test role
|
||||
v.Config.Token = testVaultRoleAndToken(v, t, 5)
|
||||
//testVaultRoleAndToken(v, t, 5)
|
||||
// Start the client
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
client, err := NewVaultClient(v.Config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build vault client: %v", err)
|
||||
}
|
||||
defer client.Stop()
|
||||
|
||||
waitForConnection(client, t)
|
||||
|
||||
// Create an allocation that requires a Vault policy
|
||||
a := mock.Alloc()
|
||||
task := a.Job.TaskGroups[0].Tasks[0]
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
s, err := client.CreateToken(context.Background(), a, task.Name)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateToken failed: %v", err)
|
||||
}
|
||||
|
||||
// Ensure that created secret is a wrapped token
|
||||
if s == nil || s.WrapInfo == nil {
|
||||
t.Fatalf("Bad secret: %#v", s)
|
||||
}
|
||||
|
||||
d, err := time.ParseDuration(vaultTokenCreateTTL)
|
||||
if err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
|
||||
if s.WrapInfo.WrappedAccessor == "" {
|
||||
t.Fatalf("Bad accessor: %v", s.WrapInfo.WrappedAccessor)
|
||||
} else if s.WrapInfo.Token == "" {
|
||||
t.Fatalf("Bad token: %v", s.WrapInfo.WrappedAccessor)
|
||||
} else if s.WrapInfo.TTL != int(d.Seconds()) {
|
||||
t.Fatalf("Bad ttl: %v", s.WrapInfo.WrappedAccessor)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
vapi "github.com/hashicorp/vault/api"
|
||||
)
|
||||
@@ -16,9 +18,17 @@ type TestVaultClient struct {
|
||||
// LookupTokenSecret maps a token to the Vault secret that will be returned
|
||||
// by the LookupToken call
|
||||
LookupTokenSecret map[string]*vapi.Secret
|
||||
|
||||
// CreateTokenErrors maps a token to an error that will be returned by the
|
||||
// CreateToken call
|
||||
CreateTokenErrors map[string]map[string]error
|
||||
|
||||
// CreateTokenSecret maps a token to the Vault secret that will be returned
|
||||
// by the CreateToken call
|
||||
CreateTokenSecret map[string]map[string]*vapi.Secret
|
||||
}
|
||||
|
||||
func (v *TestVaultClient) LookupToken(token string) (*vapi.Secret, error) {
|
||||
func (v *TestVaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secret, error) {
|
||||
var secret *vapi.Secret
|
||||
var err error
|
||||
|
||||
@@ -64,8 +74,56 @@ func (v *TestVaultClient) SetLookupTokenAllowedPolicies(token string, policies [
|
||||
v.SetLookupTokenSecret(token, s)
|
||||
}
|
||||
|
||||
func (v *TestVaultClient) CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error) {
|
||||
return nil, nil
|
||||
func (v *TestVaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) {
|
||||
var secret *vapi.Secret
|
||||
var err error
|
||||
|
||||
if v.CreateTokenSecret != nil {
|
||||
tasks := v.CreateTokenSecret[a.ID]
|
||||
if tasks != nil {
|
||||
secret = tasks[task]
|
||||
}
|
||||
}
|
||||
if v.CreateTokenErrors != nil {
|
||||
tasks := v.CreateTokenErrors[a.ID]
|
||||
if tasks != nil {
|
||||
err = tasks[task]
|
||||
}
|
||||
}
|
||||
|
||||
return secret, err
|
||||
}
|
||||
|
||||
// SetCreateTokenError sets the error that will be returned by the token
|
||||
// creation
|
||||
func (v *TestVaultClient) SetCreateTokenError(allocID, task string, err error) {
|
||||
if v.CreateTokenErrors == nil {
|
||||
v.CreateTokenErrors = make(map[string]map[string]error)
|
||||
}
|
||||
|
||||
tasks := v.CreateTokenErrors[allocID]
|
||||
if tasks == nil {
|
||||
tasks = make(map[string]error)
|
||||
v.CreateTokenErrors[allocID] = tasks
|
||||
}
|
||||
|
||||
v.CreateTokenErrors[allocID][task] = err
|
||||
}
|
||||
|
||||
// SetCreateTokenSecret sets the secret that will be returned by the token
|
||||
// creation
|
||||
func (v *TestVaultClient) SetCreateTokenSecret(allocID, task string, secret *vapi.Secret) {
|
||||
if v.CreateTokenSecret == nil {
|
||||
v.CreateTokenSecret = make(map[string]map[string]*vapi.Secret)
|
||||
}
|
||||
|
||||
tasks := v.CreateTokenSecret[allocID]
|
||||
if tasks == nil {
|
||||
tasks = make(map[string]*vapi.Secret)
|
||||
v.CreateTokenSecret[allocID] = tasks
|
||||
}
|
||||
|
||||
v.CreateTokenSecret[allocID][task] = secret
|
||||
}
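TestClientEndpoint_DeriveVaultToken earlier in this diff wires the mock in by stubbing a secret for one alloc/task pair and pointing the server at it; roughly (alloc, task, and srv are placeholders for the test's fixtures):

// Stub the wrapped secret the mock should hand back for this alloc/task,
// then swap the mock in for the server's real Vault client.
tvc := &TestVaultClient{}
tvc.SetCreateTokenSecret(alloc.ID, task.Name, &vapi.Secret{
	WrapInfo: &vapi.SecretWrapInfo{
		Token:           "wrapped-token", // placeholder values
		WrappedAccessor: "accessor",
		TTL:             10,
	},
})
srv.vault = tvc // the RPC path now exercises CreateToken without a real Vault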
|
||||
|
||||
func (v *TestVaultClient) Stop() {}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
# Create a temp dir and clean it up on exit
|
||||
TEMPDIR=`mktemp -d -t nomad-test.XXX`
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
export PING_SLEEP=30
|
||||
bash -c "while true; do echo \$(date) - building ...; sleep $PING_SLEEP; done" &
|
||||
|
||||
@@ -119,6 +119,6 @@ func (tv *TestVault) waitForAPI() {
|
||||
// getPort returns the next available port to bind Vault against
|
||||
func getPort() uint64 {
|
||||
p := vaultStartPort + vaultPortOffset
|
||||
offset += 1
|
||||
vaultPortOffset += 1
|
||||
return p
|
||||
}
|
||||
|
||||
vendor/github.com/hashicorp/vault/api/auth_token.go (generated, vendored; 1 line changed)
@@ -170,6 +170,7 @@ type TokenCreateRequest struct {
|
||||
Lease string `json:"lease,omitempty"`
|
||||
TTL string `json:"ttl,omitempty"`
|
||||
ExplicitMaxTTL string `json:"explicit_max_ttl,omitempty"`
|
||||
Period string `json:"period,omitempty"`
|
||||
NoParent bool `json:"no_parent,omitempty"`
|
||||
NoDefaultPolicy bool `json:"no_default_policy,omitempty"`
|
||||
DisplayName string `json:"display_name"`
|
||||
|
||||
vendor/github.com/hashicorp/vault/api/sys_audit.go (generated, vendored; 39 lines changed)
@@ -22,21 +22,12 @@ func (c *Sys) AuditHash(path string, input string) (string, error) {
}
defer resp.Body.Close()

secret, err := ParseSecret(resp.Body)
if err != nil {
return "", err
}

if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return "", nil
}

type d struct {
Hash string
Hash string `json:"hash"`
}

var result d
err = mapstructure.Decode(secret.Data, &result)
err = resp.DecodeJSON(&result)
if err != nil {
return "", err
}
@@ -52,26 +43,32 @@ func (c *Sys) ListAudit() (map[string]*Audit, error) {
}
defer resp.Body.Close()

secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}

if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}

result := map[string]*Audit{}
for k, v := range secret.Data {
mounts := map[string]*Audit{}
for k, v := range result {
switch v.(type) {
case map[string]interface{}:
default:
continue
}
var res Audit
err = mapstructure.Decode(v, &res)
if err != nil {
return nil, err
}
result[k] = &res
// Not a mount, some other api.Secret data
if res.Type == "" {
continue
}
mounts[k] = &res
}

return result, err
return mounts, nil
}

func (c *Sys) EnableAudit(
@@ -106,7 +103,7 @@ func (c *Sys) DisableAudit(path string) error {
}

// Structures for the requests/resposne are all down here. They aren't
// individually documentd because the map almost directly to the raw HTTP API
// individually documented because the map almost directly to the raw HTTP API
// documentation. Please refer to that documentation for more details.

type Audit struct {
24 vendor/github.com/hashicorp/vault/api/sys_auth.go generated vendored
@@ -14,26 +14,32 @@ func (c *Sys) ListAuth() (map[string]*AuthMount, error) {
}
defer resp.Body.Close()

secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}

if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}

result := map[string]*AuthMount{}
for k, v := range secret.Data {
mounts := map[string]*AuthMount{}
for k, v := range result {
switch v.(type) {
case map[string]interface{}:
default:
continue
}
var res AuthMount
err = mapstructure.Decode(v, &res)
if err != nil {
return nil, err
}
result[k] = &res
// Not a mount, some other api.Secret data
if res.Type == "" {
continue
}
mounts[k] = &res
}

return result, err
return mounts, nil
}

func (c *Sys) EnableAuth(path, authType, desc string) error {
9 vendor/github.com/hashicorp/vault/api/sys_capabilities.go generated vendored
@@ -28,17 +28,14 @@ func (c *Sys) Capabilities(token, path string) ([]string, error) {
}
defer resp.Body.Close()

secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}

if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}

var capabilities []string
capabilitiesRaw := secret.Data["capabilities"].([]interface{})
capabilitiesRaw := result["capabilities"].([]interface{})
for _, capability := range capabilitiesRaw {
capabilities = append(capabilities, capability.(string))
}
8 vendor/github.com/hashicorp/vault/api/sys_init.go generated vendored
@@ -45,7 +45,9 @@ type InitStatusResponse struct {
}

type InitResponse struct {
Keys []string `json:"keys"`
RecoveryKeys []string `json:"recovery_keys"`
RootToken string `json:"root_token"`
Keys []string `json:"keys"`
KeysB64 []string `json:"keys_base64"`
RecoveryKeys []string `json:"recovery_keys"`
RecoveryKeysB64 []string `json:"recovery_keys_base64"`
RootToken string `json:"root_token"`
}
35 vendor/github.com/hashicorp/vault/api/sys_mounts.go generated vendored
@@ -15,26 +15,32 @@ func (c *Sys) ListMounts() (map[string]*MountOutput, error) {
}
defer resp.Body.Close()

secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}

if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}

result := map[string]*MountOutput{}
for k, v := range secret.Data {
mounts := map[string]*MountOutput{}
for k, v := range result {
switch v.(type) {
case map[string]interface{}:
default:
continue
}
var res MountOutput
err = mapstructure.Decode(v, &res)
if err != nil {
return nil, err
}
result[k] = &res
// Not a mount, some other api.Secret data
if res.Type == "" {
continue
}
mounts[k] = &res
}

return result, nil
return mounts, nil
}

func (c *Sys) Mount(path string, mountInfo *MountInput) error {
@@ -104,17 +110,8 @@ func (c *Sys) MountConfig(path string) (*MountConfigOutput, error) {
}
defer resp.Body.Close()

secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}

if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}

var result MountConfigOutput
err = mapstructure.Decode(secret.Data, &result)
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}
41 vendor/github.com/hashicorp/vault/api/sys_policy.go generated vendored
@@ -1,10 +1,6 @@
package api

import (
"fmt"

"github.com/mitchellh/mapstructure"
)
import "fmt"

func (c *Sys) ListPolicies() ([]string, error) {
r := c.c.NewRequest("GET", "/v1/sys/policy")
@@ -14,22 +10,25 @@ func (c *Sys) ListPolicies() ([]string, error) {
}
defer resp.Body.Close()

secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}

if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
var ok bool
if _, ok = result["policies"]; !ok {
return nil, fmt.Errorf("policies not found in response")
}

var result listPoliciesResp
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return nil, err
listRaw := result["policies"].([]interface{})
var policies []string

for _, val := range listRaw {
policies = append(policies, val.(string))
}

return result.Policies, err
return policies, err
}

func (c *Sys) GetPolicy(name string) (string, error) {
@@ -45,22 +44,18 @@ func (c *Sys) GetPolicy(name string) (string, error) {
return "", err
}

secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return "", err
}

if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return "", nil
var ok bool
if _, ok = result["rules"]; !ok {
return "", fmt.Errorf("rules not found in response")
}

var result getPoliciesResp
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return "", err
}

return result.Rules, err
return result["rules"].(string), nil
}

func (c *Sys) PutPolicy(name, rules string) error {
6 vendor/github.com/hashicorp/vault/api/sys_rekey.go generated vendored
@@ -190,11 +190,13 @@ type RekeyUpdateResponse struct {
Nonce string
Complete bool
Keys []string
KeysB64 []string `json:"keys_base64"`
PGPFingerprints []string `json:"pgp_fingerprints"`
Backup bool
}

type RekeyRetrieveResponse struct {
Nonce string
Keys map[string][]string
Nonce string
Keys map[string][]string
KeysB64 map[string][]string `json:"keys_base64"`
}
27 vendor/github.com/hashicorp/vault/api/sys_rotate.go generated vendored
@@ -1,10 +1,6 @@
package api

import (
"time"

"github.com/mitchellh/mapstructure"
)
import "time"

func (c *Sys) Rotate() error {
r := c.c.NewRequest("POST", "/v1/sys/rotate")
@@ -23,25 +19,12 @@ func (c *Sys) KeyStatus() (*KeyStatus, error) {
}
defer resp.Body.Close()

secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}

if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}

var result KeyStatus
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return nil, err
}

return &result, err
result := new(KeyStatus)
err = resp.DecodeJSON(result)
return result, err
}

type KeyStatus struct {
Term int
Term int `json:"term"`
InstallTime time.Time `json:"install_time"`
}
27 vendor/golang.org/x/sync/LICENSE generated vendored Normal file
@@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
22 vendor/golang.org/x/sync/PATENTS generated vendored Normal file
@@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)

"This implementation" means the copyrightable works distributed by
Google as part of the Go project.

Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.
67 vendor/golang.org/x/sync/errgroup/errgroup.go generated vendored Normal file
@@ -0,0 +1,67 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package errgroup provides synchronization, error propagation, and Context
// cancelation for groups of goroutines working on subtasks of a common task.
package errgroup

import (
"sync"

"golang.org/x/net/context"
)

// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
cancel func()

wg sync.WaitGroup

errOnce sync.Once
err error
}

// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &Group{cancel: cancel}, ctx
}

// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
g.cancel()
}
return g.err
}

// Go calls the given function in a new goroutine.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
g.wg.Add(1)

go func() {
defer g.wg.Done()

if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel()
}
})
}
}()
}
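For orientation, a minimal sketch of the newly vendored errgroup package in use; the task closures and printed messages are illustrative, not part of this commit:

package main

import (
	"fmt"

	"golang.org/x/net/context"
	"golang.org/x/sync/errgroup"
)

func main() {
	// The derived ctx is canceled as soon as any task returns an error.
	g, ctx := errgroup.WithContext(context.Background())

	for i := 0; i < 3; i++ {
		i := i // capture the loop variable for the closure
		g.Go(func() error {
			select {
			case <-ctx.Done():
				return ctx.Err()
			default:
				fmt.Println("task", i, "done")
				return nil
			}
		})
	}

	// Wait blocks for all tasks and returns the first non-nil error, if any.
	if err := g.Wait(); err != nil {
		fmt.Println("group failed:", err)
	}
}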
27 vendor/golang.org/x/time/LICENSE generated vendored Normal file
@@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
22 vendor/golang.org/x/time/PATENTS generated vendored Normal file
@@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)

"This implementation" means the copyrightable works distributed by
Google as part of the Go project.

Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.
368 vendor/golang.org/x/time/rate/rate.go generated vendored Normal file
@@ -0,0 +1,368 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package rate provides a rate limiter.
|
||||
package rate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Limit defines the maximum frequency of some events.
|
||||
// Limit is represented as number of events per second.
|
||||
// A zero Limit allows no events.
|
||||
type Limit float64
|
||||
|
||||
// Inf is the infinite rate limit; it allows all events (even if burst is zero).
|
||||
const Inf = Limit(math.MaxFloat64)
|
||||
|
||||
// Every converts a minimum time interval between events to a Limit.
|
||||
func Every(interval time.Duration) Limit {
|
||||
if interval <= 0 {
|
||||
return Inf
|
||||
}
|
||||
return 1 / Limit(interval.Seconds())
|
||||
}
|
||||
|
||||
// A Limiter controls how frequently events are allowed to happen.
|
||||
// It implements a "token bucket" of size b, initially full and refilled
|
||||
// at rate r tokens per second.
|
||||
// Informally, in any large enough time interval, the Limiter limits the
|
||||
// rate to r tokens per second, with a maximum burst size of b events.
|
||||
// As a special case, if r == Inf (the infinite rate), b is ignored.
|
||||
// See https://en.wikipedia.org/wiki/Token_bucket for more about token buckets.
|
||||
//
|
||||
// The zero value is a valid Limiter, but it will reject all events.
|
||||
// Use NewLimiter to create non-zero Limiters.
|
||||
//
|
||||
// Limiter has three main methods, Allow, Reserve, and Wait.
|
||||
// Most callers should use Wait.
|
||||
//
|
||||
// Each of the three methods consumes a single token.
|
||||
// They differ in their behavior when no token is available.
|
||||
// If no token is available, Allow returns false.
|
||||
// If no token is available, Reserve returns a reservation for a future token
|
||||
// and the amount of time the caller must wait before using it.
|
||||
// If no token is available, Wait blocks until one can be obtained
|
||||
// or its associated context.Context is canceled.
|
||||
//
|
||||
// The methods AllowN, ReserveN, and WaitN consume n tokens.
|
||||
type Limiter struct {
|
||||
limit Limit
|
||||
burst int
|
||||
|
||||
mu sync.Mutex
|
||||
tokens float64
|
||||
// last is the last time the limiter's tokens field was updated
|
||||
last time.Time
|
||||
// lastEvent is the latest time of a rate-limited event (past or future)
|
||||
lastEvent time.Time
|
||||
}
|
||||
|
||||
// Limit returns the maximum overall event rate.
|
||||
func (lim *Limiter) Limit() Limit {
|
||||
lim.mu.Lock()
|
||||
defer lim.mu.Unlock()
|
||||
return lim.limit
|
||||
}
|
||||
|
||||
// Burst returns the maximum burst size. Burst is the maximum number of tokens
|
||||
// that can be consumed in a single call to Allow, Reserve, or Wait, so higher
|
||||
// Burst values allow more events to happen at once.
|
||||
// A zero Burst allows no events, unless limit == Inf.
|
||||
func (lim *Limiter) Burst() int {
|
||||
return lim.burst
|
||||
}
|
||||
|
||||
// NewLimiter returns a new Limiter that allows events up to rate r and permits
|
||||
// bursts of at most b tokens.
|
||||
func NewLimiter(r Limit, b int) *Limiter {
|
||||
return &Limiter{
|
||||
limit: r,
|
||||
burst: b,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow is shorthand for AllowN(time.Now(), 1).
|
||||
func (lim *Limiter) Allow() bool {
|
||||
return lim.AllowN(time.Now(), 1)
|
||||
}
|
||||
|
||||
// AllowN reports whether n events may happen at time now.
|
||||
// Use this method if you intend to drop / skip events that exceed the rate limit.
|
||||
// Otherwise use Reserve or Wait.
|
||||
func (lim *Limiter) AllowN(now time.Time, n int) bool {
|
||||
return lim.reserveN(now, n, 0).ok
|
||||
}
|
||||
|
||||
// A Reservation holds information about events that are permitted by a Limiter to happen after a delay.
|
||||
// A Reservation may be canceled, which may enable the Limiter to permit additional events.
|
||||
type Reservation struct {
|
||||
ok bool
|
||||
lim *Limiter
|
||||
tokens int
|
||||
timeToAct time.Time
|
||||
// This is the Limit at reservation time, it can change later.
|
||||
limit Limit
|
||||
}
|
||||
|
||||
// OK returns whether the limiter can provide the requested number of tokens
|
||||
// within the maximum wait time. If OK is false, Delay returns InfDuration, and
|
||||
// Cancel does nothing.
|
||||
func (r *Reservation) OK() bool {
|
||||
return r.ok
|
||||
}
|
||||
|
||||
// Delay is shorthand for DelayFrom(time.Now()).
|
||||
func (r *Reservation) Delay() time.Duration {
|
||||
return r.DelayFrom(time.Now())
|
||||
}
|
||||
|
||||
// InfDuration is the duration returned by Delay when a Reservation is not OK.
|
||||
const InfDuration = time.Duration(1<<63 - 1)
|
||||
|
||||
// DelayFrom returns the duration for which the reservation holder must wait
|
||||
// before taking the reserved action. Zero duration means act immediately.
|
||||
// InfDuration means the limiter cannot grant the tokens requested in this
|
||||
// Reservation within the maximum wait time.
|
||||
func (r *Reservation) DelayFrom(now time.Time) time.Duration {
|
||||
if !r.ok {
|
||||
return InfDuration
|
||||
}
|
||||
delay := r.timeToAct.Sub(now)
|
||||
if delay < 0 {
|
||||
return 0
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
// Cancel is shorthand for CancelAt(time.Now()).
|
||||
func (r *Reservation) Cancel() {
|
||||
r.CancelAt(time.Now())
|
||||
return
|
||||
}
|
||||
|
||||
// CancelAt indicates that the reservation holder will not perform the reserved action
|
||||
// and reverses the effects of this Reservation on the rate limit as much as possible,
|
||||
// considering that other reservations may have already been made.
|
||||
func (r *Reservation) CancelAt(now time.Time) {
|
||||
if !r.ok {
|
||||
return
|
||||
}
|
||||
|
||||
r.lim.mu.Lock()
|
||||
defer r.lim.mu.Unlock()
|
||||
|
||||
if r.lim.limit == Inf || r.tokens == 0 || r.timeToAct.Before(now) {
|
||||
return
|
||||
}
|
||||
|
||||
// calculate tokens to restore
|
||||
// The duration between lim.lastEvent and r.timeToAct tells us how many tokens were reserved
|
||||
// after r was obtained. These tokens should not be restored.
|
||||
restoreTokens := float64(r.tokens) - r.limit.tokensFromDuration(r.lim.lastEvent.Sub(r.timeToAct))
|
||||
if restoreTokens <= 0 {
|
||||
return
|
||||
}
|
||||
// advance time to now
|
||||
now, _, tokens := r.lim.advance(now)
|
||||
// calculate new number of tokens
|
||||
tokens += restoreTokens
|
||||
if burst := float64(r.lim.burst); tokens > burst {
|
||||
tokens = burst
|
||||
}
|
||||
// update state
|
||||
r.lim.last = now
|
||||
r.lim.tokens = tokens
|
||||
if r.timeToAct == r.lim.lastEvent {
|
||||
prevEvent := r.timeToAct.Add(r.limit.durationFromTokens(float64(-r.tokens)))
|
||||
if !prevEvent.Before(now) {
|
||||
r.lim.lastEvent = prevEvent
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Reserve is shorthand for ReserveN(time.Now(), 1).
|
||||
func (lim *Limiter) Reserve() *Reservation {
|
||||
return lim.ReserveN(time.Now(), 1)
|
||||
}
|
||||
|
||||
// ReserveN returns a Reservation that indicates how long the caller must wait before n events happen.
|
||||
// The Limiter takes this Reservation into account when allowing future events.
|
||||
// ReserveN returns false if n exceeds the Limiter's burst size.
|
||||
// Usage example:
|
||||
// r, ok := lim.ReserveN(time.Now(), 1)
|
||||
// if !ok {
|
||||
// // Not allowed to act! Did you remember to set lim.burst to be > 0 ?
|
||||
// }
|
||||
// time.Sleep(r.Delay())
|
||||
// Act()
|
||||
// Use this method if you wish to wait and slow down in accordance with the rate limit without dropping events.
|
||||
// If you need to respect a deadline or cancel the delay, use Wait instead.
|
||||
// To drop or skip events exceeding rate limit, use Allow instead.
|
||||
func (lim *Limiter) ReserveN(now time.Time, n int) *Reservation {
|
||||
r := lim.reserveN(now, n, InfDuration)
|
||||
return &r
|
||||
}
|
||||
|
||||
// Wait is shorthand for WaitN(ctx, 1).
|
||||
func (lim *Limiter) Wait(ctx context.Context) (err error) {
|
||||
return lim.WaitN(ctx, 1)
|
||||
}
|
||||
|
||||
// WaitN blocks until lim permits n events to happen.
|
||||
// It returns an error if n exceeds the Limiter's burst size, the Context is
|
||||
// canceled, or the expected wait time exceeds the Context's Deadline.
|
||||
func (lim *Limiter) WaitN(ctx context.Context, n int) (err error) {
|
||||
if n > lim.burst {
|
||||
return fmt.Errorf("rate: Wait(n=%d) exceeds limiter's burst %d", n, lim.burst)
|
||||
}
|
||||
// Check if ctx is already cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
// Determine wait limit
|
||||
now := time.Now()
|
||||
waitLimit := InfDuration
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
waitLimit = deadline.Sub(now)
|
||||
}
|
||||
// Reserve
|
||||
r := lim.reserveN(now, n, waitLimit)
|
||||
if !r.ok {
|
||||
return fmt.Errorf("rate: Wait(n=%d) would exceed context deadline", n)
|
||||
}
|
||||
// Wait
|
||||
t := time.NewTimer(r.DelayFrom(now))
|
||||
defer t.Stop()
|
||||
select {
|
||||
case <-t.C:
|
||||
// We can proceed.
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
// Context was canceled before we could proceed. Cancel the
|
||||
// reservation, which may permit other events to proceed sooner.
|
||||
r.Cancel()
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// SetLimit is shorthand for SetLimitAt(time.Now(), newLimit).
|
||||
func (lim *Limiter) SetLimit(newLimit Limit) {
|
||||
lim.SetLimitAt(time.Now(), newLimit)
|
||||
}
|
||||
|
||||
// SetLimitAt sets a new Limit for the limiter. The new Limit, and Burst, may be violated
|
||||
// or underutilized by those which reserved (using Reserve or Wait) but did not yet act
|
||||
// before SetLimitAt was called.
|
||||
func (lim *Limiter) SetLimitAt(now time.Time, newLimit Limit) {
|
||||
lim.mu.Lock()
|
||||
defer lim.mu.Unlock()
|
||||
|
||||
now, _, tokens := lim.advance(now)
|
||||
|
||||
lim.last = now
|
||||
lim.tokens = tokens
|
||||
lim.limit = newLimit
|
||||
}
|
||||
|
||||
// reserveN is a helper method for AllowN, ReserveN, and WaitN.
|
||||
// maxFutureReserve specifies the maximum reservation wait duration allowed.
|
||||
// reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN.
|
||||
func (lim *Limiter) reserveN(now time.Time, n int, maxFutureReserve time.Duration) Reservation {
|
||||
lim.mu.Lock()
|
||||
defer lim.mu.Unlock()
|
||||
|
||||
if lim.limit == Inf {
|
||||
return Reservation{
|
||||
ok: true,
|
||||
lim: lim,
|
||||
tokens: n,
|
||||
timeToAct: now,
|
||||
}
|
||||
}
|
||||
|
||||
now, last, tokens := lim.advance(now)
|
||||
|
||||
// Calculate the remaining number of tokens resulting from the request.
|
||||
tokens -= float64(n)
|
||||
|
||||
// Calculate the wait duration
|
||||
var waitDuration time.Duration
|
||||
if tokens < 0 {
|
||||
waitDuration = lim.limit.durationFromTokens(-tokens)
|
||||
}
|
||||
|
||||
// Decide result
|
||||
ok := n <= lim.burst && waitDuration <= maxFutureReserve
|
||||
|
||||
// Prepare reservation
|
||||
r := Reservation{
|
||||
ok: ok,
|
||||
lim: lim,
|
||||
limit: lim.limit,
|
||||
}
|
||||
if ok {
|
||||
r.tokens = n
|
||||
r.timeToAct = now.Add(waitDuration)
|
||||
}
|
||||
|
||||
// Update state
|
||||
if ok {
|
||||
lim.last = now
|
||||
lim.tokens = tokens
|
||||
lim.lastEvent = r.timeToAct
|
||||
} else {
|
||||
lim.last = last
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// advance calculates and returns an updated state for lim resulting from the passage of time.
|
||||
// lim is not changed.
|
||||
func (lim *Limiter) advance(now time.Time) (newNow time.Time, newLast time.Time, newTokens float64) {
|
||||
last := lim.last
|
||||
if now.Before(last) {
|
||||
last = now
|
||||
}
|
||||
|
||||
// Avoid making delta overflow below when last is very old.
|
||||
maxElapsed := lim.limit.durationFromTokens(float64(lim.burst) - lim.tokens)
|
||||
elapsed := now.Sub(last)
|
||||
if elapsed > maxElapsed {
|
||||
elapsed = maxElapsed
|
||||
}
|
||||
|
||||
// Calculate the new number of tokens, due to time that passed.
|
||||
delta := lim.limit.tokensFromDuration(elapsed)
|
||||
tokens := lim.tokens + delta
|
||||
if burst := float64(lim.burst); tokens > burst {
|
||||
tokens = burst
|
||||
}
|
||||
|
||||
return now, last, tokens
|
||||
}
|
||||
|
||||
// durationFromTokens is a unit conversion function from the number of tokens to the duration
|
||||
// of time it takes to accumulate them at a rate of limit tokens per second.
|
||||
func (limit Limit) durationFromTokens(tokens float64) time.Duration {
|
||||
seconds := tokens / float64(limit)
|
||||
return time.Nanosecond * time.Duration(1e9*seconds)
|
||||
}
|
||||
|
||||
// tokensFromDuration is a unit conversion function from a time duration to the number of tokens
|
||||
// which could be accumulated during that duration at a rate of limit tokens per second.
|
||||
func (limit Limit) tokensFromDuration(d time.Duration) float64 {
|
||||
return d.Seconds() * float64(limit)
|
||||
}
|
||||
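For orientation, a minimal sketch of the newly vendored rate package in use; the limit, burst size, and loop are illustrative, not part of this commit:

package main

import (
	"fmt"

	"golang.org/x/net/context"
	"golang.org/x/time/rate"
)

func main() {
	// Allow roughly 5 events per second with bursts of up to 2.
	lim := rate.NewLimiter(rate.Limit(5), 2)

	ctx := context.Background()
	for i := 0; i < 10; i++ {
		// Wait blocks until a token is available or ctx is canceled.
		if err := lim.Wait(ctx); err != nil {
			fmt.Println("rate limit wait failed:", err)
			return
		}
		fmt.Println("event", i)
	}
}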
24 vendor/vendor.json vendored
@@ -623,10 +623,16 @@
"revisionTime": "2016-06-09T00:18:40Z"
},
{
"checksumSHA1": "0rkVtm9F1/pW9EGhHYJpCnY99O8=",
"checksumSHA1": "RAJfRxZ8UmcL6+7VuXAZxBlnM/4=",
"path": "github.com/hashicorp/vault",
"revision": "fece3ca069fc5bafec5280bbcb0c0693ff69fdaf",
"revisionTime": "2016-08-17T21:47:06Z"
},
{
"checksumSHA1": "JH8wmQ8cWdn7mYu1T7gJ3IMIrec=",
"path": "github.com/hashicorp/vault/api",
"revision": "fbecd94926e289d3b81d8dae6136452a6c4c93f6",
"revisionTime": "2016-08-13T15:54:01Z"
"revision": "fece3ca069fc5bafec5280bbcb0c0693ff69fdaf",
"revisionTime": "2016-08-17T21:47:06Z"
},
{
"checksumSHA1": "5lR6EdY0ARRdKAq3hZcL38STD8Q=",
@@ -827,6 +833,12 @@
"revision": "30db96677b74e24b967e23f911eb3364fc61a011",
"revisionTime": "2016-05-25T13:11:03Z"
},
{
"checksumSHA1": "S0DP7Pn7sZUmXc55IzZnNvERu6s=",
"path": "golang.org/x/sync/errgroup",
"revision": "316e794f7b5e3df4e95175a45a5fb8b12f85cb4f",
"revisionTime": "2016-07-15T18:54:39Z"
},
{
"path": "golang.org/x/sys/unix",
"revision": "50c6bc5e4292a1d4e65c6e9be5f53be28bcbe28e"
@@ -837,6 +849,12 @@
"revision": "b776ec39b3e54652e09028aaaaac9757f4f8211a",
"revisionTime": "2016-04-21T02:29:30Z"
},
{
"checksumSHA1": "h/06ikMECfJoTkWj2e1nJ30aDDg=",
"path": "golang.org/x/time/rate",
"revision": "a4bde12657593d5e90d0533a3e4fd95e635124cb",
"revisionTime": "2016-02-02T18:34:45Z"
},
{
"checksumSHA1": "93uHIq25lffEKY47PV8dBPD+XuQ=",
"path": "gopkg.in/fsnotify.v1",