diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index e9f101a..d9c3e77 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -106,7 +106,7 @@ func (h *Http) Run(ctx context.Context) error { h.headersHandler(h.ProxyHeaders), h.accessLogHandler(h.AccessLog), h.stdoutLogHandler(h.StdOutEnabled, logger.New(logger.Log(log.Default()), logger.Prefix("[INFO]")).Handler), - R.SizeLimit(h.MaxBodySize), + h.maxReqSizeHandler(h.MaxBodySize), h.gzipHandler(), ) @@ -335,6 +335,17 @@ func (h *Http) stdoutLogHandler(enable bool, lh func(next http.Handler) http.Han } } +func (h *Http) maxReqSizeHandler(maxSize int64) func(next http.Handler) http.Handler { + if maxSize <= 0 { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) + } + } + return R.SizeLimit(maxSize) +} + func (h *Http) makeHTTPServer(addr string, router http.Handler) *http.Server { return &http.Server{ Addr: addr, diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index 3d79c11..fcaa915 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -1,6 +1,7 @@ package proxy import ( + "bytes" "context" "fmt" "io" @@ -248,7 +249,70 @@ func TestHttp_DoWithAssetRules(t *testing.T) { assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) } +} +func TestHttp_DoLimitedReq(t *testing.T) { + port := rand.Intn(10000) + 40000 + h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port), + AccessLog: io.Discard, Signature: true, ProxyHeaders: []string{"hh1:vv1", "hh2:vv2"}, StdOutEnabled: true, + Reporter: &ErrorReporter{Nice: true}, MaxBodySize: 10} + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("req: %v", r) + w.Header().Add("h1", "v1") + require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP")) + fmt.Fprintf(w, "response %s", r.URL.String()) + })) + + svc := discovery.NewService([]discovery.Provider{ + &provider.Static{Rules: []string{ + "localhost,^/api/(.*)," + ds.URL + "/123/$1,", + "127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,", + }, + }}, time.Millisecond*10) + + go func() { + _ = svc.Run(context.Background()) + }() + + time.Sleep(50 * time.Millisecond) + h.Matcher, h.Metrics = svc, mgmt.NewMetrics() + + go func() { + _ = h.Run(ctx) + }() + time.Sleep(10 * time.Millisecond) + + client := http.Client{} + + { + req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + t.Logf("%+v", resp.Header) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "response /567/something", string(body)) + assert.Equal(t, "reproxy", resp.Header.Get("App-Name")) + assert.Equal(t, "v1", resp.Header.Get("h1")) + assert.Equal(t, "vv1", resp.Header.Get("hh1")) + assert.Equal(t, "vv2", resp.Header.Get("hh2")) + } + + { + req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg1234567")) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode) + } } func TestHttp_toHttp(t *testing.T) {