Merge pull request #1629 from hashicorp/f-derive-token

Server Deriving Tokens on behalf of Clients
This commit is contained in:
Alex Dadgar
2016-08-23 13:58:47 -07:00
committed by GitHub
34 changed files with 1892 additions and 167 deletions

View File

@@ -35,6 +35,7 @@ const (
TimeTableSnapshot
PeriodicLaunchSnapshot
JobSummarySnapshot
VaultAccessorSnapshot
)
// nomadFSM implements a finite state machine that is used
@@ -137,6 +138,8 @@ func (n *nomadFSM) Apply(log *raft.Log) interface{} {
return n.applyAllocClientUpdate(buf[1:], log.Index)
case structs.ReconcileJobSummariesRequestType:
return n.applyReconcileSummaries(buf[1:], log.Index)
case structs.VaultAccessorRegisterRequestType:
return n.applyUpsertVaultAccessor(buf[1:], log.Index)
default:
if ignoreUnknown {
n.logger.Printf("[WARN] nomad.fsm: ignoring unknown message type (%d), upgrade to newer version", msgType)
@@ -454,6 +457,23 @@ func (n *nomadFSM) applyReconcileSummaries(buf []byte, index uint64) interface{}
return n.reconcileQueuedAllocations(index)
}
// applyUpsertVaultAccessor stores the Vault accessors for a given allocation
// and task
func (n *nomadFSM) applyUpsertVaultAccessor(buf []byte, index uint64) interface{} {
defer metrics.MeasureSince([]string{"nomad", "fsm", "upsert_vault_accessor"}, time.Now())
var req structs.VaultAccessorRegisterRequest
if err := structs.Decode(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode request: %v", err))
}
if err := n.state.UpsertVaultAccessor(index, req.Accessors); err != nil {
n.logger.Printf("[ERR] nomad.fsm: UpsertVaultAccessor failed: %v", err)
return err
}
return nil
}
func (n *nomadFSM) Snapshot() (raft.FSMSnapshot, error) {
// Create a new snapshot
snap, err := n.state.Snapshot()
@@ -583,6 +603,15 @@ func (n *nomadFSM) Restore(old io.ReadCloser) error {
return err
}
case VaultAccessorSnapshot:
accessor := new(structs.VaultAccessor)
if err := dec.Decode(accessor); err != nil {
return err
}
if err := restore.VaultAccessorRestore(accessor); err != nil {
return err
}
default:
return fmt.Errorf("Unrecognized snapshot type: %v", msgType)
}
@@ -756,6 +785,10 @@ func (s *nomadSnapshot) Persist(sink raft.SnapshotSink) error {
sink.Cancel()
return err
}
if err := s.persistVaultAccessors(sink, encoder); err != nil {
sink.Cancel()
return err
}
return nil
}
@@ -945,6 +978,30 @@ func (s *nomadSnapshot) persistJobSummaries(sink raft.SnapshotSink,
return nil
}
func (s *nomadSnapshot) persistVaultAccessors(sink raft.SnapshotSink,
encoder *codec.Encoder) error {
accessors, err := s.snap.VaultAccessors()
if err != nil {
return err
}
for {
raw := accessors.Next()
if raw == nil {
break
}
accessor := raw.(*structs.VaultAccessor)
sink.Write([]byte{byte(VaultAccessorSnapshot)})
if err := encoder.Encode(accessor); err != nil {
return err
}
}
return nil
}
// Release is a no-op, as we just need to GC the pointer
// to the state store snapshot. There is nothing to explicitly
// cleanup.

View File

