diff --git a/plugins/device/cmd/nvidia/device.go b/plugins/device/cmd/nvidia/device.go
index 84b22e5da..2613d8e77 100644
--- a/plugins/device/cmd/nvidia/device.go
+++ b/plugins/device/cmd/nvidia/device.go
@@ -3,6 +3,7 @@ package nvidia
 import (
 	"context"
 	"fmt"
+	"strings"
 	"sync"
 	"time"
 
@@ -29,6 +30,11 @@ const (
 	notAvailable = "N/A"
 )
 
+const (
+	// Nvidia-container-runtime environment variable names
+	nvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES"
+)
+
 var (
 	// pluginInfo describes the plugin
 	pluginInfo = &base.PluginInfoResponse{
@@ -149,9 +155,50 @@ func (d *NvidiaDevice) Fingerprint(ctx context.Context) (<-chan *device.Fingerpr
 	return outCh, nil
 }
 
-// Reserve returns information on how to mount the given devices.
+type reservationError struct {
+	notExistingIDs []string
+}
+
+func (e *reservationError) Error() string {
+	return fmt.Sprintf("unknown device IDs: %s", strings.Join(e.notExistingIDs, ","))
+}
+
+// Reserve returns information on how to mount the given devices.
+// The assumption is made that the Nomad server is responsible for correctness of
+// GPU allocations, handling tricky cases such as double-allocation of a single GPU
 func (d *NvidiaDevice) Reserve(deviceIDs []string) (*device.ContainerReservation, error) {
-	return nil, nil
+	if len(deviceIDs) == 0 {
+		return &device.ContainerReservation{}, nil
+	}
+	// Due to the asynchronous nature of NvidiaPlugin, there is a possibility
+	// of a race condition
+	//
+	// Timeline:
+	// 1 - fingerprint reports that GPU with id "1" is present
+	// 2 - the following events happen at the same time:
+	//     a) server decides to allocate GPU with id "1"
+	//     b) fingerprint check reports that GPU with id "1" is no longer present
+	//
+	// The latest and always valid set of fingerprinted ids is stored in
+	// the d.devices map. To avoid this race condition, an error is returned if
+	// any of the provided deviceIDs is not found in the d.devices map
+	d.deviceLock.RLock()
+	var notExistingIDs []string
+	for _, id := range deviceIDs {
+		if _, deviceIDExists := d.devices[id]; !deviceIDExists {
+			notExistingIDs = append(notExistingIDs, id)
+		}
+	}
+	d.deviceLock.RUnlock()
+	if len(notExistingIDs) != 0 {
+		return nil, &reservationError{notExistingIDs}
+	}
+
+	return &device.ContainerReservation{
+		Envs: map[string]string{
+			nvidiaVisibleDevices: strings.Join(deviceIDs, ","),
+		},
+	}, nil
 }
 
 // Stats streams statistics for the detected devices.
diff --git a/plugins/device/cmd/nvidia/device_test.go b/plugins/device/cmd/nvidia/device_test.go
index 43ef1f570..b1fa4b17a 100644
--- a/plugins/device/cmd/nvidia/device_test.go
+++ b/plugins/device/cmd/nvidia/device_test.go
@@ -1,7 +1,13 @@
 package nvidia
 
 import (
+	"testing"
+
 	"github.com/hashicorp/nomad/plugins/device/cmd/nvidia/nvml"
+
+	hclog "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/nomad/plugins/device"
+	"github.com/stretchr/testify/require"
 )
 
 type MockNvmlClient struct {
@@ -19,3 +25,91 @@ func (c *MockNvmlClient) GetFingerprintData() (*nvml.FingerprintData, error) {
 func (c *MockNvmlClient) GetStatsData() ([]*nvml.StatsData, error) {
 	return c.StatsResponseReturned, c.StatsError
 }
+
+func TestReserve(t *testing.T) {
+	for _, testCase := range []struct {
+		Name                string
+		ExpectedReservation *device.ContainerReservation
+		ExpectedError       error
+		Device              *NvidiaDevice
+		RequestedIDs        []string
+	}{
+		{
+			Name:                "All RequestedIDs are not managed by Device",
+			ExpectedReservation: nil,
+			ExpectedError: &reservationError{[]string{
+				"UUID1",
+				"UUID2",
+				"UUID3",
+			}},
+			RequestedIDs: []string{
+				"UUID1",
+				"UUID2",
+				"UUID3",
+			},
+			Device: &NvidiaDevice{
+				logger: hclog.NewNullLogger(),
+			},
+		},
+		{
+			Name:                "Some RequestedIDs are not managed by Device",
+			ExpectedReservation: nil,
+			ExpectedError: &reservationError{[]string{
+				"UUID1",
+				"UUID2",
+			}},
+			RequestedIDs: []string{
+				"UUID1",
+				"UUID2",
+				"UUID3",
+			},
+			Device: &NvidiaDevice{
+				devices: map[string]struct{}{
+					"UUID3": {},
+				},
+				logger: hclog.NewNullLogger(),
+			},
+		},
+		{
+			Name: "All RequestedIDs are managed by Device",
+			ExpectedReservation: &device.ContainerReservation{
+				Envs: map[string]string{
+					nvidiaVisibleDevices: "UUID1,UUID2,UUID3",
+				},
+			},
+			ExpectedError: nil,
+			RequestedIDs: []string{
+				"UUID1",
+				"UUID2",
+				"UUID3",
+			},
+			Device: &NvidiaDevice{
+				devices: map[string]struct{}{
+					"UUID1": {},
+					"UUID2": {},
+					"UUID3": {},
+				},
+				logger: hclog.NewNullLogger(),
+			},
+		},
+		{
+			Name:                "No IDs requested",
+			ExpectedReservation: &device.ContainerReservation{},
+			ExpectedError:       nil,
+			RequestedIDs:        nil,
+			Device: &NvidiaDevice{
+				devices: map[string]struct{}{
+					"UUID1": {},
+					"UUID2": {},
+					"UUID3": {},
+				},
+				logger: hclog.NewNullLogger(),
+			},
+		},
+	} {
+		actualReservation, actualError := testCase.Device.Reserve(testCase.RequestedIDs)
+		req := require.New(t)
+		req.Equal(testCase.ExpectedReservation, actualReservation)
+		req.Equal(testCase.ExpectedError, actualError)
+	}
+}