diff --git a/.changelog/24991.txt b/.changelog/24991.txt
new file mode 100644
index 000000000..6e9a190a3
--- /dev/null
+++ b/.changelog/24991.txt
@@ -0,0 +1,3 @@
+```release-note:bug
+docker: Fixed a bug that prevented image_pull_timeout from being applied
+```
diff --git a/drivers/docker/coordinator.go b/drivers/docker/coordinator.go
index 6f013775a..d0a1b38c0 100644
--- a/drivers/docker/coordinator.go
+++ b/drivers/docker/coordinator.go
@@ -43,9 +43,14 @@ func newPullFuture() *pullFuture {
 	}
 }
 
-// wait waits till the future has a result
-func (p *pullFuture) wait() *pullFuture {
-	<-p.waitCh
+// wait waits till the future has a result or the context is canceled
+func (p *pullFuture) wait(ctx context.Context) *pullFuture {
+	select {
+	case <-ctx.Done():
+		p.err = fmt.Errorf("wait aborted: %w", ctx.Err())
+	case <-p.waitCh:
+		// all good
+	}
 	return p
 }
 
@@ -80,6 +85,7 @@ func noopLogEventFn(string, map[string]string) {}
 
 // dockerCoordinatorConfig is used to configure the Docker coordinator.
 type dockerCoordinatorConfig struct {
+	// ctx should be the driver context to handle shutdowns
 	ctx context.Context
 
 	// logger is the logger the coordinator should use
@@ -153,10 +159,11 @@ func (d *dockerCoordinator) PullImage(image string, authOptions *registry.AuthCo
 		d.pullFutures[image] = future
 		go d.pullImageImpl(image, authOptions, pullTimeout, pullActivityTimeout, future)
 	}
+	// We unlock while we wait since this can take a while
 	d.imageLock.Unlock()
 
-	// We unlock while we wait since this can take a while
-	id, user, err := future.wait().result()
+	// passing driver context here to stop waiting at driver shutdown
+	id, user, err := future.wait(d.ctx).result()
 
 	d.imageLock.Lock()
 	defer d.imageLock.Unlock()
@@ -182,7 +189,8 @@ func (d *dockerCoordinator) pullImageImpl(imageID string, authOptions *registry.
 	defer d.clearPullLogger(imageID)
 	// Parse the repo and tag
 	repo, tag := parseDockerImage(imageID)
-	ctx, cancel := context.WithTimeout(context.Background(), pullTimeout)
+
+	pullCtx, cancel := context.WithTimeout(d.ctx, pullTimeout)
 	defer cancel()
 
 	pm := newImageProgressManager(imageID, cancel, pullActivityTimeout, d.handlePullInactivity,
@@ -196,11 +204,11 @@
 	}
 
 	pullOptions := image.PullOptions{RegistryAuth: auth.Auth}
-	reader, err := d.client.ImagePull(d.ctx, dockerImageRef(repo, tag), pullOptions)
+	reader, err := d.client.ImagePull(pullCtx, dockerImageRef(repo, tag), pullOptions)
 
-	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+	if errors.Is(err, context.DeadlineExceeded) {
 		d.logger.Error("timeout pulling container", "image_ref", dockerImageRef(repo, tag))
-		future.set("", "", recoverablePullError(ctx.Err(), imageID))
+		future.set("", "", recoverablePullError(err, imageID))
 		return
 	}
 
diff --git a/drivers/docker/coordinator_test.go b/drivers/docker/coordinator_test.go
index 92f6843af..850d1d0e6 100644
--- a/drivers/docker/coordinator_test.go
+++ b/drivers/docker/coordinator_test.go
@@ -41,7 +41,11 @@ func newMockImageClient(idToName map[string]string, pullDelay time.Duration) *mo
 }
 
 func (m *mockImageClient) ImagePull(ctx context.Context, refStr string, opts image.PullOptions) (io.ReadCloser, error) {
-	time.Sleep(m.pullDelay)
+	select {
+	case <-ctx.Done():
+		return nil, fmt.Errorf("mockImageClient.ImagePull aborted: %w", ctx.Err())
+	case <-time.After(m.pullDelay):
+	}
 	m.lock.Lock()
 	defer m.lock.Unlock()
 	m.pulled[refStr]++
@@ -361,11 +365,73 @@ func TestDockerCoordinator_PullImage_ProgressError(t *testing.T) {
 	}
 	coordinator := newDockerCoordinator(config)
 
-	// this error should get set() on the future by pullImageImpl(),
-	// then returned by PullImage()
 	readErr := errors.New("a bad bad thing happened")
 	mock.pullReader = &readErrorer{readErr: readErr}
 
 	_, _, err := coordinator.PullImage("foo", nil, uuid.Generate(), nil, timeout, timeout)
 	must.ErrorIs(t, err, readErr)
 }
+
+func TestDockerCoordinator_PullImage_Timeouts(t *testing.T) {
+	ci.Parallel(t)
+
+	cases := []struct {
+		name          string
+		driverTimeout time.Duration // used in driver context to simulate driver/agent shutdown
+		pullTimeout   time.Duration // user provided `image_pull_timeout`
+		pullDelay     time.Duration // mock delay - how long it "actually" takes to pull the image
+		expectErr     string
+	}{
+		{
+			name:          "pull completes",
+			pullDelay:     10 * time.Millisecond,
+			pullTimeout:   200 * time.Millisecond,
+			driverTimeout: 400 * time.Millisecond,
+			expectErr:     "",
+		},
+		{
+			name:          "pull timeout",
+			pullDelay:     400 * time.Millisecond,
+			pullTimeout:   10 * time.Millisecond,
+			driverTimeout: 200 * time.Millisecond,
+			expectErr:     "mockImageClient.ImagePull aborted",
+		},
+		{
+			name:          "driver shutdown",
+			pullDelay:     400 * time.Millisecond,
+			pullTimeout:   200 * time.Millisecond,
+			driverTimeout: 10 * time.Millisecond,
+			expectErr:     "wait aborted",
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			driverCtx, cancel := context.WithTimeout(context.Background(), tc.driverTimeout)
+			defer cancel()
+
+			mapping := map[string]string{"foo:v1": "foo"}
+			mock := newMockImageClient(mapping, tc.pullDelay)
+			config := &dockerCoordinatorConfig{
+				ctx:         driverCtx,
+				logger:      testlog.HCLogger(t),
+				cleanup:     true,
+				client:      mock,
+				removeDelay: 1 * time.Millisecond,
+			}
+			coordinator := newDockerCoordinator(config)
+			progressTimeout := 10 * time.Millisecond // does not apply here
+
+			id, _, err := coordinator.PullImage("foo:v1", nil, uuid.Generate(), nil,
+				tc.pullTimeout, progressTimeout)
+
+			if tc.expectErr == "" {
+				must.NoError(t, err)
+				must.Eq(t, "foo", id)
+			} else {
+				must.ErrorIs(t, err, context.DeadlineExceeded)
+				must.ErrorContains(t, err, tc.expectErr)
+			}
+		})
+	}
+}