agent: set content type header explicitly (#24489)

This PR addresses an XSS vulnerability where Nomad agents wouldn't explicitly
set content type headers for error responses.
This commit is contained in:
Piotr Kazmierczak
2024-11-20 10:18:30 +01:00
committed by GitHub
parent 11bba3dbcd
commit 9c5078f151
3 changed files with 43 additions and 42 deletions

3
.changelog/24489.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:security
security: Explicitly set 'Content-Type' header to mitigate XSS vulnerability
```

View File

@@ -58,6 +58,9 @@ const (
// MissingRequestID is a placeholder if we cannot retrieve a request
// UUID from context
MissingRequestID = "<missing request id>"
contentTypeHeader = "Content-Type"
plainContentType = "text/plain; charset=utf-8"
)
var (
@@ -743,6 +746,7 @@ func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Reque
}
}
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(code)
resp.Write([]byte(errMsg))
if isAPIClientError(code) {
@@ -801,6 +805,7 @@ func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *htt
// Check for an error
if err != nil {
code, errMsg := errCodeFromHandler(err)
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(code)
resp.Write([]byte(errMsg))
if isAPIClientError(code) {
@@ -810,7 +815,6 @@ func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *htt
}
return
}
// write response
if obj != nil {
resp.Write(obj)
@@ -884,6 +888,7 @@ func parseWait(resp http.ResponseWriter, req *http.Request, b *structs.QueryOpti
if wait := query.Get("wait"); wait != "" {
dur, err := time.ParseDuration(wait)
if err != nil {
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(http.StatusBadRequest)
resp.Write([]byte("Invalid wait time"))
return true
@@ -893,6 +898,7 @@ func parseWait(resp http.ResponseWriter, req *http.Request, b *structs.QueryOpti
if idx := query.Get("index"); idx != "" {
index, err := strconv.ParseUint(idx, 10, 64)
if err != nil {
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(http.StatusBadRequest)
resp.Write([]byte("Invalid index"))
return true
@@ -913,6 +919,7 @@ func parseConsistency(resp http.ResponseWriter, req *http.Request, b *structs.Qu
staleQuery, err := strconv.ParseBool(staleVal[0])
if err != nil {
errMsg := "Expect `true` or `false` for `stale` query string parameter"
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(http.StatusBadRequest)
resp.Write([]byte(errMsg))
return CodedError(http.StatusBadRequest, errMsg)
@@ -1037,6 +1044,7 @@ func parsePagination(resp http.ResponseWriter, req *http.Request, b *structs.Que
perPage, err := strconv.ParseInt(rawPerPage, 10, 32)
if err != nil {
errMsg := "Expect a number for `per_page` query string parameter"
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(http.StatusBadRequest)
resp.Write([]byte(errMsg))
return CodedError(http.StatusBadRequest, errMsg)
@@ -1158,6 +1166,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
reply := structs.ACLWhoAmIResponse{}
if a.srv.parse(resp, req, &args.Region, &args.QueryOptions) {
// Error parsing request, 400
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(http.StatusBadRequest)
resp.Write([]byte(http.StatusText(http.StatusBadRequest)))
return
@@ -1165,6 +1174,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
if args.AuthToken == "" {
// 401 instead of 403 since no token was present.
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(http.StatusUnauthorized)
resp.Write([]byte(http.StatusText(http.StatusUnauthorized)))
return
@@ -1175,12 +1185,14 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
// credentials, so convert it to a Forbidden response code.
if strings.HasSuffix(err.Error(), structs.ErrPermissionDenied.Error()) {
a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL)
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(http.StatusForbidden)
resp.Write([]byte(http.StatusText(http.StatusForbidden)))
return
}
a.srv.logger.Error("error authenticating built API request", "error", err, "url", req.URL, "method", req.Method)
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(http.StatusInternalServerError)
resp.Write([]byte("Server error authenticating request\n"))
return
@@ -1189,6 +1201,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
// Require an acl token or workload identity
if reply.Identity == nil || (reply.Identity.ACLToken == nil && reply.Identity.Claims == nil) {
a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL)
resp.Header().Set(contentTypeHeader, plainContentType)
resp.WriteHeader(http.StatusForbidden)
resp.Write([]byte(http.StatusText(http.StatusForbidden)))
return

View File

@@ -275,7 +275,8 @@ func TestWrapNonJSON(t *testing.T) {
s.Server.wrapNonJSON(handler)(resp, req)
respBody, _ := io.ReadAll(resp.Body)
require.Equal(t, respBody, []byte("test response"))
must.Eq(t, respBody, []byte("test response"))
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
}
@@ -298,8 +299,9 @@ func TestWrapNonJSON_Error(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, "/v1/kv/key", nil)
s.Server.wrapNonJSON(handlerRPCErr)(resp, req)
respBody, _ := io.ReadAll(resp.Body)
require.Equal(t, []byte("not found"), respBody)
require.Equal(t, 404, resp.Code)
must.Eq(t, []byte("not found"), respBody)
must.Eq(t, 404, resp.Code)
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
}
// CodedError
@@ -308,8 +310,9 @@ func TestWrapNonJSON_Error(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, "/v1/kv/key", nil)
s.Server.wrapNonJSON(handlerCodedErr)(resp, req)
respBody, _ := io.ReadAll(resp.Body)
require.Equal(t, []byte("unprocessable"), respBody)
require.Equal(t, 422, resp.Code)
must.Eq(t, []byte("unprocessable"), respBody)
must.Eq(t, 422, resp.Code)
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
}
}
@@ -381,7 +384,8 @@ func TestPermissionDenied(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, "/v1/job/foo", nil)
s.Server.wrap(handler)(resp, req)
assert.Equal(t, resp.Code, 403)
must.Eq(t, resp.Code, 403)
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
}
// When remote RPC is used the errors have "rpc error: " prependend
@@ -393,7 +397,8 @@ func TestPermissionDenied(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, "/v1/job/foo", nil)
s.Server.wrap(handler)(resp, req)
assert.Equal(t, resp.Code, 403)
must.Eq(t, resp.Code, 403)
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
}
}
@@ -411,7 +416,8 @@ func TestTokenNotFound(t *testing.T) {
urlStr := "/v1/job/foo"
req, _ := http.NewRequest(http.MethodGet, urlStr, nil)
s.Server.wrap(handler)(resp, req)
assert.Equal(t, resp.Code, 403)
must.Eq(t, resp.Code, 403)
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
}
func TestParseWait(t *testing.T) {
@@ -421,20 +427,11 @@ func TestParseWait(t *testing.T) {
req, err := http.NewRequest(http.MethodGet,
"/v1/catalog/nodes?wait=60s&index=1000", nil)
if err != nil {
t.Fatalf("err: %v", err)
}
must.NoError(t, err)
if d := parseWait(resp, req, &b); d {
t.Fatalf("unexpected done")
}
if b.MinQueryIndex != 1000 {
t.Fatalf("Bad: %v", b)
}
if b.MaxQueryTime != 60*time.Second {
t.Fatalf("Bad: %v", b)
}
must.False(t, parseWait(resp, req, &b))
must.Eq(t, b.MinQueryIndex, 1000)
must.Eq(t, b.MaxQueryTime, 60*time.Second)
}
func TestParseWait_InvalidTime(t *testing.T) {
@@ -444,17 +441,11 @@ func TestParseWait_InvalidTime(t *testing.T) {
req, err := http.NewRequest(http.MethodGet,
"/v1/catalog/nodes?wait=60foo&index=1000", nil)
if err != nil {
t.Fatalf("err: %v", err)
}
must.NoError(t, err)
if d := parseWait(resp, req, &b); !d {
t.Fatalf("expected done")
}
if resp.Code != 400 {
t.Fatalf("bad code: %v", resp.Code)
}
must.True(t, parseWait(resp, req, &b))
must.Eq(t, resp.Code, 400)
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
}
func TestParseWait_InvalidIndex(t *testing.T) {
@@ -464,17 +455,11 @@ func TestParseWait_InvalidIndex(t *testing.T) {
req, err := http.NewRequest(http.MethodGet,
"/v1/catalog/nodes?wait=60s&index=foo", nil)
if err != nil {
t.Fatalf("err: %v", err)
}
must.NoError(t, err)
if d := parseWait(resp, req, &b); !d {
t.Fatalf("expected done")
}
if resp.Code != 400 {
t.Fatalf("bad code: %v", resp.Code)
}
must.True(t, parseWait(resp, req, &b))
must.Eq(t, resp.Code, 400)
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
}
func TestParseConsistency(t *testing.T) {