@@ -770,6 +770,54 @@ func TestFSM_UpdateAllocFromClient(t *testing.T) {
}
}
func TestFSM_UpsertVaultAccessor(t *testing.T) {
fsm := testFSM(t)
fsm.blockedEvals.SetEnabled(true)
va := mock.VaultAccessor()
va2 := mock.VaultAccessor()
req := structs.VaultAccessorRegisterRequest{
Accessors: []*structs.VaultAccessor{va, va2},
}
buf, err := structs.Encode(structs.VaultAccessorRegisterRequestType, req)
if err != nil {
t.Fatalf("err: %v", err)
}
resp := fsm.Apply(makeLog(buf))
if resp != nil {
t.Fatalf("resp: %v", resp)
}
// Verify we are registered
out1, err := fsm.State().VaultAccessor(va.Accessor)
if err != nil {
t.Fatalf("err: %v", err)
}
if out1 == nil {
t.Fatalf("not found!")
}
if out1.CreateIndex != 1 {
t.Fatalf("bad index: %d", out1.CreateIndex)
}
out2, err := fsm.State().VaultAccessor(va2.Accessor)
if err != nil {
t.Fatalf("err: %v", err)
}
if out2 == nil {
t.Fatalf("not found!")
}
if out1.CreateIndex != 1 {
t.Fatalf("bad index: %d", out2.CreateIndex)
}
tt := fsm.TimeTable()
index := tt.NearestIndex(time.Now().UTC())
if index != 1 {
t.Fatalf("bad: %d", index)
}
}
func testSnapshotRestore(t *testing.T, fsm *nomadFSM) *nomadFSM {
// Snapshot
snap, err := fsm.Snapshot()
@@ -976,6 +1024,27 @@ func TestFSM_SnapshotRestore_JobSummary(t *testing.T) {
}
}
func TestFSM_SnapshotRestore_VaultAccessors(t *testing.T) {
// Add some state
fsm := testFSM(t)
state := fsm.State()
a1 := mock.VaultAccessor()
a2 := mock.VaultAccessor()
state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{a1, a2})
// Verify the contents
fsm2 := testSnapshotRestore(t, fsm)
state2 := fsm2.State()
out1, _ := state2.VaultAccessor(a1.Accessor)
out2, _ := state2.VaultAccessor(a2.Accessor)
if !reflect.DeepEqual(a1, out1) {
t.Fatalf("bad: \n%#v\n%#v", out1, a1)
}
if !reflect.DeepEqual(a2, out2) {
t.Fatalf("bad: \n%#v\n%#v", out2, a2)
}
}
func TestFSM_SnapshotRestore_AddMissingSummary(t *testing.T) {
// Add some state
fsm := testFSM(t)

View File

@@ -1,6 +1,7 @@
package nomad
import (
"context"
"fmt"
"strings"
"time"
@@ -83,7 +84,7 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis
}
vault := j.srv.vault
s, err := vault.LookupToken(args.Job.VaultToken)
s, err := vault.LookupToken(context.Background(), args.Job.VaultToken)
if err != nil {
return err
}

View File

@@ -290,6 +290,16 @@ func Alloc() *structs.Allocation {
return alloc
}
func VaultAccessor() *structs.VaultAccessor {
return &structs.VaultAccessor{
Accessor: structs.GenerateUUID(),
NodeID: structs.GenerateUUID(),
AllocID: structs.GenerateUUID(),
CreationTTL: 86400,
Task: "foo",
}
}
func Plan() *structs.Plan {
return &structs.Plan{
Priority: 50,

View File

@@ -1,21 +1,29 @@
package nomad
import (
"context"
"fmt"
"strings"
"sync"
"time"
"golang.org/x/sync/errgroup"
"github.com/armon/go-metrics"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/watch"
vapi "github.com/hashicorp/vault/api"
)
const (
// batchUpdateInterval is how long we wait to batch updates
batchUpdateInterval = 50 * time.Millisecond
// maxParallelRequestsPerDerive is the maximum number of parallel Vault
// create token requests that may be outstanding per derive request
maxParallelRequestsPerDerive = 16
)
// Node endpoint is used for client interactions
@@ -868,3 +876,176 @@ func (b *batchFuture) Respond(index uint64, err error) {
b.err = err
close(b.doneCh)
}
// DeriveVaultToken is used by the clients to request wrapped Vault tokens for
// tasks
func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
reply *structs.DeriveVaultTokenResponse) error {
if done, err := n.srv.forward("Node.DeriveVaultToken", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "client", "derive_vault_token"}, time.Now())
// Verify the arguments
if args.NodeID == "" {
return fmt.Errorf("missing node ID")
}
if args.SecretID == "" {
return fmt.Errorf("missing node SecretID")
}
if args.AllocID == "" {
return fmt.Errorf("missing allocation ID")
}
if len(args.Tasks) == 0 {
return fmt.Errorf("no tasks specified")
}
// Verify the following:
// * The Node exists and has the correct SecretID
// * The Allocation exists on the specified node
// * The allocation contains the given tasks and they each require Vault
// tokens
snap, err := n.srv.fsm.State().Snapshot()
if err != nil {
return err
}
node, err := snap.NodeByID(args.NodeID)
if err != nil {
return err
}
if node == nil {
return fmt.Errorf("Node %q does not exist", args.NodeID)
}
if node.SecretID != args.SecretID {
return fmt.Errorf("SecretID mismatch")
}
alloc, err := snap.AllocByID(args.AllocID)
if err != nil {
return err
}
if alloc == nil {
return fmt.Errorf("Allocation %q does not exist", args.AllocID)
}
if alloc.NodeID != args.NodeID {
return fmt.Errorf("Allocation %q not running on Node %q", args.AllocID, args.NodeID)
}
if alloc.TerminalStatus() {
return fmt.Errorf("Can't request Vault token for terminal allocation")
}
// Check the policies
policies := alloc.Job.VaultPolicies()
if policies == nil {
return fmt.Errorf("Job doesn't require Vault policies")
}
tg, ok := policies[alloc.TaskGroup]
if !ok {
return fmt.Errorf("Task group does not require Vault policies")
}
var unneeded []string
for _, task := range args.Tasks {
taskVault := tg[task]
if taskVault == nil || len(taskVault.Policies) == 0 {
unneeded = append(unneeded, task)
}
}
if len(unneeded) != 0 {
return fmt.Errorf("Requested Vault tokens for tasks without defined Vault policies: %s",
strings.Join(unneeded, ", "))
}
// At this point the request is valid and we should contact Vault for
// tokens.
// Create an error group where we will spin up a fixed set of goroutines to
// handle deriving tokens but where if any fails the whole group is
// canceled.
g, ctx := errgroup.WithContext(context.Background())
// Cap the handlers
handlers := len(args.Tasks)
if handlers > maxParallelRequestsPerDerive {
handlers = maxParallelRequestsPerDerive
}
// Create the Vault Tokens
input := make(chan string, handlers)
results := make(map[string]*vapi.Secret, len(args.Tasks))
for i := 0; i < handlers; i++ {
g.Go(func() error {
for {
select {
case task, ok := <-input:
if !ok {
return nil
}
secret, err := n.srv.vault.CreateToken(ctx, alloc, task)
if err != nil {
return fmt.Errorf("failed to create token for task %q: %v", task, err)
}
results[task] = secret
case <-ctx.Done():
return nil
}
}
})
}
// Send the input
go func() {
defer close(input)
for _, task := range args.Tasks {
select {
case <-ctx.Done():
return
case input <- task:
}
}
}()
// Wait for everything to complete or for an error
err = g.Wait()
if err != nil {
// TODO Revoke any created token
return err
}
// Commit to Raft before returning any of the tokens
accessors := make([]*structs.VaultAccessor, 0, len(results))
tokens := make(map[string]string, len(results))
for task, secret := range results {
w := secret.WrapInfo
if w == nil {
return fmt.Errorf("Vault returned Secret without WrapInfo")
}
tokens[task] = w.Token
accessor := &structs.VaultAccessor{
Accessor: w.WrappedAccessor,
Task: task,
NodeID: alloc.NodeID,
AllocID: alloc.ID,
CreationTTL: w.TTL,
}
accessors = append(accessors, accessor)
}
req := structs.VaultAccessorRegisterRequest{Accessors: accessors}
_, index, err := n.srv.raftApply(structs.VaultAccessorRegisterRequestType, &req)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: Register Vault accessors failed: %v", err)
return err
}
reply.Index = index
reply.Tasks = tokens
n.srv.setQueryMeta(&reply.QueryMeta)
return nil
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
vapi "github.com/hashicorp/vault/api"
)
func TestClientEndpoint_Register(t *testing.T) {
@@ -1597,3 +1598,160 @@ func TestBatchFuture(t *testing.T) {
t.Fatalf("bad: %d", bf.Index())
}
}
func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) {
s1 := testServer(t, nil)
defer s1.Shutdown()
state := s1.fsm.State()
codec := rpcClient(t, s1)
testutil.WaitForLeader(t, s1.RPC)
// Create the node
node := mock.Node()
if err := state.UpsertNode(2, node); err != nil {
t.Fatalf("err: %v", err)
}
// Create an alloc
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
tasks := []string{task.Name}
if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil {
t.Fatalf("err: %v", err)
}
req := &structs.DeriveVaultTokenRequest{
NodeID: node.ID,
SecretID: structs.GenerateUUID(),
AllocID: alloc.ID,
Tasks: tasks,
QueryOptions: structs.QueryOptions{
Region: "global",
},
}
var resp structs.DeriveVaultTokenResponse
err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
if err == nil || !strings.Contains(err.Error(), "SecretID mismatch") {
t.Fatalf("Expected SecretID mismatch: %v", err)
}
// Put the correct SecretID
req.SecretID = node.SecretID
// Now we should get an error about the allocation not running on the node
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
if err == nil || !strings.Contains(err.Error(), "not running on Node") {
t.Fatalf("Expected not running on node error: %v", err)
}
// Update to be running on the node
alloc.NodeID = node.ID
if err := state.UpsertAllocs(4, []*structs.Allocation{alloc}); err != nil {
t.Fatalf("err: %v", err)
}
// Now we should get an error about the job not needing any Vault secrets
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
if err == nil || !strings.Contains(err.Error(), "without defined Vault") {
t.Fatalf("Expected no policies error: %v", err)
}
// Update to be terminal
alloc.DesiredStatus = structs.AllocDesiredStatusStop
if err := state.UpsertAllocs(5, []*structs.Allocation{alloc}); err != nil {
t.Fatalf("err: %v", err)
}
// Now we should get an error about the job not needing any Vault secrets
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
if err == nil || !strings.Contains(err.Error(), "terminal") {
t.Fatalf("Expected terminal allocation error: %v", err)
}
}
func TestClientEndpoint_DeriveVaultToken(t *testing.T) {
s1 := testServer(t, nil)
defer s1.Shutdown()
state := s1.fsm.State()
codec := rpcClient(t, s1)
testutil.WaitForLeader(t, s1.RPC)
// Enable vault and allow authenticated
s1.config.VaultConfig.Enabled = true
s1.config.VaultConfig.AllowUnauthenticated = true
// Replace the Vault Client on the server
tvc := &TestVaultClient{}
s1.vault = tvc
// Create the node
node := mock.Node()
if err := state.UpsertNode(2, node); err != nil {
t.Fatalf("err: %v", err)
}
// Create an alloc an allocation that has vault policies required
alloc := mock.Alloc()
alloc.NodeID = node.ID
task := alloc.Job.TaskGroups[0].Tasks[0]
tasks := []string{task.Name}
task.Vault = &structs.Vault{Policies: []string{"a", "b"}}
if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil {
t.Fatalf("err: %v", err)
}
// Return a secret for the task
token := structs.GenerateUUID()
accessor := structs.GenerateUUID()
ttl := 10
secret := &vapi.Secret{
WrapInfo: &vapi.SecretWrapInfo{
Token: token,
WrappedAccessor: accessor,
TTL: ttl,
},
}
tvc.SetCreateTokenSecret(alloc.ID, task.Name, secret)
req := &structs.DeriveVaultTokenRequest{
NodeID: node.ID,
SecretID: node.SecretID,
AllocID: alloc.ID,
Tasks: tasks,
QueryOptions: structs.QueryOptions{
Region: "global",
},
}
var resp structs.DeriveVaultTokenResponse
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
t.Fatalf("bad: %v", err)
}
// Check the state store and ensure that we created a VaultAccessor
va, err := state.VaultAccessor(accessor)
if err != nil {
t.Fatalf("bad: %v", err)
}
if va == nil {
t.Fatalf("bad: %v", va)
}
if va.CreateIndex == 0 {
t.Fatalf("bad: %v", va)
}
va.CreateIndex = 0
expected := &structs.VaultAccessor{
AllocID: alloc.ID,
Task: task.Name,
NodeID: alloc.NodeID,
Accessor: accessor,
CreationTTL: ttl,
}
if !reflect.DeepEqual(expected, va) {
t.Fatalf("Got %#v; want %#v", va, expected)
}
}

View File

