mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
agent: set content type header explicitly (#24489)
This PR addresses an XSS vulnerability where Nomad agents wouldn't explicitly set content type headers for error responses.
This commit is contained in:
committed by
GitHub
parent
11bba3dbcd
commit
9c5078f151
3
.changelog/24489.txt
Normal file
3
.changelog/24489.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
```release-note:security
|
||||
security: Explicitly set 'Content-Type' header to mitigate XSS vulnerability
|
||||
```
|
||||
@@ -58,6 +58,9 @@ const (
|
||||
// MissingRequestID is a placeholder if we cannot retrieve a request
|
||||
// UUID from context
|
||||
MissingRequestID = "<missing request id>"
|
||||
|
||||
contentTypeHeader = "Content-Type"
|
||||
plainContentType = "text/plain; charset=utf-8"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -743,6 +746,7 @@ func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Reque
|
||||
}
|
||||
}
|
||||
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(code)
|
||||
resp.Write([]byte(errMsg))
|
||||
if isAPIClientError(code) {
|
||||
@@ -801,6 +805,7 @@ func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *htt
|
||||
// Check for an error
|
||||
if err != nil {
|
||||
code, errMsg := errCodeFromHandler(err)
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(code)
|
||||
resp.Write([]byte(errMsg))
|
||||
if isAPIClientError(code) {
|
||||
@@ -810,7 +815,6 @@ func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *htt
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// write response
|
||||
if obj != nil {
|
||||
resp.Write(obj)
|
||||
@@ -884,6 +888,7 @@ func parseWait(resp http.ResponseWriter, req *http.Request, b *structs.QueryOpti
|
||||
if wait := query.Get("wait"); wait != "" {
|
||||
dur, err := time.ParseDuration(wait)
|
||||
if err != nil {
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
resp.Write([]byte("Invalid wait time"))
|
||||
return true
|
||||
@@ -893,6 +898,7 @@ func parseWait(resp http.ResponseWriter, req *http.Request, b *structs.QueryOpti
|
||||
if idx := query.Get("index"); idx != "" {
|
||||
index, err := strconv.ParseUint(idx, 10, 64)
|
||||
if err != nil {
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
resp.Write([]byte("Invalid index"))
|
||||
return true
|
||||
@@ -913,6 +919,7 @@ func parseConsistency(resp http.ResponseWriter, req *http.Request, b *structs.Qu
|
||||
staleQuery, err := strconv.ParseBool(staleVal[0])
|
||||
if err != nil {
|
||||
errMsg := "Expect `true` or `false` for `stale` query string parameter"
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
resp.Write([]byte(errMsg))
|
||||
return CodedError(http.StatusBadRequest, errMsg)
|
||||
@@ -1037,6 +1044,7 @@ func parsePagination(resp http.ResponseWriter, req *http.Request, b *structs.Que
|
||||
perPage, err := strconv.ParseInt(rawPerPage, 10, 32)
|
||||
if err != nil {
|
||||
errMsg := "Expect a number for `per_page` query string parameter"
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
resp.Write([]byte(errMsg))
|
||||
return CodedError(http.StatusBadRequest, errMsg)
|
||||
@@ -1158,6 +1166,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
|
||||
reply := structs.ACLWhoAmIResponse{}
|
||||
if a.srv.parse(resp, req, &args.Region, &args.QueryOptions) {
|
||||
// Error parsing request, 400
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
resp.Write([]byte(http.StatusText(http.StatusBadRequest)))
|
||||
return
|
||||
@@ -1165,6 +1174,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
|
||||
|
||||
if args.AuthToken == "" {
|
||||
// 401 instead of 403 since no token was present.
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(http.StatusUnauthorized)
|
||||
resp.Write([]byte(http.StatusText(http.StatusUnauthorized)))
|
||||
return
|
||||
@@ -1175,12 +1185,14 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
|
||||
// credentials, so convert it to a Forbidden response code.
|
||||
if strings.HasSuffix(err.Error(), structs.ErrPermissionDenied.Error()) {
|
||||
a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL)
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(http.StatusForbidden)
|
||||
resp.Write([]byte(http.StatusText(http.StatusForbidden)))
|
||||
return
|
||||
}
|
||||
|
||||
a.srv.logger.Error("error authenticating built API request", "error", err, "url", req.URL, "method", req.Method)
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(http.StatusInternalServerError)
|
||||
resp.Write([]byte("Server error authenticating request\n"))
|
||||
return
|
||||
@@ -1189,6 +1201,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
|
||||
// Require an acl token or workload identity
|
||||
if reply.Identity == nil || (reply.Identity.ACLToken == nil && reply.Identity.Claims == nil) {
|
||||
a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL)
|
||||
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||
resp.WriteHeader(http.StatusForbidden)
|
||||
resp.Write([]byte(http.StatusText(http.StatusForbidden)))
|
||||
return
|
||||
|
||||
@@ -275,7 +275,8 @@ func TestWrapNonJSON(t *testing.T) {
|
||||
s.Server.wrapNonJSON(handler)(resp, req)
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(t, respBody, []byte("test response"))
|
||||
must.Eq(t, respBody, []byte("test response"))
|
||||
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||
|
||||
}
|
||||
|
||||
@@ -298,8 +299,9 @@ func TestWrapNonJSON_Error(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/kv/key", nil)
|
||||
s.Server.wrapNonJSON(handlerRPCErr)(resp, req)
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(t, []byte("not found"), respBody)
|
||||
require.Equal(t, 404, resp.Code)
|
||||
must.Eq(t, []byte("not found"), respBody)
|
||||
must.Eq(t, 404, resp.Code)
|
||||
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||
}
|
||||
|
||||
// CodedError
|
||||
@@ -308,8 +310,9 @@ func TestWrapNonJSON_Error(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/kv/key", nil)
|
||||
s.Server.wrapNonJSON(handlerCodedErr)(resp, req)
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(t, []byte("unprocessable"), respBody)
|
||||
require.Equal(t, 422, resp.Code)
|
||||
must.Eq(t, []byte("unprocessable"), respBody)
|
||||
must.Eq(t, 422, resp.Code)
|
||||
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||
}
|
||||
|
||||
}
|
||||
@@ -381,7 +384,8 @@ func TestPermissionDenied(t *testing.T) {
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/job/foo", nil)
|
||||
s.Server.wrap(handler)(resp, req)
|
||||
assert.Equal(t, resp.Code, 403)
|
||||
must.Eq(t, resp.Code, 403)
|
||||
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||
}
|
||||
|
||||
// When remote RPC is used the errors have "rpc error: " prependend
|
||||
@@ -393,7 +397,8 @@ func TestPermissionDenied(t *testing.T) {
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/job/foo", nil)
|
||||
s.Server.wrap(handler)(resp, req)
|
||||
assert.Equal(t, resp.Code, 403)
|
||||
must.Eq(t, resp.Code, 403)
|
||||
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -411,7 +416,8 @@ func TestTokenNotFound(t *testing.T) {
|
||||
urlStr := "/v1/job/foo"
|
||||
req, _ := http.NewRequest(http.MethodGet, urlStr, nil)
|
||||
s.Server.wrap(handler)(resp, req)
|
||||
assert.Equal(t, resp.Code, 403)
|
||||
must.Eq(t, resp.Code, 403)
|
||||
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||
}
|
||||
|
||||
func TestParseWait(t *testing.T) {
|
||||
@@ -421,20 +427,11 @@ func TestParseWait(t *testing.T) {
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet,
|
||||
"/v1/catalog/nodes?wait=60s&index=1000", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
must.NoError(t, err)
|
||||
|
||||
if d := parseWait(resp, req, &b); d {
|
||||
t.Fatalf("unexpected done")
|
||||
}
|
||||
|
||||
if b.MinQueryIndex != 1000 {
|
||||
t.Fatalf("Bad: %v", b)
|
||||
}
|
||||
if b.MaxQueryTime != 60*time.Second {
|
||||
t.Fatalf("Bad: %v", b)
|
||||
}
|
||||
must.False(t, parseWait(resp, req, &b))
|
||||
must.Eq(t, b.MinQueryIndex, 1000)
|
||||
must.Eq(t, b.MaxQueryTime, 60*time.Second)
|
||||
}
|
||||
|
||||
func TestParseWait_InvalidTime(t *testing.T) {
|
||||
@@ -444,17 +441,11 @@ func TestParseWait_InvalidTime(t *testing.T) {
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet,
|
||||
"/v1/catalog/nodes?wait=60foo&index=1000", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
must.NoError(t, err)
|
||||
|
||||
if d := parseWait(resp, req, &b); !d {
|
||||
t.Fatalf("expected done")
|
||||
}
|
||||
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("bad code: %v", resp.Code)
|
||||
}
|
||||
must.True(t, parseWait(resp, req, &b))
|
||||
must.Eq(t, resp.Code, 400)
|
||||
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||
}
|
||||
|
||||
func TestParseWait_InvalidIndex(t *testing.T) {
|
||||
@@ -464,17 +455,11 @@ func TestParseWait_InvalidIndex(t *testing.T) {
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet,
|
||||
"/v1/catalog/nodes?wait=60s&index=foo", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
must.NoError(t, err)
|
||||
|
||||
if d := parseWait(resp, req, &b); !d {
|
||||
t.Fatalf("expected done")
|
||||
}
|
||||
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("bad code: %v", resp.Code)
|
||||
}
|
||||
must.True(t, parseWait(resp, req, &b))
|
||||
must.Eq(t, resp.Code, 400)
|
||||
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||
}
|
||||
|
||||
func TestParseConsistency(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user