From 920813677cd0eee29be05e05c29b4f8f55439bb6 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Sun, 7 Jun 2015 19:38:01 -0500 Subject: [PATCH] memdb: insert and first working --- nomad/memdb/index.go | 59 +++++++++++------ nomad/memdb/index_test.go | 48 +++++++++++--- nomad/memdb/memdb.go | 14 ++-- nomad/memdb/schema.go | 21 ++++-- nomad/memdb/schema_test.go | 8 ++- nomad/memdb/txn.go | 129 +++++++++++++++++++++++++++++++++++++ nomad/memdb/txn_test.go | 20 ++++++ 7 files changed, 257 insertions(+), 42 deletions(-) diff --git a/nomad/memdb/index.go b/nomad/memdb/index.go index 774cdb895..3d70bfc79 100644 --- a/nomad/memdb/index.go +++ b/nomad/memdb/index.go @@ -8,25 +8,42 @@ import ( // StringFieldIndex is used to extract a field from an object // using reflection and builds an index on that field. -func StringFieldIndex(field string, lowercase bool) IndexerFunc { - return func(obj interface{}) (bool, []byte, error) { - v := reflect.ValueOf(obj) - v = reflect.Indirect(v) // Derefence the pointer if any - - fv := v.FieldByName(field) - if !fv.IsValid() { - return false, nil, - fmt.Errorf("field '%s' for %#v is invalid", field, obj) - } - - val := fv.String() - if val == "" { - return false, nil, nil - } - - if lowercase { - val = strings.ToLower(val) - } - return true, []byte(val), nil - } +type StringFieldIndex struct { + Field string + Lowercase bool +} + +func (s *StringFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Derefence the pointer if any + + fv := v.FieldByName(s.Field) + if !fv.IsValid() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj) + } + + val := fv.String() + if val == "" { + return false, nil, nil + } + + if s.Lowercase { + val = strings.ToLower(val) + } + return true, []byte(val), nil +} + +func (s *StringFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + if s.Lowercase { + arg = strings.ToLower(arg) + } + return []byte(arg), nil } diff --git a/nomad/memdb/index_test.go b/nomad/memdb/index_test.go index 24d825a4b..41d755df8 100644 --- a/nomad/memdb/index_test.go +++ b/nomad/memdb/index_test.go @@ -20,11 +20,11 @@ func testObj() *TestObject { return obj } -func TestStringFieldIndex(t *testing.T) { +func TestStringFieldIndex_FromObject(t *testing.T) { obj := testObj() - indexer := StringFieldIndex("Foo", false) + indexer := StringFieldIndex{"Foo", false} - ok, val, err := indexer(obj) + ok, val, err := indexer.FromObject(obj) if err != nil { t.Fatalf("err: %v", err) } @@ -35,8 +35,8 @@ func TestStringFieldIndex(t *testing.T) { t.Fatalf("should be ok") } - lower := StringFieldIndex("Foo", true) - ok, val, err = lower(obj) + lower := StringFieldIndex{"Foo", true} + ok, val, err = lower.FromObject(obj) if err != nil { t.Fatalf("err: %v", err) } @@ -47,14 +47,14 @@ func TestStringFieldIndex(t *testing.T) { t.Fatalf("should be ok") } - badField := StringFieldIndex("NA", true) - ok, val, err = badField(obj) + badField := StringFieldIndex{"NA", true} + ok, val, err = badField.FromObject(obj) if err == nil { t.Fatalf("should get error") } - emptyField := StringFieldIndex("Empty", true) - ok, val, err = emptyField(obj) + emptyField := StringFieldIndex{"Empty", true} + ok, val, err = emptyField.FromObject(obj) if err != nil { t.Fatalf("err: %v", err) } @@ -62,3 +62,33 @@ func TestStringFieldIndex(t *testing.T) { t.Fatalf("should not ok") } } + +func TestStringFieldIndex_FromArgs(t *testing.T) { + indexer := StringFieldIndex{"Foo", false} + _, err := indexer.FromArgs() + if err == nil { + t.Fatalf("should get err") + } + + _, err = indexer.FromArgs(42) + if err == nil { + t.Fatalf("should get err") + } + + val, err := indexer.FromArgs("foo") + if err != nil { + t.Fatalf("err: %v", err) + } + if string(val) != "foo" { + t.Fatalf("foo") + } + + lower := StringFieldIndex{"Foo", true} + val, err = lower.FromArgs("Foo") + if err != nil { + t.Fatalf("err: %v", err) + } + if string(val) != "foo" { + t.Fatalf("foo") + } +} diff --git a/nomad/memdb/memdb.go b/nomad/memdb/memdb.go index a8c46b0c0..4cd47d30f 100644 --- a/nomad/memdb/memdb.go +++ b/nomad/memdb/memdb.go @@ -52,13 +52,17 @@ func (db *MemDB) Txn(write bool) *Txn { // initialize is used to setup the DB for use after creation func (db *MemDB) initialize() error { - for _, tableSchema := range db.schema.Tables { - table := iradix.New() - for _, indexSchema := range tableSchema.Indexes { + for tName, tableSchema := range db.schema.Tables { + for iName, _ := range tableSchema.Indexes { index := iradix.New() - table, _, _ = table.Insert([]byte(indexSchema.Name), index) + path := indexPath(tName, iName) + db.root, _, _ = db.root.Insert(path, index) } - db.root, _, _ = db.root.Insert([]byte(tableSchema.Name), table) } return nil } + +// indexPath returns the path from the root to the given table index +func indexPath(table, index string) []byte { + return []byte(table + "." + index) +} diff --git a/nomad/memdb/schema.go b/nomad/memdb/schema.go index cd8ec313d..454f848d6 100644 --- a/nomad/memdb/schema.go +++ b/nomad/memdb/schema.go @@ -40,6 +40,12 @@ func (s *TableSchema) Validate() error { if len(s.Indexes) == 0 { return fmt.Errorf("missing table schemas for '%s'", s.Name) } + if _, ok := s.Indexes["id"]; !ok { + return fmt.Errorf("must have id index") + } + if !s.Indexes["id"].Unique { + return fmt.Errorf("id index must be unique") + } for name, index := range s.Indexes { if name != index.Name { return fmt.Errorf("index name mis-match for '%s'", name) @@ -51,16 +57,23 @@ func (s *TableSchema) Validate() error { return nil } -// IndexerFunc is used to extract an index value from an -// object or to indicate that the index value is missing. -type IndexerFunc func(interface{}) (bool, []byte, error) +// Indexer is an interface used for defining indexes +type Indexer interface { + // FromObject is used to extract an index value from an + // object or to indicate that the index value is missing. + FromObject(raw interface{}) (bool, []byte, error) + + // ExactFromArgs is used to build an exact index lookup + // based on arguments + FromArgs(args ...interface{}) ([]byte, error) +} // IndexSchema contains the schema for an index type IndexSchema struct { Name string AllowMissing bool Unique bool - Indexer IndexerFunc + Indexer Indexer } func (s *IndexSchema) Validate() error { diff --git a/nomad/memdb/schema_test.go b/nomad/memdb/schema_test.go index 73e13e36d..139a00e00 100644 --- a/nomad/memdb/schema_test.go +++ b/nomad/memdb/schema_test.go @@ -10,7 +10,8 @@ func testValidSchema() *DBSchema { Indexes: map[string]*IndexSchema{ "id": &IndexSchema{ Name: "id", - Indexer: StringFieldIndex("ID", false), + Unique: true, + Indexer: &StringFieldIndex{Field: "ID"}, }, }, }, @@ -60,7 +61,8 @@ func TestTableSchema_Validate(t *testing.T) { Indexes: map[string]*IndexSchema{ "id": &IndexSchema{ Name: "id", - Indexer: StringFieldIndex("ID", true), + Unique: true, + Indexer: &StringFieldIndex{Field: "ID", Lowercase: true}, }, }, } @@ -83,7 +85,7 @@ func TestIndexSchema_Validate(t *testing.T) { t.Fatalf("should not validate, no indexer") } - s.Indexer = StringFieldIndex("Foo", false) + s.Indexer = &StringFieldIndex{Field: "Foo", Lowercase: false} err = s.Validate() if err != nil { t.Fatalf("should validate: %v", err) diff --git a/nomad/memdb/txn.go b/nomad/memdb/txn.go index 371be50fa..4d967b28f 100644 --- a/nomad/memdb/txn.go +++ b/nomad/memdb/txn.go @@ -6,11 +6,65 @@ import ( "github.com/hashicorp/go-immutable-radix" ) +// tableIndex is a tuple of (Table, Index) used for lookups +type tableIndex struct { + Table string + Index string +} + // Txn is a transaction against a MemDB. This can be a read or write transaction. type Txn struct { db *MemDB write bool rootTxn *iradix.Txn + + modified map[tableIndex]*iradix.Txn +} + +// readableIndex returns a transaction usable for reading the given +// index in a table. If a write transaction is in progress, we may need +// to use an existing modified txn. +func (txn *Txn) readableIndex(table, index string) *iradix.Txn { + // Look for existing transaction + if txn.write && txn.modified != nil { + key := tableIndex{table, index} + exist, ok := txn.modified[key] + if ok { + return exist + } + } + + // Create a read transaction + path := indexPath(table, index) + raw, _ := txn.rootTxn.Get(path) + indexRoot := toTree(raw) + indexTxn := indexRoot.Txn() + return indexTxn +} + +// writableIndex returns a transaction usable for modifying the +// given index in a table. +func (txn *Txn) writableIndex(table, index string) *iradix.Txn { + if txn.modified == nil { + txn.modified = make(map[tableIndex]*iradix.Txn) + } + + // Look for existing transaction + key := tableIndex{table, index} + exist, ok := txn.modified[key] + if ok { + return exist + } + + // Start a new transaction + path := indexPath(table, index) + raw, _ := txn.rootTxn.Get(path) + indexRoot := toTree(raw) + indexTxn := indexRoot.Txn() + + // Keep this open for the duration of the txn + txn.modified[key] = indexTxn + return indexTxn } // Abort is used to cancel this transaction. This is a noop for read transactions. @@ -27,6 +81,7 @@ func (txn *Txn) Abort() { // Clear the txn txn.rootTxn = nil + txn.modified = nil // Release the writer lock since this is invalid txn.db.writer.Unlock() @@ -44,11 +99,18 @@ func (txn *Txn) Commit() { return } + // Commit each sub-transaction scoped to (table, index) + for key, subTxn := range txn.modified { + path := indexPath(key.Table, key.Index) + txn.rootTxn.Insert(path, subTxn.Commit()) + } + // Update the root of the DB txn.db.root = txn.rootTxn.Commit() // Clear the txn txn.rootTxn = nil + txn.modified = nil // Release the writer lock since this is invalid txn.db.writer.Unlock() @@ -59,6 +121,33 @@ func (txn *Txn) Insert(table string, obj interface{}) error { if !txn.write { return fmt.Errorf("cannot insert in read-only transaction") } + + // Get the table schema + tableSchema, ok := txn.db.schema.Tables[table] + if !ok { + return fmt.Errorf("invalid table '%s'", table) + } + + // Lookup the object by ID first + // TODO: Handle delete if existing (update) + + for name, indexSchema := range tableSchema.Indexes { + ok, val, err := indexSchema.Indexer.FromObject(obj) + if err != nil { + return fmt.Errorf("failed to build index '%s': %v", name, err) + } + if !ok { + if indexSchema.AllowMissing { + continue + } else { + return fmt.Errorf("missing value for index '%s'", name) + } + } + + // TODO: Handle non-unique index + indexTxn := txn.writableIndex(table, name) + indexTxn.Insert(val, obj) + } return nil } @@ -69,6 +158,38 @@ func (txn *Txn) Delete(table, index string, args ...interface{}) error { return nil } +func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) { + // Get the table schema + tableSchema, ok := txn.db.schema.Tables[table] + if !ok { + return nil, fmt.Errorf("invalid table '%s'", table) + } + + // Get the index schema + indexSchema, ok := tableSchema.Indexes[index] + if !ok { + return nil, fmt.Errorf("invalid index '%s'", index) + } + + // Get the exact match index + val, err := indexSchema.Indexer.FromArgs(args...) + if err != nil { + return nil, fmt.Errorf("index error: %v", err) + } + + // Get the index itself + indexTxn := txn.readableIndex(table, index) + + // Do an exact lookup + obj, ok := indexTxn.Get(val) + if !ok { + return nil, nil + } + + // TODO: handle non-unique index + return obj, nil +} + type ResultIterator interface { Next() interface{} } @@ -76,3 +197,11 @@ type ResultIterator interface { func (txn *Txn) Get(table, index string, args ...interface{}) (ResultIterator, error) { return nil, nil } + +// toTree is used to do a fast assertion of type in cases +// where it is known to avoid the overhead of reflection +func toTree(raw interface{}) *iradix.Tree { + return raw.(*iradix.Tree) + // TODO: Fix this + //return (*iradix.Tree)(raw.(unsafe.Pointer)) +} diff --git a/nomad/memdb/txn_test.go b/nomad/memdb/txn_test.go index 000da73bf..2eeab866e 100644 --- a/nomad/memdb/txn_test.go +++ b/nomad/memdb/txn_test.go @@ -36,3 +36,23 @@ func TestTxn_Write_AbortCommit(t *testing.T) { txn.Abort() txn.Abort() } + +func TestTxn_Insert_First(t *testing.T) { + db := testDB(t) + txn := db.Txn(true) + + obj := testObj() + err := txn.Insert("main", obj) + if err != nil { + t.Fatalf("err: %v", err) + } + + raw, err := txn.First("main", "id", obj.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + + if raw != obj { + t.Fatalf("bad: %#v %#v", raw, obj) + } +}