tests: swap testify for test in more places (#20028)

* tests: swap testify for test in plugins/csi/client_test.go

* tests: swap testify for test in testutil/

* tests: swap testify for test in host_test.go

* tests: swap testify for test in plugin_test.go

* tests: swap testify for test in utils_test.go

* tests: swap testify for test in scheduler/

* tests: swap testify for test in parse_test.go

* tests: swap testify for test in attribute_test.go

* tests: swap testify for test in plugins/drivers/

* tests: swap testify for test in command/

* tests: fixup some test usages

* go: run go mod tidy

* windows: cpuset test only on linux
This commit is contained in:
Seth Hoenig
2024-02-29 12:11:35 -06:00
committed by GitHub
parent c2fe51bf11
commit 4d83733909
106 changed files with 1297 additions and 1567 deletions

View File

@@ -11,14 +11,13 @@ import (
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/stretchr/testify/require"
"github.com/shoenig/test/must"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/msgpack"
)
func TestBasePlugin_PluginInfo_GRPC(t *testing.T) {
ci.Parallel(t)
require := require.New(t)
var (
apiVersions = []string{"v0.1.0", "v0.1.1"}
@@ -69,22 +68,20 @@ func TestBasePlugin_PluginInfo_GRPC(t *testing.T) {
}
resp, err := impl.PluginInfo()
require.NoError(err)
require.Equal(apiVersions, resp.PluginApiVersions)
require.Equal(pluginVersion, resp.PluginVersion)
require.Equal(pluginName, resp.Name)
require.Equal(PluginTypeDriver, resp.Type)
must.NoError(t, err)
must.Eq(t, apiVersions, resp.PluginApiVersions)
must.Eq(t, pluginVersion, resp.PluginVersion)
must.Eq(t, pluginName, resp.Name)
must.Eq(t, PluginTypeDriver, resp.Type)
// Swap the implementation to return an unknown type
mock.PluginInfoF = unknownType
_, err = impl.PluginInfo()
require.Error(err)
require.Contains(err.Error(), "unknown type")
must.ErrorContains(t, err, "unknown type")
}
func TestBasePlugin_ConfigSchema(t *testing.T) {
ci.Parallel(t)
require := require.New(t)
mock := &MockPlugin{
ConfigSchemaF: func() (*hclspec.Spec, error) {
@@ -99,23 +96,18 @@ func TestBasePlugin_ConfigSchema(t *testing.T) {
defer client.Close()
raw, err := client.Dispense(PluginTypeBase)
if err != nil {
t.Fatalf("err: %s", err)
}
must.NoError(t, err)
impl, ok := raw.(BasePlugin)
if !ok {
t.Fatalf("bad: %#v", raw)
}
must.True(t, ok)
specOut, err := impl.ConfigSchema()
require.NoError(err)
require.True(pb.Equal(TestSpec, specOut))
must.NoError(t, err)
must.True(t, pb.Equal(TestSpec, specOut))
}
func TestBasePlugin_SetConfig(t *testing.T) {
ci.Parallel(t)
require := require.New(t)
var receivedData []byte
mock := &MockPlugin{
@@ -138,29 +130,25 @@ func TestBasePlugin_SetConfig(t *testing.T) {
defer client.Close()
raw, err := client.Dispense(PluginTypeBase)
if err != nil {
t.Fatalf("err: %s", err)
}
must.NoError(t, err)
impl, ok := raw.(BasePlugin)
if !ok {
t.Fatalf("bad: %#v", raw)
}
must.True(t, ok)
config := cty.ObjectVal(map[string]cty.Value{
"foo": cty.StringVal("v1"),
"bar": cty.NumberIntVal(1337),
"baz": cty.BoolVal(true),
})
cdata, err := msgpack.Marshal(config, config.Type())
require.NoError(err)
require.NoError(impl.SetConfig(&Config{PluginConfig: cdata}))
require.Equal(cdata, receivedData)
must.NoError(t, err)
must.NoError(t, impl.SetConfig(&Config{PluginConfig: cdata}))
must.Eq(t, cdata, receivedData)
// Decode the value back
var actual TestConfig
require.NoError(structs.Decode(receivedData, &actual))
require.Equal("v1", actual.Foo)
require.EqualValues(1337, actual.Bar)
require.True(actual.Baz)
must.NoError(t, structs.Decode(receivedData, &actual))
must.Eq(t, "v1", actual.Foo)
must.Eq(t, 1337, actual.Bar)
must.True(t, actual.Baz)
}

View File

@@ -14,7 +14,6 @@ import (
csipbv1 "github.com/container-storage-interface/spec/lib/go/csi"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -102,10 +101,10 @@ func TestClient_RPC_PluginProbe(t *testing.T) {
resp, err := client.PluginProbe(context.TODO())
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
}
require.Equal(t, tc.ExpectedResponse, resp)
must.Eq(t, tc.ExpectedResponse, resp)
})
}
@@ -156,11 +155,11 @@ func TestClient_RPC_PluginInfo(t *testing.T) {
name, version, err := client.PluginGetInfo(context.TODO())
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
}
require.Equal(t, tc.ExpectedResponseName, name)
require.Equal(t, tc.ExpectedResponseVersion, version)
must.Eq(t, tc.ExpectedResponseName, name)
must.Eq(t, tc.ExpectedResponseVersion, version)
})
}
@@ -223,10 +222,10 @@ func TestClient_RPC_PluginGetCapabilities(t *testing.T) {
resp, err := client.PluginGetCapabilities(context.TODO())
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
}
require.Equal(t, tc.ExpectedResponse, resp)
must.Eq(t, tc.ExpectedResponse, resp)
})
}
}
@@ -323,10 +322,10 @@ func TestClient_RPC_ControllerGetCapabilities(t *testing.T) {
resp, err := client.ControllerGetCapabilities(context.TODO())
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
}
require.Equal(t, tc.ExpectedResponse, resp)
must.Eq(t, tc.ExpectedResponse, resp)
})
}
}
@@ -383,10 +382,10 @@ func TestClient_RPC_NodeGetCapabilities(t *testing.T) {
resp, err := client.NodeGetCapabilities(context.TODO())
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
}
require.Equal(t, tc.ExpectedResponse, resp)
must.Eq(t, tc.ExpectedResponse, resp)
})
}
}
@@ -450,10 +449,10 @@ func TestClient_RPC_ControllerPublishVolume(t *testing.T) {
resp, err := client.ControllerPublishVolume(context.TODO(), tc.Request)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
}
require.Equal(t, tc.ExpectedResponse, resp)
must.Eq(t, tc.ExpectedResponse, resp)
})
}
}
@@ -498,10 +497,10 @@ func TestClient_RPC_ControllerUnpublishVolume(t *testing.T) {
resp, err := client.ControllerUnpublishVolume(context.TODO(), tc.Request)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
}
require.Equal(t, tc.ExpectedResponse, resp)
must.Eq(t, tc.ExpectedResponse, resp)
})
}
}
@@ -723,9 +722,9 @@ func TestClient_RPC_ControllerValidateVolume(t *testing.T) {
err := client.ControllerValidateCapabilities(context.TODO(), req)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
} else {
require.NoError(t, err, tc.Name)
must.NoError(t, err, must.Sprint("name", tc.Name))
}
})
}
@@ -832,24 +831,24 @@ func TestClient_RPC_ControllerCreateVolume(t *testing.T) {
resp, err := client.ControllerCreateVolume(context.TODO(), req)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
return
}
require.NoError(t, err, tc.Name)
must.NoError(t, err, must.Sprint("name", tc.Name))
if tc.Response == nil {
require.Nil(t, resp)
must.Nil(t, resp)
return
}
if tc.CapacityRange != nil {
require.Greater(t, resp.Volume.CapacityBytes, int64(0))
must.Greater(t, 0, resp.Volume.CapacityBytes)
}
if tc.ContentSource != nil {
require.Equal(t, tc.ContentSource.CloneID, resp.Volume.ContentSource.CloneID)
require.Equal(t, tc.ContentSource.SnapshotID, resp.Volume.ContentSource.SnapshotID)
must.Eq(t, tc.ContentSource.CloneID, resp.Volume.ContentSource.CloneID)
must.Eq(t, tc.ContentSource.SnapshotID, resp.Volume.ContentSource.SnapshotID)
}
if tc.Response != nil && tc.Response.Volume != nil {
require.Len(t, resp.Volume.AccessibleTopology, 1)
require.Equal(t,
must.SliceLen(t, 1, resp.Volume.AccessibleTopology)
must.Eq(t,
req.AccessibilityRequirements.Requisite[0].Segments,
resp.Volume.AccessibleTopology[0].Segments,
)
@@ -894,10 +893,10 @@ func TestClient_RPC_ControllerDeleteVolume(t *testing.T) {
cc.NextErr = tc.ResponseErr
err := client.ControllerDeleteVolume(context.TODO(), tc.Request)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
return
}
require.NoError(t, err, tc.Name)
must.NoError(t, err, must.Sprint("name", tc.Name))
})
}
}
@@ -987,11 +986,11 @@ func TestClient_RPC_ControllerListVolume(t *testing.T) {
resp, err := client.ControllerListVolumes(context.TODO(), tc.Request)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
return
}
require.NoError(t, err, tc.Name)
require.NotNil(t, resp)
must.NoError(t, err, must.Sprint("name", tc.Name))
must.NotNil(t, resp)
})
}
@@ -1054,11 +1053,11 @@ func TestClient_RPC_ControllerCreateSnapshot(t *testing.T) {
// from protobuf to our struct
resp, err := client.ControllerCreateSnapshot(context.TODO(), tc.Request)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
} else {
require.NoError(t, err, tc.Name)
require.NotZero(t, resp.Snapshot.CreateTime)
require.Equal(t, now.Second(), time.Unix(resp.Snapshot.CreateTime, 0).Second())
must.NoError(t, err, must.Sprint("name", tc.Name))
must.Positive(t, resp.Snapshot.CreateTime)
must.Eq(t, now.Second(), time.Unix(resp.Snapshot.CreateTime, 0).Second())
}
})
}
@@ -1099,10 +1098,10 @@ func TestClient_RPC_ControllerDeleteSnapshot(t *testing.T) {
cc.NextErr = tc.ResponseErr
err := client.ControllerDeleteSnapshot(context.TODO(), tc.Request)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
return
}
require.NoError(t, err, tc.Name)
must.NoError(t, err, must.Sprint("name", tc.Name))
})
}
}
@@ -1162,14 +1161,14 @@ func TestClient_RPC_ControllerListSnapshots(t *testing.T) {
resp, err := client.ControllerListSnapshots(context.TODO(), tc.Request)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
return
}
require.NoError(t, err, tc.Name)
require.NotNil(t, resp)
require.Len(t, resp.Entries, 1)
require.NotZero(t, resp.Entries[0].Snapshot.CreateTime)
require.Equal(t, now.Second(),
must.NoError(t, err, must.Sprint("name", tc.Name))
must.NotNil(t, resp)
must.Len(t, 1, resp.Entries)
must.Positive(t, resp.Entries[0].Snapshot.CreateTime)
must.Eq(t, now.Second(),
time.Unix(resp.Entries[0].Snapshot.CreateTime, 0).Second())
})
}
@@ -1359,9 +1358,9 @@ func TestClient_RPC_NodeStageVolume(t *testing.T) {
VolumeCapability: &VolumeCapability{},
})
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
} else {
require.Nil(t, err)
must.NoError(t, err)
}
})
}
@@ -1398,9 +1397,9 @@ func TestClient_RPC_NodeUnstageVolume(t *testing.T) {
err := client.NodeUnstageVolume(context.TODO(), "foo", "/foo")
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
} else {
require.Nil(t, err)
must.NoError(t, err)
}
})
}
@@ -1456,9 +1455,9 @@ func TestClient_RPC_NodePublishVolume(t *testing.T) {
err := client.NodePublishVolume(context.TODO(), tc.Request)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
} else {
require.Nil(t, err)
must.NoError(t, err)
}
})
}
@@ -1511,9 +1510,9 @@ func TestClient_RPC_NodeUnpublishVolume(t *testing.T) {
err := client.NodeUnpublishVolume(context.TODO(), tc.ExternalID, tc.TargetPath)
if tc.ExpectedErr != nil {
require.EqualError(t, err, tc.ExpectedErr.Error())
must.EqError(t, err, tc.ExpectedErr.Error())
} else {
require.Nil(t, err)
must.NoError(t, err)
}
})
}

