diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index 1b92827..ce3a531 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -89,10 +89,11 @@ func (h *Http) Run(ctx context.Context) error { h.signatureHandler(), h.pingHandler, h.healthMiddleware, + // R.Headers(h.ProxyHeaders...), + h.headersHandler(h.ProxyHeaders), h.accessLogHandler(h.AccessLog), h.stdoutLogHandler(h.StdOutEnabled, logger.New(logger.Log(log.Default()), logger.Prefix("[INFO]")).Handler), R.SizeLimit(h.MaxBodySize), - R.Headers(h.ProxyHeaders...), h.gzipHandler(), ) @@ -238,6 +239,26 @@ func (h *Http) signatureHandler() func(next http.Handler) http.Handler { } } +func (h *Http) headersHandler(headers []string) func(next http.Handler) http.Handler { + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(h.ProxyHeaders) == 0 { + next.ServeHTTP(w, r) + return + } + for _, h := range headers { + elems := strings.Split(h, ":") + if len(elems) != 2 { + continue + } + w.Header().Set(strings.TrimSpace(elems[0]), strings.TrimSpace(elems[1])) + } + next.ServeHTTP(w, r) + }) + } +} + func (h *Http) accessLogHandler(wr io.Writer) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return handlers.CombinedLoggingHandler(wr, next) diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index 84c860a..2a18baa 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -21,7 +21,7 @@ import ( func TestHttp_Do(t *testing.T) { port := rand.Intn(10000) + 40000 h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port), - AccessLog: io.Discard, Signature: true} + AccessLog: io.Discard, Signature: true, ProxyHeaders: []string{"hh1:vv1", "hh2:vv2"}} ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() @@ -66,6 +66,8 @@ func TestHttp_Do(t *testing.T) { assert.Equal(t, "response /567/something", string(body)) assert.Equal(t, "reproxy", resp.Header.Get("App-Name")) assert.Equal(t, "v1", resp.Header.Get("h1")) + assert.Equal(t, "vv1", resp.Header.Get("hh1")) + assert.Equal(t, "vv2", resp.Header.Get("hh2")) } {