diff --git a/app/discovery/discovery.go b/app/discovery/discovery.go index 593300b..945549e 100644 --- a/app/discovery/discovery.go +++ b/app/discovery/discovery.go @@ -1,6 +1,6 @@ // Package discovery provides a common interface for all providers and Match to // transform source to destination URL. -// Do func starts event loop checking all providers and retrieving lists of rules. +// Run func starts event loop checking all providers and retrieving lists of rules. // All lists combined into a merged one. package discovery @@ -47,9 +47,9 @@ func NewService(providers []Provider) *Service { return &Service{providers: providers} } -// Do runs blocking loop getting events from all providers +// Run runs blocking loop getting events from all providers // and updating mappers on each event -func (s *Service) Do(ctx context.Context) error { +func (s *Service) Run(ctx context.Context) error { var evChs []<-chan struct{} for _, p := range s.providers { diff --git a/app/discovery/discovery_test.go b/app/discovery/discovery_test.go index a8ed5dd..15f4f41 100644 --- a/app/discovery/discovery_test.go +++ b/app/discovery/discovery_test.go @@ -46,7 +46,7 @@ func TestService_Do(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - err := svc.Do(ctx) + err := svc.Run(ctx) require.Error(t, err) assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, 3, len(svc.mappers)) @@ -100,7 +100,7 @@ func TestService_Match(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - err := svc.Do(ctx) + err := svc.Run(ctx) require.Error(t, err) assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, 3, len(svc.mappers)) diff --git a/app/main.go b/app/main.go index 42ba2ba..4bc03ba 100644 --- a/app/main.go +++ b/app/main.go @@ -3,7 +3,6 @@ package main import ( "context" "fmt" - "log" "os" "os/signal" "runtime" @@ -11,7 +10,7 @@ import ( "time" docker "github.com/fsouza/go-dockerclient" - "github.com/go-pkgz/lgr" + log "github.com/go-pkgz/lgr" "github.com/pkg/errors" "github.com/umputun/go-flags" @@ -27,6 +26,14 @@ var opts struct { GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"` ProxyHeaders []string `short:"x" long:"header" env:"HEADER" description:"proxy headers"` + SSL struct { + Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` //nolint + Cert string `long:"cert" env:"CERT" description:"path to cert.pem file"` + Key string `long:"key" env:"KEY" description:"path to key.pem file"` + ACMELocation string `long:"acme-location" env:"ACME_LOCATION" description:"dir where certificates will be stored by autocert manager" default:"./var/acme"` + ACMEEmail string `long:"acme-email" env:"ACME_EMAIL" description:"admin email for certificate notifications"` + } `group:"ssl" namespace:"ssl" env-namespace:"SSL"` + Assets struct { Location string `short:"a" long:"location" env:"LOCATION" default:"" description:"assets location"` WebRoot string `long:"root" env:"ROOT" default:"/" description:"assets web root"` @@ -54,6 +61,16 @@ var opts struct { Dbg bool `long:"dbg" env:"DEBUG" description:"debug mode"` } +// SSLGroup defines options group for server ssl params +type SSLGroup struct { + Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` //nolint + Port int `long:"port" env:"PORT" description:"port number for https server" default:"8443"` + Cert string `long:"cert" env:"CERT" description:"path to cert.pem file"` + Key string `long:"key" env:"KEY" description:"path to key.pem file"` + ACMELocation string `long:"acme-location" env:"ACME_LOCATION" description:"dir where certificates will be stored by autocert manager" default:"./var/acme"` + ACMEEmail string `long:"acme-email" env:"ACME_EMAIL" description:"admin email for certificate notifications"` +} + var revision = "unknown" func main() { @@ -84,11 +101,16 @@ func main() { svc := discovery.NewService(providers) go func() { - if err := svc.Do(context.Background()); err != nil { + if err := svc.Run(context.Background()); err != nil { log.Fatalf("[ERROR] discovery failed, %v", err) } }() + sslConfig, err := makeSSLConfig() + if err != nil { + log.Fatalf("[ERROR] failed to make config of ssl server params, %v", err) + } + px := &proxy.Http{ Version: revision, Matcher: svc, @@ -98,8 +120,9 @@ func main() { AssetsLocation: opts.Assets.Location, AssetsWebRoot: opts.Assets.WebRoot, GzEnabled: opts.GzipEnabled, + SSLConfig: sslConfig, } - if err := px.Do(context.Background()); err != nil { + if err := px.Run(context.Background()); err != nil { log.Fatalf("[ERROR] proxy server failed, %v", err) } } @@ -133,13 +156,34 @@ func makeProviders() ([]discovery.Provider, error) { return res, nil } -func setupLog(dbg bool) { - - logOpts := []lgr.Option{lgr.Msec, lgr.LevelBraces, lgr.StackTraceOnError} - if dbg { - logOpts = []lgr.Option{lgr.Debug, lgr.CallerFile, lgr.CallerFunc, lgr.Msec, lgr.LevelBraces, lgr.StackTraceOnError} +func makeSSLConfig() (config proxy.SSLConfig, err error) { + switch opts.SSL.Type { + case "none": + config.SSLMode = proxy.SSLNone + case "static": + if opts.SSL.Cert == "" { + return config, errors.New("path to cert.pem is required") + } + if opts.SSL.Key == "" { + return config, errors.New("path to key.pem is required") + } + config.SSLMode = proxy.SSLStatic + config.Cert = opts.SSL.Cert + config.Key = opts.SSL.Key + case "auto": + config.SSLMode = proxy.SSLAuto + config.ACMELocation = opts.SSL.ACMELocation + config.ACMEEmail = opts.SSL.ACMEEmail } - lgr.SetupStdLogger(logOpts...) + return config, err +} + +func setupLog(dbg bool) { + if dbg { + log.Setup(log.Debug, log.CallerFile, log.CallerFunc, log.Msec, log.LevelBraces) + return + } + log.Setup(log.Msec, log.LevelBraces) } func catchSignal() { diff --git a/app/proxy/middleware/gzip_test.go b/app/proxy/middleware/gzip_test.go deleted file mode 100644 index 6b5c68b..0000000 --- a/app/proxy/middleware/gzip_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package middleware - -import ( - "bytes" - "compress/gzip" - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestGzip(t *testing.T) { - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("Lorem Ipsum is simply dummy text of the printing and typesetting industry. " + - "Lorem Ipsum has been the industry’s standard dummy text ever since the 1500s, when an unknown printer took " + - "a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries," + - " but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised" + - " in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, " + - "and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.")) - require.NoError(t, err) - }) - ts := httptest.NewServer(Gzip(handler)) - defer ts.Close() - - client := http.Client{} - - { - req, err := http.NewRequest("GET", ts.URL+"/something", nil) - require.NoError(t, err) - req.Header.Set("Accept-Encoding", "gzip") - resp, err := client.Do(req) - require.NoError(t, err) - assert.Equal(t, 200, resp.StatusCode) - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, 355, len(b), "compressed size") - - gzr, err := gzip.NewReader(bytes.NewBuffer(b)) - require.NoError(t, err) - b, err = ioutil.ReadAll(gzr) - require.NoError(t, err) - assert.True(t, strings.HasPrefix(string(b), "Lorem Ipsum"), string(b)) - } - { - req, err := http.NewRequest("GET", ts.URL+"/something", nil) - require.NoError(t, err) - resp, err := client.Do(req) - require.Nil(t, err) - assert.Equal(t, 200, resp.StatusCode) - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, 576, len(b), "uncompressed size") - - } - -} diff --git a/app/proxy/middleware/headers.go b/app/proxy/middleware/headers.go deleted file mode 100644 index f8c8218..0000000 --- a/app/proxy/middleware/headers.go +++ /dev/null @@ -1,25 +0,0 @@ -package middleware - -import ( - "net/http" - "strings" -) - -// Headers middleware adds headers to request -func Headers(headers ...string) func(http.Handler) http.Handler { - - return func(h http.Handler) http.Handler { - - fn := func(w http.ResponseWriter, r *http.Request) { - for _, h := range headers { - elems := strings.Split(h, ":") - if len(elems) != 2 { - continue - } - r.Header.Set(strings.TrimSpace(elems[0]), strings.TrimSpace(elems[1])) - } - h.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) - } -} diff --git a/app/proxy/middleware/headers_test.go b/app/proxy/middleware/headers_test.go deleted file mode 100644 index 140efd3..0000000 --- a/app/proxy/middleware/headers_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestHeaders(t *testing.T) { - req := httptest.NewRequest("GET", "/something", nil) - w := httptest.NewRecorder() - - h := Headers("h1:v1", "bad", "h2:v2")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - h.ServeHTTP(w, req) - resp := w.Result() - assert.Equal(t, http.StatusOK, resp.StatusCode) - t.Logf("%+v", req.Header) - assert.Equal(t, "v1", req.Header.Get("h1")) - assert.Equal(t, "v2", req.Header.Get("h2")) - assert.Equal(t, 2, len(req.Header)) -} diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index 49a7ec1..cfc4e66 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -11,9 +11,9 @@ import ( "github.com/go-pkgz/lgr" log "github.com/go-pkgz/lgr" "github.com/go-pkgz/rest" + R "github.com/go-pkgz/rest" "github.com/go-pkgz/rest/logger" - - "github.com/umputun/docker-proxy/app/proxy/middleware" + "github.com/pkg/errors" ) type Http struct { @@ -33,40 +33,12 @@ type Matcher interface { Match(srv, src string) (string, bool) } -func (h *Http) Do(ctx context.Context) error { - log.Printf("[INFO] run proxy on %s", h.Address) - if h.AssetsLocation != "" { - log.Printf("[DEBUG] assets file server enabled for %s", h.AssetsLocation) - } - - httpServer := &http.Server{ - Addr: h.Address, - Handler: h.wrap(h.proxyHandler(), - rest.Recoverer(lgr.Default()), - rest.AppInfo("dpx", "umputun", h.Version), - rest.Ping, - logger.New(logger.Prefix("[DEBUG] PROXY")).Handler, - rest.SizeLimit(h.MaxBodySize), - middleware.Headers(h.ProxyHeaders...), - h.gzipHandler(), - ), - ReadHeaderTimeout: 5 * time.Second, - WriteTimeout: 120 * time.Second, - IdleTimeout: 30 * time.Second, - } - - go func() { - <-ctx.Done() - if err := httpServer.Close(); err != nil { - log.Printf("[ERROR] failed to close proxy server, %v", err) - } - }() - - return httpServer.ListenAndServe() -} - // Run the lister and request's router, activate rest server -func (h *Http) Run(ctx context.Context) { +func (h *Http) Run(ctx context.Context) error { + + if h.AssetsLocation != "" { + log.Printf("[DEBUG] assets file server enabled for %s, webroot %s", h.AssetsLocation, h.AssetsWebRoot) + } var httpServer, httpsServer *http.Server @@ -84,24 +56,23 @@ func (h *Http) Run(ctx context.Context) { } }() - handler := h.wrap(h.proxyHandler(), - rest.Recoverer(lgr.Default()), - rest.AppInfo("dpx", "umputun", h.Version), - rest.Ping, + handler := R.Wrap(h.proxyHandler(), + R.Recoverer(lgr.Default()), + R.AppInfo("dpx", "umputun", h.Version), + R.Ping, logger.New(logger.Prefix("[DEBUG] PROXY")).Handler, - rest.SizeLimit(h.MaxBodySize), - middleware.Headers(h.ProxyHeaders...), + R.SizeLimit(h.MaxBodySize), + R.Headers(h.ProxyHeaders...), h.gzipHandler(), ) switch h.SSLConfig.SSLMode { - case None: + case SSLNone: log.Printf("[INFO] activate http proxy server on %s", h.Address) httpServer = h.makeHTTPServer(h.Address, handler) httpServer.ErrorLog = log.ToStdLogger(log.Default(), "WARN") - err := httpServer.ListenAndServe() - log.Printf("[WARN] http server terminated, %s", err) - case Static: + return httpServer.ListenAndServe() + case SSLStatic: log.Printf("[INFO] activate https server in 'static' mode on %s", h.Address) httpsServer = h.makeHTTPSServer(h.Address, handler) @@ -115,9 +86,8 @@ func (h *Http) Run(ctx context.Context) { err := httpServer.ListenAndServe() log.Printf("[WARN] http redirect server terminated, %s", err) }() - err := httpServer.ListenAndServeTLS(h.SSLConfig.Cert, h.SSLConfig.Key) - log.Printf("[WARN] https server terminated, %s", err) - case Auto: + return httpServer.ListenAndServeTLS(h.SSLConfig.Cert, h.SSLConfig.Key) + case SSLAuto: log.Printf("[INFO] activate https server in 'auto' mode on %s", h.Address) m := h.makeAutocertManager() @@ -133,9 +103,9 @@ func (h *Http) Run(ctx context.Context) { log.Printf("[WARN] http challenge server terminated, %s", err) }() - err := httpsServer.ListenAndServeTLS("", "") - log.Printf("[WARN] https server terminated, %s", err) + return httpsServer.ListenAndServeTLS("", "") } + return errors.Errorf("unknown SSL type %v", h.SSLConfig.SSLMode) } func (h *Http) toHttp(address string) string { @@ -143,24 +113,15 @@ func (h *Http) toHttp(address string) string { } func (h *Http) gzipHandler() func(next http.Handler) http.Handler { - gzHandler := func(next http.Handler) http.Handler { + if h.GzEnabled { + return R.Gzip + } + + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r) }) } - if h.GzEnabled { - gzHandler = middleware.Gzip - } - return gzHandler -} - -// wrap convert a list of middlewares to nested calls, in reversed order -func (h *Http) wrap(p http.Handler, mws ...func(http.Handler) http.Handler) http.Handler { - res := p - for i := len(mws) - 1; i >= 0; i-- { - res = mws[i](res) - } - return res } func (h *Http) proxyHandler() http.HandlerFunc { diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index 227f5d6..bcc9fbd 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -37,12 +37,12 @@ func TestHttp_Do(t *testing.T) { }}) go func() { - svc.Do(context.Background()) + svc.Run(context.Background()) }() h.Matcher = svc go func() { - h.Do(ctx) + h.Run(ctx) }() time.Sleep(10 * time.Millisecond) diff --git a/app/proxy/ssl.go b/app/proxy/ssl.go index ebe8153..8315c07 100644 --- a/app/proxy/ssl.go +++ b/app/proxy/ssl.go @@ -16,14 +16,14 @@ import ( type sslMode int8 const ( - // None defines to run http server only - None sslMode = iota + // SSLNone defines to run http server only + SSLNone sslMode = iota - // Static defines to run both https and http server. Redirect http to https - Static + // SSLStatic defines to run both https and http server. Redirect http to https + SSLStatic - // Auto defines to run both https and http server. Redirect http to https. Https server with autocert support - Auto + // SSLAuto defines to run both https and http server. Redirect http to https. Https server with autocert support + SSLAuto ) // SSLConfig holds all ssl params for rest server @@ -31,7 +31,6 @@ type SSLConfig struct { SSLMode sslMode Cert string Key string - Port int ACMELocation string ACMEEmail string FQDNs []string @@ -41,7 +40,7 @@ type SSLConfig struct { // with default middlewares. Used in 'static' ssl mode. func (h *Http) httpToHTTPSRouter() http.Handler { log.Printf("[DEBUG] create https-to-http redirect routes") - return h.wrap(h.redirectHandler(), R.Recoverer(log.Default())) + return R.Wrap(h.redirectHandler(), R.Recoverer(log.Default())) } // httpChallengeRouter creates new router which performs ACME "http-01" challenge response @@ -50,7 +49,7 @@ func (h *Http) httpToHTTPSRouter() http.Handler { // Used in 'auto' ssl mode. func (h *Http) httpChallengeRouter(m *autocert.Manager) http.Handler { log.Printf("[DEBUG] create http-challenge routes") - return h.wrap(m.HTTPHandler(h.redirectHandler()), R.Recoverer(log.Default())) + return R.Wrap(m.HTTPHandler(h.redirectHandler()), R.Recoverer(log.Default())) } func (h *Http) redirectHandler() http.Handler {