View File

@@ -8,7 +8,6 @@ import (
"fmt"
"io"
"os"
"reflect"
"regexp"
"runtime"
"strings"
@@ -20,7 +19,7 @@ import (
"github.com/hashicorp/nomad/plugins/drivers"
dproto "github.com/hashicorp/nomad/plugins/drivers/proto"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/require"
"github.com/shoenig/test/must"
)
func ExecTaskStreamingConformanceTests(t *testing.T, driver *DriverHarness, taskID string) {
@@ -121,30 +120,29 @@ func TestExecTaskStreamingBasicResponses(t *testing.T, driver *DriverHarness, ta
result := execTask(t, driver, taskID, c.Command, c.Tty, c.Stdin)
require.Equal(t, c.ExitCode, result.exitCode)
must.Eq(t, c.ExitCode, result.exitCode)
switch s := c.Stdout.(type) {
case string:
require.Equal(t, s, result.stdout)
must.Eq(t, s, result.stdout)
case *regexp.Regexp:
require.Regexp(t, s, result.stdout)
must.RegexMatch(t, s, result.stdout)
case nil:
require.Empty(t, result.stdout)
must.Eq(t, "", result.stdout)
default:
require.Fail(t, "unexpected stdout type", "found %v (%v), but expected string or regexp", s, reflect.TypeOf(s))
t.Fatal("unexpected type")
}
switch s := c.Stderr.(type) {
case string:
require.Equal(t, s, result.stderr)
must.Eq(t, s, result.stderr)
case *regexp.Regexp:
require.Regexp(t, s, result.stderr)
must.RegexMatch(t, s, result.stderr)
case nil:
require.Empty(t, result.stderr)
must.Eq(t, "", result.stderr)
default:
require.Fail(t, "unexpected stderr type", "found %v (%v), but expected string or regexp", s, reflect.TypeOf(s))
t.Fatal("unexpected type")
}
})
}
}
@@ -154,7 +152,7 @@ func TestExecTaskStreamingBasicResponses(t *testing.T, driver *DriverHarness, ta
func TestExecFSIsolation(t *testing.T, driver *DriverHarness, taskID string) {
t.Run("isolation", func(t *testing.T) {
caps, err := driver.Capabilities()
require.NoError(t, err)
must.NoError(t, err)
isolated := (caps.FSIsolation != drivers.FSIsolationNone)
@@ -164,7 +162,7 @@ func TestExecFSIsolation(t *testing.T, driver *DriverHarness, taskID string) {
w := execTask(t, driver, taskID,
fmt.Sprintf(`FILE=$(mktemp); echo "$FILE"; echo %q >> "${FILE}"`, text),
false, "")
require.Zero(t, w.exitCode)
must.Zero(t, w.exitCode)
tempfile := strings.TrimSpace(w.stdout)
if !isolated {
@@ -176,26 +174,26 @@ func TestExecFSIsolation(t *testing.T, driver *DriverHarness, taskID string) {
// read from host
b, err := os.ReadFile(tempfile)
if !isolated {
require.NoError(t, err)
require.Equal(t, text, strings.TrimSpace(string(b)))
must.NoError(t, err)
must.Eq(t, text, strings.TrimSpace(string(b)))
} else {
require.Error(t, err)
require.True(t, os.IsNotExist(err))
must.Error(t, err)
must.True(t, os.IsNotExist(err))
}
// read should succeed from task again
r := execTask(t, driver, taskID,
fmt.Sprintf("cat %q", tempfile),
false, "")
require.Zero(t, r.exitCode)
require.Equal(t, text, strings.TrimSpace(r.stdout))
must.Zero(t, r.exitCode)
must.Eq(t, text, strings.TrimSpace(r.stdout))
// we always run in a cgroup - testing freezer cgroup
r = execTask(t, driver, taskID,
"cat /proc/self/cgroup",
false, "",
)
require.Zero(t, r.exitCode)
must.Zero(t, r.exitCode)
switch cgroupslib.GetMode() {
@@ -214,7 +212,7 @@ func TestExecFSIsolation(t *testing.T, driver *DriverHarness, taskID string) {
}
}
if !ok {
require.Fail(t, "unexpected freezer cgroup", "expected freezer to be /nomad/ or /docker/, but found:\n%s", r.stdout)
t.Fatal("unexpected freezer cgroup")
}
case cgroupslib.CG2:
info, _ := driver.PluginInfo()
@@ -225,7 +223,7 @@ func TestExecFSIsolation(t *testing.T, driver *DriverHarness, taskID string) {
t.Skip("/proc/self/cgroup not useful in docker cgroups.v2")
}
// e.g. 0::/testing.slice/5bdbd6c2-8aba-3ab2-728b-0ff3a81727a9.sleep.scope
require.True(t, strings.HasSuffix(strings.TrimSpace(r.stdout), ".scope"), "actual stdout %q", r.stdout)
must.True(t, strings.HasSuffix(strings.TrimSpace(r.stdout), ".scope"), must.Sprintf("actual stdout %q", r.stdout))
}
})
}
@@ -249,27 +247,27 @@ func execTask(t *testing.T, driver *DriverHarness, taskID string, cmd string, tt
isRaw = true
err := raw.ExecTaskStreamingRaw(ctx, taskID,
command, tty, stream)
require.NoError(t, err)
must.NoError(t, err)
} else if d, ok := driver.impl.(drivers.ExecTaskStreamingDriver); ok {
execOpts, errCh := drivers.StreamToExecOptions(ctx, command, tty, stream)
r, err := d.ExecTaskStreaming(ctx, taskID, execOpts)
require.NoError(t, err)
must.NoError(t, err)
select {
case err := <-errCh:
require.NoError(t, err)
must.NoError(t, err)
default:
// all good
}
exitCode = r.ExitCode
} else {
require.Fail(t, "driver does not support exec")
t.Fatal("driver does not support exec")
}
result := stream.currentResult()
require.NoError(t, result.err)
must.NoError(t, result.err)
if !isRaw {
result.exitCode = exitCode

View File

@@ -25,7 +25,7 @@ import (
"github.com/hashicorp/nomad/plugins/drivers"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
testing "github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
"github.com/shoenig/test/must"
)
type DriverHarness struct {
@@ -55,7 +55,7 @@ func NewDriverHarness(t testing.T, d drivers.DriverPlugin) *DriverHarness {
)
raw, err := client.Dispense(base.PluginTypeDriver)
require.NoError(t, err, "failed to dispense plugin")
must.NoError(t, err)
dClient := raw.(drivers.DriverPlugin)
return &DriverHarness{
@@ -80,21 +80,21 @@ func (h *DriverHarness) Kill() {
// between tests.
func (h *DriverHarness) MkAllocDir(t *drivers.TaskConfig, enableLogs bool) func() {
dir, err := os.MkdirTemp("", "nomad_driver_harness-")
require.NoError(h.t, err)
must.NoError(h.t, err)
allocDir := allocdir.NewAllocDir(h.logger, dir, t.AllocID)
require.NoError(h.t, allocDir.Build())
must.NoError(h.t, allocDir.Build())
t.AllocDir = allocDir.AllocDir
taskDir := allocDir.NewTaskDir(t.Name)
caps, err := h.Capabilities()
require.NoError(h.t, err)
must.NoError(h.t, err)
fsi := caps.FSIsolation
h.logger.Trace("FS isolation", "fsi", fsi)
require.NoError(h.t, taskDir.Build(fsi == drivers.FSIsolationChroot, ci.TinyChroot))
must.NoError(h.t, taskDir.Build(fsi == drivers.FSIsolationChroot, ci.TinyChroot))
task := &structs.Task{
Name: t.Name,
@@ -142,7 +142,7 @@ func (h *DriverHarness) MkAllocDir(t *drivers.TaskConfig, enableLogs bool) func(
MaxFiles: 10,
MaxFileSizeMB: 10,
})
require.NoError(h.t, err)
must.NoError(h.t, err)
return func() {
lm.Stop()

View File

@@ -16,7 +16,7 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/drivers"
pstructs "github.com/hashicorp/nomad/plugins/shared/structs"
"github.com/stretchr/testify/require"
"github.com/shoenig/test/must"
)
var _ drivers.DriverPlugin = (*MockDriver)(nil)
@@ -34,8 +34,8 @@ func TestDriverHarness(t *testing.T) {
harness := NewDriverHarness(t, d)
defer harness.Kill()
actual, _, err := harness.StartTask(&drivers.TaskConfig{})
require.NoError(t, err)
require.Equal(t, handle.Config.Name, actual.Config.Name)
must.NoError(t, err)
must.Eq(t, handle.Config.Name, actual.Config.Name)
}
type testDriverState struct {
@@ -45,7 +45,6 @@ type testDriverState struct {
func TestBaseDriver_Fingerprint(t *testing.T) {
ci.Parallel(t)
require := require.New(t)
fingerprints := []*drivers.Fingerprint{
{
@@ -81,7 +80,7 @@ func TestBaseDriver_Fingerprint(t *testing.T) {
defer harness.Kill()
ch, err := harness.Fingerprint(context.Background())
require.NoError(err)
must.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
@@ -89,25 +88,24 @@ func TestBaseDriver_Fingerprint(t *testing.T) {
defer wg.Done()
select {
case f := <-ch:
require.Exactly(f, fingerprints[0])
must.Eq(t, f, fingerprints[0])
case <-time.After(1 * time.Second):
require.Fail("did not receive fingerprint[0]")
t.Fatal("did not receive fingerprint[0]")
}
select {
case f := <-ch:
require.Exactly(f, fingerprints[1])
must.Eq(t, f, fingerprints[1])
case <-time.After(1 * time.Second):
require.Fail("did not receive fingerprint[1]")
t.Fatal("did not receive fingerprint[1]")
}
}()
require.False(complete.Load().(bool))
must.False(t, complete.Load().(bool))
wg.Wait()
require.True(complete.Load().(bool))
must.True(t, complete.Load().(bool))
}
func TestBaseDriver_RecoverTask(t *testing.T) {
ci.Parallel(t)
require := require.New(t)
// build driver state and encode it into proto msg
state := testDriverState{Pid: 1, Log: "foo"}
@@ -119,8 +117,8 @@ func TestBaseDriver_RecoverTask(t *testing.T) {
impl := &MockDriver{
RecoverTaskF: func(h *drivers.TaskHandle) error {
var actual testDriverState
require.NoError(h.GetDriverState(&actual))
require.Equal(state, actual)
must.NoError(t, h.GetDriverState(&actual))
must.Eq(t, state, actual)
return nil
},
}
@@ -132,12 +130,11 @@ func TestBaseDriver_RecoverTask(t *testing.T) {
DriverState: buf.Bytes(),
}
err := harness.RecoverTask(handle)
require.NoError(err)
must.NoError(t, err)
}
func TestBaseDriver_StartTask(t *testing.T) {
ci.Parallel(t)
require := require.New(t)
cfg := &drivers.TaskConfig{
ID: "foo",
@@ -157,19 +154,18 @@ func TestBaseDriver_StartTask(t *testing.T) {
harness := NewDriverHarness(t, impl)
defer harness.Kill()
resp, _, err := harness.StartTask(cfg)
require.NoError(err)
require.Equal(cfg.ID, resp.Config.ID)
require.Equal(handle.State, resp.State)
must.NoError(t, err)
must.Eq(t, cfg.ID, resp.Config.ID)
must.Eq(t, handle.State, resp.State)
var actualState testDriverState
require.NoError(resp.GetDriverState(&actualState))
require.Equal(*state, actualState)
must.NoError(t, resp.GetDriverState(&actualState))
must.Eq(t, *state, actualState)
}
func TestBaseDriver_WaitTask(t *testing.T) {
ci.Parallel(t)
require := require.New(t)
result := &drivers.ExitResult{ExitCode: 1, Signal: 9}
@@ -194,20 +190,19 @@ func TestBaseDriver_WaitTask(t *testing.T) {
go func() {
defer wg.Done()
ch, err := harness.WaitTask(context.TODO(), "foo")
require.NoError(err)
must.NoError(t, err)
actualResult := <-ch
finished = true
require.Exactly(result, actualResult)
must.Eq(t, result, actualResult)
}()
require.False(finished)
must.False(t, finished)
close(signalTask)
wg.Wait()
require.True(finished)
must.True(t, finished)
}
func TestBaseDriver_TaskEvents(t *testing.T) {
ci.Parallel(t)
require := require.New(t)
now := time.Now().UTC().Truncate(time.Millisecond)
events := []*drivers.TaskEvent{
@@ -254,14 +249,14 @@ func TestBaseDriver_TaskEvents(t *testing.T) {
defer harness.Kill()
ch, err := harness.TaskEvents(context.Background())
require.NoError(err)
must.NoError(t, err)
for _, event := range events {
select {
case actual := <-ch:
require.Exactly(actual, event)
must.Eq(t, actual, event)
case <-time.After(500 * time.Millisecond):
require.Fail("failed to receive event")
t.Fatal("failed to receive event")
}
}
@@ -291,6 +286,6 @@ func TestBaseDriver_Capabilities(t *testing.T) {
defer harness.Kill()
caps, err := harness.Capabilities()
require.NoError(t, err)
require.Equal(t, capabilities, caps)
must.NoError(t, err)
must.Eq(t, capabilities, caps)
}

View File

@@ -9,8 +9,7 @@ import (
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/drivers/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/shoenig/test/must"
)
func TestResourceUsageRoundTrip(t *testing.T) {
@@ -36,8 +35,7 @@ func TestResourceUsageRoundTrip(t *testing.T) {
}
parsed := resourceUsageFromProto(resourceUsageToProto(input))
require.EqualValues(t, parsed, input)
must.Eq(t, parsed, input)
}
func TestTaskConfigRoundTrip(t *testing.T) {
@@ -109,8 +107,7 @@ func TestTaskConfigRoundTrip(t *testing.T) {
}
parsed := taskConfigFromProto(taskConfigToProto(input))
require.EqualValues(t, input, parsed)
must.Eq(t, input, parsed)
}
@@ -140,7 +137,7 @@ func Test_networkCreateRequestFromProto(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualOutput := networkCreateRequestFromProto(tc.inputPB)
assert.Equal(t, tc.expectedOutput, actualOutput, tc.name)
must.Eq(t, tc.expectedOutput, actualOutput)
})
}
}

View File

@@ -8,7 +8,7 @@ import (
"testing"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/stretchr/testify/require"
"github.com/shoenig/test/must"
)
func TestAttribute_Validate(t *testing.T) {
@@ -77,7 +77,7 @@ func TestAttribute_Validate(t *testing.T) {
for _, c := range cases {
t.Run(c.Input.GoString(), func(t *testing.T) {
if err := c.Input.Validate(); err != nil && !c.Fail {
require.NoError(t, err)
must.NoError(t, err)
}
})
}
@@ -538,7 +538,7 @@ func testComparison(t *testing.T, cases []*compareTestCase) {
if !ok && !c.NotComparable {
t.Fatal("should be comparable")
} else if ok {
require.Equal(t, c.Expected, v)
must.Eq(t, c.Expected, v)
}
})
}
@@ -662,8 +662,8 @@ func TestAttribute_ParseAndValidate(t *testing.T) {
for _, c := range cases {
t.Run(c.Input, func(t *testing.T) {
a := ParseAttribute(c.Input)
require.Equal(t, c.Expected, a)
require.NoError(t, a.Validate())
must.Eq(t, c.Expected, a)
must.NoError(t, a.Validate())
})
}
}