diff --git a/.changelog/24489.txt b/.changelog/24489.txt new file mode 100644 index 000000000..0357b81db --- /dev/null +++ b/.changelog/24489.txt @@ -0,0 +1,3 @@ +```release-note:security +security: Explicitly set 'Content-Type' header to mitigate XSS vulnerability +``` diff --git a/command/agent/http.go b/command/agent/http.go index 23d5cef15..3f4db49d6 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -58,6 +58,9 @@ const ( // MissingRequestID is a placeholder if we cannot retrieve a request // UUID from context MissingRequestID = "" + + contentTypeHeader = "Content-Type" + plainContentType = "text/plain; charset=utf-8" ) var ( @@ -743,6 +746,7 @@ func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Reque } } + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(code) resp.Write([]byte(errMsg)) if isAPIClientError(code) { @@ -801,6 +805,7 @@ func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *htt // Check for an error if err != nil { code, errMsg := errCodeFromHandler(err) + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(code) resp.Write([]byte(errMsg)) if isAPIClientError(code) { @@ -810,7 +815,6 @@ func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *htt } return } - // write response if obj != nil { resp.Write(obj) @@ -884,6 +888,7 @@ func parseWait(resp http.ResponseWriter, req *http.Request, b *structs.QueryOpti if wait := query.Get("wait"); wait != "" { dur, err := time.ParseDuration(wait) if err != nil { + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(http.StatusBadRequest) resp.Write([]byte("Invalid wait time")) return true @@ -893,6 +898,7 @@ func parseWait(resp http.ResponseWriter, req *http.Request, b *structs.QueryOpti if idx := query.Get("index"); idx != "" { index, err := strconv.ParseUint(idx, 10, 64) if err != nil { + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(http.StatusBadRequest) resp.Write([]byte("Invalid index")) return true @@ -913,6 +919,7 @@ func parseConsistency(resp http.ResponseWriter, req *http.Request, b *structs.Qu staleQuery, err := strconv.ParseBool(staleVal[0]) if err != nil { errMsg := "Expect `true` or `false` for `stale` query string parameter" + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(http.StatusBadRequest) resp.Write([]byte(errMsg)) return CodedError(http.StatusBadRequest, errMsg) @@ -1037,6 +1044,7 @@ func parsePagination(resp http.ResponseWriter, req *http.Request, b *structs.Que perPage, err := strconv.ParseInt(rawPerPage, 10, 32) if err != nil { errMsg := "Expect a number for `per_page` query string parameter" + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(http.StatusBadRequest) resp.Write([]byte(errMsg)) return CodedError(http.StatusBadRequest, errMsg) @@ -1158,6 +1166,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request) reply := structs.ACLWhoAmIResponse{} if a.srv.parse(resp, req, &args.Region, &args.QueryOptions) { // Error parsing request, 400 + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(http.StatusBadRequest) resp.Write([]byte(http.StatusText(http.StatusBadRequest))) return @@ -1165,6 +1174,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request) if args.AuthToken == "" { // 401 instead of 403 since no token was present. + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(http.StatusUnauthorized) resp.Write([]byte(http.StatusText(http.StatusUnauthorized))) return @@ -1175,12 +1185,14 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request) // credentials, so convert it to a Forbidden response code. if strings.HasSuffix(err.Error(), structs.ErrPermissionDenied.Error()) { a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL) + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(http.StatusForbidden) resp.Write([]byte(http.StatusText(http.StatusForbidden))) return } a.srv.logger.Error("error authenticating built API request", "error", err, "url", req.URL, "method", req.Method) + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(http.StatusInternalServerError) resp.Write([]byte("Server error authenticating request\n")) return @@ -1189,6 +1201,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request) // Require an acl token or workload identity if reply.Identity == nil || (reply.Identity.ACLToken == nil && reply.Identity.Claims == nil) { a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL) + resp.Header().Set(contentTypeHeader, plainContentType) resp.WriteHeader(http.StatusForbidden) resp.Write([]byte(http.StatusText(http.StatusForbidden))) return diff --git a/command/agent/http_test.go b/command/agent/http_test.go index a79fed7e5..7ecc12086 100644 --- a/command/agent/http_test.go +++ b/command/agent/http_test.go @@ -275,7 +275,8 @@ func TestWrapNonJSON(t *testing.T) { s.Server.wrapNonJSON(handler)(resp, req) respBody, _ := io.ReadAll(resp.Body) - require.Equal(t, respBody, []byte("test response")) + must.Eq(t, respBody, []byte("test response")) + must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType) } @@ -298,8 +299,9 @@ func TestWrapNonJSON_Error(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "/v1/kv/key", nil) s.Server.wrapNonJSON(handlerRPCErr)(resp, req) respBody, _ := io.ReadAll(resp.Body) - require.Equal(t, []byte("not found"), respBody) - require.Equal(t, 404, resp.Code) + must.Eq(t, []byte("not found"), respBody) + must.Eq(t, 404, resp.Code) + must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType) } // CodedError @@ -308,8 +310,9 @@ func TestWrapNonJSON_Error(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "/v1/kv/key", nil) s.Server.wrapNonJSON(handlerCodedErr)(resp, req) respBody, _ := io.ReadAll(resp.Body) - require.Equal(t, []byte("unprocessable"), respBody) - require.Equal(t, 422, resp.Code) + must.Eq(t, []byte("unprocessable"), respBody) + must.Eq(t, 422, resp.Code) + must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType) } } @@ -381,7 +384,8 @@ func TestPermissionDenied(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "/v1/job/foo", nil) s.Server.wrap(handler)(resp, req) - assert.Equal(t, resp.Code, 403) + must.Eq(t, resp.Code, 403) + must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType) } // When remote RPC is used the errors have "rpc error: " prependend @@ -393,7 +397,8 @@ func TestPermissionDenied(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "/v1/job/foo", nil) s.Server.wrap(handler)(resp, req) - assert.Equal(t, resp.Code, 403) + must.Eq(t, resp.Code, 403) + must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType) } } @@ -411,7 +416,8 @@ func TestTokenNotFound(t *testing.T) { urlStr := "/v1/job/foo" req, _ := http.NewRequest(http.MethodGet, urlStr, nil) s.Server.wrap(handler)(resp, req) - assert.Equal(t, resp.Code, 403) + must.Eq(t, resp.Code, 403) + must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType) } func TestParseWait(t *testing.T) { @@ -421,20 +427,11 @@ func TestParseWait(t *testing.T) { req, err := http.NewRequest(http.MethodGet, "/v1/catalog/nodes?wait=60s&index=1000", nil) - if err != nil { - t.Fatalf("err: %v", err) - } + must.NoError(t, err) - if d := parseWait(resp, req, &b); d { - t.Fatalf("unexpected done") - } - - if b.MinQueryIndex != 1000 { - t.Fatalf("Bad: %v", b) - } - if b.MaxQueryTime != 60*time.Second { - t.Fatalf("Bad: %v", b) - } + must.False(t, parseWait(resp, req, &b)) + must.Eq(t, b.MinQueryIndex, 1000) + must.Eq(t, b.MaxQueryTime, 60*time.Second) } func TestParseWait_InvalidTime(t *testing.T) { @@ -444,17 +441,11 @@ func TestParseWait_InvalidTime(t *testing.T) { req, err := http.NewRequest(http.MethodGet, "/v1/catalog/nodes?wait=60foo&index=1000", nil) - if err != nil { - t.Fatalf("err: %v", err) - } + must.NoError(t, err) - if d := parseWait(resp, req, &b); !d { - t.Fatalf("expected done") - } - - if resp.Code != 400 { - t.Fatalf("bad code: %v", resp.Code) - } + must.True(t, parseWait(resp, req, &b)) + must.Eq(t, resp.Code, 400) + must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType) } func TestParseWait_InvalidIndex(t *testing.T) { @@ -464,17 +455,11 @@ func TestParseWait_InvalidIndex(t *testing.T) { req, err := http.NewRequest(http.MethodGet, "/v1/catalog/nodes?wait=60s&index=foo", nil) - if err != nil { - t.Fatalf("err: %v", err) - } + must.NoError(t, err) - if d := parseWait(resp, req, &b); !d { - t.Fatalf("expected done") - } - - if resp.Code != 400 { - t.Fatalf("bad code: %v", resp.Code) - } + must.True(t, parseWait(resp, req, &b)) + must.Eq(t, resp.Code, 400) + must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType) } func TestParseConsistency(t *testing.T) {