add OverrideHeadersIn and OverrideHeadersOut support

This is an attempt to address #108

Instead of dedicated support of header`s removal it allows to return a flag indicating plugin's full control over headers. With this flag set, the conductor won't mix response headers with originals but rather will count on a plugin to provide all the headers.
This commit is contained in:
Umputun
2021-09-07 00:23:47 -05:00
parent 506ded3ad4
commit c7a2308267
4 changed files with 31 additions and 14 deletions

View File

@@ -87,6 +87,20 @@ func (c *Conductor) Run(ctx context.Context) error {
// Failed plugin calls ignored. Status code from any plugin may stop the chain of calls if not 200. This is needed
// to allow plugins like auth which has to terminate request in some cases.
func (c *Conductor) Middleware(next http.Handler) http.Handler {
setHeaders := func(src, alt http.Header, overrideHeaders bool) {
if overrideHeaders {
for k := range src {
src.Del(k)
}
}
for k, vv := range alt {
for _, v := range vv {
src.Add(k, v)
}
}
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c.lock.RLock()
@@ -101,16 +115,10 @@ func (c *Conductor) Middleware(next http.Handler) http.Handler {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
for k, vv := range reply.HeadersIn {
for _, v := range vv {
r.Header.Add(k, v)
}
}
for k, vv := range reply.HeadersOut {
for _, v := range vv {
w.Header().Add(k, v)
}
}
setHeaders(r.Header, reply.HeadersIn, reply.OverrideHeadersIn)
setHeaders(w.Header(), reply.HeadersOut, reply.OverrideHeadersOut)
if reply.StatusCode >= 400 {
c.lock.RUnlock()
http.Error(w, http.StatusText(reply.StatusCode), reply.StatusCode)

View File

@@ -223,6 +223,7 @@ func TestConductor_Middleware(t *testing.T) {
reply.(*lib.Response).StatusCode = 200
reply.(*lib.Response).HeadersOut = map[string][]string{}
reply.(*lib.Response).HeadersOut.Set("k11", "v11")
reply.(*lib.Response).OverrideHeadersOut = true
}
if serviceMethod == "Test1.Mw3" {
t.Fatal("shouldn't be called")
@@ -285,9 +286,10 @@ func TestConductor_Middleware(t *testing.T) {
}))
h.ServeHTTP(w, rr)
assert.Equal(t, 200, w.Result().StatusCode)
assert.Equal(t, "v1", w.Result().Header.Get("k1"))
assert.Equal(t, "", w.Result().Header.Get("k1"))
assert.Equal(t, "v2", w.Result().Header.Get("k2"))
assert.Equal(t, "v21", rr.Header.Get("k21"))
assert.Equal(t, "v11", w.Result().Header.Get("k11"))
t.Logf("req: %+v", rr)
t.Logf("resp: %+v", w.Result())
}