diff --git a/client/allocrunnerv2/taskrunner/task_runner.go b/client/allocrunnerv2/taskrunner/task_runner.go index 45ecb158e..9844bb28e 100644 --- a/client/allocrunnerv2/taskrunner/task_runner.go +++ b/client/allocrunnerv2/taskrunner/task_runner.go @@ -1,10 +1,8 @@ package taskrunner import ( - "bytes" "context" "fmt" - "io" "sync" "time" @@ -21,8 +19,6 @@ import ( cstate "github.com/hashicorp/nomad/client/state" "github.com/hashicorp/nomad/client/vaultclient" "github.com/hashicorp/nomad/nomad/structs" - "github.com/ugorji/go/codec" - "golang.org/x/crypto/blake2b" ) const ( @@ -453,44 +449,11 @@ func (tr *TaskRunner) handleDestroy(handle driver.DriverHandle) (destroyed bool, } // persistLocalState persists local state to disk synchronously. -// -//XXX Not safe for concurrent calls. Should it be? func (tr *TaskRunner) persistLocalState() error { tr.localStateLock.Lock() defer tr.localStateLock.Unlock() - // buffer for writing serialized state to - var buf bytes.Buffer - - // Hash for skipping unnecessary writes - h, err := blake2b.New256(nil) - if err != nil { - // Programming error that should never happen! - return err - } - - // Multiplex writes to both - w := io.MultiWriter(h, &buf) - - // Encode as msgpack value - if err := codec.NewEncoder(w, structs.MsgpackHandle).Encode(&tr.localState); err != nil { - return fmt.Errorf("failed to serialize snapshot: %v", err) - } - - // If the hashes are equal, skip the write - hashVal := h.Sum(nil) - if bytes.Equal(hashVal, tr.persistedHash) { - return nil - } - - if err := tr.stateDB.PutTaskRunnerLocalState(tr.allocID, tr.taskName, buf.Bytes()); err != nil { - return err - } - - // State was persisted, set the hash - tr.persistedHash = hashVal - - return nil + return tr.stateDB.PutTaskRunnerLocalState(tr.allocID, tr.taskName, tr.localState) } // XXX If the objects don't exists since the client shutdown before the task diff --git a/client/client.go b/client/client.go index e52d5b917..13b3ec415 100644 --- a/client/client.go +++ b/client/client.go @@ -381,7 +381,7 @@ func (c *Client) init() error { c.logger.Printf("[INFO] client: using state directory %v", c.config.StateDir) // Open the state database - db, err := state.NewStateDB(c.config.StateDir, c.config.DevMode) + db, err := state.GetStateDBFactory(c.config.DevMode)(c.config.StateDir) if err != nil { return fmt.Errorf("failed to open state database: %v", err) } diff --git a/client/state/interface.go b/client/state/interface.go index e4d60f211..59d6bd08b 100644 --- a/client/state/interface.go +++ b/client/state/interface.go @@ -10,7 +10,7 @@ type StateDB interface { GetAllAllocations() ([]*structs.Allocation, map[string]error, error) PutAllocation(*structs.Allocation) error GetTaskRunnerState(allocID, taskName string) (*state.LocalState, *structs.TaskState, error) - PutTaskRunnerLocalState(allocID, taskName string, buf []byte) error + PutTaskRunnerLocalState(allocID, taskName string, val interface{}) error PutTaskState(allocID, taskName string, state *structs.TaskState) error Close() error } diff --git a/client/state/kvcodec.go b/client/state/kvcodec.go new file mode 100644 index 000000000..581ea528c --- /dev/null +++ b/client/state/kvcodec.go @@ -0,0 +1,104 @@ +package state + +import ( + "bytes" + "fmt" + "io" + "sync" + + "github.com/hashicorp/nomad/nomad/structs" + "github.com/ugorji/go/codec" + "golang.org/x/crypto/blake2b" +) + +type kvStore interface { + Get(key []byte) (val []byte) + Put(key, val []byte) error + Writable() bool +} + +// keyValueCodec handles encoding and decoding values from a key/value store +// such as boltdb. +type keyValueCodec struct { + // hashes maps keys to the hash of the last content written + hashes map[string][]byte + hashesLock sync.Mutex +} + +func newKeyValueCodec() *keyValueCodec { + return &keyValueCodec{ + hashes: make(map[string][]byte), + } +} + +// hashKey returns a unique key for each hashed boltdb value +func (c *keyValueCodec) hashKey(path string, key []byte) string { + return path + "-" + string(key) +} + +// Put into kv store iff it has changed since the last write. A globally +// unique key is constructed for each value by concatinating the path and key +// passed in. +func (c *keyValueCodec) Put(bkt kvStore, path string, key []byte, val interface{}) error { + if !bkt.Writable() { + return fmt.Errorf("bucket must be writable") + } + + // buffer for writing serialized state to + var buf bytes.Buffer + + // Hash for skipping unnecessary writes + h, err := blake2b.New256(nil) + if err != nil { + // Programming error that should never happen! + return err + } + + // Multiplex writes to both hasher and buffer + w := io.MultiWriter(h, &buf) + + // Serialize the object + if err := codec.NewEncoder(w, structs.MsgpackHandle).Encode(val); err != nil { + return fmt.Errorf("failed to encode passed object: %v", err) + } + + // If the hashes are equal, skip the write + hashVal := h.Sum(nil) + hashKey := c.hashKey(path, key) + + c.hashesLock.Lock() + persistedHash := c.hashes[hashKey] + c.hashesLock.Unlock() + + if bytes.Equal(hashVal, persistedHash) { + return nil + } + + if err := bkt.Put(key, buf.Bytes()); err != nil { + return fmt.Errorf("failed to write data at key %s: %v", key, err) + } + + // New value written, store hash + c.hashesLock.Lock() + c.hashes[hashKey] = hashVal + c.hashesLock.Unlock() + + return nil + +} + +// Get value by key from boltdb. +func (c *keyValueCodec) Get(bkt kvStore, key []byte, obj interface{}) error { + // Get the data + data := bkt.Get(key) + if data == nil { + return fmt.Errorf("no data at key %v", string(key)) + } + + // Deserialize the object + if err := codec.NewDecoderBytes(data, structs.MsgpackHandle).Decode(obj); err != nil { + return fmt.Errorf("failed to decode data into passed object: %v", err) + } + + return nil +} diff --git a/client/state/kvcodec_test.go b/client/state/kvcodec_test.go new file mode 100644 index 000000000..49c110b91 --- /dev/null +++ b/client/state/kvcodec_test.go @@ -0,0 +1,70 @@ +package state + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// mockKVStore tracks puts and is useful for testing KVCodec's write-on-change +// code. +type mockKVStore struct { + puts int +} + +func (mockKVStore) Get(key []byte) (val []byte) { + return nil +} + +func (m *mockKVStore) Put(key, val []byte) error { + m.puts++ + return nil +} + +func (mockKVStore) Writable() bool { + return true +} + +// TestKVCodec_PutHash asserts that Puts on the underlying kvstore only occur +// when the data actually changes. +func TestKVCodec_PutHash(t *testing.T) { + require := require.New(t) + codec := newKeyValueCodec() + + // Create arguments for Put + kv := new(mockKVStore) + path := "path-path" + key := []byte("key1") + val := &struct { + Val int + }{ + Val: 1, + } + + // Initial Put should be written + require.NoError(codec.Put(kv, path, key, val)) + require.Equal(1, kv.puts) + + // Writing the same values again should be a noop + require.NoError(codec.Put(kv, path, key, val)) + require.Equal(1, kv.puts) + + // Changing the value should write again + val.Val++ + require.NoError(codec.Put(kv, path, key, val)) + require.Equal(2, kv.puts) + + // Changing the key should write again + key = []byte("key2") + require.NoError(codec.Put(kv, path, key, val)) + require.Equal(3, kv.puts) + + // Changing the path should write again + path = "new-path" + require.NoError(codec.Put(kv, path, key, val)) + require.Equal(4, kv.puts) + + // Writing the same values again should be a noop + require.NoError(codec.Put(kv, path, key, val)) + require.Equal(4, kv.puts) +} diff --git a/client/state/noopdb.go b/client/state/noopdb.go index 67f5f1760..7a31c461f 100644 --- a/client/state/noopdb.go +++ b/client/state/noopdb.go @@ -19,7 +19,7 @@ func (n noopDB) GetTaskRunnerState(allocID string, taskName string) (*state.Loca return nil, nil, nil } -func (n noopDB) PutTaskRunnerLocalState(allocID string, taskName string, buf []byte) error { +func (n noopDB) PutTaskRunnerLocalState(allocID string, taskName string, val interface{}) error { return nil } diff --git a/client/state/state_database.go b/client/state/state_database.go index 1828f62a2..8793d651e 100644 --- a/client/state/state_database.go +++ b/client/state/state_database.go @@ -1,14 +1,13 @@ package state import ( - "bytes" "fmt" "path/filepath" + "strings" "github.com/boltdb/bolt" trstate "github.com/hashicorp/nomad/client/allocrunnerv2/taskrunner/state" "github.com/hashicorp/nomad/nomad/structs" - "github.com/ugorji/go/codec" ) /* @@ -40,6 +39,7 @@ var ( taskStateKey = []byte("task_state") ) +//TODO delete from kvcodec // DeleteAllocationBucket is used to delete an allocation bucket if it exists. func DeleteAllocationBucket(tx *bolt.Tx, allocID string) error { if !tx.Writable() { @@ -61,6 +61,7 @@ func DeleteAllocationBucket(tx *bolt.Tx, allocID string) error { return allocations.DeleteBucket(key) } +//TODO delete from kvcodec // DeleteTaskBucket is used to delete a task bucket if it exists. func DeleteTaskBucket(tx *bolt.Tx, allocID, taskName string) error { if !tx.Writable() { @@ -88,13 +89,70 @@ func DeleteTaskBucket(tx *bolt.Tx, allocID, taskName string) error { return alloc.DeleteBucket(key) } -var () +// NewStateDBFunc creates a StateDB given a state directory. +type NewStateDBFunc func(stateDir string) (StateDB, error) +// GetStateDBFactory returns a func for creating a StateDB +func GetStateDBFactory(devMode bool) NewStateDBFunc { + // Return a noop state db implementation when in debug mode + if devMode { + return func(string) (StateDB, error) { + return noopDB{}, nil + } + } + + return NewBoltStateDB +} + +// BoltStateDB persists and restores Nomad client state in a boltdb. All +// methods are safe for concurrent access. Create via NewStateDB by setting +// devMode=false. +type BoltStateDB struct { + db *bolt.DB + codec *keyValueCodec +} + +func NewBoltStateDB(stateDir string) (StateDB, error) { + // Create or open the boltdb state database + db, err := bolt.Open(filepath.Join(stateDir, "state.db"), 0600, nil) + if err != nil { + return nil, fmt.Errorf("failed to create state database: %v", err) + } + + sdb := &BoltStateDB{ + db: db, + codec: newKeyValueCodec(), + } + return sdb, nil +} + +// GetAllAllocations gets all allocations persisted by this client and returns +// a map of alloc ids to errors for any allocations that could not be restored. +// +// If a fatal error was encountered it will be returned and the other two +// values will be nil. +func (s *BoltStateDB) GetAllAllocations() ([]*structs.Allocation, map[string]error, error) { + var allocs []*structs.Allocation + var errs map[string]error + err := s.db.View(func(tx *bolt.Tx) error { + allocs, errs = s.getAllAllocations(tx) + return nil + }) + + // db.View itself may return an error, so still check + if err != nil { + return nil, nil, err + } + + return allocs, errs, nil +} + +// allocEntry wraps values in the Allocations buckets type allocEntry struct { Alloc *structs.Allocation } -func getAllAllocations(tx *bolt.Tx) ([]*structs.Allocation, map[string]error) { +func (s *BoltStateDB) getAllAllocations(tx *bolt.Tx) ([]*structs.Allocation, map[string]error) { allocationsBkt := tx.Bucket(allocationsBucket) if allocationsBkt == nil { // No allocs @@ -117,7 +175,7 @@ func getAllAllocations(tx *bolt.Tx) ([]*structs.Allocation, map[string]error) { } var allocState allocEntry - if err := getObject(allocBkt, allocKey, &allocState); err != nil { + if err := s.codec.Get(allocBkt, allocKey, &allocState); err != nil { errs[allocID] = fmt.Errorf("failed to decode alloc %v", err) continue } @@ -128,52 +186,6 @@ func getAllAllocations(tx *bolt.Tx) ([]*structs.Allocation, map[string]error) { return allocs, errs } -// BoltStateDB persists and restores Nomad client state in a boltdb. All -// methods are safe for concurrent access. Create via NewStateDB by setting -// devMode=false. -type BoltStateDB struct { - db *bolt.DB -} - -func NewStateDB(stateDir string, devMode bool) (StateDB, error) { - // Return a noop state db implementation when in debug mode - if devMode { - return noopDB{}, nil - } - - // Create or open the boltdb state database - db, err := bolt.Open(filepath.Join(stateDir, "state.db"), 0600, nil) - if err != nil { - return nil, fmt.Errorf("failed to create state database: %v", err) - } - - sdb := &BoltStateDB{ - db: db, - } - return sdb, nil -} - -// GetAllAllocations gets all allocations persisted by this client and returns -// a map of alloc ids to errors for any allocations that could not be restored. -// -// If a fatal error was encountered it will be returned and the other two -// values will be nil. -func (s *BoltStateDB) GetAllAllocations() ([]*structs.Allocation, map[string]error, error) { - var allocs []*structs.Allocation - var errs map[string]error - err := s.db.View(func(tx *bolt.Tx) error { - allocs, errs = getAllAllocations(tx) - return nil - }) - - // db.View itself may return an error, so still check - if err != nil { - return nil, nil, err - } - - return allocs, errs, nil -} - // PutAllocation stores an allocation or returns an error. func (s *BoltStateDB) PutAllocation(alloc *structs.Allocation) error { return s.db.Update(func(tx *bolt.Tx) error { @@ -193,7 +205,7 @@ func (s *BoltStateDB) PutAllocation(alloc *structs.Allocation) error { allocState := allocEntry{ Alloc: alloc, } - return putObject(allocBkt, allocKey, &allocState) + return s.codec.Put(allocBkt, alloc.ID, allocKey, &allocState) }) } @@ -211,12 +223,12 @@ func (s *BoltStateDB) GetTaskRunnerState(allocID, taskName string) (*trstate.Loc // Restore Local State //XXX set persisted hash to avoid immediate write on first use? - if err := getObject(bkt, taskLocalStateKey, &ls); err != nil { + if err := s.codec.Get(bkt, taskLocalStateKey, &ls); err != nil { return fmt.Errorf("failed to read local task runner state: %v", err) } // Restore Task State - if err := getObject(bkt, taskStateKey, &ts); err != nil { + if err := s.codec.Get(bkt, taskStateKey, &ts); err != nil { return fmt.Errorf("failed to read task state: %v", err) } @@ -237,14 +249,15 @@ func (s *BoltStateDB) GetTaskRunnerState(allocID, taskName string) (*trstate.Loc // PutTaskRunnerLocalState stores TaskRunner's LocalState or returns an error. // It is up to the caller to serialize the state to bytes. -func (s *BoltStateDB) PutTaskRunnerLocalState(allocID, taskName string, buf []byte) error { +func (s *BoltStateDB) PutTaskRunnerLocalState(allocID, taskName string, val interface{}) error { return s.db.Update(func(tx *bolt.Tx) error { taskBkt, err := getTaskBucket(tx, allocID, taskName) if err != nil { return fmt.Errorf("failed to retrieve allocation bucket: %v", err) } - if err := putData(taskBkt, taskLocalStateKey, buf); err != nil { + path := strings.Join([]string{allocID, taskName, string(taskLocalStateKey)}, "-") + if err := s.codec.Put(taskBkt, path, taskLocalStateKey, val); err != nil { return fmt.Errorf("failed to write task_runner state: %v", err) } @@ -260,7 +273,8 @@ func (s *BoltStateDB) PutTaskState(allocID, taskName string, state *structs.Task return fmt.Errorf("failed to retrieve allocation bucket: %v", err) } - return putObject(taskBkt, taskStateKey, state) + path := strings.Join([]string{allocID, taskName, string(taskStateKey)}, "-") + return s.codec.Put(taskBkt, path, taskStateKey, state) }) } @@ -270,51 +284,6 @@ func (s *BoltStateDB) Close() error { return s.db.Close() } -func putObject(bkt *bolt.Bucket, key []byte, obj interface{}) error { - if !bkt.Writable() { - return fmt.Errorf("bucket must be writable") - } - - // Serialize the object - var buf bytes.Buffer - if err := codec.NewEncoder(&buf, structs.MsgpackHandle).Encode(obj); err != nil { - return fmt.Errorf("failed to encode passed object: %v", err) - } - - if err := bkt.Put(key, buf.Bytes()); err != nil { - return fmt.Errorf("failed to write data at key %v: %v", string(key), err) - } - - return nil -} - -func putData(bkt *bolt.Bucket, key, value []byte) error { - if !bkt.Writable() { - return fmt.Errorf("bucket must be writable") - } - - if err := bkt.Put(key, value); err != nil { - return fmt.Errorf("failed to write data at key %v: %v", string(key), err) - } - - return nil -} - -func getObject(bkt *bolt.Bucket, key []byte, obj interface{}) error { - // Get the data - data := bkt.Get(key) - if data == nil { - return fmt.Errorf("no data at key %v", string(key)) - } - - // Deserialize the object - if err := codec.NewDecoderBytes(data, structs.MsgpackHandle).Decode(obj); err != nil { - return fmt.Errorf("failed to decode data into passed object: %v", err) - } - - return nil -} - // getAllocationBucket returns the bucket used to persist state about a // particular allocation. If the root allocation bucket or the specific // allocation bucket doesn't exist, it will be created as long as the