diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index a2a8f4efd..1f4f533aa 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -3467,7 +3467,13 @@ func validateServices(t *Task) error { // Ensure that services don't ask for non-existent ports and their names are // unique. - servicePorts := make(map[string][]string) + servicePorts := make(map[string]map[string]struct{}) + addServicePort := func(label, service string) { + if _, ok := servicePorts[label]; !ok { + servicePorts[label] = map[string]struct{}{} + } + servicePorts[label][service] = struct{}{} + } knownServices := make(map[string]struct{}) for i, service := range t.Services { if err := service.Validate(); err != nil { @@ -3488,20 +3494,58 @@ func validateServices(t *Task) error { _, err := strconv.Atoi(service.PortLabel) if err != nil { // Not a numeric port label, add it to list to check - servicePorts[service.PortLabel] = append(servicePorts[service.PortLabel], service.Name) + addServicePort(service.PortLabel, service.Name) } } else { - servicePorts[service.PortLabel] = append(servicePorts[service.PortLabel], service.Name) + addServicePort(service.PortLabel, service.Name) } } - // Ensure that check names are unique. + // Ensure that check names are unique and have valid ports knownChecks := make(map[string]struct{}) for _, check := range service.Checks { if _, ok := knownChecks[check.Name]; ok { mErr.Errors = append(mErr.Errors, fmt.Errorf("check %q is duplicate", check.Name)) } knownChecks[check.Name] = struct{}{} + + if !check.RequiresPort() { + // No need to continue validating check if it doesn't need a port + continue + } + + effectivePort := check.PortLabel + if effectivePort == "" { + // Inherits from service + effectivePort = service.PortLabel + } + + if effectivePort == "" { + mErr.Errors = append(mErr.Errors, fmt.Errorf("check %q is missing a port", check.Name)) + continue + } + + isNumeric := false + portNumber, err := strconv.Atoi(effectivePort) + if err == nil { + isNumeric = true + } + + // Numeric ports are fine for address_mode = "driver" + if check.AddressMode == "driver" && isNumeric { + if portNumber <= 0 { + mErr.Errors = append(mErr.Errors, fmt.Errorf("check %q has invalid numeric port %d", check.Name, portNumber)) + } + continue + } + + if isNumeric { + mErr.Errors = append(mErr.Errors, fmt.Errorf(`check %q cannot use a numeric port %d without setting address_mode="driver"`, check.Name, portNumber)) + continue + } + + // PortLabel must exist, report errors by its parent service + addServicePort(effectivePort, service.Name) } } @@ -3520,7 +3564,14 @@ func validateServices(t *Task) error { for servicePort, services := range servicePorts { _, ok := portLabels[servicePort] if !ok { - joined := strings.Join(services, ", ") + names := make([]string, 0, len(services)) + for name := range services { + names = append(names, name) + } + + // Keep order deterministic + sort.Strings(names) + joined := strings.Join(names, ", ") err := fmt.Errorf("port label %q referenced by services %v does not exist", servicePort, joined) mErr.Errors = append(mErr.Errors, err) } diff --git a/nomad/structs/structs_test.go b/nomad/structs/structs_test.go index 3840b9cf4..07db00439 100644 --- a/nomad/structs/structs_test.go +++ b/nomad/structs/structs_test.go @@ -1180,6 +1180,102 @@ func TestTask_Validate_Service_Check(t *testing.T) { } } +// TestTask_Validate_Service_Check_AddressMode asserts that checks do not +// inherit address mode but do inherit ports. +func TestTask_Validate_Service_Check_AddressMode(t *testing.T) { + task := &Task{ + Resources: &Resources{ + Networks: []*NetworkResource{ + { + DynamicPorts: []Port{ + { + Label: "http", + Value: 9999, + }, + }, + }, + }, + }, + Services: []*Service{ + { + Name: "invalid-driver", + PortLabel: "80", + AddressMode: "host", + }, + { + Name: "http-driver", + PortLabel: "80", + AddressMode: "driver", + Checks: []*ServiceCheck{ + { + // Should fail + Name: "invalid-check-1", + Type: "tcp", + Interval: time.Second, + Timeout: time.Second, + }, + { + // Should fail + Name: "invalid-check-2", + Type: "tcp", + PortLabel: "80", + Interval: time.Second, + Timeout: time.Second, + }, + { + // Should fail + Name: "invalid-check-3", + Type: "tcp", + PortLabel: "missing-port-label", + Interval: time.Second, + Timeout: time.Second, + }, + { + // Should pass + Name: "valid-script-check", + Type: "script", + Command: "ok", + Interval: time.Second, + Timeout: time.Second, + }, + { + // Should pass + Name: "valid-host-check", + Type: "tcp", + PortLabel: "http", + Interval: time.Second, + Timeout: time.Second, + }, + { + // Should pass + Name: "valid-driver-check", + Type: "tcp", + AddressMode: "driver", + Interval: time.Second, + Timeout: time.Second, + }, + }, + }, + }, + } + err := validateServices(task) + if err == nil { + t.Fatalf("expected errors but task validated successfully") + } + errs := err.(*multierror.Error).Errors + if expected := 4; len(errs) != expected { + for i, err := range errs { + t.Logf("errs[%d] -> %s", i, err) + } + t.Fatalf("expected %d errors but found %d", expected, len(errs)) + } + + assert.Contains(t, errs[0].Error(), `check "invalid-check-1" cannot use a numeric port`) + assert.Contains(t, errs[1].Error(), `check "invalid-check-2" cannot use a numeric port`) + assert.Contains(t, errs[2].Error(), `port label "80" referenced`) + assert.Contains(t, errs[3].Error(), `port label "missing-port-label" referenced`) +} + func TestTask_Validate_Service_Check_CheckRestart(t *testing.T) { invalidCheckRestart := &CheckRestart{ Limit: -1,