diff --git a/nomad/memdb/txn.go b/nomad/memdb/txn.go index 864f97582..ce9a72d6e 100644 --- a/nomad/memdb/txn.go +++ b/nomad/memdb/txn.go @@ -2,6 +2,7 @@ package memdb import ( "fmt" + "strings" "github.com/hashicorp/go-immutable-radix" ) @@ -278,26 +279,14 @@ func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error) // First is used to return the first matching object for // the given constraints on the index func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) { - // Get the table schema - tableSchema, ok := txn.db.schema.Tables[table] - if !ok { - return nil, fmt.Errorf("invalid table '%s'", table) - } - - // Get the index schema - indexSchema, ok := tableSchema.Indexes[index] - if !ok { - return nil, fmt.Errorf("invalid index '%s'", index) - } - - // Get the exact match index - val, err := indexSchema.Indexer.FromArgs(args...) + // Get the index value + indexSchema, val, err := txn.getIndexValue(table, index, args...) if err != nil { - return nil, fmt.Errorf("index error: %v", err) + return nil, err } // Get the index itself - indexTxn := txn.readableIndex(table, index) + indexTxn := txn.readableIndex(table, indexSchema.Name) // Do an exact lookup if indexSchema.Unique { @@ -320,6 +309,58 @@ func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, er return firstVal, nil } +// getIndexValue is used to get the IndexSchema and the value +// used to scan the index given the parameters. This handles prefix based +// scans when the index has the "_prefix" suffix. The index must support +// prefix iteration. +func (txn *Txn) getIndexValue(table, index string, args ...interface{}) (*IndexSchema, []byte, error) { + // Get the table schema + tableSchema, ok := txn.db.schema.Tables[table] + if !ok { + return nil, nil, fmt.Errorf("invalid table '%s'", table) + } + + // Check for a prefix scan + prefixScan := false + if strings.HasSuffix(index, "_prefix") { + index = strings.TrimSuffix(index, "_prefix") + prefixScan = true + } + + // Get the index schema + indexSchema, ok := tableSchema.Indexes[index] + if !ok { + return nil, nil, fmt.Errorf("invalid index '%s'", index) + } + + // Hot-path for when there are no arguments + if len(args) == 0 { + return indexSchema, nil, nil + } + + // Special case the prefix scanning + if prefixScan { + prefixIndexer, ok := indexSchema.Indexer.(PrefixIndexer) + if !ok { + return indexSchema, nil, + fmt.Errorf("index '%s' does not support prefix scanning", index) + } + + val, err := prefixIndexer.PrefixFromArgs(args...) + if err != nil { + return indexSchema, nil, fmt.Errorf("index error: %v", err) + } + return indexSchema, val, err + } + + // Get the exact match index + val, err := indexSchema.Indexer.FromArgs(args...) + if err != nil { + return indexSchema, nil, fmt.Errorf("index error: %v", err) + } + return indexSchema, val, err +} + // ResultIterator is used to iterate over a list of results // from a Get query on a table. type ResultIterator interface { @@ -329,30 +370,14 @@ type ResultIterator interface { // Get is used to construct a ResultIterator over all the // rows that match the given constraints of an index. func (txn *Txn) Get(table, index string, args ...interface{}) (ResultIterator, error) { - // Get the table schema - tableSchema, ok := txn.db.schema.Tables[table] - if !ok { - return nil, fmt.Errorf("invalid table '%s'", table) - } - - // Get the index schema - indexSchema, ok := tableSchema.Indexes[index] - if !ok { - return nil, fmt.Errorf("invalid index '%s'", index) - } - - // Get the exact match index if any arguments given - var val []byte - if len(args) > 0 { - var err error - val, err = indexSchema.Indexer.FromArgs(args...) - if err != nil { - return nil, fmt.Errorf("index error: %v", err) - } + // Get the index value to scan + indexSchema, val, err := txn.getIndexValue(table, index, args...) + if err != nil { + return nil, err } // Get the index itself - indexTxn := txn.readableIndex(table, index) + indexTxn := txn.readableIndex(table, indexSchema.Name) indexRoot := indexTxn.Root() // Collect all the objects by walking the prefix. This should obviously diff --git a/nomad/memdb/txn_test.go b/nomad/memdb/txn_test.go index 8fbd10924..5cd9803b1 100644 --- a/nomad/memdb/txn_test.go +++ b/nomad/memdb/txn_test.go @@ -427,3 +427,102 @@ func TestTxn_DeleteAll_Simple(t *testing.T) { t.Fatalf("bad: %#v", raw) } } + +func TestTxn_InsertGet_Prefix(t *testing.T) { + db := testDB(t) + txn := db.Txn(true) + + obj1 := &TestObject{ + ID: "my-cool-thing", + Foo: "foobarbaz", + } + obj2 := &TestObject{ + ID: "my-other-cool-thing", + Foo: "foozipzap", + } + + err := txn.Insert("main", obj1) + if err != nil { + t.Fatalf("err: %v", err) + } + err = txn.Insert("main", obj2) + if err != nil { + t.Fatalf("err: %v", err) + } + + checkResult := func(txn *Txn) { + // Attempt a row scan on the ID Prefix + result, err := txn.Get("main", "id_prefix") + if err != nil { + t.Fatalf("err: %v", err) + } + + if raw := result.Next(); raw != obj1 { + t.Fatalf("bad: %#v %#v", raw, obj1) + } + + if raw := result.Next(); raw != obj2 { + t.Fatalf("bad: %#v %#v", raw, obj2) + } + + if raw := result.Next(); raw != nil { + t.Fatalf("bad: %#v %#v", raw, nil) + } + + // Attempt a row scan on the ID with specific ID prefix + result, err = txn.Get("main", "id_prefix", "my-c") + if err != nil { + t.Fatalf("err: %v", err) + } + + if raw := result.Next(); raw != obj1 { + t.Fatalf("bad: %#v %#v", raw, obj1) + } + + if raw := result.Next(); raw != nil { + t.Fatalf("bad: %#v %#v", raw, nil) + } + + // Attempt a row scan secondary index + result, err = txn.Get("main", "foo_prefix", "foo") + if err != nil { + t.Fatalf("err: %v", err) + } + + if raw := result.Next(); raw != obj1 { + t.Fatalf("bad: %#v %#v", raw, obj1) + } + + if raw := result.Next(); raw != obj2 { + t.Fatalf("bad: %#v %#v", raw, obj2) + } + + if raw := result.Next(); raw != nil { + t.Fatalf("bad: %#v %#v", raw, nil) + } + + // Attempt a row scan secondary index, tigher prefix + result, err = txn.Get("main", "foo_prefix", "foob") + if err != nil { + t.Fatalf("err: %v", err) + } + + if raw := result.Next(); raw != obj1 { + t.Fatalf("bad: %#v %#v", raw, obj1) + } + + if raw := result.Next(); raw != nil { + t.Fatalf("bad: %#v %#v", raw, nil) + } + } + + // Check the results within the txn + checkResult(txn) + + // Commit and start a new read transaction + txn.Commit() + txn = db.Txn(false) + + // Check the results in a new txn + checkResult(txn) +}