diff --git a/app/discovery/discovery.go b/app/discovery/discovery.go index 15df868..86136b0 100644 --- a/app/discovery/discovery.go +++ b/app/discovery/discovery.go @@ -37,6 +37,7 @@ type URLMapper struct { PingURL string MatchType MatchType RedirectType RedirectType + OnlyFromIPs []string AssetsLocation string // local FS root location AssetsWebRoot string // web root location @@ -484,16 +485,6 @@ func (s *Service) mergeEvents(ctx context.Context, chs ...<-chan ProviderID) <-c return out } -// Contains checks if the input string (e) in the given slice -func Contains(e string, s []string) bool { - for _, a := range s { - if a == e { - return true - } - } - return false -} - // IsAlive indicates whether mapper destination is alive func (m URLMapper) IsAlive() bool { return !m.dead @@ -515,3 +506,24 @@ func (m URLMapper) ping() (string, error) { return "", err } + +// Contains checks if the input string (e) in the given slice +func Contains(e string, s []string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} + +// ParseOnlyFrom parses comma separated list of IPs +func ParseOnlyFrom(s string) (res []string) { + if s == "" { + return []string{} + } + for _, v := range strings.Split(s, ",") { + res = append(res, strings.TrimSpace(v)) + } + return res +} diff --git a/app/discovery/discovery_test.go b/app/discovery/discovery_test.go index d6f46b1..4bb0de8 100644 --- a/app/discovery/discovery_test.go +++ b/app/discovery/discovery_test.go @@ -39,7 +39,7 @@ func TestService_Run(t *testing.T) { ListFunc: func() ([]URLMapper, error) { return []URLMapper{ {Server: "localhost", SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), - Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker}, + Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker, OnlyFromIPs: []string{"127.0.0.1"}}, }, nil }, } @@ -66,6 +66,7 @@ func TestService_Run(t *testing.T) { assert.Equal(t, "localhost", mappers[0].Server) assert.Equal(t, "/api/svc3/xyz", mappers[0].SrcMatch.String()) assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", mappers[0].Dst) + assert.Equal(t, []string{"127.0.0.1"}, mappers[0].OnlyFromIPs) assert.Equal(t, 1, len(p1.EventsCalls())) assert.Equal(t, 1, len(p2.EventsCalls())) @@ -104,7 +105,8 @@ func TestService_Match(t *testing.T) { }, ListFunc: func() ([]URLMapper, error) { return []URLMapper{ - {SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker}, + {SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), Dst: "http://127.0.0.3:8080/blah3/xyz", + OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}, ProviderID: PIDocker}, {SrcMatch: *regexp.MustCompile("/web"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic, AssetsWebRoot: "/web", AssetsLocation: "/var/web"}, {SrcMatch: *regexp.MustCompile("/www/"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic, @@ -131,9 +133,11 @@ func TestService_Match(t *testing.T) { res Matches }{ {"example.com", "/api/svc3/xyz/something", Matches{MTProxy, []MatchedRoute{ - {Destination: "http://127.0.0.3:8080/blah3/xyz/something", Alive: true}}}}, + {Destination: "http://127.0.0.3:8080/blah3/xyz/something", Alive: true, + Mapper: URLMapper{OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}}}}}}, {"example.com", "/api/svc3/xyz", Matches{MTProxy, []MatchedRoute{{ - Destination: "http://127.0.0.3:8080/blah3/xyz", Alive: true}}}}, + Destination: "http://127.0.0.3:8080/blah3/xyz", Alive: true, + Mapper: URLMapper{OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}}}}}}, {"abc.example.com", "/api/svc1/1234", Matches{MTProxy, []MatchedRoute{ {Destination: "http://127.0.0.1:8080/blah1/1234", Alive: true}}}}, {"zzz.example.com", "/aaa/api/svc1/1234", Matches{MTProxy, nil}}, @@ -167,6 +171,7 @@ func TestService_Match(t *testing.T) { for i := 0; i < len(res.Routes); i++ { assert.Equal(t, tt.res.Routes[i].Alive, res.Routes[i].Alive) assert.Equal(t, tt.res.Routes[i].Destination, res.Routes[i].Destination) + assert.Equal(t, tt.res.Routes[i].Mapper.OnlyFromIPs, res.Routes[i].Mapper.OnlyFromIPs) } assert.Equal(t, tt.res.MatchType, res.MatchType) }) @@ -608,3 +613,39 @@ func TestCheckHealth(t *testing.T) { assert.NoError(t, res[ts.URL]) assert.NoError(t, res[ts2.URL]) } + +func TestParseOnlyFrom(t *testing.T) { + tbl := []struct { + name string + input string + expected []string + }{ + { + name: "empty string", + input: "", + expected: []string{}, + }, + { + name: "single IP", + input: "192.168.1.1", + expected: []string{"192.168.1.1"}, + }, + { + name: "multiple IPs", + input: "192.168.1.1, 192.168.1.2, 192.168.1.3, 10.0.0.0/16", + expected: []string{"192.168.1.1", "192.168.1.2", "192.168.1.3", "10.0.0.0/16"}, + }, + { + name: "multiple IPs with extra spaces", + input: " 192.168.1.1 , 192.168.1.2 , 192.168.1.3 ", + expected: []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}, + }, + } + + for _, tt := range tbl { + t.Run(tt.name, func(t *testing.T) { + result := ParseOnlyFrom(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/app/discovery/provider/consulcatalog/consulcatalog.go b/app/discovery/provider/consulcatalog/consulcatalog.go index 258681e..6beda13 100644 --- a/app/discovery/provider/consulcatalog/consulcatalog.go +++ b/app/discovery/provider/consulcatalog/consulcatalog.go @@ -3,12 +3,13 @@ package consulcatalog import ( "context" "fmt" - "github.com/umputun/reproxy/app/discovery" "log" "regexp" "sort" "strings" "time" + + "github.com/umputun/reproxy/app/discovery" ) //go:generate moq -out consul_client_mock.go -skip-ensure -fmt goimports . ConsulClient @@ -139,6 +140,7 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) { destURL := fmt.Sprintf("http://%s:%d/$1", c.ServiceAddress, c.ServicePort) pingURL := fmt.Sprintf("http://%s:%d/ping", c.ServiceAddress, c.ServicePort) server := "*" + onlyFrom := []string{} if v, ok := c.Labels["reproxy.enabled"]; ok && (v == "true" || v == "yes" || v == "1") { enabled = true @@ -159,6 +161,10 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) { server = v } + if v, ok := c.Labels["reproxy.remote"]; ok { + onlyFrom = discovery.ParseOnlyFrom(v) + } + if v, ok := c.Labels["reproxy.ping"]; ok { enabled = true pingURL = fmt.Sprintf("http://%s:%d%s", c.ServiceAddress, c.ServicePort, v) @@ -177,7 +183,7 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) { // server label may have multiple, comma separated servers for _, srv := range strings.Split(server, ",") { res = append(res, discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL, - PingURL: pingURL, ProviderID: discovery.PIConsulCatalog}) + OnlyFromIPs: onlyFrom, PingURL: pingURL, ProviderID: discovery.PIConsulCatalog}) } } diff --git a/app/discovery/provider/consulcatalog/consulcatalog_test.go b/app/discovery/provider/consulcatalog/consulcatalog_test.go index 7e05610..e5c973d 100644 --- a/app/discovery/provider/consulcatalog/consulcatalog_test.go +++ b/app/discovery/provider/consulcatalog/consulcatalog_test.go @@ -3,12 +3,14 @@ package consulcatalog import ( "context" "fmt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/umputun/reproxy/app/discovery" "sort" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/umputun/reproxy/app/discovery" ) func TestNew(t *testing.T) { @@ -62,7 +64,8 @@ func TestConsulCatalog_List(t *testing.T) { ServiceAddress: "addr3", ServicePort: 3000, Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1", - "reproxy.server": "example.com,domain.com", "reproxy.ping": "/ping", "reproxy.enabled": "yes"}, + "reproxy.server": "example.com,domain.com", "reproxy.ping": "/ping", + "reproxy.enabled": "yes", "reproxy.remote": "127.0.0.1, 192.168.1.0/24"}, }, { ServiceID: "id4", @@ -91,21 +94,25 @@ func TestConsulCatalog_List(t *testing.T) { assert.Equal(t, "http://addr3:3000/blah/$1", res[0].Dst) assert.Equal(t, "example.com", res[0].Server) assert.Equal(t, "http://addr3:3000/ping", res[0].PingURL) + assert.Equal(t, []string{"127.0.0.1", "192.168.1.0/24"}, res[0].OnlyFromIPs) assert.Equal(t, "^/api/123/(.*)", res[1].SrcMatch.String()) assert.Equal(t, "http://addr3:3000/blah/$1", res[1].Dst) assert.Equal(t, "domain.com", res[1].Server) assert.Equal(t, "http://addr3:3000/ping", res[1].PingURL) + assert.Equal(t, []string{"127.0.0.1", "192.168.1.0/24"}, res[1].OnlyFromIPs) assert.Equal(t, "^/(.*)", res[2].SrcMatch.String()) assert.Equal(t, "http://addr44:4000/$1", res[2].Dst) assert.Equal(t, "http://addr44:4000/ping", res[2].PingURL) assert.Equal(t, "*", res[2].Server) + assert.Equal(t, []string{}, res[2].OnlyFromIPs) assert.Equal(t, "^/(.*)", res[3].SrcMatch.String()) assert.Equal(t, "http://addr2:2000/$1", res[3].Dst) assert.Equal(t, "http://addr2:2000/ping", res[3].PingURL) assert.Equal(t, "*", res[3].Server) + assert.Equal(t, []string{}, res[3].OnlyFromIPs) } func TestConsulCatalog_serviceListWasChanged(t *testing.T) { diff --git a/app/discovery/provider/docker.go b/app/discovery/provider/docker.go index d7354b9..6ced47d 100644 --- a/app/discovery/provider/docker.go +++ b/app/discovery/provider/docker.go @@ -103,6 +103,7 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper) // defaults destURL, pingURL, server := fmt.Sprintf("http://%s:%d/$1", c.IP, port), fmt.Sprintf("http://%s:%d/ping", c.IP, port), "*" assetsWebRoot, assetsLocation, assetsSPA := "", "", false + onlyFrom := []string{} if d.AutoAPI && n == 0 { enabled = true @@ -133,6 +134,10 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper) server = v } + if v, ok := d.labelN(c.Labels, n, "remote"); ok { + onlyFrom = discovery.ParseOnlyFrom(v) + } + if v, ok := d.labelN(c.Labels, n, "ping"); ok { enabled = true if strings.HasPrefix(v, "http://") || strings.HasPrefix(v, "https://") { @@ -171,7 +176,7 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper) // docker server label may have multiple, comma separated servers for _, srv := range strings.Split(server, ",") { mp := discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL, - PingURL: pingURL, ProviderID: discovery.PIDocker, MatchType: discovery.MTProxy} + PingURL: pingURL, OnlyFromIPs: onlyFrom, ProviderID: discovery.PIDocker, MatchType: discovery.MTProxy} // for assets we add the second proxy mapping only if explicitly requested if assetsWebRoot != "" && explicit { diff --git a/app/discovery/provider/docker_test.go b/app/discovery/provider/docker_test.go index 8295145..953a7b6 100644 --- a/app/discovery/provider/docker_test.go +++ b/app/discovery/provider/docker_test.go @@ -30,7 +30,7 @@ func TestDocker_List(t *testing.T) { { Name: "c1", State: "running", IP: "127.0.0.2", Ports: []int{12345}, Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1", - "reproxy.server": "example.com", "reproxy.ping": "/ping"}, + "reproxy.server": "example.com", "reproxy.ping": "/ping", "reproxy.remote": "192.168.1.0/24, 127.0.0.1"}, }, { Name: "c1", State: "running", IP: "127.0.0.21", Ports: []int{12345}, @@ -64,21 +64,25 @@ func TestDocker_List(t *testing.T) { assert.Equal(t, "http://127.0.0.2:12345/blah/$1", res[0].Dst) assert.Equal(t, "example.com", res[0].Server) assert.Equal(t, "http://127.0.0.2:12345/ping", res[0].PingURL) + assert.Equal(t, []string{"192.168.1.0/24", "127.0.0.1"}, res[0].OnlyFromIPs) assert.Equal(t, "^/api/90/(.*)", res[1].SrcMatch.String()) assert.Equal(t, "http://example.com/blah/$1", res[1].Dst) assert.Equal(t, "https://example.com//ping", res[1].PingURL) assert.Equal(t, "example.com", res[1].Server) + assert.Equal(t, []string{}, res[1].OnlyFromIPs) assert.Equal(t, "^/c2/(.*)", res[2].SrcMatch.String()) assert.Equal(t, "http://127.0.0.3:12346/$1", res[2].Dst) assert.Equal(t, "http://127.0.0.3:12346/ping", res[2].PingURL) assert.Equal(t, "*", res[2].Server) + assert.Equal(t, []string{}, res[2].OnlyFromIPs) assert.Equal(t, "^/a/(.*)", res[3].SrcMatch.String()) assert.Equal(t, "http://127.0.0.2:12348/a/$1", res[3].Dst) assert.Equal(t, "http://127.0.0.2:12348/ping", res[3].PingURL) assert.Equal(t, "example.com", res[3].Server) + assert.Equal(t, []string{}, res[3].OnlyFromIPs) } func TestDocker_ListMulti(t *testing.T) { diff --git a/app/discovery/provider/file.go b/app/discovery/provider/file.go index 41fa485..c78bed5 100644 --- a/app/discovery/provider/file.go +++ b/app/discovery/provider/file.go @@ -84,6 +84,7 @@ func (d *File) List() (res []discovery.URLMapper, err error) { Ping string `yaml:"ping"` AssetsEnabled bool `yaml:"assets"` AssetsSPA bool `yaml:"spa"` + OnlyFrom string `yaml:"remote"` } fh, err := os.Open(d.FileName) if err != nil { @@ -106,12 +107,13 @@ func (d *File) List() (res []discovery.URLMapper, err error) { srv = "*" } mapper := discovery.URLMapper{ - Server: srv, - SrcMatch: *rx, - Dst: f.Dest, - PingURL: f.Ping, - ProviderID: discovery.PIFile, - MatchType: discovery.MTProxy, + Server: srv, + SrcMatch: *rx, + Dst: f.Dest, + PingURL: f.Ping, + ProviderID: discovery.PIFile, + MatchType: discovery.MTProxy, + OnlyFromIPs: discovery.ParseOnlyFrom(f.OnlyFrom), } if f.AssetsEnabled || f.AssetsSPA { mapper.MatchType = discovery.MTStatic diff --git a/app/discovery/provider/file_test.go b/app/discovery/provider/file_test.go index 4832abd..4f1ac2c 100644 --- a/app/discovery/provider/file_test.go +++ b/app/discovery/provider/file_test.go @@ -113,18 +113,21 @@ func TestFile_List(t *testing.T) { assert.Equal(t, "", res[0].PingURL) assert.Equal(t, "srv.example.com", res[0].Server) assert.Equal(t, discovery.MTProxy, res[0].MatchType) + assert.Equal(t, []string{}, res[0].OnlyFromIPs) assert.Equal(t, "^/api/svc1/(.*)", res[1].SrcMatch.String()) assert.Equal(t, "http://127.0.0.1:8080/blah1/$1", res[1].Dst) assert.Equal(t, "", res[1].PingURL) assert.Equal(t, "*", res[1].Server) assert.Equal(t, discovery.MTProxy, res[1].MatchType) + assert.Equal(t, []string{}, res[0].OnlyFromIPs) assert.Equal(t, "/api/svc3/xyz", res[2].SrcMatch.String()) assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", res[2].Dst) assert.Equal(t, "http://127.0.0.3:8080/ping", res[2].PingURL) assert.Equal(t, "*", res[2].Server) assert.Equal(t, discovery.MTProxy, res[2].MatchType) + assert.Equal(t, []string{}, res[0].OnlyFromIPs) assert.Equal(t, "/web/", res[3].SrcMatch.String()) assert.Equal(t, "/var/web", res[3].Dst) @@ -132,6 +135,7 @@ func TestFile_List(t *testing.T) { assert.Equal(t, "*", res[3].Server) assert.Equal(t, discovery.MTStatic, res[3].MatchType) assert.Equal(t, false, res[3].AssetsSPA) + assert.Equal(t, []string{"192.168.1.0/24", "124.0.0.1"}, res[3].OnlyFromIPs) assert.Equal(t, "/web2/", res[4].SrcMatch.String()) assert.Equal(t, "/var/web2", res[4].Dst) @@ -139,4 +143,5 @@ func TestFile_List(t *testing.T) { assert.Equal(t, "*", res[4].Server) assert.Equal(t, discovery.MTStatic, res[4].MatchType) assert.Equal(t, true, res[4].AssetsSPA) + assert.Equal(t, []string{}, res[0].OnlyFromIPs) } diff --git a/app/discovery/provider/static.go b/app/discovery/provider/static.go index 01f5a83..1e0a68e 100644 --- a/app/discovery/provider/static.go +++ b/app/discovery/provider/static.go @@ -9,7 +9,7 @@ import ( "github.com/umputun/reproxy/app/discovery" ) -// Static provider, rules are server,from,to +// Static provider, rules are server,source_url,destination[,ping] type Static struct { Rules []string // each rule is 4 elements comma separated - server,source_url,destination,ping } diff --git a/app/discovery/provider/testdata/config.yml b/app/discovery/provider/testdata/config.yml index d0662bb..85e618f 100644 --- a/app/discovery/provider/testdata/config.yml +++ b/app/discovery/provider/testdata/config.yml @@ -1,7 +1,7 @@ default: - {route: "^/api/svc1/(.*)", dest: "http://127.0.0.1:8080/blah1/$1"} - {route: "/api/svc3/xyz", dest: "http://127.0.0.3:8080/blah3/xyz", "ping": "http://127.0.0.3:8080/ping"} - - {route: "/web/", dest: "/var/web", "assets": yes} + - {route: "/web/", dest: "/var/web", "assets": yes, "remote": "192.168.1.0/24, 124.0.0.1"} - {route: "/web2/", dest: "/var/web2", "spa": yes} srv.example.com: - {route: "^/api/svc2/(.*)", dest: "http://127.0.0.2:8080/blah2/$1/abc"} diff --git a/app/main.go b/app/main.go index 5190731..cb078f9 100644 --- a/app/main.go +++ b/app/main.go @@ -29,14 +29,14 @@ import ( ) var opts struct { - Listen string `short:"l" long:"listen" env:"LISTEN" description:"listen on host:port (default: 0.0.0.0:8080/8443 under docker, 127.0.0.1:80/443 without)"` - MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"` - GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"` - ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside "" - DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","` - AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"` - - LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint + Listen string `short:"l" long:"listen" env:"LISTEN" description:"listen on host:port (default: 0.0.0.0:8080/8443 under docker, 127.0.0.1:80/443 without)"` + MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"` + GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"` + ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside "" + DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","` + AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"` + RemoteLookupHeaders bool `long:"remote-lookup-headers" env:"REMOTE_LOOKUP_HEADERS" description:"enable remote lookup headers"` + LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint SSL struct { Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` // nolint @@ -273,10 +273,11 @@ func run() error { ThrottleUser: opts.Throttle.User, BasicAuthEnabled: len(basicAuthAllowed) > 0, BasicAuthAllowed: basicAuthAllowed, + OnlyFrom: makeOnlyFromMiddleware(), } err = px.Run(ctx) - if err != nil && err == http.ErrServerClosed { + if err != nil && errors.Is(err, http.ErrServerClosed) { log.Printf("[WARN] proxy server closed, %v", err) // nolint gocritic return nil } @@ -424,6 +425,13 @@ func makeLBSelector() func(len int) int { } } +func makeOnlyFromMiddleware() *proxy.OnlyFrom { + if opts.RemoteLookupHeaders { + return proxy.NewOnlyFrom(proxy.OFRealIP, proxy.OFForwarded, proxy.OFRemoteAddr) + } + return proxy.NewOnlyFrom(proxy.OFRemoteAddr) +} + func makeErrorReporter() (proxy.Reporter, error) { result := &proxy.ErrorReporter{ Nice: opts.ErrorReport.Enabled, diff --git a/app/proxy/handlers_test.go b/app/proxy/handlers_test.go index 2e7391a..9c7fc1d 100644 --- a/app/proxy/handlers_test.go +++ b/app/proxy/handlers_test.go @@ -244,5 +244,4 @@ func TestHttp_basicAuthHandler(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) }) } - } diff --git a/app/proxy/only_from.go b/app/proxy/only_from.go new file mode 100644 index 0000000..dc337ca --- /dev/null +++ b/app/proxy/only_from.go @@ -0,0 +1,149 @@ +package proxy + +import ( + "bytes" + "net" + "net/http" + "strings" + + "github.com/umputun/reproxy/app/discovery" +) + +// OnlyFrom implements middleware to allow access for a limited list of source IPs. +type OnlyFrom struct { + lookups []OFLookup +} + +// OFLookup defines lookup method for source IP. +type OFLookup string + +// enum of possible lookup methods +const ( + OFRemoteAddr OFLookup = "remote-addr" + OFRealIP OFLookup = "real-ip" + OFForwarded OFLookup = "forwarded" +) + +// NewOnlyFrom creates OnlyFrom middleware with given lookup methods. +func NewOnlyFrom(lookups ...OFLookup) *OnlyFrom { + return &OnlyFrom{lookups: lookups} +} + +// Handler implements middleware interface. +func (o *OnlyFrom) Handler(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + var allowedIPs []string + reqCtx := r.Context() + if reqCtx.Value(ctxMatch) != nil { // route match detected by matchHandler + match := reqCtx.Value(ctxMatch).(discovery.MatchedRoute) + allowedIPs = match.Mapper.OnlyFromIPs + } + if len(allowedIPs) == 0 { + // no restrictions if no ips defined + next.ServeHTTP(w, r) + return + } + + realIP := o.realIP(o.lookups, r) + if realIP != "" && o.matchRemoteIP(realIP, allowedIPs) { + next.ServeHTTP(w, r) + return + } + w.WriteHeader(http.StatusForbidden) + } + return http.HandlerFunc(fn) +} + +func (o *OnlyFrom) realIP(ipLookups []OFLookup, r *http.Request) string { + realIP := r.Header.Get("X-Real-IP") + forwardedFor := r.Header.Get("X-Forwarded-For") + + for _, lookup := range ipLookups { + + if lookup == OFRemoteAddr { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr // can't parse, return as is + } + return ip + } + + if lookup == OFForwarded && forwardedFor != "" { + // X-Forwarded-For is potentially a list of addresses separated with "," + // The left-most being the original client, and each successive proxy that passed the request + // adding the IP address where it received the request from. + // In case if the original IP is a private behind a proxy, we need to get the first public IP from the list + return preferPublicIP(strings.Split(forwardedFor, ",")) + } + + if lookup == OFRealIP && realIP != "" { + return realIP + } + } + + return "" // we can't get real ip +} + +// matchRemoteIP returns true if request's ip matches any of ips in the list of allowedIPs. +// allowedIPs can be defined as IP (like 192.168.1.12) or CIDR (192.168.0.0/16) +func (o *OnlyFrom) matchRemoteIP(remoteIP string, allowedIPs []string) bool { + for _, allowedIP := range allowedIPs { + // check for ip prefix or CIDR + if _, cidrnet, err := net.ParseCIDR(allowedIP); err == nil { + if cidrnet.Contains(net.ParseIP(remoteIP)) { + return true + } + } + // check for ip match + if remoteIP == allowedIP { + return true + } + } + return false +} + +// preferPublicIP returns first public IP from the list of IPs +// if no public IP found, returns first IP from the list +func preferPublicIP(ips []string) string { + for _, ip := range ips { + ip = strings.TrimSpace(ip) + if net.ParseIP(ip).IsGlobalUnicast() && !isPrivateSubnet(net.ParseIP(ip)) { + return ip + } + } + return strings.TrimSpace(ips[0]) +} + +type ipRange struct { + start net.IP + end net.IP +} + +var privateRanges = []ipRange{ + {start: net.ParseIP("10.0.0.0"), end: net.ParseIP("10.255.255.255")}, + {start: net.ParseIP("100.64.0.0"), end: net.ParseIP("100.127.255.255")}, + {start: net.ParseIP("172.16.0.0"), end: net.ParseIP("172.31.255.255")}, + {start: net.ParseIP("192.0.0.0"), end: net.ParseIP("192.0.0.255")}, + {start: net.ParseIP("192.168.0.0"), end: net.ParseIP("192.168.255.255")}, + {start: net.ParseIP("198.18.0.0"), end: net.ParseIP("198.19.255.255")}, + {start: net.ParseIP("::1"), end: net.ParseIP("::1")}, + {start: net.ParseIP("fc00::"), end: net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")}, + {start: net.ParseIP("fe80::"), end: net.ParseIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff")}, +} + +// isPrivateSubnet - check to see if this ip is in a private subnet +func isPrivateSubnet(ipAddress net.IP) bool { + inRange := func(r ipRange, ipAddress net.IP) bool { + // ensure the IPs are in the same format for comparison + ipAddress = ipAddress.To16() + r.start = r.start.To16() + r.end = r.end.To16() + return bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) <= 0 + } + for _, r := range privateRanges { + if inRange(r, ipAddress) { + return true + } + } + return false +} diff --git a/app/proxy/only_from_test.go b/app/proxy/only_from_test.go new file mode 100644 index 0000000..37a8156 --- /dev/null +++ b/app/proxy/only_from_test.go @@ -0,0 +1,144 @@ +package proxy + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/umputun/reproxy/app/discovery" +) + +func TestOnlyFrom_Handler(t *testing.T) { + tbl := []struct { + name string + lookups []OFLookup + allowedIPs []string + remoteAddr string + realIP string + forwardedFor string + expectedStatusCode int + }{ + { + name: "allowed IP", + lookups: []OFLookup{OFRemoteAddr}, + allowedIPs: []string{"192.168.1.1"}, + remoteAddr: "192.168.1.1:1234", + expectedStatusCode: http.StatusOK, + }, + { + name: "disallowed IP", + lookups: []OFLookup{OFRemoteAddr}, + allowedIPs: []string{"192.168.1.1"}, + remoteAddr: "192.168.1.2:1234", + expectedStatusCode: http.StatusForbidden, + }, + { + name: "no restrictions", + lookups: []OFLookup{OFRemoteAddr}, + allowedIPs: []string{}, + remoteAddr: "192.168.1.2:1234", + expectedStatusCode: http.StatusOK, + }, + { + name: "allowed IP with RealIP lookup", + lookups: []OFLookup{OFRealIP}, + allowedIPs: []string{"192.168.1.1"}, + realIP: "192.168.1.1", + expectedStatusCode: http.StatusOK, + }, + { + name: "disallowed IP with RealIP lookup", + lookups: []OFLookup{OFRealIP}, + allowedIPs: []string{"192.168.1.1"}, + realIP: "192.168.1.2", + expectedStatusCode: http.StatusForbidden, + }, + { + name: "allowed IP with Forwarded lookup", + lookups: []OFLookup{OFForwarded}, + allowedIPs: []string{"192.168.1.1"}, + forwardedFor: "192.168.1.1", + expectedStatusCode: http.StatusOK, + }, + { + name: "allowed IP with Forwarded lookup, mix private and public IPs", + lookups: []OFLookup{OFForwarded}, + allowedIPs: []string{"8.8.8.8"}, + forwardedFor: "192.168.1.1, 10.0.0.5, 8.8.8.8, 10.10.10.10", + expectedStatusCode: http.StatusOK, + }, + { + name: "disallowed IP with Forwarded lookup", + lookups: []OFLookup{OFForwarded}, + allowedIPs: []string{"192.168.1.1"}, + forwardedFor: "192.168.1.2", + expectedStatusCode: http.StatusForbidden, + }, + { + name: "multiple lookups, allowed IP", + lookups: []OFLookup{OFRemoteAddr, OFRealIP}, + allowedIPs: []string{"192.168.1.1", "192.168.1.2"}, + remoteAddr: "192.168.1.2:1234", + realIP: "192.168.1.1", + expectedStatusCode: http.StatusOK, + }, + { + name: "multiple lookups, disallowed IP", + lookups: []OFLookup{OFRemoteAddr, OFRealIP}, + allowedIPs: []string{"192.168.1.1", "192.168.1.2"}, + remoteAddr: "192.168.1.3:1234", + realIP: "192.168.1.3", + expectedStatusCode: http.StatusForbidden, + }, + { + name: "CIDR block, allowed IP", + lookups: []OFLookup{OFRemoteAddr}, + allowedIPs: []string{"192.168.1.0/24"}, + remoteAddr: "192.168.1.2:1234", + expectedStatusCode: http.StatusOK, + }, + { + name: "CIDR block, disallowed IP", + lookups: []OFLookup{OFRemoteAddr}, + allowedIPs: []string{"192.168.1.0/24"}, + remoteAddr: "192.168.2.2:1234", + expectedStatusCode: http.StatusForbidden, + }, + { + name: "invalid remote address format", + lookups: []OFLookup{OFRemoteAddr}, + allowedIPs: []string{"192.168.1.1"}, + remoteAddr: "invalid_format", + expectedStatusCode: http.StatusForbidden, + }, + { + name: "empty remote address", + lookups: []OFLookup{OFRemoteAddr}, + allowedIPs: []string{"192.168.1.1"}, + remoteAddr: "", + expectedStatusCode: http.StatusForbidden, + }, + } + + for _, tt := range tbl { + t.Run(tt.name, func(t *testing.T) { + onlyFrom := NewOnlyFrom(tt.lookups...) + handler := onlyFrom.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + req := httptest.NewRequest("GET", "http://example.com/foo", http.NoBody) + req.RemoteAddr = tt.remoteAddr + req.Header.Set("X-Real-IP", tt.realIP) + req.Header.Set("X-Forwarded-For", tt.forwardedFor) + req = req.WithContext(context.WithValue(req.Context(), + ctxMatch, discovery.MatchedRoute{Mapper: discovery.URLMapper{OnlyFromIPs: tt.allowedIPs}})) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatusCode, rr.Code) + }) + } +} diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index 281d4b9..24d6ed7 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -28,27 +28,27 @@ import ( // Http is a proxy server for both http and https type Http struct { // nolint golint Matcher - Address string - AssetsLocation string - AssetsWebRoot string - Assets404 string - AssetsSPA bool - MaxBodySize int64 - GzEnabled bool - ProxyHeaders []string - DropHeader []string - SSLConfig SSLConfig - Version string - AccessLog io.Writer - StdOutEnabled bool - Signature bool - Timeouts Timeouts - CacheControl MiddlewareProvider - Metrics MiddlewareProvider - PluginConductor MiddlewareProvider - Reporter Reporter - LBSelector func(len int) int - + Address string + AssetsLocation string + AssetsWebRoot string + Assets404 string + AssetsSPA bool + MaxBodySize int64 + GzEnabled bool + ProxyHeaders []string + DropHeader []string + SSLConfig SSLConfig + Version string + AccessLog io.Writer + StdOutEnabled bool + Signature bool + Timeouts Timeouts + CacheControl MiddlewareProvider + Metrics MiddlewareProvider + PluginConductor MiddlewareProvider + Reporter Reporter + LBSelector func(len int) int + OnlyFrom *OnlyFrom BasicAuthEnabled bool BasicAuthAllowed []string @@ -121,18 +121,19 @@ func (h *Http) Run(ctx context.Context) error { }() handler := R.Wrap(h.proxyHandler(), - R.Recoverer(log.Default()), // recover on errors - signatureHandler(h.Signature, h.Version), // send app signature - h.pingHandler, // respond to /ping + R.Recoverer(log.Default()), // recover on errors + signatureHandler(h.Signature, h.Version), // send app signature + h.OnlyFrom.Handler, // limit source (remote) IPs if defined + h.pingHandler, // respond to /ping basicAuthHandler(h.BasicAuthEnabled, h.BasicAuthAllowed), // basic auth - h.healthMiddleware, // respond to /health - h.matchHandler, // set matched routes to context - limiterSystemHandler(h.ThrottleSystem), // limit total requests/sec - limiterUserHandler(h.ThrottleUser), // req/seq per user/route match - h.mgmtHandler(), // handles /metrics and /routes for prometheus - h.pluginHandler(), // prc to external plugins - headersHandler(h.ProxyHeaders, h.DropHeader), // add response headers and delete some request headers - accessLogHandler(h.AccessLog), // apache-format log file + h.healthMiddleware, // respond to /health + h.matchHandler, // set matched routes to context + limiterSystemHandler(h.ThrottleSystem), // limit total requests/sec + limiterUserHandler(h.ThrottleUser), // req/seq per user/route match + h.mgmtHandler(), // handles /metrics and /routes for prometheus + h.pluginHandler(), // prc to external plugins + headersHandler(h.ProxyHeaders, h.DropHeader), // add response headers and delete some request headers + accessLogHandler(h.AccessLog), // apache-format log file stdoutLogHandler(h.StdOutEnabled, logger.New(logger.Log(log.Default()), logger.Prefix("[INFO]")).Handler), maxReqSizeHandler(h.MaxBodySize), // limit request max size gzipHandler(h.GzEnabled), // gzip response @@ -400,22 +401,22 @@ func (h *Http) makeHTTPServer(addr string, router http.Handler) *http.Server { } func (h *Http) setXRealIP(r *http.Request) { - - remoteIP := r.Header.Get("X-Forwarded-For") - if remoteIP == "" { - remoteIP = r.RemoteAddr - } - - ip, _, err := net.SplitHostPort(remoteIP) - if err != nil { + if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { + // use the left-most non-private client IP address + // if there is no any non-private IP address, use the left-most address + r.Header.Set("X-Real-IP", preferPublicIP(strings.Split(forwarded, ","))) return } + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return + } userIP := net.ParseIP(ip) if userIP == nil { return } - r.Header.Add("X-Real-IP", ip) + r.Header.Set("X-Real-IP", ip) } // discoveredServers gets the list of servers discovered by providers. diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index 0154432..4d916b1 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -34,6 +34,7 @@ func TestHttp_Do(t *testing.T) { t.Logf("req: %v", r) w.Header().Add("h1", "v1") require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP")) + require.Equal(t, "127.0.0.1", r.Header.Get("X-Forwarded-For")) fmt.Fprintf(w, "response %s", r.URL.String()) })) @@ -59,7 +60,7 @@ func TestHttp_Do(t *testing.T) { client := http.Client{} - { + t.Run("to 127.0.0.1, good", func(t *testing.T) { req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) require.NoError(t, err) resp, err := client.Do(req) @@ -75,9 +76,9 @@ func TestHttp_Do(t *testing.T) { assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "vv1", resp.Header.Get("hh1")) assert.Equal(t, "vv2", resp.Header.Get("hh2")) - } + }) - { + t.Run("to localhost, good", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/something") require.NoError(t, err) defer resp.Body.Close() @@ -89,9 +90,9 @@ func TestHttp_Do(t *testing.T) { assert.Equal(t, "response /123/something", string(body)) assert.Equal(t, "reproxy", resp.Header.Get("App-Name")) assert.Equal(t, "v1", resp.Header.Get("h1")) - } + }) - { + t.Run("bad gateway", func(t *testing.T) { resp, err := client.Get("http://127.0.0.1:" + strconv.Itoa(port) + "/bad/something") require.NoError(t, err) defer resp.Body.Close() @@ -100,9 +101,9 @@ func TestHttp_Do(t *testing.T) { require.NoError(t, err) assert.Contains(t, string(b), "Sorry for the inconvenience") assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) - } + }) - { + t.Run("url encode", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21") require.NoError(t, err) defer resp.Body.Close() @@ -114,7 +115,7 @@ func TestHttp_Do(t *testing.T) { assert.Equal(t, "response /123/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21", string(body)) assert.Equal(t, "reproxy", resp.Header.Get("App-Name")) assert.Equal(t, "v1", resp.Header.Get("h1")) - } + }) } func TestHttp_DoWithAssets(t *testing.T) { @@ -153,7 +154,7 @@ func TestHttp_DoWithAssets(t *testing.T) { client := http.Client{} - { + t.Run("api call", func(t *testing.T) { req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) require.NoError(t, err) resp, err := client.Do(req) @@ -167,9 +168,9 @@ func TestHttp_DoWithAssets(t *testing.T) { assert.Equal(t, "response /567/something", string(body)) assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "v1", resp.Header.Get("h1")) - } + }) - { + t.Run("static call, good", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html") require.NoError(t, err) defer resp.Body.Close() @@ -182,9 +183,9 @@ func TestHttp_DoWithAssets(t *testing.T) { assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) - } + }) - { + t.Run("static call, bad", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") require.NoError(t, err) defer resp.Body.Close() @@ -192,9 +193,9 @@ func TestHttp_DoWithAssets(t *testing.T) { body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "404 page not found\n", string(body)) - } + }) - { + t.Run("bad url", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad") require.NoError(t, err) defer resp.Body.Close() @@ -203,7 +204,7 @@ func TestHttp_DoWithAssets(t *testing.T) { require.NoError(t, err) assert.Contains(t, string(body), "Server error") assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type")) - } + }) } func TestHttp_DoWithAssetsCustom404(t *testing.T) { @@ -243,7 +244,7 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) { client := http.Client{} - { + t.Run("api call, found", func(t *testing.T) { req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) require.NoError(t, err) resp, err := client.Do(req) @@ -257,9 +258,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) { assert.Equal(t, "response /567/something", string(body)) assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "v1", resp.Header.Get("h1")) - } + }) - { + t.Run("static call, found", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html") require.NoError(t, err) defer resp.Body.Close() @@ -272,9 +273,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) { assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) - } + }) - { + t.Run("static call, not found", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") require.NoError(t, err) defer resp.Body.Close() @@ -284,9 +285,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) { assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body)) t.Logf("%+v", resp.Header) assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) - } + }) - { + t.Run("another static call, not found", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad2.html") require.NoError(t, err) defer resp.Body.Close() @@ -296,7 +297,7 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) { assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body)) t.Logf("%+v", resp.Header) assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) - } + }) } func TestHttp_DoWithSpaAssets(t *testing.T) { @@ -336,7 +337,7 @@ func TestHttp_DoWithSpaAssets(t *testing.T) { client := http.Client{} - { + t.Run("api call, good", func(t *testing.T) { req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) require.NoError(t, err) resp, err := client.Do(req) @@ -350,9 +351,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) { assert.Equal(t, "response /567/something", string(body)) assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "v1", resp.Header.Get("h1")) - } + }) - { + t.Run("static call, good", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html") require.NoError(t, err) defer resp.Body.Close() @@ -365,9 +366,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) { assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) - } + }) - { + t.Run("static call, not found server index", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") require.NoError(t, err) defer resp.Body.Close() @@ -380,9 +381,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) { assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) - } + }) - { + t.Run("static call, bad url", func(t *testing.T) { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad") require.NoError(t, err) defer resp.Body.Close() @@ -391,7 +392,7 @@ func TestHttp_DoWithSpaAssets(t *testing.T) { require.NoError(t, err) assert.Contains(t, string(body), "Server error") assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type")) - } + }) } func TestHttp_DoWithAssetRules(t *testing.T) { @@ -715,16 +716,16 @@ func TestHttp_withBasicAuth(t *testing.T) { client := http.Client{} - { + t.Run("no auth", func(t *testing.T) { req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - } + }) - { + t.Run("bad auth", func(t *testing.T) { req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req.SetBasicAuth("test", "badpasswd") require.NoError(t, err) @@ -732,8 +733,9 @@ func TestHttp_withBasicAuth(t *testing.T) { require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - } - { + }) + + t.Run("good auth", func(t *testing.T) { req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req.SetBasicAuth("test", "passwd") require.NoError(t, err) @@ -741,8 +743,9 @@ func TestHttp_withBasicAuth(t *testing.T) { require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) - } - { + }) + + t.Run("good auth 2", func(t *testing.T) { req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req.SetBasicAuth("test2", "passwd2") require.NoError(t, err) @@ -750,7 +753,7 @@ func TestHttp_withBasicAuth(t *testing.T) { require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) - } + }) } func TestHttp_toHttp(t *testing.T) { @@ -766,9 +769,9 @@ func TestHttp_toHttp(t *testing.T) { } h := Http{} - for i, tt := range tbl { + for _, tt := range tbl { tt := tt - t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Run(tt.addr, func(t *testing.T) { assert.Equal(t, tt.res, h.toHTTP(tt.addr, tt.port)) }) } @@ -791,8 +794,8 @@ func TestHttp_isAssetRequest(t *testing.T) { {"/static/", "/tmp", "", false}, } - for i, tt := range tbl { - t.Run(strconv.Itoa(i), func(t *testing.T) { + for _, tt := range tbl { + t.Run(tt.req, func(t *testing.T) { h := Http{AssetsLocation: tt.assetsLocation, AssetsWebRoot: tt.assetsWebRoot} r, err := http.NewRequest("GET", tt.req, http.NoBody) require.NoError(t, err) @@ -803,56 +806,61 @@ func TestHttp_isAssetRequest(t *testing.T) { } func TestHttp_matchHandler(t *testing.T) { - tbl := []struct { + name string matches discovery.Matches res string ok bool }{ - { - discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ + name: "all alive destinations", + matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ {Destination: "dest1", Alive: true}, {Destination: "dest2", Alive: true}, {Destination: "dest3", Alive: true}, }}, - "dest1", true, + res: "dest1", ok: true, }, { - discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ + name: "second alive destination", + matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ {Destination: "dest1", Alive: false}, {Destination: "dest2", Alive: true}, {Destination: "dest3", Alive: false}, }}, - "dest2", true, + res: "dest2", ok: true, }, { - discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ + name: "one dead destination", + matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ {Destination: "dest1", Alive: false}, {Destination: "dest2", Alive: true}, {Destination: "dest3", Alive: true}, }}, - "dest2", true, + res: "dest2", ok: true, }, { - discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ + name: "last alive destination", + matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ {Destination: "dest1", Alive: false}, {Destination: "dest2", Alive: false}, {Destination: "dest3", Alive: true}, }}, - "dest3", true, + res: "dest3", ok: true, }, { - discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ + name: "all dead destinations", + matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ {Destination: "dest1", Alive: false}, {Destination: "dest2", Alive: false}, {Destination: "dest3", Alive: false}, }}, - "", false, + res: "", ok: false, }, { - discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{}}, "", false, + name: "no destinations", + matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{}}, res: "", ok: false, }, } @@ -864,9 +872,8 @@ func TestHttp_matchHandler(t *testing.T) { } client := http.Client{} - for i, tt := range tbl { - t.Run(strconv.Itoa(i), func(t *testing.T) { - + for _, tt := range tbl { + t.Run(tt.name, func(t *testing.T) { h := Http{Matcher: matcherMock, LBSelector: func(len int) int { return 0 }} handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Logf("req: %+v", r) @@ -893,7 +900,6 @@ func TestHttp_matchHandler(t *testing.T) { } func TestHttp_discoveredServers(t *testing.T) { - calls := 0 m := &MatcherMock{ServersFunc: func() []string { defer func() { calls++ }()