diff --git a/api/util_test.go b/api/util_test.go index c5932d4c0..0e179b307 100644 --- a/api/util_test.go +++ b/api/util_test.go @@ -119,6 +119,10 @@ func testQuotaSpec() *QuotaSpec { RegionLimit: &Resources{ CPU: pointerOf(2000), MemoryMB: pointerOf(2000), + Devices: []*RequestedDevice{{ + Name: "nvidia/gpu/1080ti", + Count: pointerOf(uint64(2)), + }}, }, }, }, diff --git a/command/quota_apply.go b/command/quota_apply.go index a7aef1e66..6d17a90e2 100644 --- a/command/quota_apply.go +++ b/command/quota_apply.go @@ -270,6 +270,7 @@ func parseQuotaResource(result *api.Resources, list *ast.ObjectList) error { "cpu", "memory", "memory_max", + "devices", } if err := helper.CheckHCLKeys(listVal, valid); err != nil { return multierror.Prefix(err, "resources ->") @@ -280,9 +281,46 @@ func parseQuotaResource(result *api.Resources, list *ast.ObjectList) error { return err } + // Manually parse + delete(m, "devices") + if err := mapstructure.WeakDecode(m, result); err != nil { return err } + // Parse devices + if o := listVal.Filter("device"); len(o.Items) > 0 { + result.Devices = make([]*api.RequestedDevice, 0) + if err := parseDeviceResource(&result.Devices, o); err != nil { + return multierror.Prefix(err, "devices ->") + } + } + + return nil +} + +func parseDeviceResource(result *[]*api.RequestedDevice, list *ast.ObjectList) error { + for _, o := range list.Elem().Items { + // Check for invalid keys + valid := []string{ + "name", + "count", + } + if err := helper.CheckHCLKeys(o.Val, valid); err != nil { + return err + } + + var m map[string]interface{} + if err := hcl.DecodeObject(&m, o.Val); err != nil { + return err + } + + var device api.RequestedDevice + if err := mapstructure.WeakDecode(m, &device); err != nil { + return err + } + + *result = append(*result, &device) + } return nil }