implement simple on/off basic-auth for all resources

lint: err shadowing

extract htpasswd file load and add tests
This commit is contained in:
Umputun
2021-11-07 15:31:33 -06:00
parent 184d5ba87c
commit 8c59be3612
13 changed files with 1068 additions and 22 deletions

View File

@@ -30,11 +30,12 @@ import (
)
var opts struct {
Listen string `short:"l" long:"listen" env:"LISTEN" description:"listen on host:port (default: 0.0.0.0:8080/8443 under docker, 127.0.0.1:80/443 without)"`
MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"`
GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"`
ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside ""
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
Listen string `short:"l" long:"listen" env:"LISTEN" description:"listen on host:port (default: 0.0.0.0:8080/8443 under docker, 127.0.0.1:80/443 without)"`
MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"`
GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"`
ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside ""
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"`
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` //nolint
@@ -228,6 +229,11 @@ func run() error {
proxyHeaders = splitAtCommas(os.Getenv("HEADER")) // env value may have comma inside "", parsed separately
}
basicAuthAllowed, baErr := makeBasicAuth(opts.AuthBasicHtpasswd)
if baErr != nil {
return fmt.Errorf("failed to load basic auth: %w", baErr)
}
px := &proxy.Http{
Version: revision,
Matcher: svc,
@@ -256,21 +262,41 @@ func run() error {
ExpectContinue: opts.Timeouts.ExpectContinue,
ResponseHeader: opts.Timeouts.ResponseHeader,
},
Metrics: makeMetrics(ctx, svc),
Reporter: errReporter,
PluginConductor: makePluginConductor(ctx),
ThrottleSystem: opts.Throttle.System * 3,
ThrottleUser: opts.Throttle.User,
Metrics: makeMetrics(ctx, svc),
Reporter: errReporter,
PluginConductor: makePluginConductor(ctx),
ThrottleSystem: opts.Throttle.System * 3,
ThrottleUser: opts.Throttle.User,
BasicAuthEnabled: len(basicAuthAllowed) > 0,
BasicAuthAllowed: basicAuthAllowed,
}
err = px.Run(ctx)
if err != nil && err == http.ErrServerClosed {
log.Printf("[WARN] proxy server closed, %v", err) //nolint gocritic
log.Printf("[WARN] proxy server closed, %v", err) // nolint gocritic
return nil
}
return err
}
// makeBasicAuth returns a list of allowed basic auth users and password hashes.
// if no htpasswd file is specified, an empty list is returned.
func makeBasicAuth(htpasswdFile string) ([]string, error) {
var basicAuthAllowed []string
if htpasswdFile != "" {
data, err := ioutil.ReadFile(htpasswdFile)
if err != nil {
return nil, fmt.Errorf("failed to read htpasswd file %s: %w", htpasswdFile, err)
}
basicAuthAllowed = strings.Split(string(data), "\n")
for i, v := range basicAuthAllowed {
basicAuthAllowed[i] = strings.TrimSpace(v)
basicAuthAllowed[i] = strings.Replace(basicAuthAllowed[i], "\t", "", -1)
}
}
return basicAuthAllowed, nil
}
// make all providers. the order is matter, defines which provider will have priority in case of conflicting rules
// static first, file second and docker the last one
func makeProviders() ([]discovery.Provider, error) {

View File

@@ -376,3 +376,21 @@ func Test_splitAtCommas(t *testing.T) {
}
}
func Test_makeBasicAuth(t *testing.T) {
pf := `test:$2y$05$zMxDmK65SjcH2vJQNopVSO/nE8ngVLx65RoETyHpez7yTS/8CLEiW
test2:$2y$05$TLQqHh6VT4JxysdKGPOlJeSkkMsv.Ku/G45i7ssIm80XuouCrES12
bad bad`
fh, err := os.CreateTemp(os.TempDir(), "reproxy_auth_*")
require.NoError(t, err)
defer fh.Close()
n, err := fh.WriteString(pf)
require.Equal(t, len(pf), n)
res, err := makeBasicAuth(fh.Name())
require.NoError(t, err)
assert.Equal(t, 3, len(res))
assert.Equal(t, []string{"test:$2y$05$zMxDmK65SjcH2vJQNopVSO/nE8ngVLx65RoETyHpez7yTS/8CLEiW", "test2:$2y$05$TLQqHh6VT4JxysdKGPOlJeSkkMsv.Ku/G45i7ssIm80XuouCrES12", "bad bad"}, res)
}

View File

@@ -1,6 +1,8 @@
package proxy
import (
"crypto/sha256"
"crypto/subtle"
"io"
"net/http"
"strings"
@@ -10,6 +12,7 @@ import (
log "github.com/go-pkgz/lgr"
R "github.com/go-pkgz/rest"
"github.com/gorilla/handlers"
"golang.org/x/crypto/bcrypt"
"github.com/umputun/reproxy/app/discovery"
)
@@ -156,6 +159,57 @@ func limiterUserHandler(reqSec int) func(next http.Handler) http.Handler {
}
}
// basicAuthHandler is a middleware that authenticates via basic auth, if enabled
// allowed is a list of user:bcrypt(passwd) strings generated by `htpasswd -nbB user passwd`
func basicAuthHandler(enabled bool, allowed []string) func(next http.Handler) http.Handler {
if !enabled {
return passThroughHandler
}
unauthorized := func(w http.ResponseWriter) {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
w.WriteHeader(http.StatusUnauthorized)
}
return func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok {
unauthorized(w)
return
}
passed := false
for _, a := range allowed {
alwElems := strings.Split(strings.TrimSpace(a), ":")
if len(alwElems) != 2 {
continue
}
// hash to ensure constant time comparison not affected by username length
usernameHash := sha256.Sum256([]byte(username))
expectedUsernameHash := sha256.Sum256([]byte(alwElems[0]))
expectedPasswordHash := alwElems[1]
userMatched := subtle.ConstantTimeCompare(usernameHash[:], expectedUsernameHash[:])
passMatchErr := bcrypt.CompareHashAndPassword([]byte(expectedPasswordHash), []byte(password))
if userMatched == 1 && passMatchErr == nil {
passed = true // don't stop here, check all allowed to keep the overall time consistent
}
}
if !passed {
unauthorized(w)
return
}
h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}
func passThroughHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)

View File

@@ -186,3 +186,63 @@ func Test_limiterClientHandlerWithMatches(t *testing.T) {
wg.Wait()
assert.Equal(t, int32(20), atomic.LoadInt32(&passed))
}
func TestHttp_basicAuthHandler(t *testing.T) {
allowed := []string{
"test:$2y$05$zMxDmK65SjcH2vJQNopVSO/nE8ngVLx65RoETyHpez7yTS/8CLEiW",
"test2:$2y$05$TLQqHh6VT4JxysdKGPOlJeSkkMsv.Ku/G45i7ssIm80XuouCrES12 ",
"bad bad",
}
handler := basicAuthHandler(true, allowed)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("req: %v", r)
}))
ts := httptest.NewServer(handler)
client := http.Client{}
tbl := []struct {
reqFn func(r *http.Request)
ok bool
}{
{func(r *http.Request) {}, false},
{func(r *http.Request) { r.SetBasicAuth("test", "passwd") }, true},
{func(r *http.Request) { r.SetBasicAuth("test", "passwdbad") }, false},
{func(r *http.Request) { r.SetBasicAuth("test2", "passwd2") }, true},
{func(r *http.Request) { r.SetBasicAuth("test2", "passwbad") }, false},
{func(r *http.Request) { r.SetBasicAuth("testbad", "passwbad") }, false},
}
for i, tt := range tbl {
t.Run(strconv.Itoa(i), func(t *testing.T) {
req, err := http.NewRequest("GET", ts.URL, nil)
require.NoError(t, err)
tt.reqFn(req)
resp, err := client.Do(req)
require.NoError(t, err)
resp.Body.Close()
if tt.ok {
require.Equal(t, http.StatusOK, resp.StatusCode)
return
}
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
})
}
handler = basicAuthHandler(false, allowed)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("req: %v", r)
}))
ts2 := httptest.NewServer(handler)
for i, tt := range tbl {
t.Run(strconv.Itoa(i), func(t *testing.T) {
req, err := http.NewRequest("GET", ts2.URL, nil)
require.NoError(t, err)
tt.reqFn(req)
resp, err := client.Do(req)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
})
}
}

View File

@@ -45,6 +45,9 @@ type Http struct { // nolint golint
Reporter Reporter
LBSelector func(len int) int
BasicAuthEnabled bool
BasicAuthAllowed []string
ThrottleSystem int
ThrottleUser int
}
@@ -111,17 +114,18 @@ func (h *Http) Run(ctx context.Context) error {
}()
handler := R.Wrap(h.proxyHandler(),
R.Recoverer(log.Default()), // recover on errors
signatureHandler(h.Signature, h.Version), // send app signature
h.pingHandler, // respond to /ping
h.healthMiddleware, // respond to /health
h.matchHandler, // set matched routes to context
limiterSystemHandler(h.ThrottleSystem), // limit total requests/sec
limiterUserHandler(h.ThrottleUser), // req/seq per user/route match
h.mgmtHandler(), // handles /metrics and /routes for prometheus
h.pluginHandler(), // prc to external plugins
headersHandler(h.ProxyHeaders, h.DropHeader), // add response headers and delete some request headers
accessLogHandler(h.AccessLog), // apache-format log file
R.Recoverer(log.Default()), // recover on errors
signatureHandler(h.Signature, h.Version), // send app signature
h.pingHandler, // respond to /ping
basicAuthHandler(h.BasicAuthEnabled, h.BasicAuthAllowed), // basic auth
h.healthMiddleware, // respond to /health
h.matchHandler, // set matched routes to context
limiterSystemHandler(h.ThrottleSystem), // limit total requests/sec
limiterUserHandler(h.ThrottleUser), // req/seq per user/route match
h.mgmtHandler(), // handles /metrics and /routes for prometheus
h.pluginHandler(), // prc to external plugins
headersHandler(h.ProxyHeaders, h.DropHeader), // add response headers and delete some request headers
accessLogHandler(h.AccessLog), // apache-format log file
stdoutLogHandler(h.StdOutEnabled, logger.New(logger.Log(log.Default()), logger.Prefix("[INFO]")).Handler),
maxReqSizeHandler(h.MaxBodySize), // limit request max size
gzipHandler(h.GzEnabled), // gzip response

View File

@@ -565,6 +565,84 @@ func TestHttp_health(t *testing.T) {
}
}
func TestHttp_withBasicAuth(t *testing.T) {
port := rand.Intn(10000) + 40000
h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
AccessLog: io.Discard, Signature: true, ProxyHeaders: []string{"hh1:vv1", "hh2:vv2"}, StdOutEnabled: true,
Reporter: &ErrorReporter{Nice: true}, BasicAuthEnabled: true, BasicAuthAllowed: []string{
"test:$2y$05$zMxDmK65SjcH2vJQNopVSO/nE8ngVLx65RoETyHpez7yTS/8CLEiW",
"test2:$2y$05$TLQqHh6VT4JxysdKGPOlJeSkkMsv.Ku/G45i7ssIm80XuouCrES12",
}}
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("req: %v", r)
w.Header().Add("h1", "v1")
require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
fmt.Fprintf(w, "response %s", r.URL.String())
}))
svc := discovery.NewService([]discovery.Provider{
&provider.Static{Rules: []string{
"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
"*,/web,spa:testdata,",
},
}}, time.Millisecond*10)
go func() {
_ = svc.Run(context.Background())
}()
time.Sleep(50 * time.Millisecond)
h.Matcher, h.Metrics = svc, mgmt.NewMetrics()
go func() {
_ = h.Run(ctx)
}()
time.Sleep(10 * time.Millisecond)
client := http.Client{}
{
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}
{
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test", "badpasswd")
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}
{
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test", "passwd")
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
{
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test2", "passwd2")
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
}
func TestHttp_toHttp(t *testing.T) {
tbl := []struct {