From bbbd24dd536c8acd84c3209192eec59c37147d72 Mon Sep 17 00:00:00 2001 From: Umputun Date: Sat, 17 Apr 2021 13:11:10 -0500 Subject: [PATCH] fix missing url.Host forward --- app/main_test.go | 15 +++++++++++++-- app/proxy/proxy.go | 5 +++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/app/main_test.go b/app/main_test.go index 0ea87b5..aa65140 100644 --- a/app/main_test.go +++ b/app/main_test.go @@ -21,6 +21,7 @@ func Test_Main(t *testing.T) { port := chooseRandomUnusedPort() os.Args = []string{"test", "--static.enabled", "--static.rule=*,/svc1, https://httpbin.org/get,https://feedmaster.umputun.com/ping", + "--static.rule=*,/svc2/(.*), https://echo.umputun.com/$1,https://feedmaster.umputun.com/ping", "--dbg", "--logger.stdout", "--listen=127.0.0.1:" + strconv.Itoa(port), "--signature"} done := make(chan struct{}) @@ -62,11 +63,21 @@ func Test_Main(t *testing.T) { assert.Equal(t, 200, resp.StatusCode) body, err := ioutil.ReadAll(resp.Body) assert.NoError(t, err) - assert.Contains(t, string(body), `"Host": "127.0.0.1"`) + assert.Contains(t, string(body), `"Host": "httpbin.org"`) } { client := http.Client{Timeout: 10 * time.Second} - resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/bas", port)) + resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/svc2/test", port)) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + body, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Contains(t, string(body), `echo echo 123`) + } + { + client := http.Client{Timeout: 10 * time.Second} + resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/bad", port)) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusBadGateway, resp.StatusCode) diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index 3b97998..5de5f18 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -151,11 +151,12 @@ func (h *Http) proxyHandler() http.HandlerFunc { Director: func(r *http.Request) { ctx := r.Context() uu := ctx.Value(contextKey("url")).(*url.URL) + r.Header.Add("X-Forwarded-Host", uu.Host) + r.Header.Set("X-Origin-Host", r.Host) r.URL.Path = uu.Path r.URL.Host = uu.Host r.URL.Scheme = uu.Scheme - r.Header.Add("X-Forwarded-Host", uu.Host) - r.Header.Add("X-Origin-Host", r.Host) + r.Host = uu.Host h.setXRealIP(r) }, Transport: &http.Transport{