mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
agent: set content type header explicitly (#24489)
This PR addresses an XSS vulnerability where Nomad agents wouldn't explicitly set content type headers for error responses.
This commit is contained in:
committed by
GitHub
parent
11bba3dbcd
commit
9c5078f151
3
.changelog/24489.txt
Normal file
3
.changelog/24489.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
```release-note:security
|
||||||
|
security: Explicitly set 'Content-Type' header to mitigate XSS vulnerability
|
||||||
|
```
|
||||||
@@ -58,6 +58,9 @@ const (
|
|||||||
// MissingRequestID is a placeholder if we cannot retrieve a request
|
// MissingRequestID is a placeholder if we cannot retrieve a request
|
||||||
// UUID from context
|
// UUID from context
|
||||||
MissingRequestID = "<missing request id>"
|
MissingRequestID = "<missing request id>"
|
||||||
|
|
||||||
|
contentTypeHeader = "Content-Type"
|
||||||
|
plainContentType = "text/plain; charset=utf-8"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -743,6 +746,7 @@ func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Reque
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(code)
|
resp.WriteHeader(code)
|
||||||
resp.Write([]byte(errMsg))
|
resp.Write([]byte(errMsg))
|
||||||
if isAPIClientError(code) {
|
if isAPIClientError(code) {
|
||||||
@@ -801,6 +805,7 @@ func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *htt
|
|||||||
// Check for an error
|
// Check for an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
code, errMsg := errCodeFromHandler(err)
|
code, errMsg := errCodeFromHandler(err)
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(code)
|
resp.WriteHeader(code)
|
||||||
resp.Write([]byte(errMsg))
|
resp.Write([]byte(errMsg))
|
||||||
if isAPIClientError(code) {
|
if isAPIClientError(code) {
|
||||||
@@ -810,7 +815,6 @@ func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *htt
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// write response
|
// write response
|
||||||
if obj != nil {
|
if obj != nil {
|
||||||
resp.Write(obj)
|
resp.Write(obj)
|
||||||
@@ -884,6 +888,7 @@ func parseWait(resp http.ResponseWriter, req *http.Request, b *structs.QueryOpti
|
|||||||
if wait := query.Get("wait"); wait != "" {
|
if wait := query.Get("wait"); wait != "" {
|
||||||
dur, err := time.ParseDuration(wait)
|
dur, err := time.ParseDuration(wait)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
resp.Write([]byte("Invalid wait time"))
|
resp.Write([]byte("Invalid wait time"))
|
||||||
return true
|
return true
|
||||||
@@ -893,6 +898,7 @@ func parseWait(resp http.ResponseWriter, req *http.Request, b *structs.QueryOpti
|
|||||||
if idx := query.Get("index"); idx != "" {
|
if idx := query.Get("index"); idx != "" {
|
||||||
index, err := strconv.ParseUint(idx, 10, 64)
|
index, err := strconv.ParseUint(idx, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
resp.Write([]byte("Invalid index"))
|
resp.Write([]byte("Invalid index"))
|
||||||
return true
|
return true
|
||||||
@@ -913,6 +919,7 @@ func parseConsistency(resp http.ResponseWriter, req *http.Request, b *structs.Qu
|
|||||||
staleQuery, err := strconv.ParseBool(staleVal[0])
|
staleQuery, err := strconv.ParseBool(staleVal[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := "Expect `true` or `false` for `stale` query string parameter"
|
errMsg := "Expect `true` or `false` for `stale` query string parameter"
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
resp.Write([]byte(errMsg))
|
resp.Write([]byte(errMsg))
|
||||||
return CodedError(http.StatusBadRequest, errMsg)
|
return CodedError(http.StatusBadRequest, errMsg)
|
||||||
@@ -1037,6 +1044,7 @@ func parsePagination(resp http.ResponseWriter, req *http.Request, b *structs.Que
|
|||||||
perPage, err := strconv.ParseInt(rawPerPage, 10, 32)
|
perPage, err := strconv.ParseInt(rawPerPage, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := "Expect a number for `per_page` query string parameter"
|
errMsg := "Expect a number for `per_page` query string parameter"
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
resp.Write([]byte(errMsg))
|
resp.Write([]byte(errMsg))
|
||||||
return CodedError(http.StatusBadRequest, errMsg)
|
return CodedError(http.StatusBadRequest, errMsg)
|
||||||
@@ -1158,6 +1166,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
|
|||||||
reply := structs.ACLWhoAmIResponse{}
|
reply := structs.ACLWhoAmIResponse{}
|
||||||
if a.srv.parse(resp, req, &args.Region, &args.QueryOptions) {
|
if a.srv.parse(resp, req, &args.Region, &args.QueryOptions) {
|
||||||
// Error parsing request, 400
|
// Error parsing request, 400
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(http.StatusBadRequest)
|
resp.WriteHeader(http.StatusBadRequest)
|
||||||
resp.Write([]byte(http.StatusText(http.StatusBadRequest)))
|
resp.Write([]byte(http.StatusText(http.StatusBadRequest)))
|
||||||
return
|
return
|
||||||
@@ -1165,6 +1174,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
|
|||||||
|
|
||||||
if args.AuthToken == "" {
|
if args.AuthToken == "" {
|
||||||
// 401 instead of 403 since no token was present.
|
// 401 instead of 403 since no token was present.
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(http.StatusUnauthorized)
|
resp.WriteHeader(http.StatusUnauthorized)
|
||||||
resp.Write([]byte(http.StatusText(http.StatusUnauthorized)))
|
resp.Write([]byte(http.StatusText(http.StatusUnauthorized)))
|
||||||
return
|
return
|
||||||
@@ -1175,12 +1185,14 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
|
|||||||
// credentials, so convert it to a Forbidden response code.
|
// credentials, so convert it to a Forbidden response code.
|
||||||
if strings.HasSuffix(err.Error(), structs.ErrPermissionDenied.Error()) {
|
if strings.HasSuffix(err.Error(), structs.ErrPermissionDenied.Error()) {
|
||||||
a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL)
|
a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL)
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(http.StatusForbidden)
|
resp.WriteHeader(http.StatusForbidden)
|
||||||
resp.Write([]byte(http.StatusText(http.StatusForbidden)))
|
resp.Write([]byte(http.StatusText(http.StatusForbidden)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
a.srv.logger.Error("error authenticating built API request", "error", err, "url", req.URL, "method", req.Method)
|
a.srv.logger.Error("error authenticating built API request", "error", err, "url", req.URL, "method", req.Method)
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(http.StatusInternalServerError)
|
resp.WriteHeader(http.StatusInternalServerError)
|
||||||
resp.Write([]byte("Server error authenticating request\n"))
|
resp.Write([]byte("Server error authenticating request\n"))
|
||||||
return
|
return
|
||||||
@@ -1189,6 +1201,7 @@ func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request)
|
|||||||
// Require an acl token or workload identity
|
// Require an acl token or workload identity
|
||||||
if reply.Identity == nil || (reply.Identity.ACLToken == nil && reply.Identity.Claims == nil) {
|
if reply.Identity == nil || (reply.Identity.ACLToken == nil && reply.Identity.Claims == nil) {
|
||||||
a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL)
|
a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL)
|
||||||
|
resp.Header().Set(contentTypeHeader, plainContentType)
|
||||||
resp.WriteHeader(http.StatusForbidden)
|
resp.WriteHeader(http.StatusForbidden)
|
||||||
resp.Write([]byte(http.StatusText(http.StatusForbidden)))
|
resp.Write([]byte(http.StatusText(http.StatusForbidden)))
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -275,7 +275,8 @@ func TestWrapNonJSON(t *testing.T) {
|
|||||||
s.Server.wrapNonJSON(handler)(resp, req)
|
s.Server.wrapNonJSON(handler)(resp, req)
|
||||||
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
require.Equal(t, respBody, []byte("test response"))
|
must.Eq(t, respBody, []byte("test response"))
|
||||||
|
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,8 +299,9 @@ func TestWrapNonJSON_Error(t *testing.T) {
|
|||||||
req, _ := http.NewRequest(http.MethodGet, "/v1/kv/key", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/v1/kv/key", nil)
|
||||||
s.Server.wrapNonJSON(handlerRPCErr)(resp, req)
|
s.Server.wrapNonJSON(handlerRPCErr)(resp, req)
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
require.Equal(t, []byte("not found"), respBody)
|
must.Eq(t, []byte("not found"), respBody)
|
||||||
require.Equal(t, 404, resp.Code)
|
must.Eq(t, 404, resp.Code)
|
||||||
|
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CodedError
|
// CodedError
|
||||||
@@ -308,8 +310,9 @@ func TestWrapNonJSON_Error(t *testing.T) {
|
|||||||
req, _ := http.NewRequest(http.MethodGet, "/v1/kv/key", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/v1/kv/key", nil)
|
||||||
s.Server.wrapNonJSON(handlerCodedErr)(resp, req)
|
s.Server.wrapNonJSON(handlerCodedErr)(resp, req)
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
require.Equal(t, []byte("unprocessable"), respBody)
|
must.Eq(t, []byte("unprocessable"), respBody)
|
||||||
require.Equal(t, 422, resp.Code)
|
must.Eq(t, 422, resp.Code)
|
||||||
|
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -381,7 +384,8 @@ func TestPermissionDenied(t *testing.T) {
|
|||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, "/v1/job/foo", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/v1/job/foo", nil)
|
||||||
s.Server.wrap(handler)(resp, req)
|
s.Server.wrap(handler)(resp, req)
|
||||||
assert.Equal(t, resp.Code, 403)
|
must.Eq(t, resp.Code, 403)
|
||||||
|
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// When remote RPC is used the errors have "rpc error: " prependend
|
// When remote RPC is used the errors have "rpc error: " prependend
|
||||||
@@ -393,7 +397,8 @@ func TestPermissionDenied(t *testing.T) {
|
|||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, "/v1/job/foo", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/v1/job/foo", nil)
|
||||||
s.Server.wrap(handler)(resp, req)
|
s.Server.wrap(handler)(resp, req)
|
||||||
assert.Equal(t, resp.Code, 403)
|
must.Eq(t, resp.Code, 403)
|
||||||
|
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -411,7 +416,8 @@ func TestTokenNotFound(t *testing.T) {
|
|||||||
urlStr := "/v1/job/foo"
|
urlStr := "/v1/job/foo"
|
||||||
req, _ := http.NewRequest(http.MethodGet, urlStr, nil)
|
req, _ := http.NewRequest(http.MethodGet, urlStr, nil)
|
||||||
s.Server.wrap(handler)(resp, req)
|
s.Server.wrap(handler)(resp, req)
|
||||||
assert.Equal(t, resp.Code, 403)
|
must.Eq(t, resp.Code, 403)
|
||||||
|
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseWait(t *testing.T) {
|
func TestParseWait(t *testing.T) {
|
||||||
@@ -421,20 +427,11 @@ func TestParseWait(t *testing.T) {
|
|||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet,
|
req, err := http.NewRequest(http.MethodGet,
|
||||||
"/v1/catalog/nodes?wait=60s&index=1000", nil)
|
"/v1/catalog/nodes?wait=60s&index=1000", nil)
|
||||||
if err != nil {
|
must.NoError(t, err)
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if d := parseWait(resp, req, &b); d {
|
must.False(t, parseWait(resp, req, &b))
|
||||||
t.Fatalf("unexpected done")
|
must.Eq(t, b.MinQueryIndex, 1000)
|
||||||
}
|
must.Eq(t, b.MaxQueryTime, 60*time.Second)
|
||||||
|
|
||||||
if b.MinQueryIndex != 1000 {
|
|
||||||
t.Fatalf("Bad: %v", b)
|
|
||||||
}
|
|
||||||
if b.MaxQueryTime != 60*time.Second {
|
|
||||||
t.Fatalf("Bad: %v", b)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseWait_InvalidTime(t *testing.T) {
|
func TestParseWait_InvalidTime(t *testing.T) {
|
||||||
@@ -444,17 +441,11 @@ func TestParseWait_InvalidTime(t *testing.T) {
|
|||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet,
|
req, err := http.NewRequest(http.MethodGet,
|
||||||
"/v1/catalog/nodes?wait=60foo&index=1000", nil)
|
"/v1/catalog/nodes?wait=60foo&index=1000", nil)
|
||||||
if err != nil {
|
must.NoError(t, err)
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if d := parseWait(resp, req, &b); !d {
|
must.True(t, parseWait(resp, req, &b))
|
||||||
t.Fatalf("expected done")
|
must.Eq(t, resp.Code, 400)
|
||||||
}
|
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||||
|
|
||||||
if resp.Code != 400 {
|
|
||||||
t.Fatalf("bad code: %v", resp.Code)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseWait_InvalidIndex(t *testing.T) {
|
func TestParseWait_InvalidIndex(t *testing.T) {
|
||||||
@@ -464,17 +455,11 @@ func TestParseWait_InvalidIndex(t *testing.T) {
|
|||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet,
|
req, err := http.NewRequest(http.MethodGet,
|
||||||
"/v1/catalog/nodes?wait=60s&index=foo", nil)
|
"/v1/catalog/nodes?wait=60s&index=foo", nil)
|
||||||
if err != nil {
|
must.NoError(t, err)
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if d := parseWait(resp, req, &b); !d {
|
must.True(t, parseWait(resp, req, &b))
|
||||||
t.Fatalf("expected done")
|
must.Eq(t, resp.Code, 400)
|
||||||
}
|
must.Eq(t, resp.Header().Get(contentTypeHeader), plainContentType)
|
||||||
|
|
||||||
if resp.Code != 400 {
|
|
||||||
t.Fatalf("bad code: %v", resp.Code)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseConsistency(t *testing.T) {
|
func TestParseConsistency(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user