diff --git a/nomad/state/state_store_service_regisration_test.go b/nomad/state/state_store_service_regisration_test.go new file mode 100644 index 000000000..84ca6c93d --- /dev/null +++ b/nomad/state/state_store_service_regisration_test.go @@ -0,0 +1,619 @@ +package state + +import ( + "strconv" + "testing" + + "github.com/hashicorp/go-memdb" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/require" +) + +func TestStateStore_UpsertServiceRegistrations(t *testing.T) { + t.Parallel() + testState := testStateStore(t) + + // SubTest Marker: This ensures new service registrations are inserted as + // expected with their correct indexes, along with an update to the index + // table. + services := mock.ServiceRegistrations() + insertIndex := uint64(20) + + // Perform the initial upsert of service registrations. + err := testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, insertIndex, services) + require.NoError(t, err) + + // Check that the index for the table was modified as expected. + initialIndex, err := testState.Index(TableServiceRegistrations) + require.NoError(t, err) + require.Equal(t, insertIndex, initialIndex) + + // List all the service registrations in the table, so we can perform a + // number of tests on the return array. + ws := memdb.NewWatchSet() + iter, err := testState.GetServiceRegistrations(ws) + require.NoError(t, err) + + // Count how many table entries we have, to ensure it is the expected + // number. + var count int + + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count++ + + // Ensure the create and modify indexes are populated correctly. + serviceReg := raw.(*structs.ServiceRegistration) + require.Equal(t, insertIndex, serviceReg.CreateIndex, "incorrect create index", serviceReg.ID) + require.Equal(t, insertIndex, serviceReg.ModifyIndex, "incorrect modify index", serviceReg.ID) + } + require.Equal(t, 2, count, "incorrect number of service registrations found") + + // SubTest Marker: This section attempts to upsert the exact same service + // registrations without any modification. In this case, the index table + // should not be updated, indicating no write actually happened due to + // equality checking. + reInsertIndex := uint64(30) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, reInsertIndex, services)) + reInsertActualIndex, err := testState.Index(TableServiceRegistrations) + require.NoError(t, err) + require.Equal(t, insertIndex, reInsertActualIndex, "index should not have changed") + + // SubTest Marker: This section modifies a single one of the previously + // inserted service registrations and performs an upsert. This ensures the + // index table is modified correctly and that each service registration is + // updated, or not, as expected. + service1Update := services[0].Copy() + service1Update.Tags = []string{"modified"} + services1Update := []*structs.ServiceRegistration{service1Update} + + update1Index := uint64(40) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, update1Index, services1Update)) + + // Check that the index for the table was modified as expected. + updateActualIndex, err := testState.Index(TableServiceRegistrations) + require.NoError(t, err) + require.Equal(t, update1Index, updateActualIndex, "index should have changed") + + // Get the service registrations from the table. + iter, err = testState.GetServiceRegistrations(ws) + require.NoError(t, err) + + // Iterate all the stored registrations and assert they are as expected. + for raw := iter.Next(); raw != nil; raw = iter.Next() { + serviceReg := raw.(*structs.ServiceRegistration) + + var expectedModifyIndex uint64 + + switch serviceReg.ID { + case service1Update.ID: + expectedModifyIndex = update1Index + case services[1].ID: + expectedModifyIndex = insertIndex + default: + t.Errorf("unknown service registration found: %s", serviceReg.ID) + continue + } + require.Equal(t, insertIndex, serviceReg.CreateIndex, "incorrect create index", serviceReg.ID) + require.Equal(t, expectedModifyIndex, serviceReg.ModifyIndex, "incorrect modify index", serviceReg.ID) + } + + // SubTest Marker: Here we modify the second registration but send an + // upsert request that includes this and the already modified registration. + service2Update := services[1].Copy() + service2Update.Tags = []string{"modified"} + services2Update := []*structs.ServiceRegistration{service1Update, service2Update} + + update2Index := uint64(50) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, update2Index, services2Update)) + + // Check that the index for the table was modified as expected. + update2ActualIndex, err := testState.Index(TableServiceRegistrations) + require.NoError(t, err) + require.Equal(t, update2Index, update2ActualIndex, "index should have changed") + + // Get the service registrations from the table. + iter, err = testState.GetServiceRegistrations(ws) + require.NoError(t, err) + + // Iterate all the stored registrations and assert they are as expected. + for raw := iter.Next(); raw != nil; raw = iter.Next() { + serviceReg := raw.(*structs.ServiceRegistration) + + var ( + expectedModifyIndex uint64 + expectedServiceReg *structs.ServiceRegistration + ) + + switch serviceReg.ID { + case service2Update.ID: + expectedModifyIndex = update2Index + expectedServiceReg = service2Update + case service1Update.ID: + expectedModifyIndex = update1Index + expectedServiceReg = service1Update + default: + t.Errorf("unknown service registration found: %s", serviceReg.ID) + continue + } + require.Equal(t, insertIndex, serviceReg.CreateIndex, "incorrect create index", serviceReg.ID) + require.Equal(t, expectedModifyIndex, serviceReg.ModifyIndex, "incorrect modify index", serviceReg.ID) + require.True(t, expectedServiceReg.Equals(serviceReg)) + } +} + +func TestStateStore_DeleteServiceRegistrationByID(t *testing.T) { + t.Parallel() + testState := testStateStore(t) + + // Generate some test services that we will use and modify throughout. + services := mock.ServiceRegistrations() + + // SubTest Marker: This section attempts to delete a service registration + // by an ID that does not exist. This is easy to perform here as the state + // is empty. + initialIndex := uint64(10) + err := testState.DeleteServiceRegistrationByID( + structs.MsgTypeTestSetup, initialIndex, services[0].Namespace, services[0].ID) + require.EqualError(t, err, "service registration not found") + + actualInitialIndex, err := testState.Index(TableServiceRegistrations) + require.NoError(t, err) + require.Equal(t, uint64(0), actualInitialIndex, "index should not have changed") + + // SubTest Marker: This section upserts two registrations, deletes one, + // then ensure the remaining is left as expected. + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, initialIndex, services)) + + // Perform the delete. + delete1Index := uint64(20) + require.NoError(t, testState.DeleteServiceRegistrationByID( + structs.MsgTypeTestSetup, delete1Index, services[0].Namespace, services[0].ID)) + + // Check that the index for the table was modified as expected. + actualDelete1Index, err := testState.Index(TableServiceRegistrations) + require.NoError(t, err) + require.Equal(t, delete1Index, actualDelete1Index, "index should have changed") + + ws := memdb.NewWatchSet() + + // Get the service registrations from the table. + iter, err := testState.GetServiceRegistrations(ws) + require.NoError(t, err) + + var delete1Count int + + // Iterate all the stored registrations and assert we have the expected + // number. + for raw := iter.Next(); raw != nil; raw = iter.Next() { + delete1Count++ + } + require.Equal(t, 1, delete1Count, "unexpected number of registrations in table") + + // SubTest Marker: Delete the remaining registration and ensure all indexes + // are updated as expected and the table is empty. + delete2Index := uint64(30) + require.NoError(t, testState.DeleteServiceRegistrationByID( + structs.MsgTypeTestSetup, delete2Index, services[1].Namespace, services[1].ID)) + + // Check that the index for the table was modified as expected. + actualDelete2Index, err := testState.Index(TableServiceRegistrations) + require.NoError(t, err) + require.Equal(t, delete2Index, actualDelete2Index, "index should have changed") + + // Get the service registrations from the table. + iter, err = testState.GetServiceRegistrations(ws) + require.NoError(t, err) + + var delete2Count int + + // Iterate all the stored registrations and assert we have the expected + // number. + for raw := iter.Next(); raw != nil; raw = iter.Next() { + delete2Count++ + } + require.Equal(t, 0, delete2Count, "unexpected number of registrations in table") +} + +func TestStateStore_DeleteServiceRegistrationByNodeID(t *testing.T) { + t.Parallel() + testState := testStateStore(t) + + // Generate some test services that we will use and modify throughout. + services := mock.ServiceRegistrations() + + // SubTest Marker: This section attempts to delete a service registration + // by a nodeID that does not exist. This is easy to perform here as the + // state is empty. + initialIndex := uint64(10) + require.NoError(t, + testState.DeleteServiceRegistrationByNodeID(structs.MsgTypeTestSetup, initialIndex, services[0].NodeID)) + + actualInitialIndex, err := testState.Index(TableServiceRegistrations) + require.NoError(t, err) + require.Equal(t, uint64(0), actualInitialIndex, "index should not have changed") + + // SubTest Marker: This section upserts two registrations then deletes one + // by using the nodeID. + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, initialIndex, services)) + + // Perform the delete. + delete1Index := uint64(20) + require.NoError(t, testState.DeleteServiceRegistrationByNodeID( + structs.MsgTypeTestSetup, delete1Index, services[0].NodeID)) + + // Check that the index for the table was modified as expected. + actualDelete1Index, err := testState.Index(TableServiceRegistrations) + require.NoError(t, err) + require.Equal(t, delete1Index, actualDelete1Index, "index should have changed") + + ws := memdb.NewWatchSet() + + // Get the service registrations from the table. + iter, err := testState.GetServiceRegistrations(ws) + require.NoError(t, err) + + var delete1Count int + + // Iterate all the stored registrations and assert we have the expected + // number. + for raw := iter.Next(); raw != nil; raw = iter.Next() { + delete1Count++ + } + require.Equal(t, 1, delete1Count, "unexpected number of registrations in table") + + // SubTest Marker: Add multiple service registrations for a single nodeID + // then delete these via the nodeID. + delete2NodeID := services[1].NodeID + var delete2NodeServices []*structs.ServiceRegistration + + for i := 0; i < 4; i++ { + iString := strconv.Itoa(i) + delete2NodeServices = append(delete2NodeServices, &structs.ServiceRegistration{ + ID: "_nomad-task-ca60e901-675a-0ab2-2e57-2f3b05fdc540-group-api-countdash-api-http-" + iString, + ServiceName: "countdash-api-" + iString, + Namespace: "platform", + NodeID: delete2NodeID, + Datacenter: "dc2", + JobID: "countdash-api-" + iString, + AllocID: "ca60e901-675a-0ab2-2e57-2f3b05fdc54" + iString, + Tags: []string{"bar"}, + Address: "192.168.200.200", + Port: 27500 + i, + }) + } + + // Upsert the new service registrations. + delete2UpsertIndex := uint64(30) + require.NoError(t, + testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, delete2UpsertIndex, delete2NodeServices)) + + delete2Index := uint64(40) + require.NoError(t, testState.DeleteServiceRegistrationByNodeID( + structs.MsgTypeTestSetup, delete2Index, delete2NodeID)) + + // Check that the index for the table was modified as expected. + actualDelete2Index, err := testState.Index(TableServiceRegistrations) + require.NoError(t, err) + require.Equal(t, delete2Index, actualDelete2Index, "index should have changed") + + // Get the service registrations from the table. + iter, err = testState.GetServiceRegistrations(ws) + require.NoError(t, err) + + var delete2Count int + + // Iterate all the stored registrations and assert we have the expected + // number. + for raw := iter.Next(); raw != nil; raw = iter.Next() { + delete2Count++ + } + require.Equal(t, 0, delete2Count, "unexpected number of registrations in table") +} + +func TestStateStore_GetServiceRegistrations(t *testing.T) { + t.Parallel() + testState := testStateStore(t) + + // Generate some test services and upsert them. + services := mock.ServiceRegistrations() + initialIndex := uint64(10) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, initialIndex, services)) + + // Read the service registrations and check the objects. + ws := memdb.NewWatchSet() + iter, err := testState.GetServiceRegistrations(ws) + require.NoError(t, err) + + var count int + + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count++ + + serviceReg := raw.(*structs.ServiceRegistration) + require.Equal(t, initialIndex, serviceReg.CreateIndex, "incorrect create index", serviceReg.ID) + require.Equal(t, initialIndex, serviceReg.ModifyIndex, "incorrect modify index", serviceReg.ID) + + switch serviceReg.ID { + case services[0].ID: + require.Equal(t, services[0], serviceReg) + case services[1].ID: + require.Equal(t, services[1], serviceReg) + default: + t.Errorf("unknown service registration found: %s", serviceReg.ID) + } + } + require.Equal(t, 2, count) +} + +func TestStateStore_GetServiceRegistrationsByNamespace(t *testing.T) { + t.Parallel() + testState := testStateStore(t) + + // Generate some test services and upsert them. + services := mock.ServiceRegistrations() + initialIndex := uint64(10) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, initialIndex, services)) + + // Look up services using the namespace of the first service. + ws := memdb.NewWatchSet() + iter, err := testState.GetServiceRegistrationsByNamespace(ws, services[0].Namespace) + require.NoError(t, err) + + var count1 int + + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count1++ + serviceReg := raw.(*structs.ServiceRegistration) + require.Equal(t, initialIndex, serviceReg.CreateIndex, "incorrect create index", serviceReg.ID) + require.Equal(t, initialIndex, serviceReg.ModifyIndex, "incorrect modify index", serviceReg.ID) + require.Equal(t, services[0].Namespace, serviceReg.Namespace) + } + require.Equal(t, 1, count1) + + // Look up services using the namespace of the second service. + iter, err = testState.GetServiceRegistrationsByNamespace(ws, services[1].Namespace) + require.NoError(t, err) + + var count2 int + + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count2++ + serviceReg := raw.(*structs.ServiceRegistration) + require.Equal(t, initialIndex, serviceReg.CreateIndex, "incorrect create index", serviceReg.ID) + require.Equal(t, initialIndex, serviceReg.ModifyIndex, "incorrect modify index", serviceReg.ID) + require.Equal(t, services[1].Namespace, serviceReg.Namespace) + } + require.Equal(t, 1, count2) + + // Look up services using a namespace that shouldn't contain any + // registrations. + iter, err = testState.GetServiceRegistrationsByNamespace(ws, "pony-club") + require.NoError(t, err) + + var count3 int + + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count3++ + } + require.Equal(t, 0, count3) +} + +func TestStateStore_GetServiceRegistrationByName(t *testing.T) { + t.Parallel() + testState := testStateStore(t) + + // Generate some test services and upsert them. + services := mock.ServiceRegistrations() + initialIndex := uint64(10) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, initialIndex, services)) + + // Try reading a service by a name that shouldn't exist. + ws := memdb.NewWatchSet() + iter, err := testState.GetServiceRegistrationByName(ws, "default", "pony-glitter-api") + require.NoError(t, err) + + var count1 int + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count1++ + } + require.Equal(t, 0, count1) + + // Read one of the known service registrations. + expectedReg := services[1].Copy() + + iter, err = testState.GetServiceRegistrationByName(ws, expectedReg.Namespace, expectedReg.ServiceName) + require.NoError(t, err) + + var count2 int + + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count2++ + serviceReg := raw.(*structs.ServiceRegistration) + require.Equal(t, expectedReg.ServiceName, serviceReg.ServiceName) + require.Equal(t, expectedReg.Namespace, serviceReg.Namespace) + } + require.Equal(t, 1, count2) + + // Create a bunch of additional services whose name and namespace match + // that of expectedReg. + var newServices []*structs.ServiceRegistration + + for i := 0; i < 4; i++ { + iString := strconv.Itoa(i) + newServices = append(newServices, &structs.ServiceRegistration{ + ID: "_nomad-task-ca60e901-675a-0ab2-2e57-2f3b05fdc540-group-api-countdash-api-http-" + iString, + ServiceName: expectedReg.ServiceName, + Namespace: expectedReg.Namespace, + NodeID: "2873cf75-42e5-7c45-ca1c-415f3e18be3d", + Datacenter: "dc1", + JobID: expectedReg.JobID, + AllocID: "ca60e901-675a-0ab2-2e57-2f3b05fdc54" + iString, + Tags: []string{"bar"}, + Address: "192.168.200.200", + Port: 27500 + i, + }) + } + + updateIndex := uint64(20) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, updateIndex, newServices)) + + iter, err = testState.GetServiceRegistrationByName(ws, expectedReg.Namespace, expectedReg.ServiceName) + require.NoError(t, err) + + var count3 int + + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count3++ + serviceReg := raw.(*structs.ServiceRegistration) + require.Equal(t, expectedReg.ServiceName, serviceReg.ServiceName) + require.Equal(t, expectedReg.Namespace, serviceReg.Namespace) + } + require.Equal(t, 5, count3) +} + +func TestStateStore_GetServiceRegistrationByID(t *testing.T) { + t.Parallel() + testState := testStateStore(t) + + // Generate some test services and upsert them. + services := mock.ServiceRegistrations() + initialIndex := uint64(10) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, initialIndex, services)) + + ws := memdb.NewWatchSet() + + // Try reading a service by an ID that shouldn't exist. + serviceReg, err := testState.GetServiceRegistrationByID(ws, "default", "pony-glitter-sparkles") + require.NoError(t, err) + require.Nil(t, serviceReg) + + // Read the two services that we should find. + serviceReg, err = testState.GetServiceRegistrationByID(ws, services[0].Namespace, services[0].ID) + require.NoError(t, err) + require.Equal(t, services[0], serviceReg) + + serviceReg, err = testState.GetServiceRegistrationByID(ws, services[1].Namespace, services[1].ID) + require.NoError(t, err) + require.Equal(t, services[1], serviceReg) +} + +func TestStateStore_GetServiceRegistrationsByAllocID(t *testing.T) { + t.Parallel() + testState := testStateStore(t) + + // Generate some test services and upsert them. + services := mock.ServiceRegistrations() + initialIndex := uint64(10) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, initialIndex, services)) + + ws := memdb.NewWatchSet() + + // Try reading services by an allocation that doesn't have any + // registrations. + iter, err := testState.GetServiceRegistrationsByAllocID(ws, "4eed3c6d-6bf1-60d6-040a-e347accae6c4") + require.NoError(t, err) + + var count1 int + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count1++ + } + require.Equal(t, 0, count1) + + // Read the two allocations that we should find. + iter, err = testState.GetServiceRegistrationsByAllocID(ws, services[0].AllocID) + + var count2 int + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count2++ + serviceReg := raw.(*structs.ServiceRegistration) + require.Equal(t, services[0].AllocID, serviceReg.AllocID) + } + require.Equal(t, 1, count2) + + iter, err = testState.GetServiceRegistrationsByAllocID(ws, services[1].AllocID) + + var count3 int + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count3++ + serviceReg := raw.(*structs.ServiceRegistration) + require.Equal(t, services[1].AllocID, serviceReg.AllocID) + } + require.Equal(t, 1, count3) +} + +func TestStateStore_GetServiceRegistrationsByJobID(t *testing.T) { + t.Parallel() + testState := testStateStore(t) + + // Generate some test services and upsert them. + services := mock.ServiceRegistrations() + initialIndex := uint64(10) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, initialIndex, services)) + + ws := memdb.NewWatchSet() + + // Perform a query against a job that shouldn't have any registrations. + iter, err := testState.GetServiceRegistrationsByJobID(ws, "default", "tamagotchi") + require.NoError(t, err) + + var count1 int + for raw := iter.Next(); raw != nil; raw = iter.Next() { + count1++ + } + require.Equal(t, 0, count1) + + // Look up services using the namespace and jobID of the first service. + iter, err = testState.GetServiceRegistrationsByJobID(ws, services[0].Namespace, services[0].JobID) + require.NoError(t, err) + + var outputList1 []*structs.ServiceRegistration + + for raw := iter.Next(); raw != nil; raw = iter.Next() { + serviceReg := raw.(*structs.ServiceRegistration) + require.Equal(t, initialIndex, serviceReg.CreateIndex, "incorrect create index", serviceReg.ID) + require.Equal(t, initialIndex, serviceReg.ModifyIndex, "incorrect modify index", serviceReg.ID) + outputList1 = append(outputList1, serviceReg) + } + require.ElementsMatch(t, outputList1, []*structs.ServiceRegistration{services[0]}) + + // Look up services using the namespace and jobID of the second service. + iter, err = testState.GetServiceRegistrationsByJobID(ws, services[1].Namespace, services[1].JobID) + require.NoError(t, err) + + var outputList2 []*structs.ServiceRegistration + + for raw := iter.Next(); raw != nil; raw = iter.Next() { + serviceReg := raw.(*structs.ServiceRegistration) + require.Equal(t, initialIndex, serviceReg.CreateIndex, "incorrect create index", serviceReg.ID) + require.Equal(t, initialIndex, serviceReg.ModifyIndex, "incorrect modify index", serviceReg.ID) + outputList2 = append(outputList2, serviceReg) + } + require.ElementsMatch(t, outputList2, []*structs.ServiceRegistration{services[1]}) +} + +func TestStateStore_GetServiceRegistrationsByNodeID(t *testing.T) { + t.Parallel() + testState := testStateStore(t) + + // Generate some test services and upsert them. + services := mock.ServiceRegistrations() + initialIndex := uint64(10) + require.NoError(t, testState.UpsertServiceRegistrations(structs.MsgTypeTestSetup, initialIndex, services)) + + ws := memdb.NewWatchSet() + + // Perform a query against a node that shouldn't have any registrations. + serviceRegs, err := testState.GetServiceRegistrationsByNodeID(ws, "4eed3c6d-6bf1-60d6-040a-e347accae6c4") + require.NoError(t, err) + require.Len(t, serviceRegs, 0) + + // Read the two nodes that we should find entries for. + serviceRegs, err = testState.GetServiceRegistrationsByNodeID(ws, services[0].NodeID) + require.NoError(t, err) + require.Len(t, serviceRegs, 1) + + serviceRegs, err = testState.GetServiceRegistrationsByNodeID(ws, services[1].NodeID) + require.NoError(t, err) + require.Len(t, serviceRegs, 1) +} diff --git a/nomad/state/state_store_service_registration.go b/nomad/state/state_store_service_registration.go new file mode 100644 index 000000000..1db2d1d28 --- /dev/null +++ b/nomad/state/state_store_service_registration.go @@ -0,0 +1,271 @@ +package state + +import ( + "errors" + "fmt" + + "github.com/hashicorp/go-memdb" + "github.com/hashicorp/nomad/nomad/structs" +) + +// UpsertServiceRegistrations is used to insert a number of service +// registrations into the state store. It uses a single write transaction for +// efficiency, however, any error means no entries will be committed. +func (s *StateStore) UpsertServiceRegistrations( + msgType structs.MessageType, index uint64, services []*structs.ServiceRegistration) error { + + // Grab a write transaction, so we can use this across all service inserts. + txn := s.db.WriteTxnMsgT(msgType, index) + defer txn.Abort() + + // updated tracks whether any inserts have been made. This allows us to + // skip updating the index table if we do not need to. + var updated bool + + // Iterate the array of services. In the event of a single error, all + // inserts fail via the txn.Abort() defer. + for _, service := range services { + serviceUpdated, err := s.upsertServiceRegistrationTxn(index, txn, service) + if err != nil { + return err + } + // Ensure we track whether any inserts have been made. + updated = updated || serviceUpdated + } + + // If we did not perform any inserts, exit early. + if !updated { + return nil + } + + // Perform the index table update to mark the new inserts. + if err := txn.Insert(tableIndex, &IndexEntry{TableServiceRegistrations, index}); err != nil { + return fmt.Errorf("index update failed: %v", err) + } + + return txn.Commit() +} + +// upsertServiceRegistrationTxn inserts a single service registration into the +// state store using the provided write transaction. It is the responsibility +// of the caller to update the index table. +func (s *StateStore) upsertServiceRegistrationTxn( + index uint64, txn *txn, service *structs.ServiceRegistration) (bool, error) { + + existing, err := txn.First(TableServiceRegistrations, indexID, service.Namespace, service.ID) + if err != nil { + return false, fmt.Errorf("service registration lookup failed: %v", err) + } + + // Set up the indexes correctly to ensure existing indexes are maintained. + if existing != nil { + exist := existing.(*structs.ServiceRegistration) + if exist.Equals(service) { + return false, nil + } + service.CreateIndex = exist.CreateIndex + service.ModifyIndex = index + } else { + service.CreateIndex = index + service.ModifyIndex = index + } + + // Insert the service registration into the table. + if err := txn.Insert(TableServiceRegistrations, service); err != nil { + return false, fmt.Errorf("service registration insert failed: %v", err) + } + return true, nil +} + +// DeleteServiceRegistrationByID is responsible for deleting a single service +// registration based on it's ID and namespace. If the service registration is +// not found within state, an error will be returned. +func (s *StateStore) DeleteServiceRegistrationByID( + msgType structs.MessageType, index uint64, namespace, id string) error { + + txn := s.db.WriteTxnMsgT(msgType, index) + defer txn.Abort() + + if err := s.deleteServiceRegistrationByIDTxn(index, txn, namespace, id); err != nil { + return err + } + return txn.Commit() +} + +func (s *StateStore) deleteServiceRegistrationByIDTxn( + index uint64, txn *txn, namespace, id string) error { + + // Lookup the service registration by its ID and namespace. This is a + // unique index and therefore there will be a maximum of one entry. + existing, err := txn.First(TableServiceRegistrations, indexID, namespace, id) + if err != nil { + return fmt.Errorf("service registration lookup failed: %v", err) + } + if existing == nil { + return errors.New("service registration not found") + } + + // Delete the existing entry from the table. + if err := txn.Delete(TableServiceRegistrations, existing); err != nil { + return fmt.Errorf("service registration deletion failed: %v", err) + } + + // Update the index table to indicate an update has occurred. + if err := txn.Insert(tableIndex, &IndexEntry{TableServiceRegistrations, index}); err != nil { + return fmt.Errorf("index update failed: %v", err) + } + return nil +} + +// DeleteServiceRegistrationByNodeID deletes all service registrations that +// belong on a single node. If there are no registrations tied to the nodeID, +// the call will noop without an error. +func (s *StateStore) DeleteServiceRegistrationByNodeID( + msgType structs.MessageType, index uint64, nodeID string) error { + + txn := s.db.WriteTxnMsgT(msgType, index) + defer txn.Abort() + + num, err := txn.DeleteAll(TableServiceRegistrations, indexNodeID, nodeID) + if err != nil { + return fmt.Errorf("deleting service registrations failed: %v", err) + } + + // If we did not delete any entries, do not update the index table. + // Otherwise, update the table with the latest index. + switch num { + case 0: + return nil + default: + if err := txn.Insert(tableIndex, &IndexEntry{TableServiceRegistrations, index}); err != nil { + return fmt.Errorf("index update failed: %v", err) + } + } + + return txn.Commit() +} + +// GetServiceRegistrations returns an iterator that contains all service +// registrations stored within state. This is primarily useful when performing +// listings which use the namespace wildcard operator. The caller is +// responsible for ensuring ACL access is confirmed, or filtering is performed +// before responding. +func (s *StateStore) GetServiceRegistrations(ws memdb.WatchSet) (memdb.ResultIterator, error) { + txn := s.db.ReadTxn() + + // Walk the entire table. + iter, err := txn.Get(TableServiceRegistrations, indexID) + if err != nil { + return nil, fmt.Errorf("service registration lookup failed: %v", err) + } + ws.Add(iter.WatchCh()) + return iter, nil +} + +// GetServiceRegistrationsByNamespace returns an iterator that contains all +// registrations belonging to the provided namespace. +func (s *StateStore) GetServiceRegistrationsByNamespace( + ws memdb.WatchSet, namespace string) (memdb.ResultIterator, error) { + txn := s.db.ReadTxn() + + // Walk the entire table. + iter, err := txn.Get(TableServiceRegistrations, indexID+"_prefix", namespace, "") + if err != nil { + return nil, fmt.Errorf("service registration lookup failed: %v", err) + } + ws.Add(iter.WatchCh()) + + return iter, nil +} + +// GetServiceRegistrationByName returns an iterator that contains all service +// registrations whose namespace and name match the input parameters. This func +// therefore represents how to identify a single, collection of services that +// are logically grouped together. +func (s *StateStore) GetServiceRegistrationByName( + ws memdb.WatchSet, namespace, name string) (memdb.ResultIterator, error) { + + txn := s.db.ReadTxn() + + iter, err := txn.Get(TableServiceRegistrations, indexServiceName, namespace, name) + if err != nil { + return nil, fmt.Errorf("service registration lookup failed: %v", err) + } + ws.Add(iter.WatchCh()) + + return iter, nil +} + +// GetServiceRegistrationByID returns a single registration. The registration +// will be nil, if no matching entry was found; it is the responsibility of the +// caller to check for this. +func (s *StateStore) GetServiceRegistrationByID( + ws memdb.WatchSet, namespace, id string) (*structs.ServiceRegistration, error) { + + txn := s.db.ReadTxn() + + watchCh, existing, err := txn.FirstWatch(TableServiceRegistrations, indexID, namespace, id) + if err != nil { + return nil, fmt.Errorf("service registration lookup failed: %v", err) + } + ws.Add(watchCh) + + if existing != nil { + return existing.(*structs.ServiceRegistration), nil + } + return nil, nil +} + +// GetServiceRegistrationsByAllocID returns an iterator containing all the +// service registrations corresponding to a single allocation. +func (s *StateStore) GetServiceRegistrationsByAllocID( + ws memdb.WatchSet, allocID string) (memdb.ResultIterator, error) { + + txn := s.db.ReadTxn() + + iter, err := txn.Get(TableServiceRegistrations, indexAllocID, allocID) + if err != nil { + return nil, fmt.Errorf("service registration lookup failed: %v", err) + } + ws.Add(iter.WatchCh()) + + return iter, nil +} + +// GetServiceRegistrationsByJobID returns an iterator containing all the +// service registrations corresponding to a single job. +func (s *StateStore) GetServiceRegistrationsByJobID( + ws memdb.WatchSet, namespace, jobID string) (memdb.ResultIterator, error) { + + txn := s.db.ReadTxn() + + iter, err := txn.Get(TableServiceRegistrations, indexJob, namespace, jobID) + if err != nil { + return nil, fmt.Errorf("service registration lookup failed: %v", err) + } + ws.Add(iter.WatchCh()) + + return iter, nil +} + +// GetServiceRegistrationsByNodeID identifies all service registrations tied to +// the specified nodeID. This is useful for performing an in-memory lookup in +// order to avoid calling DeleteServiceRegistrationByNodeID via a Raft message. +func (s *StateStore) GetServiceRegistrationsByNodeID( + ws memdb.WatchSet, nodeID string) ([]*structs.ServiceRegistration, error) { + + txn := s.db.ReadTxn() + + iter, err := txn.Get(TableServiceRegistrations, indexNodeID, nodeID) + if err != nil { + return nil, fmt.Errorf("service registration lookup failed: %v", err) + } + ws.Add(iter.WatchCh()) + + var result []*structs.ServiceRegistration + for raw := iter.Next(); raw != nil; raw = iter.Next() { + result = append(result, raw.(*structs.ServiceRegistration)) + } + + return result, nil +}