diff --git a/nomad/fsm.go b/nomad/fsm.go index 4b757b89a..318ad8bf3 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -35,6 +35,7 @@ const ( TimeTableSnapshot PeriodicLaunchSnapshot JobSummarySnapshot + VaultAccessorSnapshot ) // nomadFSM implements a finite state machine that is used @@ -137,6 +138,8 @@ func (n *nomadFSM) Apply(log *raft.Log) interface{} { return n.applyAllocClientUpdate(buf[1:], log.Index) case structs.ReconcileJobSummariesRequestType: return n.applyReconcileSummaries(buf[1:], log.Index) + case structs.VaultAccessorRegisterRequestType: + return n.applyUpsertVaultAccessor(buf[1:], log.Index) default: if ignoreUnknown { n.logger.Printf("[WARN] nomad.fsm: ignoring unknown message type (%d), upgrade to newer version", msgType) @@ -454,6 +457,23 @@ func (n *nomadFSM) applyReconcileSummaries(buf []byte, index uint64) interface{} return n.reconcileQueuedAllocations(index) } +// applyUpsertVaultAccessor stores the Vault accessors for a given allocation +// and task +func (n *nomadFSM) applyUpsertVaultAccessor(buf []byte, index uint64) interface{} { + defer metrics.MeasureSince([]string{"nomad", "fsm", "upsert_vault_accessor"}, time.Now()) + var req structs.VaultAccessorRegisterRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + + if err := n.state.UpsertVaultAccessor(index, req.Accessors); err != nil { + n.logger.Printf("[ERR] nomad.fsm: UpsertVaultAccessor failed: %v", err) + return err + } + + return nil +} + func (n *nomadFSM) Snapshot() (raft.FSMSnapshot, error) { // Create a new snapshot snap, err := n.state.Snapshot() @@ -583,6 +603,15 @@ func (n *nomadFSM) Restore(old io.ReadCloser) error { return err } + case VaultAccessorSnapshot: + accessor := new(structs.VaultAccessor) + if err := dec.Decode(accessor); err != nil { + return err + } + if err := restore.VaultAccessorRestore(accessor); err != nil { + return err + } + default: return fmt.Errorf("Unrecognized snapshot type: %v", msgType) } @@ -756,6 +785,10 @@ func (s *nomadSnapshot) Persist(sink raft.SnapshotSink) error { sink.Cancel() return err } + if err := s.persistVaultAccessors(sink, encoder); err != nil { + sink.Cancel() + return err + } return nil } @@ -945,6 +978,30 @@ func (s *nomadSnapshot) persistJobSummaries(sink raft.SnapshotSink, return nil } +func (s *nomadSnapshot) persistVaultAccessors(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + + accessors, err := s.snap.VaultAccessors() + if err != nil { + return err + } + + for { + raw := accessors.Next() + if raw == nil { + break + } + + accessor := raw.(*structs.VaultAccessor) + + sink.Write([]byte{byte(VaultAccessorSnapshot)}) + if err := encoder.Encode(accessor); err != nil { + return err + } + } + return nil +} + // Release is a no-op, as we just need to GC the pointer // to the state store snapshot. There is nothing to explicitly // cleanup. diff --git a/nomad/fsm_test.go b/nomad/fsm_test.go index 2274ea27d..805e365c4 100644 --- a/nomad/fsm_test.go +++ b/nomad/fsm_test.go @@ -770,6 +770,54 @@ func TestFSM_UpdateAllocFromClient(t *testing.T) { } } +func TestFSM_UpsertVaultAccessor(t *testing.T) { + fsm := testFSM(t) + fsm.blockedEvals.SetEnabled(true) + + va := mock.VaultAccessor() + va2 := mock.VaultAccessor() + req := structs.VaultAccessorRegisterRequest{ + Accessors: []*structs.VaultAccessor{va, va2}, + } + buf, err := structs.Encode(structs.VaultAccessorRegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify we are registered + out1, err := fsm.State().VaultAccessor(va.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + if out1 == nil { + t.Fatalf("not found!") + } + if out1.CreateIndex != 1 { + t.Fatalf("bad index: %d", out1.CreateIndex) + } + out2, err := fsm.State().VaultAccessor(va2.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + if out2 == nil { + t.Fatalf("not found!") + } + if out1.CreateIndex != 1 { + t.Fatalf("bad index: %d", out2.CreateIndex) + } + + tt := fsm.TimeTable() + index := tt.NearestIndex(time.Now().UTC()) + if index != 1 { + t.Fatalf("bad: %d", index) + } +} + func testSnapshotRestore(t *testing.T, fsm *nomadFSM) *nomadFSM { // Snapshot snap, err := fsm.Snapshot() @@ -976,6 +1024,27 @@ func TestFSM_SnapshotRestore_JobSummary(t *testing.T) { } } +func TestFSM_SnapshotRestore_VaultAccessors(t *testing.T) { + // Add some state + fsm := testFSM(t) + state := fsm.State() + a1 := mock.VaultAccessor() + a2 := mock.VaultAccessor() + state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{a1, a2}) + + // Verify the contents + fsm2 := testSnapshotRestore(t, fsm) + state2 := fsm2.State() + out1, _ := state2.VaultAccessor(a1.Accessor) + out2, _ := state2.VaultAccessor(a2.Accessor) + if !reflect.DeepEqual(a1, out1) { + t.Fatalf("bad: \n%#v\n%#v", out1, a1) + } + if !reflect.DeepEqual(a2, out2) { + t.Fatalf("bad: \n%#v\n%#v", out2, a2) + } +} + func TestFSM_SnapshotRestore_AddMissingSummary(t *testing.T) { // Add some state fsm := testFSM(t) diff --git a/nomad/job_endpoint.go b/nomad/job_endpoint.go index 9900bd7cd..b3cde6a89 100644 --- a/nomad/job_endpoint.go +++ b/nomad/job_endpoint.go @@ -1,6 +1,7 @@ package nomad import ( + "context" "fmt" "strings" "time" @@ -83,7 +84,7 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis } vault := j.srv.vault - s, err := vault.LookupToken(args.Job.VaultToken) + s, err := vault.LookupToken(context.Background(), args.Job.VaultToken) if err != nil { return err } diff --git a/nomad/mock/mock.go b/nomad/mock/mock.go index 1cdec8d18..937d09783 100644 --- a/nomad/mock/mock.go +++ b/nomad/mock/mock.go @@ -290,6 +290,16 @@ func Alloc() *structs.Allocation { return alloc } +func VaultAccessor() *structs.VaultAccessor { + return &structs.VaultAccessor{ + Accessor: structs.GenerateUUID(), + NodeID: structs.GenerateUUID(), + AllocID: structs.GenerateUUID(), + CreationTTL: 86400, + Task: "foo", + } +} + func Plan() *structs.Plan { return &structs.Plan{ Priority: 50, diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index f760a1796..abed2632a 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -1,21 +1,29 @@ package nomad import ( + "context" "fmt" "strings" "sync" "time" + "golang.org/x/sync/errgroup" + "github.com/armon/go-metrics" "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/watch" + vapi "github.com/hashicorp/vault/api" ) const ( // batchUpdateInterval is how long we wait to batch updates batchUpdateInterval = 50 * time.Millisecond + + // maxParallelRequestsPerDerive is the maximum number of parallel Vault + // create token requests that may be outstanding per derive request + maxParallelRequestsPerDerive = 16 ) // Node endpoint is used for client interactions @@ -868,3 +876,176 @@ func (b *batchFuture) Respond(index uint64, err error) { b.err = err close(b.doneCh) } + +// DeriveVaultToken is used by the clients to request wrapped Vault tokens for +// tasks +func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, + reply *structs.DeriveVaultTokenResponse) error { + if done, err := n.srv.forward("Node.DeriveVaultToken", args, args, reply); done { + return err + } + defer metrics.MeasureSince([]string{"nomad", "client", "derive_vault_token"}, time.Now()) + + // Verify the arguments + if args.NodeID == "" { + return fmt.Errorf("missing node ID") + } + if args.SecretID == "" { + return fmt.Errorf("missing node SecretID") + } + if args.AllocID == "" { + return fmt.Errorf("missing allocation ID") + } + if len(args.Tasks) == 0 { + return fmt.Errorf("no tasks specified") + } + + // Verify the following: + // * The Node exists and has the correct SecretID + // * The Allocation exists on the specified node + // * The allocation contains the given tasks and they each require Vault + // tokens + snap, err := n.srv.fsm.State().Snapshot() + if err != nil { + return err + } + node, err := snap.NodeByID(args.NodeID) + if err != nil { + return err + } + if node == nil { + return fmt.Errorf("Node %q does not exist", args.NodeID) + } + if node.SecretID != args.SecretID { + return fmt.Errorf("SecretID mismatch") + } + + alloc, err := snap.AllocByID(args.AllocID) + if err != nil { + return err + } + if alloc == nil { + return fmt.Errorf("Allocation %q does not exist", args.AllocID) + } + if alloc.NodeID != args.NodeID { + return fmt.Errorf("Allocation %q not running on Node %q", args.AllocID, args.NodeID) + } + if alloc.TerminalStatus() { + return fmt.Errorf("Can't request Vault token for terminal allocation") + } + + // Check the policies + policies := alloc.Job.VaultPolicies() + if policies == nil { + return fmt.Errorf("Job doesn't require Vault policies") + } + tg, ok := policies[alloc.TaskGroup] + if !ok { + return fmt.Errorf("Task group does not require Vault policies") + } + + var unneeded []string + for _, task := range args.Tasks { + taskVault := tg[task] + if taskVault == nil || len(taskVault.Policies) == 0 { + unneeded = append(unneeded, task) + } + } + + if len(unneeded) != 0 { + return fmt.Errorf("Requested Vault tokens for tasks without defined Vault policies: %s", + strings.Join(unneeded, ", ")) + } + + // At this point the request is valid and we should contact Vault for + // tokens. + + // Create an error group where we will spin up a fixed set of goroutines to + // handle deriving tokens but where if any fails the whole group is + // canceled. + g, ctx := errgroup.WithContext(context.Background()) + + // Cap the handlers + handlers := len(args.Tasks) + if handlers > maxParallelRequestsPerDerive { + handlers = maxParallelRequestsPerDerive + } + + // Create the Vault Tokens + input := make(chan string, handlers) + results := make(map[string]*vapi.Secret, len(args.Tasks)) + for i := 0; i < handlers; i++ { + g.Go(func() error { + for { + select { + case task, ok := <-input: + if !ok { + return nil + } + + secret, err := n.srv.vault.CreateToken(ctx, alloc, task) + if err != nil { + return fmt.Errorf("failed to create token for task %q: %v", task, err) + } + + results[task] = secret + case <-ctx.Done(): + return nil + } + } + }) + } + + // Send the input + go func() { + defer close(input) + for _, task := range args.Tasks { + select { + case <-ctx.Done(): + return + case input <- task: + } + } + + }() + + // Wait for everything to complete or for an error + err = g.Wait() + if err != nil { + // TODO Revoke any created token + return err + } + + // Commit to Raft before returning any of the tokens + accessors := make([]*structs.VaultAccessor, 0, len(results)) + tokens := make(map[string]string, len(results)) + for task, secret := range results { + w := secret.WrapInfo + if w == nil { + return fmt.Errorf("Vault returned Secret without WrapInfo") + } + + tokens[task] = w.Token + accessor := &structs.VaultAccessor{ + Accessor: w.WrappedAccessor, + Task: task, + NodeID: alloc.NodeID, + AllocID: alloc.ID, + CreationTTL: w.TTL, + } + + accessors = append(accessors, accessor) + } + + req := structs.VaultAccessorRegisterRequest{Accessors: accessors} + _, index, err := n.srv.raftApply(structs.VaultAccessorRegisterRequestType, &req) + if err != nil { + n.srv.logger.Printf("[ERR] nomad.client: Register Vault accessors failed: %v", err) + return err + } + + reply.Index = index + reply.Tasks = tokens + n.srv.setQueryMeta(&reply.QueryMeta) + return nil +} diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 23dabec2f..93531f341 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" + vapi "github.com/hashicorp/vault/api" ) func TestClientEndpoint_Register(t *testing.T) { @@ -1597,3 +1598,160 @@ func TestBatchFuture(t *testing.T) { t.Fatalf("bad: %d", bf.Index()) } } + +func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) { + s1 := testServer(t, nil) + defer s1.Shutdown() + state := s1.fsm.State() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the node + node := mock.Node() + if err := state.UpsertNode(2, node); err != nil { + t.Fatalf("err: %v", err) + } + + // Create an alloc + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + tasks := []string{task.Name} + if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil { + t.Fatalf("err: %v", err) + } + + req := &structs.DeriveVaultTokenRequest{ + NodeID: node.ID, + SecretID: structs.GenerateUUID(), + AllocID: alloc.ID, + Tasks: tasks, + QueryOptions: structs.QueryOptions{ + Region: "global", + }, + } + + var resp structs.DeriveVaultTokenResponse + err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) + if err == nil || !strings.Contains(err.Error(), "SecretID mismatch") { + t.Fatalf("Expected SecretID mismatch: %v", err) + } + + // Put the correct SecretID + req.SecretID = node.SecretID + + // Now we should get an error about the allocation not running on the node + err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) + if err == nil || !strings.Contains(err.Error(), "not running on Node") { + t.Fatalf("Expected not running on node error: %v", err) + } + + // Update to be running on the node + alloc.NodeID = node.ID + if err := state.UpsertAllocs(4, []*structs.Allocation{alloc}); err != nil { + t.Fatalf("err: %v", err) + } + + // Now we should get an error about the job not needing any Vault secrets + err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) + if err == nil || !strings.Contains(err.Error(), "without defined Vault") { + t.Fatalf("Expected no policies error: %v", err) + } + + // Update to be terminal + alloc.DesiredStatus = structs.AllocDesiredStatusStop + if err := state.UpsertAllocs(5, []*structs.Allocation{alloc}); err != nil { + t.Fatalf("err: %v", err) + } + + // Now we should get an error about the job not needing any Vault secrets + err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) + if err == nil || !strings.Contains(err.Error(), "terminal") { + t.Fatalf("Expected terminal allocation error: %v", err) + } +} + +func TestClientEndpoint_DeriveVaultToken(t *testing.T) { + s1 := testServer(t, nil) + defer s1.Shutdown() + state := s1.fsm.State() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Enable vault and allow authenticated + s1.config.VaultConfig.Enabled = true + s1.config.VaultConfig.AllowUnauthenticated = true + + // Replace the Vault Client on the server + tvc := &TestVaultClient{} + s1.vault = tvc + + // Create the node + node := mock.Node() + if err := state.UpsertNode(2, node); err != nil { + t.Fatalf("err: %v", err) + } + + // Create an alloc an allocation that has vault policies required + alloc := mock.Alloc() + alloc.NodeID = node.ID + task := alloc.Job.TaskGroups[0].Tasks[0] + tasks := []string{task.Name} + task.Vault = &structs.Vault{Policies: []string{"a", "b"}} + if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil { + t.Fatalf("err: %v", err) + } + + // Return a secret for the task + token := structs.GenerateUUID() + accessor := structs.GenerateUUID() + ttl := 10 + secret := &vapi.Secret{ + WrapInfo: &vapi.SecretWrapInfo{ + Token: token, + WrappedAccessor: accessor, + TTL: ttl, + }, + } + tvc.SetCreateTokenSecret(alloc.ID, task.Name, secret) + + req := &structs.DeriveVaultTokenRequest{ + NodeID: node.ID, + SecretID: node.SecretID, + AllocID: alloc.ID, + Tasks: tasks, + QueryOptions: structs.QueryOptions{ + Region: "global", + }, + } + + var resp structs.DeriveVaultTokenResponse + if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil { + t.Fatalf("bad: %v", err) + } + + // Check the state store and ensure that we created a VaultAccessor + va, err := state.VaultAccessor(accessor) + if err != nil { + t.Fatalf("bad: %v", err) + } + if va == nil { + t.Fatalf("bad: %v", va) + } + + if va.CreateIndex == 0 { + t.Fatalf("bad: %v", va) + } + + va.CreateIndex = 0 + expected := &structs.VaultAccessor{ + AllocID: alloc.ID, + Task: task.Name, + NodeID: alloc.NodeID, + Accessor: accessor, + CreationTTL: ttl, + } + + if !reflect.DeepEqual(expected, va) { + t.Fatalf("Got %#v; want %#v", va, expected) + } +} diff --git a/nomad/state/schema.go b/nomad/state/schema.go index b88cf43cf..05ffd78cf 100644 --- a/nomad/state/schema.go +++ b/nomad/state/schema.go @@ -23,6 +23,7 @@ func stateStoreSchema() *memdb.DBSchema { periodicLaunchTableSchema, evalTableSchema, allocTableSchema, + vaultAccessorTableSchema, } // Add each of the tables @@ -291,3 +292,41 @@ func allocTableSchema() *memdb.TableSchema { }, } } + +// vaultAccessorTableSchema returns the MemDB schema for the Vault Accessor +// Table. This table tracks Vault accessors for tokens created on behalf of +// allocations required Vault tokens. +func vaultAccessorTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "vault_accessors", + Indexes: map[string]*memdb.IndexSchema{ + // The primary index is the accessor id + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Accessor", + }, + }, + + "alloc_id": &memdb.IndexSchema{ + Name: "alloc_id", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "AllocID", + }, + }, + + "node_id": &memdb.IndexSchema{ + Name: "node_id", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "NodeID", + }, + }, + }, + } +} diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index dfe5945ff..cee5935f3 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -1113,6 +1113,124 @@ func (s *StateStore) Allocs() (memdb.ResultIterator, error) { return iter, nil } +// UpsertVaultAccessors is used to register a set of Vault Accessors +func (s *StateStore) UpsertVaultAccessor(index uint64, accessors []*structs.VaultAccessor) error { + txn := s.db.Txn(true) + defer txn.Abort() + + for _, accessor := range accessors { + // Set the create index + accessor.CreateIndex = index + + // Insert the accessor + if err := txn.Insert("vault_accessors", accessor); err != nil { + return fmt.Errorf("accessor insert failed: %v", err) + } + } + + if err := txn.Insert("index", &IndexEntry{"vault_accessors", index}); err != nil { + return fmt.Errorf("index update failed: %v", err) + } + + txn.Commit() + return nil +} + +// DeleteVaultAccessor is used to delete a Vault Accessor +func (s *StateStore) DeleteVaultAccessor(index uint64, accessor string) error { + txn := s.db.Txn(true) + defer txn.Abort() + + // Lookup the accessor + existing, err := txn.First("vault_accessors", "id", accessor) + if err != nil { + return fmt.Errorf("accessor lookup failed: %v", err) + } + if existing == nil { + return fmt.Errorf("vault_accessor not found") + } + + // Delete the accessor + if err := txn.Delete("vault_accessors", existing); err != nil { + return fmt.Errorf("accessor delete failed: %v", err) + } + if err := txn.Insert("index", &IndexEntry{"vault_accessors", index}); err != nil { + return fmt.Errorf("index update failed: %v", err) + } + + txn.Commit() + return nil +} + +// VaultAccessor returns the given Vault accessor +func (s *StateStore) VaultAccessor(accessor string) (*structs.VaultAccessor, error) { + txn := s.db.Txn(false) + + existing, err := txn.First("vault_accessors", "id", accessor) + if err != nil { + return nil, fmt.Errorf("accessor lookup failed: %v", err) + } + + if existing != nil { + return existing.(*structs.VaultAccessor), nil + } + + return nil, nil +} + +// VaultAccessors returns an iterator of Vault accessors. +func (s *StateStore) VaultAccessors() (memdb.ResultIterator, error) { + txn := s.db.Txn(false) + + iter, err := txn.Get("vault_accessors", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +// VaultAccessorsByAlloc returns all the Vault accessors by alloc id +func (s *StateStore) VaultAccessorsByAlloc(allocID string) ([]*structs.VaultAccessor, error) { + txn := s.db.Txn(false) + + // Get an iterator over the accessors + iter, err := txn.Get("vault_accessors", "alloc_id", allocID) + if err != nil { + return nil, err + } + + var out []*structs.VaultAccessor + for { + raw := iter.Next() + if raw == nil { + break + } + out = append(out, raw.(*structs.VaultAccessor)) + } + return out, nil +} + +// VaultAccessorsByNode returns all the Vault accessors by node id +func (s *StateStore) VaultAccessorsByNode(nodeID string) ([]*structs.VaultAccessor, error) { + txn := s.db.Txn(false) + + // Get an iterator over the accessors + iter, err := txn.Get("vault_accessors", "node_id", nodeID) + if err != nil { + return nil, err + } + + var out []*structs.VaultAccessor + for { + raw := iter.Next() + if raw == nil { + break + } + out = append(out, raw.(*structs.VaultAccessor)) + } + return out, nil +} + // LastIndex returns the greatest index value for all indexes func (s *StateStore) LatestIndex() (uint64, error) { indexes, err := s.Indexes() @@ -1627,6 +1745,14 @@ func (r *StateRestore) JobSummaryRestore(jobSummary *structs.JobSummary) error { return nil } +// VaultAccessorRestore is used to restore a vault accessor +func (r *StateRestore) VaultAccessorRestore(accessor *structs.VaultAccessor) error { + if err := r.txn.Insert("vault_accessors", accessor); err != nil { + return fmt.Errorf("vault accessor insert failed: %v", err) + } + return nil +} + // stateWatch holds shared state for watching updates. This is // outside of StateStore so it can be shared with snapshots. type stateWatch struct { diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index 4ee119e68..f272b9fb1 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -2833,6 +2833,206 @@ func TestJobSummary_UpdateClientStatus(t *testing.T) { } } +func TestStateStore_UpsertVaultAccessors(t *testing.T) { + state := testStateStore(t) + a := mock.VaultAccessor() + a2 := mock.VaultAccessor() + + err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{a, a2}) + if err != nil { + t.Fatalf("err: %v", err) + } + + out, err := state.VaultAccessor(a.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !reflect.DeepEqual(a, out) { + t.Fatalf("bad: %#v %#v", a, out) + } + + out, err = state.VaultAccessor(a2.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !reflect.DeepEqual(a2, out) { + t.Fatalf("bad: %#v %#v", a2, out) + } + + iter, err := state.VaultAccessors() + if err != nil { + t.Fatalf("err: %v", err) + } + + count := 0 + for { + raw := iter.Next() + if raw == nil { + break + } + + count++ + accessor := raw.(*structs.VaultAccessor) + + if !reflect.DeepEqual(accessor, a) && !reflect.DeepEqual(accessor, a2) { + t.Fatalf("bad: %#v", accessor) + } + } + + if count != 2 { + t.Fatalf("bad: %d", count) + } + + index, err := state.Index("vault_accessors") + if err != nil { + t.Fatalf("err: %v", err) + } + if index != 1000 { + t.Fatalf("bad: %d", index) + } +} + +func TestStateStore_DeleteVaultAccessor(t *testing.T) { + state := testStateStore(t) + accessor := mock.VaultAccessor() + + err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{accessor}) + if err != nil { + t.Fatalf("err: %v", err) + } + + err = state.DeleteVaultAccessor(1001, accessor.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + + out, err := state.VaultAccessor(accessor.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + + if out != nil { + t.Fatalf("bad: %#v %#v", accessor, out) + } + + index, err := state.Index("vault_accessors") + if err != nil { + t.Fatalf("err: %v", err) + } + if index != 1001 { + t.Fatalf("bad: %d", index) + } +} + +func TestStateStore_VaultAccessorsByAlloc(t *testing.T) { + state := testStateStore(t) + alloc := mock.Alloc() + var accessors []*structs.VaultAccessor + var expected []*structs.VaultAccessor + + for i := 0; i < 5; i++ { + accessor := mock.VaultAccessor() + accessor.AllocID = alloc.ID + expected = append(expected, accessor) + accessors = append(accessors, accessor) + } + + for i := 0; i < 10; i++ { + accessor := mock.VaultAccessor() + accessors = append(accessors, accessor) + } + + err := state.UpsertVaultAccessor(1000, accessors) + if err != nil { + t.Fatalf("err: %v", err) + } + + out, err := state.VaultAccessorsByAlloc(alloc.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(expected) != len(out) { + t.Fatalf("bad: %#v %#v", len(expected), len(out)) + } + + index, err := state.Index("vault_accessors") + if err != nil { + t.Fatalf("err: %v", err) + } + if index != 1000 { + t.Fatalf("bad: %d", index) + } +} + +func TestStateStore_VaultAccessorsByNode(t *testing.T) { + state := testStateStore(t) + node := mock.Node() + var accessors []*structs.VaultAccessor + var expected []*structs.VaultAccessor + + for i := 0; i < 5; i++ { + accessor := mock.VaultAccessor() + accessor.NodeID = node.ID + expected = append(expected, accessor) + accessors = append(accessors, accessor) + } + + for i := 0; i < 10; i++ { + accessor := mock.VaultAccessor() + accessors = append(accessors, accessor) + } + + err := state.UpsertVaultAccessor(1000, accessors) + if err != nil { + t.Fatalf("err: %v", err) + } + + out, err := state.VaultAccessorsByNode(node.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(expected) != len(out) { + t.Fatalf("bad: %#v %#v", len(expected), len(out)) + } + + index, err := state.Index("vault_accessors") + if err != nil { + t.Fatalf("err: %v", err) + } + if index != 1000 { + t.Fatalf("bad: %d", index) + } +} + +func TestStateStore_RestoreVaultAccessor(t *testing.T) { + state := testStateStore(t) + a := mock.VaultAccessor() + + restore, err := state.Restore() + if err != nil { + t.Fatalf("err: %v", err) + } + + err = restore.VaultAccessorRestore(a) + if err != nil { + t.Fatalf("err: %v", err) + } + restore.Commit() + + out, err := state.VaultAccessor(a.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !reflect.DeepEqual(out, a) { + t.Fatalf("Bad: %#v %#v", out, a) + } +} + // setupNotifyTest takes a state store and a set of watch items, then creates // and subscribes a notification channel for each item. func setupNotifyTest(state *StateStore, items ...watch.Item) notifyTest { diff --git a/nomad/structs/funcs.go b/nomad/structs/funcs.go index 68f0af18c..37df72c33 100644 --- a/nomad/structs/funcs.go +++ b/nomad/structs/funcs.go @@ -253,12 +253,12 @@ func SliceStringIsSubset(larger, smaller []string) (bool, []string) { // VaultPoliciesSet takes the structure returned by VaultPolicies and returns // the set of required policies -func VaultPoliciesSet(policies map[string]map[string][]string) []string { +func VaultPoliciesSet(policies map[string]map[string]*Vault) []string { set := make(map[string]struct{}) for _, tgp := range policies { for _, tp := range tgp { - for _, p := range tp { + for _, p := range tp.Policies { set[p] = struct{}{} } } diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 47159e315..058ae16d3 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -47,6 +47,7 @@ const ( AllocUpdateRequestType AllocClientUpdateRequestType ReconcileJobSummariesRequestType + VaultAccessorRegisterRequestType ) const ( @@ -354,6 +355,41 @@ type PeriodicForceRequest struct { WriteRequest } +// DeriveVaultTokenRequest is used to request wrapped Vault tokens for the +// following tasks in the given allocation +type DeriveVaultTokenRequest struct { + NodeID string + SecretID string + AllocID string + Tasks []string + QueryOptions +} + +// VaultAccessorRegisterRequest is used to register a set of Vault accessors +type VaultAccessorRegisterRequest struct { + Accessors []*VaultAccessor +} + +// VaultAccessor is a reference to a created Vault token on behalf of +// an allocation's task. +type VaultAccessor struct { + AllocID string + Task string + NodeID string + Accessor string + CreationTTL int + + // Raft Indexes + CreateIndex uint64 +} + +// DeriveVaultTokenResponse returns the wrapped tokens for each requested task +type DeriveVaultTokenResponse struct { + // Tasks is a mapping between the task name and the wrapped token + Tasks map[string]string + QueryMeta +} + // GenericRequest is used to request where no // specific information is needed. type GenericRequest struct { @@ -1239,11 +1275,11 @@ func (j *Job) IsPeriodic() bool { } // VaultPolicies returns the set of Vault policies per task group, per task -func (j *Job) VaultPolicies() map[string]map[string][]string { - policies := make(map[string]map[string][]string, len(j.TaskGroups)) +func (j *Job) VaultPolicies() map[string]map[string]*Vault { + policies := make(map[string]map[string]*Vault, len(j.TaskGroups)) for _, tg := range j.TaskGroups { - tgPolicies := make(map[string][]string, len(tg.Tasks)) + tgPolicies := make(map[string]*Vault, len(tg.Tasks)) policies[tg.Name] = tgPolicies for _, task := range tg.Tasks { @@ -1251,7 +1287,7 @@ func (j *Job) VaultPolicies() map[string]map[string][]string { continue } - tgPolicies[task.Name] = task.Vault.Policies + tgPolicies[task.Name] = task.Vault } } diff --git a/nomad/structs/structs_test.go b/nomad/structs/structs_test.go index 413603d70..db9857571 100644 --- a/nomad/structs/structs_test.go +++ b/nomad/structs/structs_test.go @@ -224,8 +224,25 @@ func TestJob_SystemJob_Validate(t *testing.T) { func TestJob_VaultPolicies(t *testing.T) { j0 := &Job{} - e0 := make(map[string]map[string][]string, 0) + e0 := make(map[string]map[string]*Vault, 0) + vj1 := &Vault{ + Policies: []string{ + "p1", + "p2", + }, + } + vj2 := &Vault{ + Policies: []string{ + "p3", + "p4", + }, + } + vj3 := &Vault{ + Policies: []string{ + "p5", + }, + } j1 := &Job{ TaskGroups: []*TaskGroup{ &TaskGroup{ @@ -235,13 +252,8 @@ func TestJob_VaultPolicies(t *testing.T) { Name: "t1", }, &Task{ - Name: "t2", - Vault: &Vault{ - Policies: []string{ - "p1", - "p2", - }, - }, + Name: "t2", + Vault: vj1, }, }, }, @@ -249,40 +261,31 @@ func TestJob_VaultPolicies(t *testing.T) { Name: "bar", Tasks: []*Task{ &Task{ - Name: "t3", - Vault: &Vault{ - Policies: []string{ - "p3", - "p4", - }, - }, + Name: "t3", + Vault: vj2, }, &Task{ - Name: "t4", - Vault: &Vault{ - Policies: []string{ - "p5", - }, - }, + Name: "t4", + Vault: vj3, }, }, }, }, } - e1 := map[string]map[string][]string{ - "foo": map[string][]string{ - "t2": []string{"p1", "p2"}, + e1 := map[string]map[string]*Vault{ + "foo": map[string]*Vault{ + "t2": vj1, }, - "bar": map[string][]string{ - "t3": []string{"p3", "p4"}, - "t4": []string{"p5"}, + "bar": map[string]*Vault{ + "t3": vj2, + "t4": vj3, }, } cases := []struct { Job *Job - Expected map[string]map[string][]string + Expected map[string]map[string]*Vault }{ { Job: j0, diff --git a/nomad/vault.go b/nomad/vault.go index a7c734562..1866505b4 100644 --- a/nomad/vault.go +++ b/nomad/vault.go @@ -1,6 +1,7 @@ package nomad import ( + "context" "errors" "fmt" "log" @@ -12,6 +13,8 @@ import ( "github.com/hashicorp/nomad/nomad/structs/config" vapi "github.com/hashicorp/vault/api" "github.com/mitchellh/mapstructure" + + "golang.org/x/time/rate" ) const ( @@ -21,16 +24,25 @@ const ( // minimumTokenTTL is the minimum Token TTL allowed for child tokens. minimumTokenTTL = 5 * time.Minute + + // defaultTokenTTL is the default Token TTL used when the passed token is a + // root token such that child tokens aren't being created against a role + // that has defined a TTL + defaultTokenTTL = "72h" + + // requestRateLimit is the maximum number of requests per second Nomad will + // make against Vault + requestRateLimit rate.Limit = 500.0 ) // VaultClient is the Servers interface for interfacing with Vault type VaultClient interface { // CreateToken takes an allocation and task and returns an appropriate Vault // Secret - CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error) + CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) // LookupToken takes a token string and returns its capabilities. - LookupToken(token string) (*vapi.Secret, error) + LookupToken(ctx context.Context, token string) (*vapi.Secret, error) // Stop is used to stop token renewal. Stop() @@ -52,6 +64,9 @@ type tokenData struct { // the Server with the ability to create child tokens and lookup the permissions // of tokens. type vaultClient struct { + // limiter is used to rate limit requests to Vault + limiter *rate.Limiter + // client is the Vault API client client *vapi.Client @@ -104,6 +119,7 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er enabled: c.Enabled, config: c, logger: logger, + limiter: rate.NewLimiter(requestRateLimit, int(requestRateLimit)), } // If vault is not enabled do not configure an API client or start any token @@ -131,6 +147,9 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er } v.childTTL = c.TaskTokenTTL + } else { + // Default the TaskTokenTTL + v.childTTL = defaultTokenTTL } // Get the Vault API configuration @@ -157,6 +176,11 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er return v, nil } +// setLimit is used to update the rate limit +func (v *vaultClient) setLimit(l rate.Limit) { + v.limiter = rate.NewLimiter(l, int(l)) +} + // establishConnection is used to make first contact with Vault. This should be // called in a go-routine since the connection is retried til the Vault Client // is stopped or the connection is successfully made at which point the renew @@ -397,7 +421,7 @@ func (v *vaultClient) Stop() { v.l.Lock() defer v.l.Unlock() - if !v.renewalRunning || !v.establishingConn { + if !v.renewalRunning && !v.establishingConn { return } @@ -414,12 +438,9 @@ func (v *vaultClient) ConnectionEstablished() bool { return v.connEstablished } -func (v *vaultClient) CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error) { - return nil, nil -} - -// LookupToken takes a Vault token and does a lookup against Vault -func (v *vaultClient) LookupToken(token string) (*vapi.Secret, error) { +// CreateToken takes the allocation and task and returns an appropriate Vault +// token. The call is rate limited and may be canceled with the passed policy +func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) { // Nothing to do if !v.enabled { return nil, fmt.Errorf("Vault integration disabled") @@ -430,6 +451,70 @@ func (v *vaultClient) LookupToken(token string) (*vapi.Secret, error) { return nil, fmt.Errorf("Connection to Vault has not been established. Retry") } + // Retrieve the Vault block for the task + policies := a.Job.VaultPolicies() + if policies == nil { + return nil, fmt.Errorf("Job doesn't require Vault policies") + } + tg, ok := policies[a.TaskGroup] + if !ok { + return nil, fmt.Errorf("Task group does not require Vault policies") + } + taskVault, ok := tg[task] + if !ok { + return nil, fmt.Errorf("Task does not require Vault policies") + } + + // Build the creation request + req := &vapi.TokenCreateRequest{ + Policies: taskVault.Policies, + Metadata: map[string]string{ + "AllocationID": a.ID, + "Task": task, + "NodeID": a.NodeID, + }, + TTL: v.childTTL, + DisplayName: fmt.Sprintf("%s: %s", a.ID, task), + } + + // Ensure we are under our rate limit + if err := v.limiter.Wait(ctx); err != nil { + return nil, err + } + + // Make the request and switch depending on whether we are using a root + // token or a role based token + var secret *vapi.Secret + var err error + if v.token.Root { + req.Period = v.childTTL + secret, err = v.auth.Create(req) + } else { + // Make the token using the role + secret, err = v.auth.CreateWithRole(req, v.token.Role) + } + + return secret, err +} + +// LookupToken takes a Vault token and does a lookup against Vault. The call is +// rate limited and may be canceled with passed context. +func (v *vaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secret, error) { + // Nothing to do + if !v.enabled { + return nil, fmt.Errorf("Vault integration disabled") + } + + // Check if we have established a connection with Vault + if !v.ConnectionEstablished() { + return nil, fmt.Errorf("Connection to Vault has not been established. Retry") + } + + // Ensure we are under our rate limit + if err := v.limiter.Wait(ctx); err != nil { + return nil, err + } + // Lookup the token return v.auth.Lookup(token) } diff --git a/nomad/vault_test.go b/nomad/vault_test.go index b007ca22e..4d12fd6a7 100644 --- a/nomad/vault_test.go +++ b/nomad/vault_test.go @@ -1,6 +1,7 @@ package nomad import ( + "context" "encoding/json" "log" "os" @@ -9,12 +10,22 @@ import ( "testing" "time" + "golang.org/x/time/rate" + + "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" "github.com/hashicorp/nomad/testutil" vapi "github.com/hashicorp/vault/api" ) +const ( + // authPolicy is a policy that allows token creation operations + authPolicy = `path "auth/token/create/*" { + capabilities = ["create", "read", "update", "delete", "list"] +}` +) + func TestVaultClient_BadConfig(t *testing.T) { conf := &config.VaultConfig{} logger := log.New(os.Stderr, "", log.LstdFlags) @@ -24,6 +35,7 @@ func TestVaultClient_BadConfig(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } + defer client.Stop() if client.ConnectionEstablished() { t.Fatalf("bad") @@ -75,15 +87,20 @@ func TestVaultClient_EstablishConnection(t *testing.T) { } } -func TestVaultClient_RenewalLoop(t *testing.T) { - v := testutil.NewTestVault(t).Start() - defer v.Stop() +// testVaultRoleAndToken creates a test Vault role where children are created +// with the passed period. A token created in that role is returned +func testVaultRoleAndToken(v *testutil.TestVault, t *testing.T, rolePeriod int) string { + // Build the auth policy + sys := v.Client.Sys() + if err := sys.PutPolicy("auth", authPolicy); err != nil { + t.Fatalf("failed to create auth policy: %v", err) + } // Build a role l := v.Client.Logical() d := make(map[string]interface{}, 2) - d["allowed_policies"] = "default" - d["period"] = 5 + d["allowed_policies"] = "default,auth" + d["period"] = rolePeriod l.Write("auth/token/roles/test", d) // Create a new token with the role @@ -99,8 +116,15 @@ func TestVaultClient_RenewalLoop(t *testing.T) { t.Fatalf("bad secret response: %+v", s) } - // Set the configs token - v.Config.Token = s.Auth.ClientToken + return s.Auth.ClientToken +} + +func TestVaultClient_RenewalLoop(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + // Set the configs token in a new test role + v.Config.Token = testVaultRoleAndToken(v, t, 5) // Start the client logger := log.New(os.Stderr, "", log.LstdFlags) @@ -114,6 +138,7 @@ func TestVaultClient_RenewalLoop(t *testing.T) { time.Sleep(8 * time.Second) // Get the current TTL + a := v.Client.Auth().Token() s2, err := a.Lookup(v.Config.Token) if err != nil { t.Fatalf("failed to lookup token: %v", err) @@ -160,8 +185,9 @@ func TestVaultClient_LookupToken_Invalid(t *testing.T) { if err != nil { t.Fatalf("failed to build vault client: %v", err) } + defer client.Stop() - _, err = client.LookupToken("foo") + _, err = client.LookupToken(context.Background(), "foo") if err == nil || !strings.Contains(err.Error(), "disabled") { t.Fatalf("Expected error because Vault is disabled: %v", err) } @@ -175,7 +201,7 @@ func TestVaultClient_LookupToken_Invalid(t *testing.T) { t.Fatalf("failed to build vault client: %v", err) } - _, err = client.LookupToken("foo") + _, err = client.LookupToken(context.Background(), "foo") if err == nil || !strings.Contains(err.Error(), "established") { t.Fatalf("Expected error because connection to Vault hasn't been made: %v", err) } @@ -198,11 +224,12 @@ func TestVaultClient_LookupToken(t *testing.T) { if err != nil { t.Fatalf("failed to build vault client: %v", err) } + defer client.Stop() waitForConnection(client, t) // Lookup ourselves - s, err := client.LookupToken(v.Config.Token) + s, err := client.LookupToken(context.Background(), v.Config.Token) if err != nil { t.Fatalf("self lookup failed: %v", err) } @@ -233,7 +260,7 @@ func TestVaultClient_LookupToken(t *testing.T) { } // Lookup new child - s, err = client.LookupToken(s.Auth.ClientToken) + s, err = client.LookupToken(context.Background(), s.Auth.ClientToken) if err != nil { t.Fatalf("self lookup failed: %v", err) } @@ -247,3 +274,145 @@ func TestVaultClient_LookupToken(t *testing.T) { t.Fatalf("Unexpected policies; got %v; want %v", policies, expected) } } + +func TestVaultClient_LookupToken_RateLimit(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + client.setLimit(rate.Limit(1.0)) + + waitForConnection(client, t) + + // Spin up many requests. These should block + ctx, cancel := context.WithCancel(context.Background()) + + cancels := 0 + numRequests := 10 + unblock := make(chan struct{}) + for i := 0; i < numRequests; i++ { + go func() { + // Ensure all the goroutines are made + time.Sleep(10 * time.Millisecond) + + // Lookup ourselves + _, err := client.LookupToken(ctx, v.Config.Token) + if err != nil { + if err == context.Canceled { + cancels += 1 + return + } + t.Fatalf("self lookup failed: %v", err) + return + } + + // Cancel the context + cancel() + time.AfterFunc(1*time.Second, func() { close(unblock) }) + }() + } + + select { + case <-time.After(5 * time.Second): + t.Fatalf("timeout") + case <-unblock: + } + + desired := numRequests - 1 + if cancels != desired { + t.Fatalf("Incorrect number of cancels; got %d; want %d", cancels, desired) + } +} + +func TestVaultClient_CreateToken_Root(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + + waitForConnection(client, t) + + // Create an allocation that requires a Vault policy + a := mock.Alloc() + task := a.Job.TaskGroups[0].Tasks[0] + task.Vault = &structs.Vault{Policies: []string{"default"}} + + s, err := client.CreateToken(context.Background(), a, task.Name) + if err != nil { + t.Fatalf("CreateToken failed: %v", err) + } + + // Ensure that created secret is a wrapped token + if s == nil || s.WrapInfo == nil { + t.Fatalf("Bad secret: %#v", s) + } + + d, err := time.ParseDuration(vaultTokenCreateTTL) + if err != nil { + t.Fatalf("bad: %v", err) + } + + if s.WrapInfo.WrappedAccessor == "" { + t.Fatalf("Bad accessor: %v", s.WrapInfo.WrappedAccessor) + } else if s.WrapInfo.Token == "" { + t.Fatalf("Bad token: %v", s.WrapInfo.WrappedAccessor) + } else if s.WrapInfo.TTL != int(d.Seconds()) { + t.Fatalf("Bad ttl: %v", s.WrapInfo.WrappedAccessor) + } +} + +func TestVaultClient_CreateToken_Role(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + // Set the configs token in a new test role + v.Config.Token = testVaultRoleAndToken(v, t, 5) + //testVaultRoleAndToken(v, t, 5) + // Start the client + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + + waitForConnection(client, t) + + // Create an allocation that requires a Vault policy + a := mock.Alloc() + task := a.Job.TaskGroups[0].Tasks[0] + task.Vault = &structs.Vault{Policies: []string{"default"}} + + s, err := client.CreateToken(context.Background(), a, task.Name) + if err != nil { + t.Fatalf("CreateToken failed: %v", err) + } + + // Ensure that created secret is a wrapped token + if s == nil || s.WrapInfo == nil { + t.Fatalf("Bad secret: %#v", s) + } + + d, err := time.ParseDuration(vaultTokenCreateTTL) + if err != nil { + t.Fatalf("bad: %v", err) + } + + if s.WrapInfo.WrappedAccessor == "" { + t.Fatalf("Bad accessor: %v", s.WrapInfo.WrappedAccessor) + } else if s.WrapInfo.Token == "" { + t.Fatalf("Bad token: %v", s.WrapInfo.WrappedAccessor) + } else if s.WrapInfo.TTL != int(d.Seconds()) { + t.Fatalf("Bad ttl: %v", s.WrapInfo.WrappedAccessor) + } +} diff --git a/nomad/vault_testing.go b/nomad/vault_testing.go index f8558ea0c..73e5efc34 100644 --- a/nomad/vault_testing.go +++ b/nomad/vault_testing.go @@ -1,6 +1,8 @@ package nomad import ( + "context" + "github.com/hashicorp/nomad/nomad/structs" vapi "github.com/hashicorp/vault/api" ) @@ -16,9 +18,17 @@ type TestVaultClient struct { // LookupTokenSecret maps a token to the Vault secret that will be returned // by the LookupToken call LookupTokenSecret map[string]*vapi.Secret + + // CreateTokenErrors maps a token to an error that will be returned by the + // CreateToken call + CreateTokenErrors map[string]map[string]error + + // CreateTokenSecret maps a token to the Vault secret that will be returned + // by the CreateToken call + CreateTokenSecret map[string]map[string]*vapi.Secret } -func (v *TestVaultClient) LookupToken(token string) (*vapi.Secret, error) { +func (v *TestVaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secret, error) { var secret *vapi.Secret var err error @@ -64,8 +74,56 @@ func (v *TestVaultClient) SetLookupTokenAllowedPolicies(token string, policies [ v.SetLookupTokenSecret(token, s) } -func (v *TestVaultClient) CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error) { - return nil, nil +func (v *TestVaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) { + var secret *vapi.Secret + var err error + + if v.CreateTokenSecret != nil { + tasks := v.CreateTokenSecret[a.ID] + if tasks != nil { + secret = tasks[task] + } + } + if v.CreateTokenErrors != nil { + tasks := v.CreateTokenErrors[a.ID] + if tasks != nil { + err = tasks[task] + } + } + + return secret, err +} + +// SetCreateTokenError sets the error that will be returned by the token +// creation +func (v *TestVaultClient) SetCreateTokenError(allocID, task string, err error) { + if v.CreateTokenErrors == nil { + v.CreateTokenErrors = make(map[string]map[string]error) + } + + tasks := v.CreateTokenErrors[allocID] + if tasks == nil { + tasks = make(map[string]error) + v.CreateTokenErrors[allocID] = tasks + } + + v.CreateTokenErrors[allocID][task] = err +} + +// SetCreateTokenSecret sets the secret that will be returned by the token +// creation +func (v *TestVaultClient) SetCreateTokenSecret(allocID, task string, secret *vapi.Secret) { + if v.CreateTokenSecret == nil { + v.CreateTokenSecret = make(map[string]map[string]*vapi.Secret) + } + + tasks := v.CreateTokenSecret[allocID] + if tasks == nil { + tasks = make(map[string]*vapi.Secret) + v.CreateTokenSecret[allocID] = tasks + } + + v.CreateTokenSecret[allocID][task] = secret } func (v *TestVaultClient) Stop() {} diff --git a/scripts/test.sh b/scripts/test.sh index 033876ddf..decf962a6 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash +set -e # Create a temp dir and clean it up on exit TEMPDIR=`mktemp -d -t nomad-test.XXX` diff --git a/scripts/travis.sh b/scripts/travis.sh index 05d368001..4bb85a1f2 100755 --- a/scripts/travis.sh +++ b/scripts/travis.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash +set -e export PING_SLEEP=30 bash -c "while true; do echo \$(date) - building ...; sleep $PING_SLEEP; done" & diff --git a/testutil/vault.go b/testutil/vault.go index bb560cecd..1f449c73c 100644 --- a/testutil/vault.go +++ b/testutil/vault.go @@ -119,6 +119,6 @@ func (tv *TestVault) waitForAPI() { // getPort returns the next available port to bind Vault against func getPort() uint64 { p := vaultStartPort + vaultPortOffset - offset += 1 + vaultPortOffset += 1 return p } diff --git a/vendor/github.com/hashicorp/vault/api/auth_token.go b/vendor/github.com/hashicorp/vault/api/auth_token.go index 2dae4df62..1901ea110 100644 --- a/vendor/github.com/hashicorp/vault/api/auth_token.go +++ b/vendor/github.com/hashicorp/vault/api/auth_token.go @@ -170,6 +170,7 @@ type TokenCreateRequest struct { Lease string `json:"lease,omitempty"` TTL string `json:"ttl,omitempty"` ExplicitMaxTTL string `json:"explicit_max_ttl,omitempty"` + Period string `json:"period,omitempty"` NoParent bool `json:"no_parent,omitempty"` NoDefaultPolicy bool `json:"no_default_policy,omitempty"` DisplayName string `json:"display_name"` diff --git a/vendor/github.com/hashicorp/vault/api/sys_audit.go b/vendor/github.com/hashicorp/vault/api/sys_audit.go index b6fed6af9..1ffdef880 100644 --- a/vendor/github.com/hashicorp/vault/api/sys_audit.go +++ b/vendor/github.com/hashicorp/vault/api/sys_audit.go @@ -22,21 +22,12 @@ func (c *Sys) AuditHash(path string, input string) (string, error) { } defer resp.Body.Close() - secret, err := ParseSecret(resp.Body) - if err != nil { - return "", err - } - - if secret == nil || secret.Data == nil || len(secret.Data) == 0 { - return "", nil - } - type d struct { - Hash string + Hash string `json:"hash"` } var result d - err = mapstructure.Decode(secret.Data, &result) + err = resp.DecodeJSON(&result) if err != nil { return "", err } @@ -52,26 +43,32 @@ func (c *Sys) ListAudit() (map[string]*Audit, error) { } defer resp.Body.Close() - secret, err := ParseSecret(resp.Body) + var result map[string]interface{} + err = resp.DecodeJSON(&result) if err != nil { return nil, err } - if secret == nil || secret.Data == nil || len(secret.Data) == 0 { - return nil, nil - } - - result := map[string]*Audit{} - for k, v := range secret.Data { + mounts := map[string]*Audit{} + for k, v := range result { + switch v.(type) { + case map[string]interface{}: + default: + continue + } var res Audit err = mapstructure.Decode(v, &res) if err != nil { return nil, err } - result[k] = &res + // Not a mount, some other api.Secret data + if res.Type == "" { + continue + } + mounts[k] = &res } - return result, err + return mounts, nil } func (c *Sys) EnableAudit( @@ -106,7 +103,7 @@ func (c *Sys) DisableAudit(path string) error { } // Structures for the requests/resposne are all down here. They aren't -// individually documentd because the map almost directly to the raw HTTP API +// individually documented because the map almost directly to the raw HTTP API // documentation. Please refer to that documentation for more details. type Audit struct { diff --git a/vendor/github.com/hashicorp/vault/api/sys_auth.go b/vendor/github.com/hashicorp/vault/api/sys_auth.go index 743b8e6d8..1940e8417 100644 --- a/vendor/github.com/hashicorp/vault/api/sys_auth.go +++ b/vendor/github.com/hashicorp/vault/api/sys_auth.go @@ -14,26 +14,32 @@ func (c *Sys) ListAuth() (map[string]*AuthMount, error) { } defer resp.Body.Close() - secret, err := ParseSecret(resp.Body) + var result map[string]interface{} + err = resp.DecodeJSON(&result) if err != nil { return nil, err } - if secret == nil || secret.Data == nil || len(secret.Data) == 0 { - return nil, nil - } - - result := map[string]*AuthMount{} - for k, v := range secret.Data { + mounts := map[string]*AuthMount{} + for k, v := range result { + switch v.(type) { + case map[string]interface{}: + default: + continue + } var res AuthMount err = mapstructure.Decode(v, &res) if err != nil { return nil, err } - result[k] = &res + // Not a mount, some other api.Secret data + if res.Type == "" { + continue + } + mounts[k] = &res } - return result, err + return mounts, nil } func (c *Sys) EnableAuth(path, authType, desc string) error { diff --git a/vendor/github.com/hashicorp/vault/api/sys_capabilities.go b/vendor/github.com/hashicorp/vault/api/sys_capabilities.go index 6d501a495..80f621884 100644 --- a/vendor/github.com/hashicorp/vault/api/sys_capabilities.go +++ b/vendor/github.com/hashicorp/vault/api/sys_capabilities.go @@ -28,17 +28,14 @@ func (c *Sys) Capabilities(token, path string) ([]string, error) { } defer resp.Body.Close() - secret, err := ParseSecret(resp.Body) + var result map[string]interface{} + err = resp.DecodeJSON(&result) if err != nil { return nil, err } - if secret == nil || secret.Data == nil || len(secret.Data) == 0 { - return nil, nil - } - var capabilities []string - capabilitiesRaw := secret.Data["capabilities"].([]interface{}) + capabilitiesRaw := result["capabilities"].([]interface{}) for _, capability := range capabilitiesRaw { capabilities = append(capabilities, capability.(string)) } diff --git a/vendor/github.com/hashicorp/vault/api/sys_init.go b/vendor/github.com/hashicorp/vault/api/sys_init.go index 37c2bcc8c..d307f732b 100644 --- a/vendor/github.com/hashicorp/vault/api/sys_init.go +++ b/vendor/github.com/hashicorp/vault/api/sys_init.go @@ -45,7 +45,9 @@ type InitStatusResponse struct { } type InitResponse struct { - Keys []string `json:"keys"` - RecoveryKeys []string `json:"recovery_keys"` - RootToken string `json:"root_token"` + Keys []string `json:"keys"` + KeysB64 []string `json:"keys_base64"` + RecoveryKeys []string `json:"recovery_keys"` + RecoveryKeysB64 []string `json:"recovery_keys_base64"` + RootToken string `json:"root_token"` } diff --git a/vendor/github.com/hashicorp/vault/api/sys_mounts.go b/vendor/github.com/hashicorp/vault/api/sys_mounts.go index 504e5711b..ca5e42707 100644 --- a/vendor/github.com/hashicorp/vault/api/sys_mounts.go +++ b/vendor/github.com/hashicorp/vault/api/sys_mounts.go @@ -15,26 +15,32 @@ func (c *Sys) ListMounts() (map[string]*MountOutput, error) { } defer resp.Body.Close() - secret, err := ParseSecret(resp.Body) + var result map[string]interface{} + err = resp.DecodeJSON(&result) if err != nil { return nil, err } - if secret == nil || secret.Data == nil || len(secret.Data) == 0 { - return nil, nil - } - - result := map[string]*MountOutput{} - for k, v := range secret.Data { + mounts := map[string]*MountOutput{} + for k, v := range result { + switch v.(type) { + case map[string]interface{}: + default: + continue + } var res MountOutput err = mapstructure.Decode(v, &res) if err != nil { return nil, err } - result[k] = &res + // Not a mount, some other api.Secret data + if res.Type == "" { + continue + } + mounts[k] = &res } - return result, nil + return mounts, nil } func (c *Sys) Mount(path string, mountInfo *MountInput) error { @@ -104,17 +110,8 @@ func (c *Sys) MountConfig(path string) (*MountConfigOutput, error) { } defer resp.Body.Close() - secret, err := ParseSecret(resp.Body) - if err != nil { - return nil, err - } - - if secret == nil || secret.Data == nil || len(secret.Data) == 0 { - return nil, nil - } - var result MountConfigOutput - err = mapstructure.Decode(secret.Data, &result) + err = resp.DecodeJSON(&result) if err != nil { return nil, err } diff --git a/vendor/github.com/hashicorp/vault/api/sys_policy.go b/vendor/github.com/hashicorp/vault/api/sys_policy.go index 35e18b388..ba0e17fab 100644 --- a/vendor/github.com/hashicorp/vault/api/sys_policy.go +++ b/vendor/github.com/hashicorp/vault/api/sys_policy.go @@ -1,10 +1,6 @@ package api -import ( - "fmt" - - "github.com/mitchellh/mapstructure" -) +import "fmt" func (c *Sys) ListPolicies() ([]string, error) { r := c.c.NewRequest("GET", "/v1/sys/policy") @@ -14,22 +10,25 @@ func (c *Sys) ListPolicies() ([]string, error) { } defer resp.Body.Close() - secret, err := ParseSecret(resp.Body) + var result map[string]interface{} + err = resp.DecodeJSON(&result) if err != nil { return nil, err } - if secret == nil || secret.Data == nil || len(secret.Data) == 0 { - return nil, nil + var ok bool + if _, ok = result["policies"]; !ok { + return nil, fmt.Errorf("policies not found in response") } - var result listPoliciesResp - err = mapstructure.Decode(secret.Data, &result) - if err != nil { - return nil, err + listRaw := result["policies"].([]interface{}) + var policies []string + + for _, val := range listRaw { + policies = append(policies, val.(string)) } - return result.Policies, err + return policies, err } func (c *Sys) GetPolicy(name string) (string, error) { @@ -45,22 +44,18 @@ func (c *Sys) GetPolicy(name string) (string, error) { return "", err } - secret, err := ParseSecret(resp.Body) + var result map[string]interface{} + err = resp.DecodeJSON(&result) if err != nil { return "", err } - if secret == nil || secret.Data == nil || len(secret.Data) == 0 { - return "", nil + var ok bool + if _, ok = result["rules"]; !ok { + return "", fmt.Errorf("rules not found in response") } - var result getPoliciesResp - err = mapstructure.Decode(secret.Data, &result) - if err != nil { - return "", err - } - - return result.Rules, err + return result["rules"].(string), nil } func (c *Sys) PutPolicy(name, rules string) error { diff --git a/vendor/github.com/hashicorp/vault/api/sys_rekey.go b/vendor/github.com/hashicorp/vault/api/sys_rekey.go index 4fbfbb9fc..e6d039e27 100644 --- a/vendor/github.com/hashicorp/vault/api/sys_rekey.go +++ b/vendor/github.com/hashicorp/vault/api/sys_rekey.go @@ -190,11 +190,13 @@ type RekeyUpdateResponse struct { Nonce string Complete bool Keys []string + KeysB64 []string `json:"keys_base64"` PGPFingerprints []string `json:"pgp_fingerprints"` Backup bool } type RekeyRetrieveResponse struct { - Nonce string - Keys map[string][]string + Nonce string + Keys map[string][]string + KeysB64 map[string][]string `json:"keys_base64"` } diff --git a/vendor/github.com/hashicorp/vault/api/sys_rotate.go b/vendor/github.com/hashicorp/vault/api/sys_rotate.go index 2a78b4691..8108dced8 100644 --- a/vendor/github.com/hashicorp/vault/api/sys_rotate.go +++ b/vendor/github.com/hashicorp/vault/api/sys_rotate.go @@ -1,10 +1,6 @@ package api -import ( - "time" - - "github.com/mitchellh/mapstructure" -) +import "time" func (c *Sys) Rotate() error { r := c.c.NewRequest("POST", "/v1/sys/rotate") @@ -23,25 +19,12 @@ func (c *Sys) KeyStatus() (*KeyStatus, error) { } defer resp.Body.Close() - secret, err := ParseSecret(resp.Body) - if err != nil { - return nil, err - } - - if secret == nil || secret.Data == nil || len(secret.Data) == 0 { - return nil, nil - } - - var result KeyStatus - err = mapstructure.Decode(secret.Data, &result) - if err != nil { - return nil, err - } - - return &result, err + result := new(KeyStatus) + err = resp.DecodeJSON(result) + return result, err } type KeyStatus struct { - Term int + Term int `json:"term"` InstallTime time.Time `json:"install_time"` } diff --git a/vendor/golang.org/x/sync/LICENSE b/vendor/golang.org/x/sync/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/vendor/golang.org/x/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/sync/PATENTS b/vendor/golang.org/x/sync/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/vendor/golang.org/x/sync/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/sync/errgroup/errgroup.go b/vendor/golang.org/x/sync/errgroup/errgroup.go new file mode 100644 index 000000000..533438d91 --- /dev/null +++ b/vendor/golang.org/x/sync/errgroup/errgroup.go @@ -0,0 +1,67 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package errgroup provides synchronization, error propagation, and Context +// cancelation for groups of goroutines working on subtasks of a common task. +package errgroup + +import ( + "sync" + + "golang.org/x/net/context" +) + +// A Group is a collection of goroutines working on subtasks that are part of +// the same overall task. +// +// A zero Group is valid and does not cancel on error. +type Group struct { + cancel func() + + wg sync.WaitGroup + + errOnce sync.Once + err error +} + +// WithContext returns a new Group and an associated Context derived from ctx. +// +// The derived Context is canceled the first time a function passed to Go +// returns a non-nil error or the first time Wait returns, whichever occurs +// first. +func WithContext(ctx context.Context) (*Group, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return &Group{cancel: cancel}, ctx +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the first non-nil error (if any) from them. +func (g *Group) Wait() error { + g.wg.Wait() + if g.cancel != nil { + g.cancel() + } + return g.err +} + +// Go calls the given function in a new goroutine. +// +// The first call to return a non-nil error cancels the group; its error will be +// returned by Wait. +func (g *Group) Go(f func() error) { + g.wg.Add(1) + + go func() { + defer g.wg.Done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }() +} diff --git a/vendor/golang.org/x/time/LICENSE b/vendor/golang.org/x/time/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/vendor/golang.org/x/time/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/time/PATENTS b/vendor/golang.org/x/time/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/vendor/golang.org/x/time/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/time/rate/rate.go b/vendor/golang.org/x/time/rate/rate.go new file mode 100644 index 000000000..2131b9217 --- /dev/null +++ b/vendor/golang.org/x/time/rate/rate.go @@ -0,0 +1,368 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rate provides a rate limiter. +package rate + +import ( + "fmt" + "math" + "sync" + "time" + + "golang.org/x/net/context" +) + +// Limit defines the maximum frequency of some events. +// Limit is represented as number of events per second. +// A zero Limit allows no events. +type Limit float64 + +// Inf is the infinite rate limit; it allows all events (even if burst is zero). +const Inf = Limit(math.MaxFloat64) + +// Every converts a minimum time interval between events to a Limit. +func Every(interval time.Duration) Limit { + if interval <= 0 { + return Inf + } + return 1 / Limit(interval.Seconds()) +} + +// A Limiter controls how frequently events are allowed to happen. +// It implements a "token bucket" of size b, initially full and refilled +// at rate r tokens per second. +// Informally, in any large enough time interval, the Limiter limits the +// rate to r tokens per second, with a maximum burst size of b events. +// As a special case, if r == Inf (the infinite rate), b is ignored. +// See https://en.wikipedia.org/wiki/Token_bucket for more about token buckets. +// +// The zero value is a valid Limiter, but it will reject all events. +// Use NewLimiter to create non-zero Limiters. +// +// Limiter has three main methods, Allow, Reserve, and Wait. +// Most callers should use Wait. +// +// Each of the three methods consumes a single token. +// They differ in their behavior when no token is available. +// If no token is available, Allow returns false. +// If no token is available, Reserve returns a reservation for a future token +// and the amount of time the caller must wait before using it. +// If no token is available, Wait blocks until one can be obtained +// or its associated context.Context is canceled. +// +// The methods AllowN, ReserveN, and WaitN consume n tokens. +type Limiter struct { + limit Limit + burst int + + mu sync.Mutex + tokens float64 + // last is the last time the limiter's tokens field was updated + last time.Time + // lastEvent is the latest time of a rate-limited event (past or future) + lastEvent time.Time +} + +// Limit returns the maximum overall event rate. +func (lim *Limiter) Limit() Limit { + lim.mu.Lock() + defer lim.mu.Unlock() + return lim.limit +} + +// Burst returns the maximum burst size. Burst is the maximum number of tokens +// that can be consumed in a single call to Allow, Reserve, or Wait, so higher +// Burst values allow more events to happen at once. +// A zero Burst allows no events, unless limit == Inf. +func (lim *Limiter) Burst() int { + return lim.burst +} + +// NewLimiter returns a new Limiter that allows events up to rate r and permits +// bursts of at most b tokens. +func NewLimiter(r Limit, b int) *Limiter { + return &Limiter{ + limit: r, + burst: b, + } +} + +// Allow is shorthand for AllowN(time.Now(), 1). +func (lim *Limiter) Allow() bool { + return lim.AllowN(time.Now(), 1) +} + +// AllowN reports whether n events may happen at time now. +// Use this method if you intend to drop / skip events that exceed the rate limit. +// Otherwise use Reserve or Wait. +func (lim *Limiter) AllowN(now time.Time, n int) bool { + return lim.reserveN(now, n, 0).ok +} + +// A Reservation holds information about events that are permitted by a Limiter to happen after a delay. +// A Reservation may be canceled, which may enable the Limiter to permit additional events. +type Reservation struct { + ok bool + lim *Limiter + tokens int + timeToAct time.Time + // This is the Limit at reservation time, it can change later. + limit Limit +} + +// OK returns whether the limiter can provide the requested number of tokens +// within the maximum wait time. If OK is false, Delay returns InfDuration, and +// Cancel does nothing. +func (r *Reservation) OK() bool { + return r.ok +} + +// Delay is shorthand for DelayFrom(time.Now()). +func (r *Reservation) Delay() time.Duration { + return r.DelayFrom(time.Now()) +} + +// InfDuration is the duration returned by Delay when a Reservation is not OK. +const InfDuration = time.Duration(1<<63 - 1) + +// DelayFrom returns the duration for which the reservation holder must wait +// before taking the reserved action. Zero duration means act immediately. +// InfDuration means the limiter cannot grant the tokens requested in this +// Reservation within the maximum wait time. +func (r *Reservation) DelayFrom(now time.Time) time.Duration { + if !r.ok { + return InfDuration + } + delay := r.timeToAct.Sub(now) + if delay < 0 { + return 0 + } + return delay +} + +// Cancel is shorthand for CancelAt(time.Now()). +func (r *Reservation) Cancel() { + r.CancelAt(time.Now()) + return +} + +// CancelAt indicates that the reservation holder will not perform the reserved action +// and reverses the effects of this Reservation on the rate limit as much as possible, +// considering that other reservations may have already been made. +func (r *Reservation) CancelAt(now time.Time) { + if !r.ok { + return + } + + r.lim.mu.Lock() + defer r.lim.mu.Unlock() + + if r.lim.limit == Inf || r.tokens == 0 || r.timeToAct.Before(now) { + return + } + + // calculate tokens to restore + // The duration between lim.lastEvent and r.timeToAct tells us how many tokens were reserved + // after r was obtained. These tokens should not be restored. + restoreTokens := float64(r.tokens) - r.limit.tokensFromDuration(r.lim.lastEvent.Sub(r.timeToAct)) + if restoreTokens <= 0 { + return + } + // advance time to now + now, _, tokens := r.lim.advance(now) + // calculate new number of tokens + tokens += restoreTokens + if burst := float64(r.lim.burst); tokens > burst { + tokens = burst + } + // update state + r.lim.last = now + r.lim.tokens = tokens + if r.timeToAct == r.lim.lastEvent { + prevEvent := r.timeToAct.Add(r.limit.durationFromTokens(float64(-r.tokens))) + if !prevEvent.Before(now) { + r.lim.lastEvent = prevEvent + } + } + + return +} + +// Reserve is shorthand for ReserveN(time.Now(), 1). +func (lim *Limiter) Reserve() *Reservation { + return lim.ReserveN(time.Now(), 1) +} + +// ReserveN returns a Reservation that indicates how long the caller must wait before n events happen. +// The Limiter takes this Reservation into account when allowing future events. +// ReserveN returns false if n exceeds the Limiter's burst size. +// Usage example: +// r, ok := lim.ReserveN(time.Now(), 1) +// if !ok { +// // Not allowed to act! Did you remember to set lim.burst to be > 0 ? +// } +// time.Sleep(r.Delay()) +// Act() +// Use this method if you wish to wait and slow down in accordance with the rate limit without dropping events. +// If you need to respect a deadline or cancel the delay, use Wait instead. +// To drop or skip events exceeding rate limit, use Allow instead. +func (lim *Limiter) ReserveN(now time.Time, n int) *Reservation { + r := lim.reserveN(now, n, InfDuration) + return &r +} + +// Wait is shorthand for WaitN(ctx, 1). +func (lim *Limiter) Wait(ctx context.Context) (err error) { + return lim.WaitN(ctx, 1) +} + +// WaitN blocks until lim permits n events to happen. +// It returns an error if n exceeds the Limiter's burst size, the Context is +// canceled, or the expected wait time exceeds the Context's Deadline. +func (lim *Limiter) WaitN(ctx context.Context, n int) (err error) { + if n > lim.burst { + return fmt.Errorf("rate: Wait(n=%d) exceeds limiter's burst %d", n, lim.burst) + } + // Check if ctx is already cancelled + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + // Determine wait limit + now := time.Now() + waitLimit := InfDuration + if deadline, ok := ctx.Deadline(); ok { + waitLimit = deadline.Sub(now) + } + // Reserve + r := lim.reserveN(now, n, waitLimit) + if !r.ok { + return fmt.Errorf("rate: Wait(n=%d) would exceed context deadline", n) + } + // Wait + t := time.NewTimer(r.DelayFrom(now)) + defer t.Stop() + select { + case <-t.C: + // We can proceed. + return nil + case <-ctx.Done(): + // Context was canceled before we could proceed. Cancel the + // reservation, which may permit other events to proceed sooner. + r.Cancel() + return ctx.Err() + } +} + +// SetLimit is shorthand for SetLimitAt(time.Now(), newLimit). +func (lim *Limiter) SetLimit(newLimit Limit) { + lim.SetLimitAt(time.Now(), newLimit) +} + +// SetLimitAt sets a new Limit for the limiter. The new Limit, and Burst, may be violated +// or underutilized by those which reserved (using Reserve or Wait) but did not yet act +// before SetLimitAt was called. +func (lim *Limiter) SetLimitAt(now time.Time, newLimit Limit) { + lim.mu.Lock() + defer lim.mu.Unlock() + + now, _, tokens := lim.advance(now) + + lim.last = now + lim.tokens = tokens + lim.limit = newLimit +} + +// reserveN is a helper method for AllowN, ReserveN, and WaitN. +// maxFutureReserve specifies the maximum reservation wait duration allowed. +// reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN. +func (lim *Limiter) reserveN(now time.Time, n int, maxFutureReserve time.Duration) Reservation { + lim.mu.Lock() + defer lim.mu.Unlock() + + if lim.limit == Inf { + return Reservation{ + ok: true, + lim: lim, + tokens: n, + timeToAct: now, + } + } + + now, last, tokens := lim.advance(now) + + // Calculate the remaining number of tokens resulting from the request. + tokens -= float64(n) + + // Calculate the wait duration + var waitDuration time.Duration + if tokens < 0 { + waitDuration = lim.limit.durationFromTokens(-tokens) + } + + // Decide result + ok := n <= lim.burst && waitDuration <= maxFutureReserve + + // Prepare reservation + r := Reservation{ + ok: ok, + lim: lim, + limit: lim.limit, + } + if ok { + r.tokens = n + r.timeToAct = now.Add(waitDuration) + } + + // Update state + if ok { + lim.last = now + lim.tokens = tokens + lim.lastEvent = r.timeToAct + } else { + lim.last = last + } + + return r +} + +// advance calculates and returns an updated state for lim resulting from the passage of time. +// lim is not changed. +func (lim *Limiter) advance(now time.Time) (newNow time.Time, newLast time.Time, newTokens float64) { + last := lim.last + if now.Before(last) { + last = now + } + + // Avoid making delta overflow below when last is very old. + maxElapsed := lim.limit.durationFromTokens(float64(lim.burst) - lim.tokens) + elapsed := now.Sub(last) + if elapsed > maxElapsed { + elapsed = maxElapsed + } + + // Calculate the new number of tokens, due to time that passed. + delta := lim.limit.tokensFromDuration(elapsed) + tokens := lim.tokens + delta + if burst := float64(lim.burst); tokens > burst { + tokens = burst + } + + return now, last, tokens +} + +// durationFromTokens is a unit conversion function from the number of tokens to the duration +// of time it takes to accumulate them at a rate of limit tokens per second. +func (limit Limit) durationFromTokens(tokens float64) time.Duration { + seconds := tokens / float64(limit) + return time.Nanosecond * time.Duration(1e9*seconds) +} + +// tokensFromDuration is a unit conversion function from a time duration to the number of tokens +// which could be accumulated during that duration at a rate of limit tokens per second. +func (limit Limit) tokensFromDuration(d time.Duration) float64 { + return d.Seconds() * float64(limit) +} diff --git a/vendor/vendor.json b/vendor/vendor.json index cceb23094..6e0dcba03 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -623,10 +623,16 @@ "revisionTime": "2016-06-09T00:18:40Z" }, { - "checksumSHA1": "0rkVtm9F1/pW9EGhHYJpCnY99O8=", + "checksumSHA1": "RAJfRxZ8UmcL6+7VuXAZxBlnM/4=", + "path": "github.com/hashicorp/vault", + "revision": "fece3ca069fc5bafec5280bbcb0c0693ff69fdaf", + "revisionTime": "2016-08-17T21:47:06Z" + }, + { + "checksumSHA1": "JH8wmQ8cWdn7mYu1T7gJ3IMIrec=", "path": "github.com/hashicorp/vault/api", - "revision": "fbecd94926e289d3b81d8dae6136452a6c4c93f6", - "revisionTime": "2016-08-13T15:54:01Z" + "revision": "fece3ca069fc5bafec5280bbcb0c0693ff69fdaf", + "revisionTime": "2016-08-17T21:47:06Z" }, { "checksumSHA1": "5lR6EdY0ARRdKAq3hZcL38STD8Q=", @@ -827,6 +833,12 @@ "revision": "30db96677b74e24b967e23f911eb3364fc61a011", "revisionTime": "2016-05-25T13:11:03Z" }, + { + "checksumSHA1": "S0DP7Pn7sZUmXc55IzZnNvERu6s=", + "path": "golang.org/x/sync/errgroup", + "revision": "316e794f7b5e3df4e95175a45a5fb8b12f85cb4f", + "revisionTime": "2016-07-15T18:54:39Z" + }, { "path": "golang.org/x/sys/unix", "revision": "50c6bc5e4292a1d4e65c6e9be5f53be28bcbe28e" @@ -837,6 +849,12 @@ "revision": "b776ec39b3e54652e09028aaaaac9757f4f8211a", "revisionTime": "2016-04-21T02:29:30Z" }, + { + "checksumSHA1": "h/06ikMECfJoTkWj2e1nJ30aDDg=", + "path": "golang.org/x/time/rate", + "revision": "a4bde12657593d5e90d0533a3e4fd95e635124cb", + "revisionTime": "2016-02-02T18:34:45Z" + }, { "checksumSHA1": "93uHIq25lffEKY47PV8dBPD+XuQ=", "path": "gopkg.in/fsnotify.v1",