@@ -23,6 +23,7 @@ func stateStoreSchema() *memdb.DBSchema {
periodicLaunchTableSchema,
evalTableSchema,
allocTableSchema,
vaultAccessorTableSchema,
}
// Add each of the tables
@@ -291,3 +292,41 @@ func allocTableSchema() *memdb.TableSchema {
},
}
}
// vaultAccessorTableSchema returns the MemDB schema for the Vault Accessor
// Table. This table tracks Vault accessors for tokens created on behalf of
// allocations required Vault tokens.
func vaultAccessorTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: "vault_accessors",
Indexes: map[string]*memdb.IndexSchema{
// The primary index is the accessor id
"id": &memdb.IndexSchema{
Name: "id",
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Accessor",
},
},
"alloc_id": &memdb.IndexSchema{
Name: "alloc_id",
AllowMissing: false,
Unique: false,
Indexer: &memdb.StringFieldIndex{
Field: "AllocID",
},
},
"node_id": &memdb.IndexSchema{
Name: "node_id",
AllowMissing: false,
Unique: false,
Indexer: &memdb.StringFieldIndex{
Field: "NodeID",
},
},
},
}
}

View File

@@ -1113,6 +1113,124 @@ func (s *StateStore) Allocs() (memdb.ResultIterator, error) {
return iter, nil
}
// UpsertVaultAccessors is used to register a set of Vault Accessors
func (s *StateStore) UpsertVaultAccessor(index uint64, accessors []*structs.VaultAccessor) error {
txn := s.db.Txn(true)
defer txn.Abort()
for _, accessor := range accessors {
// Set the create index
accessor.CreateIndex = index
// Insert the accessor
if err := txn.Insert("vault_accessors", accessor); err != nil {
return fmt.Errorf("accessor insert failed: %v", err)
}
}
if err := txn.Insert("index", &IndexEntry{"vault_accessors", index}); err != nil {
return fmt.Errorf("index update failed: %v", err)
}
txn.Commit()
return nil
}
// DeleteVaultAccessor is used to delete a Vault Accessor
func (s *StateStore) DeleteVaultAccessor(index uint64, accessor string) error {
txn := s.db.Txn(true)
defer txn.Abort()
// Lookup the accessor
existing, err := txn.First("vault_accessors", "id", accessor)
if err != nil {
return fmt.Errorf("accessor lookup failed: %v", err)
}
if existing == nil {
return fmt.Errorf("vault_accessor not found")
}
// Delete the accessor
if err := txn.Delete("vault_accessors", existing); err != nil {
return fmt.Errorf("accessor delete failed: %v", err)
}
if err := txn.Insert("index", &IndexEntry{"vault_accessors", index}); err != nil {
return fmt.Errorf("index update failed: %v", err)
}
txn.Commit()
return nil
}
// VaultAccessor returns the given Vault accessor
func (s *StateStore) VaultAccessor(accessor string) (*structs.VaultAccessor, error) {
txn := s.db.Txn(false)
existing, err := txn.First("vault_accessors", "id", accessor)
if err != nil {
return nil, fmt.Errorf("accessor lookup failed: %v", err)
}
if existing != nil {
return existing.(*structs.VaultAccessor), nil
}
return nil, nil
}
// VaultAccessors returns an iterator of Vault accessors.
func (s *StateStore) VaultAccessors() (memdb.ResultIterator, error) {
txn := s.db.Txn(false)
iter, err := txn.Get("vault_accessors", "id")
if err != nil {
return nil, err
}
return iter, nil
}
// VaultAccessorsByAlloc returns all the Vault accessors by alloc id
func (s *StateStore) VaultAccessorsByAlloc(allocID string) ([]*structs.VaultAccessor, error) {
txn := s.db.Txn(false)
// Get an iterator over the accessors
iter, err := txn.Get("vault_accessors", "alloc_id", allocID)
if err != nil {
return nil, err
}
var out []*structs.VaultAccessor
for {
raw := iter.Next()
if raw == nil {
break
}
out = append(out, raw.(*structs.VaultAccessor))
}
return out, nil
}
// VaultAccessorsByNode returns all the Vault accessors by node id
func (s *StateStore) VaultAccessorsByNode(nodeID string) ([]*structs.VaultAccessor, error) {
txn := s.db.Txn(false)
// Get an iterator over the accessors
iter, err := txn.Get("vault_accessors", "node_id", nodeID)
if err != nil {
return nil, err
}
var out []*structs.VaultAccessor
for {
raw := iter.Next()
if raw == nil {
break
}
out = append(out, raw.(*structs.VaultAccessor))
}
return out, nil
}
// LastIndex returns the greatest index value for all indexes
func (s *StateStore) LatestIndex() (uint64, error) {
indexes, err := s.Indexes()
@@ -1627,6 +1745,14 @@ func (r *StateRestore) JobSummaryRestore(jobSummary *structs.JobSummary) error {
return nil
}
// VaultAccessorRestore is used to restore a vault accessor
func (r *StateRestore) VaultAccessorRestore(accessor *structs.VaultAccessor) error {
if err := r.txn.Insert("vault_accessors", accessor); err != nil {
return fmt.Errorf("vault accessor insert failed: %v", err)
}
return nil
}
// stateWatch holds shared state for watching updates. This is
// outside of StateStore so it can be shared with snapshots.
type stateWatch struct {

View File

@@ -2833,6 +2833,206 @@ func TestJobSummary_UpdateClientStatus(t *testing.T) {
}
}
func TestStateStore_UpsertVaultAccessors(t *testing.T) {
state := testStateStore(t)
a := mock.VaultAccessor()
a2 := mock.VaultAccessor()
err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{a, a2})
if err != nil {
t.Fatalf("err: %v", err)
}
out, err := state.VaultAccessor(a.Accessor)
if err != nil {
t.Fatalf("err: %v", err)
}
if !reflect.DeepEqual(a, out) {
t.Fatalf("bad: %#v %#v", a, out)
}
out, err = state.VaultAccessor(a2.Accessor)
if err != nil {
t.Fatalf("err: %v", err)
}
if !reflect.DeepEqual(a2, out) {
t.Fatalf("bad: %#v %#v", a2, out)
}
iter, err := state.VaultAccessors()
if err != nil {
t.Fatalf("err: %v", err)
}
count := 0
for {
raw := iter.Next()
if raw == nil {
break
}
count++
accessor := raw.(*structs.VaultAccessor)
if !reflect.DeepEqual(accessor, a) && !reflect.DeepEqual(accessor, a2) {
t.Fatalf("bad: %#v", accessor)
}
}
if count != 2 {
t.Fatalf("bad: %d", count)
}
index, err := state.Index("vault_accessors")
if err != nil {
t.Fatalf("err: %v", err)
}
if index != 1000 {
t.Fatalf("bad: %d", index)
}
}
func TestStateStore_DeleteVaultAccessor(t *testing.T) {
state := testStateStore(t)
accessor := mock.VaultAccessor()
err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{accessor})
if err != nil {
t.Fatalf("err: %v", err)
}
err = state.DeleteVaultAccessor(1001, accessor.Accessor)
if err != nil {
t.Fatalf("err: %v", err)
}
out, err := state.VaultAccessor(accessor.Accessor)
if err != nil {
t.Fatalf("err: %v", err)
}
if out != nil {
t.Fatalf("bad: %#v %#v", accessor, out)
}
index, err := state.Index("vault_accessors")
if err != nil {
t.Fatalf("err: %v", err)
}
if index != 1001 {
t.Fatalf("bad: %d", index)
}
}
func TestStateStore_VaultAccessorsByAlloc(t *testing.T) {
state := testStateStore(t)
alloc := mock.Alloc()
var accessors []*structs.VaultAccessor
var expected []*structs.VaultAccessor
for i := 0; i < 5; i++ {
accessor := mock.VaultAccessor()
accessor.AllocID = alloc.ID
expected = append(expected, accessor)
accessors = append(accessors, accessor)
}
for i := 0; i < 10; i++ {
accessor := mock.VaultAccessor()
accessors = append(accessors, accessor)
}
err := state.UpsertVaultAccessor(1000, accessors)
if err != nil {
t.Fatalf("err: %v", err)
}
out, err := state.VaultAccessorsByAlloc(alloc.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(expected) != len(out) {
t.Fatalf("bad: %#v %#v", len(expected), len(out))
}
index, err := state.Index("vault_accessors")
if err != nil {
t.Fatalf("err: %v", err)
}
if index != 1000 {
t.Fatalf("bad: %d", index)
}
}
func TestStateStore_VaultAccessorsByNode(t *testing.T) {
state := testStateStore(t)
node := mock.Node()
var accessors []*structs.VaultAccessor
var expected []*structs.VaultAccessor
for i := 0; i < 5; i++ {
accessor := mock.VaultAccessor()
accessor.NodeID = node.ID
expected = append(expected, accessor)
accessors = append(accessors, accessor)
}
for i := 0; i < 10; i++ {
accessor := mock.VaultAccessor()
accessors = append(accessors, accessor)
}
err := state.UpsertVaultAccessor(1000, accessors)
if err != nil {
t.Fatalf("err: %v", err)
}
out, err := state.VaultAccessorsByNode(node.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(expected) != len(out) {
t.Fatalf("bad: %#v %#v", len(expected), len(out))
}
index, err := state.Index("vault_accessors")
if err != nil {
t.Fatalf("err: %v", err)
}
if index != 1000 {
t.Fatalf("bad: %d", index)
}
}
func TestStateStore_RestoreVaultAccessor(t *testing.T) {
state := testStateStore(t)
a := mock.VaultAccessor()
restore, err := state.Restore()
if err != nil {
t.Fatalf("err: %v", err)
}
err = restore.VaultAccessorRestore(a)
if err != nil {
t.Fatalf("err: %v", err)
}
restore.Commit()
out, err := state.VaultAccessor(a.Accessor)
if err != nil {
t.Fatalf("err: %v", err)
}
if !reflect.DeepEqual(out, a) {
t.Fatalf("Bad: %#v %#v", out, a)
}
}
// setupNotifyTest takes a state store and a set of watch items, then creates
// and subscribes a notification channel for each item.
func setupNotifyTest(state *StateStore, items ...watch.Item) notifyTest {

View File

@@ -253,12 +253,12 @@ func SliceStringIsSubset(larger, smaller []string) (bool, []string) {
// VaultPoliciesSet takes the structure returned by VaultPolicies and returns
// the set of required policies
func VaultPoliciesSet(policies map[string]map[string][]string) []string {
func VaultPoliciesSet(policies map[string]map[string]*Vault) []string {
set := make(map[string]struct{})
for _, tgp := range policies {
for _, tp := range tgp {
for _, p := range tp {
for _, p := range tp.Policies {
set[p] = struct{}{}
}
}

View File

@@ -47,6 +47,7 @@ const (
AllocUpdateRequestType
AllocClientUpdateRequestType
ReconcileJobSummariesRequestType
VaultAccessorRegisterRequestType
)
const (
@@ -354,6 +355,41 @@ type PeriodicForceRequest struct {
WriteRequest
}
// DeriveVaultTokenRequest is used to request wrapped Vault tokens for the
// following tasks in the given allocation
type DeriveVaultTokenRequest struct {
NodeID string
SecretID string
AllocID string
Tasks []string
QueryOptions
}
// VaultAccessorRegisterRequest is used to register a set of Vault accessors
type VaultAccessorRegisterRequest struct {
Accessors []*VaultAccessor
}
// VaultAccessor is a reference to a created Vault token on behalf of
// an allocation's task.
type VaultAccessor struct {
AllocID string
Task string
NodeID string
Accessor string
CreationTTL int
// Raft Indexes
CreateIndex uint64
}
// DeriveVaultTokenResponse returns the wrapped tokens for each requested task
type DeriveVaultTokenResponse struct {
// Tasks is a mapping between the task name and the wrapped token
Tasks map[string]string
QueryMeta
}
// GenericRequest is used to request where no
// specific information is needed.
type GenericRequest struct {
@@ -1239,11 +1275,11 @@ func (j *Job) IsPeriodic() bool {
}
// VaultPolicies returns the set of Vault policies per task group, per task
func (j *Job) VaultPolicies() map[string]map[string][]string {
policies := make(map[string]map[string][]string, len(j.TaskGroups))
func (j *Job) VaultPolicies() map[string]map[string]*Vault {
policies := make(map[string]map[string]*Vault, len(j.TaskGroups))
for _, tg := range j.TaskGroups {
tgPolicies := make(map[string][]string, len(tg.Tasks))
tgPolicies := make(map[string]*Vault, len(tg.Tasks))
policies[tg.Name] = tgPolicies
for _, task := range tg.Tasks {
@@ -1251,7 +1287,7 @@ func (j *Job) VaultPolicies() map[string]map[string][]string {
continue
}
tgPolicies[task.Name] = task.Vault.Policies
tgPolicies[task.Name] = task.Vault
}
}

View File

@@ -224,8 +224,25 @@ func TestJob_SystemJob_Validate(t *testing.T) {
func TestJob_VaultPolicies(t *testing.T) {
j0 := &Job{}
e0 := make(map[string]map[string][]string, 0)
e0 := make(map[string]map[string]*Vault, 0)
vj1 := &Vault{
Policies: []string{
"p1",
"p2",
},
}
vj2 := &Vault{
Policies: []string{
"p3",
"p4",
},
}
vj3 := &Vault{
Policies: []string{
"p5",
},
}
j1 := &Job{
TaskGroups: []*TaskGroup{
&TaskGroup{
@@ -235,13 +252,8 @@ func TestJob_VaultPolicies(t *testing.T) {
Name: "t1",
},
&Task{
Name: "t2",
Vault: &Vault{
Policies: []string{
"p1",
"p2",
},
},
Name: "t2",
Vault: vj1,
},
},
},
@@ -249,40 +261,31 @@ func TestJob_VaultPolicies(t *testing.T) {
Name: "bar",
Tasks: []*Task{
&Task{
Name: "t3",
Vault: &Vault{
Policies: []string{
"p3",
"p4",
},
},
Name: "t3",
Vault: vj2,
},
&Task{
Name: "t4",
Vault: &Vault{
Policies: []string{
"p5",
},
},
Name: "t4",
Vault: vj3,
},
},
},
},
}
e1 := map[string]map[string][]string{
"foo": map[string][]string{
"t2": []string{"p1", "p2"},
e1 := map[string]map[string]*Vault{
"foo": map[string]*Vault{
"t2": vj1,
},
"bar": map[string][]string{
"t3": []string{"p3", "p4"},
"t4": []string{"p5"},
"bar": map[string]*Vault{
"t3": vj2,
"t4": vj3,
},
}
cases := []struct {
Job *Job
Expected map[string]map[string][]string
Expected map[string]map[string]*Vault
}{
{
Job: j0,

View File

@@ -1,6 +1,7 @@
package nomad
import (
"context"
"errors"
"fmt"
"log"
@@ -12,6 +13,8 @@ import (
"github.com/hashicorp/nomad/nomad/structs/config"
vapi "github.com/hashicorp/vault/api"
"github.com/mitchellh/mapstructure"
"golang.org/x/time/rate"
)
const (
@@ -21,16 +24,25 @@ const (
// minimumTokenTTL is the minimum Token TTL allowed for child tokens.
minimumTokenTTL = 5 * time.Minute
// defaultTokenTTL is the default Token TTL used when the passed token is a
// root token such that child tokens aren't being created against a role
// that has defined a TTL
defaultTokenTTL = "72h"
// requestRateLimit is the maximum number of requests per second Nomad will
// make against Vault
requestRateLimit rate.Limit = 500.0
)
// VaultClient is the Servers interface for interfacing with Vault
type VaultClient interface {
// CreateToken takes an allocation and task and returns an appropriate Vault
// Secret
CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error)
CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error)
// LookupToken takes a token string and returns its capabilities.
LookupToken(token string) (*vapi.Secret, error)
LookupToken(ctx context.Context, token string) (*vapi.Secret, error)
// Stop is used to stop token renewal.
Stop()
@@ -52,6 +64,9 @@ type tokenData struct {
// the Server with the ability to create child tokens and lookup the permissions
// of tokens.
type vaultClient struct {
// limiter is used to rate limit requests to Vault
limiter *rate.Limiter
// client is the Vault API client
client *vapi.Client
@@ -104,6 +119,7 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er
enabled: c.Enabled,
config: c,
logger: logger,
limiter: rate.NewLimiter(requestRateLimit, int(requestRateLimit)),
}
// If vault is not enabled do not configure an API client or start any token
@@ -131,6 +147,9 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er
}
v.childTTL = c.TaskTokenTTL
} else {
// Default the TaskTokenTTL
v.childTTL = defaultTokenTTL
}
// Get the Vault API configuration
@@ -157,6 +176,11 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er
return v, nil
}
// setLimit is used to update the rate limit
func (v *vaultClient) setLimit(l rate.Limit) {
v.limiter = rate.NewLimiter(l, int(l))
}
// establishConnection is used to make first contact with Vault. This should be
// called in a go-routine since the connection is retried til the Vault Client
// is stopped or the connection is successfully made at which point the renew
@@ -397,7 +421,7 @@ func (v *vaultClient) Stop() {
v.l.Lock()
defer v.l.Unlock()
if !v.renewalRunning || !v.establishingConn {
if !v.renewalRunning && !v.establishingConn {
return
}
@@ -414,12 +438,9 @@ func (v *vaultClient) ConnectionEstablished() bool {
return v.connEstablished
}
func (v *vaultClient) CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error) {
return nil, nil
}
// LookupToken takes a Vault token and does a lookup against Vault
func (v *vaultClient) LookupToken(token string) (*vapi.Secret, error) {
// CreateToken takes the allocation and task and returns an appropriate Vault
// token. The call is rate limited and may be canceled with the passed policy
func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) {
// Nothing to do
if !v.enabled {
return nil, fmt.Errorf("Vault integration disabled")
@@ -430,6 +451,70 @@ func (v *vaultClient) LookupToken(token string) (*vapi.Secret, error) {
return nil, fmt.Errorf("Connection to Vault has not been established. Retry")
}
// Retrieve the Vault block for the task
policies := a.Job.VaultPolicies()
if policies == nil {
return nil, fmt.Errorf("Job doesn't require Vault policies")
}
tg, ok := policies[a.TaskGroup]
if !ok {
return nil, fmt.Errorf("Task group does not require Vault policies")
}
taskVault, ok := tg[task]
if !ok {
return nil, fmt.Errorf("Task does not require Vault policies")
}
// Build the creation request
req := &vapi.TokenCreateRequest{
Policies: taskVault.Policies,
Metadata: map[string]string{
"AllocationID": a.ID,
"Task": task,
"NodeID": a.NodeID,
},
TTL: v.childTTL,
DisplayName: fmt.Sprintf("%s: %s", a.ID, task),
}
// Ensure we are under our rate limit
if err := v.limiter.Wait(ctx); err != nil {
return nil, err
}
// Make the request and switch depending on whether we are using a root
// token or a role based token
var secret *vapi.Secret
var err error
if v.token.Root {
req.Period = v.childTTL
secret, err = v.auth.Create(req)
} else {
// Make the token using the role
secret, err = v.auth.CreateWithRole(req, v.token.Role)
}
return secret, err
}
// LookupToken takes a Vault token and does a lookup against Vault. The call is
// rate limited and may be canceled with passed context.
func (v *vaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secret, error) {
// Nothing to do
if !v.enabled {
return nil, fmt.Errorf("Vault integration disabled")
}
// Check if we have established a connection with Vault
if !v.ConnectionEstablished() {
return nil, fmt.Errorf("Connection to Vault has not been established. Retry")
}
// Ensure we are under our rate limit
if err := v.limiter.Wait(ctx); err != nil {
return nil, err
}
// Lookup the token
return v.auth.Lookup(token)
}

View File

@@ -1,6 +1,7 @@
package nomad
import (
"context"
"encoding/json"
"log"
"os"
@@ -9,12 +10,22 @@ import (
"testing"
"time"
"golang.org/x/time/rate"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/structs/config"
"github.com/hashicorp/nomad/testutil"
vapi "github.com/hashicorp/vault/api"
)
const (
// authPolicy is a policy that allows token creation operations
authPolicy = `path "auth/token/create/*" {
capabilities = ["create", "read", "update", "delete", "list"]
}`
)
func TestVaultClient_BadConfig(t *testing.T) {
conf := &config.VaultConfig{}
logger := log.New(os.Stderr, "", log.LstdFlags)
@@ -24,6 +35,7 @@ func TestVaultClient_BadConfig(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer client.Stop()
if client.ConnectionEstablished() {
t.Fatalf("bad")
@@ -75,15 +87,20 @@ func TestVaultClient_EstablishConnection(t *testing.T) {
}
}
func TestVaultClient_RenewalLoop(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
// testVaultRoleAndToken creates a test Vault role where children are created
// with the passed period. A token created in that role is returned
func testVaultRoleAndToken(v *testutil.TestVault, t *testing.T, rolePeriod int) string {
// Build the auth policy
sys := v.Client.Sys()
if err := sys.PutPolicy("auth", authPolicy); err != nil {
t.Fatalf("failed to create auth policy: %v", err)
}
// Build a role
l := v.Client.Logical()
d := make(map[string]interface{}, 2)
d["allowed_policies"] = "default"
d["period"] = 5
d["allowed_policies"] = "default,auth"
d["period"] = rolePeriod
l.Write("auth/token/roles/test", d)
// Create a new token with the role
@@ -99,8 +116,15 @@ func TestVaultClient_RenewalLoop(t *testing.T) {
t.Fatalf("bad secret response: %+v", s)
}
// Set the configs token
v.Config.Token = s.Auth.ClientToken
return s.Auth.ClientToken
}
func TestVaultClient_RenewalLoop(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
// Set the configs token in a new test role
v.Config.Token = testVaultRoleAndToken(v, t, 5)
// Start the client
logger := log.New(os.Stderr, "", log.LstdFlags)
@@ -114,6 +138,7 @@ func TestVaultClient_RenewalLoop(t *testing.T) {
time.Sleep(8 * time.Second)
// Get the current TTL
a := v.Client.Auth().Token()
s2, err := a.Lookup(v.Config.Token)
if err != nil {
t.Fatalf("failed to lookup token: %v", err)
@@ -160,8 +185,9 @@ func TestVaultClient_LookupToken_Invalid(t *testing.T) {
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
defer client.Stop()
_, err = client.LookupToken("foo")
_, err = client.LookupToken(context.Background(), "foo")
if err == nil || !strings.Contains(err.Error(), "disabled") {
t.Fatalf("Expected error because Vault is disabled: %v", err)
}
@@ -175,7 +201,7 @@ func TestVaultClient_LookupToken_Invalid(t *testing.T) {
t.Fatalf("failed to build vault client: %v", err)
}
_, err = client.LookupToken("foo")
_, err = client.LookupToken(context.Background(), "foo")
if err == nil || !strings.Contains(err.Error(), "established") {
t.Fatalf("Expected error because connection to Vault hasn't been made: %v", err)
}
@@ -198,11 +224,12 @@ func TestVaultClient_LookupToken(t *testing.T) {
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
defer client.Stop()
waitForConnection(client, t)
// Lookup ourselves
s, err := client.LookupToken(v.Config.Token)
s, err := client.LookupToken(context.Background(), v.Config.Token)
if err != nil {
t.Fatalf("self lookup failed: %v", err)
}
@@ -233,7 +260,7 @@ func TestVaultClient_LookupToken(t *testing.T) {
}
// Lookup new child
s, err = client.LookupToken(s.Auth.ClientToken)
s, err = client.LookupToken(context.Background(), s.Auth.ClientToken)
if err != nil {
t.Fatalf("self lookup failed: %v", err)
}
@@ -247,3 +274,145 @@ func TestVaultClient_LookupToken(t *testing.T) {
t.Fatalf("Unexpected policies; got %v; want %v", policies, expected)
}
}
func TestVaultClient_LookupToken_RateLimit(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
defer client.Stop()
client.setLimit(rate.Limit(1.0))
waitForConnection(client, t)
// Spin up many requests. These should block
ctx, cancel := context.WithCancel(context.Background())
cancels := 0
numRequests := 10
unblock := make(chan struct{})
for i := 0; i < numRequests; i++ {
go func() {
// Ensure all the goroutines are made
time.Sleep(10 * time.Millisecond)
// Lookup ourselves
_, err := client.LookupToken(ctx, v.Config.Token)
if err != nil {
if err == context.Canceled {
cancels += 1
return
}
t.Fatalf("self lookup failed: %v", err)
return
}
// Cancel the context
cancel()
time.AfterFunc(1*time.Second, func() { close(unblock) })
}()
}
select {
case <-time.After(5 * time.Second):
t.Fatalf("timeout")
case <-unblock:
}
desired := numRequests - 1
if cancels != desired {
t.Fatalf("Incorrect number of cancels; got %d; want %d", cancels, desired)
}
}
func TestVaultClient_CreateToken_Root(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
defer client.Stop()
waitForConnection(client, t)
// Create an allocation that requires a Vault policy
a := mock.Alloc()
task := a.Job.TaskGroups[0].Tasks[0]
task.Vault = &structs.Vault{Policies: []string{"default"}}
s, err := client.CreateToken(context.Background(), a, task.Name)
if err != nil {
t.Fatalf("CreateToken failed: %v", err)
}
// Ensure that created secret is a wrapped token
if s == nil || s.WrapInfo == nil {
t.Fatalf("Bad secret: %#v", s)
}
d, err := time.ParseDuration(vaultTokenCreateTTL)
if err != nil {
t.Fatalf("bad: %v", err)
}
if s.WrapInfo.WrappedAccessor == "" {
t.Fatalf("Bad accessor: %v", s.WrapInfo.WrappedAccessor)
} else if s.WrapInfo.Token == "" {
t.Fatalf("Bad token: %v", s.WrapInfo.WrappedAccessor)
} else if s.WrapInfo.TTL != int(d.Seconds()) {
t.Fatalf("Bad ttl: %v", s.WrapInfo.WrappedAccessor)
}
}
func TestVaultClient_CreateToken_Role(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
// Set the configs token in a new test role
v.Config.Token = testVaultRoleAndToken(v, t, 5)
//testVaultRoleAndToken(v, t, 5)
// Start the client
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
defer client.Stop()
waitForConnection(client, t)
// Create an allocation that requires a Vault policy
a := mock.Alloc()
task := a.Job.TaskGroups[0].Tasks[0]
task.Vault = &structs.Vault{Policies: []string{"default"}}
s, err := client.CreateToken(context.Background(), a, task.Name)
if err != nil {
t.Fatalf("CreateToken failed: %v", err)
}
// Ensure that created secret is a wrapped token
if s == nil || s.WrapInfo == nil {
t.Fatalf("Bad secret: %#v", s)
}
d, err := time.ParseDuration(vaultTokenCreateTTL)
if err != nil {
t.Fatalf("bad: %v", err)
}
if s.WrapInfo.WrappedAccessor == "" {
t.Fatalf("Bad accessor: %v", s.WrapInfo.WrappedAccessor)
} else if s.WrapInfo.Token == "" {
t.Fatalf("Bad token: %v", s.WrapInfo.WrappedAccessor)
} else if s.WrapInfo.TTL != int(d.Seconds()) {
t.Fatalf("Bad ttl: %v", s.WrapInfo.WrappedAccessor)
}
}

View File

@@ -1,6 +1,8 @@
package nomad
import (
"context"
"github.com/hashicorp/nomad/nomad/structs"
vapi "github.com/hashicorp/vault/api"
)
@@ -16,9 +18,17 @@ type TestVaultClient struct {
// LookupTokenSecret maps a token to the Vault secret that will be returned
// by the LookupToken call
LookupTokenSecret map[string]*vapi.Secret
// CreateTokenErrors maps a token to an error that will be returned by the
// CreateToken call
CreateTokenErrors map[string]map[string]error
// CreateTokenSecret maps a token to the Vault secret that will be returned
// by the CreateToken call
CreateTokenSecret map[string]map[string]*vapi.Secret
}
func (v *TestVaultClient) LookupToken(token string) (*vapi.Secret, error) {
func (v *TestVaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secret, error) {
var secret *vapi.Secret
var err error
@@ -64,8 +74,56 @@ func (v *TestVaultClient) SetLookupTokenAllowedPolicies(token string, policies [
v.SetLookupTokenSecret(token, s)
}
func (v *TestVaultClient) CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error) {
return nil, nil
func (v *TestVaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) {
var secret *vapi.Secret
var err error
if v.CreateTokenSecret != nil {
tasks := v.CreateTokenSecret[a.ID]
if tasks != nil {
secret = tasks[task]
}
}
if v.CreateTokenErrors != nil {
tasks := v.CreateTokenErrors[a.ID]
if tasks != nil {
err = tasks[task]
}
}
return secret, err
}
// SetCreateTokenError sets the error that will be returned by the token
// creation
func (v *TestVaultClient) SetCreateTokenError(allocID, task string, err error) {
if v.CreateTokenErrors == nil {
v.CreateTokenErrors = make(map[string]map[string]error)
}
tasks := v.CreateTokenErrors[allocID]
if tasks == nil {
tasks = make(map[string]error)
v.CreateTokenErrors[allocID] = tasks
}
v.CreateTokenErrors[allocID][task] = err
}
// SetCreateTokenSecret sets the secret that will be returned by the token
// creation
func (v *TestVaultClient) SetCreateTokenSecret(allocID, task string, secret *vapi.Secret) {
if v.CreateTokenSecret == nil {
v.CreateTokenSecret = make(map[string]map[string]*vapi.Secret)
}
tasks := v.CreateTokenSecret[allocID]
if tasks == nil {
tasks = make(map[string]*vapi.Secret)
v.CreateTokenSecret[allocID] = tasks
}
v.CreateTokenSecret[allocID][task] = secret
}
func (v *TestVaultClient) Stop() {}

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env bash
set -e
# Create a temp dir and clean it up on exit
TEMPDIR=`mktemp -d -t nomad-test.XXX`

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env bash
set -e
export PING_SLEEP=30
bash -c "while true; do echo \$(date) - building ...; sleep $PING_SLEEP; done" &

View File

@@ -119,6 +119,6 @@ func (tv *TestVault) waitForAPI() {
// getPort returns the next available port to bind Vault against
func getPort() uint64 {
p := vaultStartPort + vaultPortOffset
offset += 1
vaultPortOffset += 1
return p
}

View File

@@ -170,6 +170,7 @@ type TokenCreateRequest struct {
Lease string `json:"lease,omitempty"`
TTL string `json:"ttl,omitempty"`
ExplicitMaxTTL string `json:"explicit_max_ttl,omitempty"`
Period string `json:"period,omitempty"`
NoParent bool `json:"no_parent,omitempty"`
NoDefaultPolicy bool `json:"no_default_policy,omitempty"`
DisplayName string `json:"display_name"`

View File

@@ -22,21 +22,12 @@ func (c *Sys) AuditHash(path string, input string) (string, error) {
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
if err != nil {
return "", err
}
if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return "", nil
}
type d struct {
Hash string
Hash string `json:"hash"`
}
var result d
err = mapstructure.Decode(secret.Data, &result)
err = resp.DecodeJSON(&result)
if err != nil {
return "", err
}
@@ -52,26 +43,32 @@ func (c *Sys) ListAudit() (map[string]*Audit, error) {
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}
result := map[string]*Audit{}
for k, v := range secret.Data {
mounts := map[string]*Audit{}
for k, v := range result {
switch v.(type) {
case map[string]interface{}:
default:
continue
}
var res Audit
err = mapstructure.Decode(v, &res)
if err != nil {
return nil, err
}
result[k] = &res
// Not a mount, some other api.Secret data
if res.Type == "" {
continue
}
mounts[k] = &res
}
return result, err
return mounts, nil
}
func (c *Sys) EnableAudit(
@@ -106,7 +103,7 @@ func (c *Sys) DisableAudit(path string) error {
}
// Structures for the requests/resposne are all down here. They aren't
// individually documentd because the map almost directly to the raw HTTP API
// individually documented because the map almost directly to the raw HTTP API
// documentation. Please refer to that documentation for more details.
type Audit struct {

View File

@@ -14,26 +14,32 @@ func (c *Sys) ListAuth() (map[string]*AuthMount, error) {
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}
result := map[string]*AuthMount{}
for k, v := range secret.Data {
mounts := map[string]*AuthMount{}
for k, v := range result {
switch v.(type) {
case map[string]interface{}:
default:
continue
}
var res AuthMount
err = mapstructure.Decode(v, &res)
if err != nil {
return nil, err
}
result[k] = &res
// Not a mount, some other api.Secret data
if res.Type == "" {
continue
}
mounts[k] = &res
}
return result, err
return mounts, nil
}
func (c *Sys) EnableAuth(path, authType, desc string) error {

View File

@@ -28,17 +28,14 @@ func (c *Sys) Capabilities(token, path string) ([]string, error) {
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}
var capabilities []string
capabilitiesRaw := secret.Data["capabilities"].([]interface{})
capabilitiesRaw := result["capabilities"].([]interface{})
for _, capability := range capabilitiesRaw {
capabilities = append(capabilities, capability.(string))
}

View File

@@ -45,7 +45,9 @@ type InitStatusResponse struct {
}
type InitResponse struct {
Keys []string `json:"keys"`
RecoveryKeys []string `json:"recovery_keys"`
RootToken string `json:"root_token"`
Keys []string `json:"keys"`
KeysB64 []string `json:"keys_base64"`
RecoveryKeys []string `json:"recovery_keys"`
RecoveryKeysB64 []string `json:"recovery_keys_base64"`
RootToken string `json:"root_token"`
}

View File

@@ -15,26 +15,32 @@ func (c *Sys) ListMounts() (map[string]*MountOutput, error) {
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}
result := map[string]*MountOutput{}
for k, v := range secret.Data {
mounts := map[string]*MountOutput{}
for k, v := range result {
switch v.(type) {
case map[string]interface{}:
default:
continue
}
var res MountOutput
err = mapstructure.Decode(v, &res)
if err != nil {
return nil, err
}
result[k] = &res
// Not a mount, some other api.Secret data
if res.Type == "" {
continue
}
mounts[k] = &res
}
return result, nil
return mounts, nil
}
func (c *Sys) Mount(path string, mountInfo *MountInput) error {
@@ -104,17 +110,8 @@ func (c *Sys) MountConfig(path string) (*MountConfigOutput, error) {
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}
var result MountConfigOutput
err = mapstructure.Decode(secret.Data, &result)
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}

View File

@@ -1,10 +1,6 @@
package api
import (
"fmt"
"github.com/mitchellh/mapstructure"
)
import "fmt"
func (c *Sys) ListPolicies() ([]string, error) {
r := c.c.NewRequest("GET", "/v1/sys/policy")
@@ -14,22 +10,25 @@ func (c *Sys) ListPolicies() ([]string, error) {
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
var ok bool
if _, ok = result["policies"]; !ok {
return nil, fmt.Errorf("policies not found in response")
}
var result listPoliciesResp
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return nil, err
listRaw := result["policies"].([]interface{})
var policies []string
for _, val := range listRaw {
policies = append(policies, val.(string))
}
return result.Policies, err
return policies, err
}
func (c *Sys) GetPolicy(name string) (string, error) {
@@ -45,22 +44,18 @@ func (c *Sys) GetPolicy(name string) (string, error) {
return "", err
}
secret, err := ParseSecret(resp.Body)
var result map[string]interface{}
err = resp.DecodeJSON(&result)
if err != nil {
return "", err
}
if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return "", nil
var ok bool
if _, ok = result["rules"]; !ok {
return "", fmt.Errorf("rules not found in response")
}
var result getPoliciesResp
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return "", err
}
return result.Rules, err
return result["rules"].(string), nil
}
func (c *Sys) PutPolicy(name, rules string) error {

View File

@@ -190,11 +190,13 @@ type RekeyUpdateResponse struct {
Nonce string
Complete bool
Keys []string
KeysB64 []string `json:"keys_base64"`
PGPFingerprints []string `json:"pgp_fingerprints"`
Backup bool
}
type RekeyRetrieveResponse struct {
Nonce string
Keys map[string][]string
Nonce string
Keys map[string][]string
KeysB64 map[string][]string `json:"keys_base64"`
}

View File

@@ -1,10 +1,6 @@
package api
import (
"time"
"github.com/mitchellh/mapstructure"
)
import "time"
func (c *Sys) Rotate() error {
r := c.c.NewRequest("POST", "/v1/sys/rotate")
@@ -23,25 +19,12 @@ func (c *Sys) KeyStatus() (*KeyStatus, error) {
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil || len(secret.Data) == 0 {
return nil, nil
}
var result KeyStatus
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return nil, err
}
return &result, err
result := new(KeyStatus)
err = resp.DecodeJSON(result)
return result, err
}
type KeyStatus struct {
Term int
Term int `json:"term"`
InstallTime time.Time `json:"install_time"`
}

27
vendor/golang.org/x/sync/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
vendor/golang.org/x/sync/PATENTS generated vendored Normal file
View File

@@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

67
vendor/golang.org/x/sync/errgroup/errgroup.go generated vendored Normal file
View File

@@ -0,0 +1,67 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package errgroup provides synchronization, error propagation, and Context
// cancelation for groups of goroutines working on subtasks of a common task.
package errgroup
import (
"sync"
"golang.org/x/net/context"
)
// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
cancel func()
wg sync.WaitGroup
errOnce sync.Once
err error
}
// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &Group{cancel: cancel}, ctx
}
// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
g.cancel()
}
return g.err
}
// Go calls the given function in a new goroutine.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
g.wg.Add(1)
go func() {
defer g.wg.Done()
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel()
}
})
}
}()
}

27
vendor/golang.org/x/time/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
vendor/golang.org/x/time/PATENTS generated vendored Normal file
View File

@@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

368
vendor/golang.org/x/time/rate/rate.go generated vendored Normal file
View File

@@ -0,0 +1,368 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rate provides a rate limiter.
package rate
import (
"fmt"
"math"
"sync"
"time"
"golang.org/x/net/context"
)
// Limit defines the maximum frequency of some events.
// Limit is represented as number of events per second.
// A zero Limit allows no events.
type Limit float64
// Inf is the infinite rate limit; it allows all events (even if burst is zero).
const Inf = Limit(math.MaxFloat64)
// Every converts a minimum time interval between events to a Limit.
func Every(interval time.Duration) Limit {
if interval <= 0 {
return Inf
}
return 1 / Limit(interval.Seconds())
}
// A Limiter controls how frequently events are allowed to happen.
// It implements a "token bucket" of size b, initially full and refilled
// at rate r tokens per second.
// Informally, in any large enough time interval, the Limiter limits the
// rate to r tokens per second, with a maximum burst size of b events.
// As a special case, if r == Inf (the infinite rate), b is ignored.
// See https://en.wikipedia.org/wiki/Token_bucket for more about token buckets.
//
// The zero value is a valid Limiter, but it will reject all events.
// Use NewLimiter to create non-zero Limiters.
//
// Limiter has three main methods, Allow, Reserve, and Wait.
// Most callers should use Wait.
//
// Each of the three methods consumes a single token.
// They differ in their behavior when no token is available.
// If no token is available, Allow returns false.
// If no token is available, Reserve returns a reservation for a future token
// and the amount of time the caller must wait before using it.
// If no token is available, Wait blocks until one can be obtained
// or its associated context.Context is canceled.
//
// The methods AllowN, ReserveN, and WaitN consume n tokens.
type Limiter struct {
limit Limit
burst int
mu sync.Mutex
tokens float64
// last is the last time the limiter's tokens field was updated
last time.Time
// lastEvent is the latest time of a rate-limited event (past or future)
lastEvent time.Time
}
// Limit returns the maximum overall event rate.
func (lim *Limiter) Limit() Limit {
lim.mu.Lock()
defer lim.mu.Unlock()
return lim.limit
}
// Burst returns the maximum burst size. Burst is the maximum number of tokens
// that can be consumed in a single call to Allow, Reserve, or Wait, so higher
// Burst values allow more events to happen at once.
// A zero Burst allows no events, unless limit == Inf.
func (lim *Limiter) Burst() int {
return lim.burst
}
// NewLimiter returns a new Limiter that allows events up to rate r and permits
// bursts of at most b tokens.
func NewLimiter(r Limit, b int) *Limiter {
return &Limiter{
limit: r,
burst: b,
}
}
// Allow is shorthand for AllowN(time.Now(), 1).
func (lim *Limiter) Allow() bool {
return lim.AllowN(time.Now(), 1)
}
// AllowN reports whether n events may happen at time now.
// Use this method if you intend to drop / skip events that exceed the rate limit.
// Otherwise use Reserve or Wait.
func (lim *Limiter) AllowN(now time.Time, n int) bool {
return lim.reserveN(now, n, 0).ok
}
// A Reservation holds information about events that are permitted by a Limiter to happen after a delay.
// A Reservation may be canceled, which may enable the Limiter to permit additional events.
type Reservation struct {
ok bool
lim *Limiter
tokens int
timeToAct time.Time
// This is the Limit at reservation time, it can change later.
limit Limit
}
// OK returns whether the limiter can provide the requested number of tokens
// within the maximum wait time. If OK is false, Delay returns InfDuration, and
// Cancel does nothing.
func (r *Reservation) OK() bool {
return r.ok
}
// Delay is shorthand for DelayFrom(time.Now()).
func (r *Reservation) Delay() time.Duration {
return r.DelayFrom(time.Now())
}
// InfDuration is the duration returned by Delay when a Reservation is not OK.
const InfDuration = time.Duration(1<<63 - 1)
// DelayFrom returns the duration for which the reservation holder must wait
// before taking the reserved action. Zero duration means act immediately.
// InfDuration means the limiter cannot grant the tokens requested in this
// Reservation within the maximum wait time.
func (r *Reservation) DelayFrom(now time.Time) time.Duration {
if !r.ok {
return InfDuration
}
delay := r.timeToAct.Sub(now)
if delay < 0 {
return 0
}
return delay
}
// Cancel is shorthand for CancelAt(time.Now()).
func (r *Reservation) Cancel() {
r.CancelAt(time.Now())
return
}
// CancelAt indicates that the reservation holder will not perform the reserved action
// and reverses the effects of this Reservation on the rate limit as much as possible,
// considering that other reservations may have already been made.
func (r *Reservation) CancelAt(now time.Time) {
if !r.ok {
return
}
r.lim.mu.Lock()
defer r.lim.mu.Unlock()
if r.lim.limit == Inf || r.tokens == 0 || r.timeToAct.Before(now) {
return
}
// calculate tokens to restore
// The duration between lim.lastEvent and r.timeToAct tells us how many tokens were reserved
// after r was obtained. These tokens should not be restored.
restoreTokens := float64(r.tokens) - r.limit.tokensFromDuration(r.lim.lastEvent.Sub(r.timeToAct))
if restoreTokens <= 0 {
return
}
// advance time to now
now, _, tokens := r.lim.advance(now)
// calculate new number of tokens
tokens += restoreTokens
if burst := float64(r.lim.burst); tokens > burst {
tokens = burst
}
// update state
r.lim.last = now
r.lim.tokens = tokens
if r.timeToAct == r.lim.lastEvent {
prevEvent := r.timeToAct.Add(r.limit.durationFromTokens(float64(-r.tokens)))
if !prevEvent.Before(now) {
r.lim.lastEvent = prevEvent
}
}
return
}
// Reserve is shorthand for ReserveN(time.Now(), 1).
func (lim *Limiter) Reserve() *Reservation {
return lim.ReserveN(time.Now(), 1)
}
// ReserveN returns a Reservation that indicates how long the caller must wait before n events happen.
// The Limiter takes this Reservation into account when allowing future events.
// ReserveN returns false if n exceeds the Limiter's burst size.
// Usage example:
// r, ok := lim.ReserveN(time.Now(), 1)
// if !ok {
// // Not allowed to act! Did you remember to set lim.burst to be > 0 ?
// }
// time.Sleep(r.Delay())
// Act()
// Use this method if you wish to wait and slow down in accordance with the rate limit without dropping events.
// If you need to respect a deadline or cancel the delay, use Wait instead.
// To drop or skip events exceeding rate limit, use Allow instead.
func (lim *Limiter) ReserveN(now time.Time, n int) *Reservation {
r := lim.reserveN(now, n, InfDuration)
return &r
}
// Wait is shorthand for WaitN(ctx, 1).
func (lim *Limiter) Wait(ctx context.Context) (err error) {
return lim.WaitN(ctx, 1)
}
// WaitN blocks until lim permits n events to happen.
// It returns an error if n exceeds the Limiter's burst size, the Context is
// canceled, or the expected wait time exceeds the Context's Deadline.
func (lim *Limiter) WaitN(ctx context.Context, n int) (err error) {
if n > lim.burst {
return fmt.Errorf("rate: Wait(n=%d) exceeds limiter's burst %d", n, lim.burst)
}
// Check if ctx is already cancelled
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// Determine wait limit
now := time.Now()
waitLimit := InfDuration
if deadline, ok := ctx.Deadline(); ok {
waitLimit = deadline.Sub(now)
}
// Reserve
r := lim.reserveN(now, n, waitLimit)
if !r.ok {
return fmt.Errorf("rate: Wait(n=%d) would exceed context deadline", n)
}
// Wait
t := time.NewTimer(r.DelayFrom(now))
defer t.Stop()
select {
case <-t.C:
// We can proceed.
return nil
case <-ctx.Done():
// Context was canceled before we could proceed. Cancel the
// reservation, which may permit other events to proceed sooner.
r.Cancel()
return ctx.Err()
}
}
// SetLimit is shorthand for SetLimitAt(time.Now(), newLimit).
func (lim *Limiter) SetLimit(newLimit Limit) {
lim.SetLimitAt(time.Now(), newLimit)
}
// SetLimitAt sets a new Limit for the limiter. The new Limit, and Burst, may be violated
// or underutilized by those which reserved (using Reserve or Wait) but did not yet act
// before SetLimitAt was called.
func (lim *Limiter) SetLimitAt(now time.Time, newLimit Limit) {
lim.mu.Lock()
defer lim.mu.Unlock()
now, _, tokens := lim.advance(now)
lim.last = now
lim.tokens = tokens
lim.limit = newLimit
}
// reserveN is a helper method for AllowN, ReserveN, and WaitN.
// maxFutureReserve specifies the maximum reservation wait duration allowed.
// reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN.
func (lim *Limiter) reserveN(now time.Time, n int, maxFutureReserve time.Duration) Reservation {
lim.mu.Lock()
defer lim.mu.Unlock()
if lim.limit == Inf {
return Reservation{
ok: true,
lim: lim,
tokens: n,
timeToAct: now,
}
}
now, last, tokens := lim.advance(now)
// Calculate the remaining number of tokens resulting from the request.
tokens -= float64(n)
// Calculate the wait duration
var waitDuration time.Duration
if tokens < 0 {
waitDuration = lim.limit.durationFromTokens(-tokens)
}
// Decide result
ok := n <= lim.burst && waitDuration <= maxFutureReserve
// Prepare reservation
r := Reservation{
ok: ok,
lim: lim,
limit: lim.limit,
}
if ok {
r.tokens = n
r.timeToAct = now.Add(waitDuration)
}
// Update state
if ok {
lim.last = now
lim.tokens = tokens
lim.lastEvent = r.timeToAct
} else {
lim.last = last
}
return r
}
// advance calculates and returns an updated state for lim resulting from the passage of time.
// lim is not changed.
func (lim *Limiter) advance(now time.Time) (newNow time.Time, newLast time.Time, newTokens float64) {
last := lim.last
if now.Before(last) {
last = now
}
// Avoid making delta overflow below when last is very old.
maxElapsed := lim.limit.durationFromTokens(float64(lim.burst) - lim.tokens)
elapsed := now.Sub(last)
if elapsed > maxElapsed {
elapsed = maxElapsed
}
// Calculate the new number of tokens, due to time that passed.
delta := lim.limit.tokensFromDuration(elapsed)
tokens := lim.tokens + delta
if burst := float64(lim.burst); tokens > burst {
tokens = burst
}
return now, last, tokens
}
// durationFromTokens is a unit conversion function from the number of tokens to the duration
// of time it takes to accumulate them at a rate of limit tokens per second.
func (limit Limit) durationFromTokens(tokens float64) time.Duration {
seconds := tokens / float64(limit)
return time.Nanosecond * time.Duration(1e9*seconds)
}
// tokensFromDuration is a unit conversion function from a time duration to the number of tokens
// which could be accumulated during that duration at a rate of limit tokens per second.
func (limit Limit) tokensFromDuration(d time.Duration) float64 {
return d.Seconds() * float64(limit)
}

24
vendor/vendor.json vendored
View File

@@ -623,10 +623,16 @@
"revisionTime": "2016-06-09T00:18:40Z"
},
{
"checksumSHA1": "0rkVtm9F1/pW9EGhHYJpCnY99O8=",
"checksumSHA1": "RAJfRxZ8UmcL6+7VuXAZxBlnM/4=",
"path": "github.com/hashicorp/vault",
"revision": "fece3ca069fc5bafec5280bbcb0c0693ff69fdaf",
"revisionTime": "2016-08-17T21:47:06Z"
},
{
"checksumSHA1": "JH8wmQ8cWdn7mYu1T7gJ3IMIrec=",
"path": "github.com/hashicorp/vault/api",
"revision": "fbecd94926e289d3b81d8dae6136452a6c4c93f6",
"revisionTime": "2016-08-13T15:54:01Z"
"revision": "fece3ca069fc5bafec5280bbcb0c0693ff69fdaf",
"revisionTime": "2016-08-17T21:47:06Z"
},
{
"checksumSHA1": "5lR6EdY0ARRdKAq3hZcL38STD8Q=",
@@ -827,6 +833,12 @@
"revision": "30db96677b74e24b967e23f911eb3364fc61a011",
"revisionTime": "2016-05-25T13:11:03Z"
},
{
"checksumSHA1": "S0DP7Pn7sZUmXc55IzZnNvERu6s=",
"path": "golang.org/x/sync/errgroup",
"revision": "316e794f7b5e3df4e95175a45a5fb8b12f85cb4f",
"revisionTime": "2016-07-15T18:54:39Z"
},
{
"path": "golang.org/x/sys/unix",
"revision": "50c6bc5e4292a1d4e65c6e9be5f53be28bcbe28e"
@@ -837,6 +849,12 @@
"revision": "b776ec39b3e54652e09028aaaaac9757f4f8211a",
"revisionTime": "2016-04-21T02:29:30Z"
},
{
"checksumSHA1": "h/06ikMECfJoTkWj2e1nJ30aDDg=",
"path": "golang.org/x/time/rate",
"revision": "a4bde12657593d5e90d0533a3e4fd95e635124cb",
"revisionTime": "2016-02-02T18:34:45Z"
},
{
"checksumSHA1": "93uHIq25lffEKY47PV8dBPD+XuQ=",
"path": "gopkg.in/fsnotify.v1",