mirror of
https://github.com/kemko/reproxy.git
synced 2026-01-01 15:55:49 +03:00
implement simple on/off basic-auth for all resources
lint: err shadowing extract htpasswd file load and add tests
This commit is contained in:
48
app/main.go
48
app/main.go
@@ -30,11 +30,12 @@ import (
|
||||
)
|
||||
|
||||
var opts struct {
|
||||
Listen string `short:"l" long:"listen" env:"LISTEN" description:"listen on host:port (default: 0.0.0.0:8080/8443 under docker, 127.0.0.1:80/443 without)"`
|
||||
MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"`
|
||||
GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"`
|
||||
ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside ""
|
||||
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
|
||||
Listen string `short:"l" long:"listen" env:"LISTEN" description:"listen on host:port (default: 0.0.0.0:8080/8443 under docker, 127.0.0.1:80/443 without)"`
|
||||
MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"`
|
||||
GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"`
|
||||
ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside ""
|
||||
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
|
||||
AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"`
|
||||
|
||||
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` //nolint
|
||||
|
||||
@@ -228,6 +229,11 @@ func run() error {
|
||||
proxyHeaders = splitAtCommas(os.Getenv("HEADER")) // env value may have comma inside "", parsed separately
|
||||
}
|
||||
|
||||
basicAuthAllowed, baErr := makeBasicAuth(opts.AuthBasicHtpasswd)
|
||||
if baErr != nil {
|
||||
return fmt.Errorf("failed to load basic auth: %w", baErr)
|
||||
}
|
||||
|
||||
px := &proxy.Http{
|
||||
Version: revision,
|
||||
Matcher: svc,
|
||||
@@ -256,21 +262,41 @@ func run() error {
|
||||
ExpectContinue: opts.Timeouts.ExpectContinue,
|
||||
ResponseHeader: opts.Timeouts.ResponseHeader,
|
||||
},
|
||||
Metrics: makeMetrics(ctx, svc),
|
||||
Reporter: errReporter,
|
||||
PluginConductor: makePluginConductor(ctx),
|
||||
ThrottleSystem: opts.Throttle.System * 3,
|
||||
ThrottleUser: opts.Throttle.User,
|
||||
Metrics: makeMetrics(ctx, svc),
|
||||
Reporter: errReporter,
|
||||
PluginConductor: makePluginConductor(ctx),
|
||||
ThrottleSystem: opts.Throttle.System * 3,
|
||||
ThrottleUser: opts.Throttle.User,
|
||||
BasicAuthEnabled: len(basicAuthAllowed) > 0,
|
||||
BasicAuthAllowed: basicAuthAllowed,
|
||||
}
|
||||
|
||||
err = px.Run(ctx)
|
||||
if err != nil && err == http.ErrServerClosed {
|
||||
log.Printf("[WARN] proxy server closed, %v", err) //nolint gocritic
|
||||
log.Printf("[WARN] proxy server closed, %v", err) // nolint gocritic
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// makeBasicAuth returns a list of allowed basic auth users and password hashes.
|
||||
// if no htpasswd file is specified, an empty list is returned.
|
||||
func makeBasicAuth(htpasswdFile string) ([]string, error) {
|
||||
var basicAuthAllowed []string
|
||||
if htpasswdFile != "" {
|
||||
data, err := ioutil.ReadFile(htpasswdFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read htpasswd file %s: %w", htpasswdFile, err)
|
||||
}
|
||||
basicAuthAllowed = strings.Split(string(data), "\n")
|
||||
for i, v := range basicAuthAllowed {
|
||||
basicAuthAllowed[i] = strings.TrimSpace(v)
|
||||
basicAuthAllowed[i] = strings.Replace(basicAuthAllowed[i], "\t", "", -1)
|
||||
}
|
||||
}
|
||||
return basicAuthAllowed, nil
|
||||
}
|
||||
|
||||
// make all providers. the order is matter, defines which provider will have priority in case of conflicting rules
|
||||
// static first, file second and docker the last one
|
||||
func makeProviders() ([]discovery.Provider, error) {
|
||||
|
||||
@@ -376,3 +376,21 @@ func Test_splitAtCommas(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func Test_makeBasicAuth(t *testing.T) {
|
||||
pf := `test:$2y$05$zMxDmK65SjcH2vJQNopVSO/nE8ngVLx65RoETyHpez7yTS/8CLEiW
|
||||
test2:$2y$05$TLQqHh6VT4JxysdKGPOlJeSkkMsv.Ku/G45i7ssIm80XuouCrES12
|
||||
bad bad`
|
||||
|
||||
fh, err := os.CreateTemp(os.TempDir(), "reproxy_auth_*")
|
||||
require.NoError(t, err)
|
||||
defer fh.Close()
|
||||
|
||||
n, err := fh.WriteString(pf)
|
||||
require.Equal(t, len(pf), n)
|
||||
|
||||
res, err := makeBasicAuth(fh.Name())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, len(res))
|
||||
assert.Equal(t, []string{"test:$2y$05$zMxDmK65SjcH2vJQNopVSO/nE8ngVLx65RoETyHpez7yTS/8CLEiW", "test2:$2y$05$TLQqHh6VT4JxysdKGPOlJeSkkMsv.Ku/G45i7ssIm80XuouCrES12", "bad bad"}, res)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -10,6 +12,7 @@ import (
|
||||
log "github.com/go-pkgz/lgr"
|
||||
R "github.com/go-pkgz/rest"
|
||||
"github.com/gorilla/handlers"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/umputun/reproxy/app/discovery"
|
||||
)
|
||||
@@ -156,6 +159,57 @@ func limiterUserHandler(reqSec int) func(next http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// basicAuthHandler is a middleware that authenticates via basic auth, if enabled
|
||||
// allowed is a list of user:bcrypt(passwd) strings generated by `htpasswd -nbB user passwd`
|
||||
func basicAuthHandler(enabled bool, allowed []string) func(next http.Handler) http.Handler {
|
||||
if !enabled {
|
||||
return passThroughHandler
|
||||
}
|
||||
|
||||
unauthorized := func(w http.ResponseWriter) {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
return func(h http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
unauthorized(w)
|
||||
return
|
||||
}
|
||||
|
||||
passed := false
|
||||
for _, a := range allowed {
|
||||
alwElems := strings.Split(strings.TrimSpace(a), ":")
|
||||
if len(alwElems) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// hash to ensure constant time comparison not affected by username length
|
||||
usernameHash := sha256.Sum256([]byte(username))
|
||||
expectedUsernameHash := sha256.Sum256([]byte(alwElems[0]))
|
||||
|
||||
expectedPasswordHash := alwElems[1]
|
||||
userMatched := subtle.ConstantTimeCompare(usernameHash[:], expectedUsernameHash[:])
|
||||
passMatchErr := bcrypt.CompareHashAndPassword([]byte(expectedPasswordHash), []byte(password))
|
||||
if userMatched == 1 && passMatchErr == nil {
|
||||
passed = true // don't stop here, check all allowed to keep the overall time consistent
|
||||
}
|
||||
}
|
||||
|
||||
if !passed {
|
||||
unauthorized(w)
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func passThroughHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
@@ -186,3 +186,63 @@ func Test_limiterClientHandlerWithMatches(t *testing.T) {
|
||||
wg.Wait()
|
||||
assert.Equal(t, int32(20), atomic.LoadInt32(&passed))
|
||||
}
|
||||
|
||||
func TestHttp_basicAuthHandler(t *testing.T) {
|
||||
allowed := []string{
|
||||
"test:$2y$05$zMxDmK65SjcH2vJQNopVSO/nE8ngVLx65RoETyHpez7yTS/8CLEiW",
|
||||
"test2:$2y$05$TLQqHh6VT4JxysdKGPOlJeSkkMsv.Ku/G45i7ssIm80XuouCrES12 ",
|
||||
"bad bad",
|
||||
}
|
||||
|
||||
handler := basicAuthHandler(true, allowed)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("req: %v", r)
|
||||
}))
|
||||
ts := httptest.NewServer(handler)
|
||||
|
||||
client := http.Client{}
|
||||
|
||||
tbl := []struct {
|
||||
reqFn func(r *http.Request)
|
||||
ok bool
|
||||
}{
|
||||
{func(r *http.Request) {}, false},
|
||||
{func(r *http.Request) { r.SetBasicAuth("test", "passwd") }, true},
|
||||
{func(r *http.Request) { r.SetBasicAuth("test", "passwdbad") }, false},
|
||||
{func(r *http.Request) { r.SetBasicAuth("test2", "passwd2") }, true},
|
||||
{func(r *http.Request) { r.SetBasicAuth("test2", "passwbad") }, false},
|
||||
{func(r *http.Request) { r.SetBasicAuth("testbad", "passwbad") }, false},
|
||||
}
|
||||
|
||||
for i, tt := range tbl {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", ts.URL, nil)
|
||||
require.NoError(t, err)
|
||||
tt.reqFn(req)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
if tt.ok {
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
return
|
||||
}
|
||||
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
handler = basicAuthHandler(false, allowed)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("req: %v", r)
|
||||
}))
|
||||
ts2 := httptest.NewServer(handler)
|
||||
for i, tt := range tbl {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", ts2.URL, nil)
|
||||
require.NoError(t, err)
|
||||
tt.reqFn(req)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -45,6 +45,9 @@ type Http struct { // nolint golint
|
||||
Reporter Reporter
|
||||
LBSelector func(len int) int
|
||||
|
||||
BasicAuthEnabled bool
|
||||
BasicAuthAllowed []string
|
||||
|
||||
ThrottleSystem int
|
||||
ThrottleUser int
|
||||
}
|
||||
@@ -111,17 +114,18 @@ func (h *Http) Run(ctx context.Context) error {
|
||||
}()
|
||||
|
||||
handler := R.Wrap(h.proxyHandler(),
|
||||
R.Recoverer(log.Default()), // recover on errors
|
||||
signatureHandler(h.Signature, h.Version), // send app signature
|
||||
h.pingHandler, // respond to /ping
|
||||
h.healthMiddleware, // respond to /health
|
||||
h.matchHandler, // set matched routes to context
|
||||
limiterSystemHandler(h.ThrottleSystem), // limit total requests/sec
|
||||
limiterUserHandler(h.ThrottleUser), // req/seq per user/route match
|
||||
h.mgmtHandler(), // handles /metrics and /routes for prometheus
|
||||
h.pluginHandler(), // prc to external plugins
|
||||
headersHandler(h.ProxyHeaders, h.DropHeader), // add response headers and delete some request headers
|
||||
accessLogHandler(h.AccessLog), // apache-format log file
|
||||
R.Recoverer(log.Default()), // recover on errors
|
||||
signatureHandler(h.Signature, h.Version), // send app signature
|
||||
h.pingHandler, // respond to /ping
|
||||
basicAuthHandler(h.BasicAuthEnabled, h.BasicAuthAllowed), // basic auth
|
||||
h.healthMiddleware, // respond to /health
|
||||
h.matchHandler, // set matched routes to context
|
||||
limiterSystemHandler(h.ThrottleSystem), // limit total requests/sec
|
||||
limiterUserHandler(h.ThrottleUser), // req/seq per user/route match
|
||||
h.mgmtHandler(), // handles /metrics and /routes for prometheus
|
||||
h.pluginHandler(), // prc to external plugins
|
||||
headersHandler(h.ProxyHeaders, h.DropHeader), // add response headers and delete some request headers
|
||||
accessLogHandler(h.AccessLog), // apache-format log file
|
||||
stdoutLogHandler(h.StdOutEnabled, logger.New(logger.Log(log.Default()), logger.Prefix("[INFO]")).Handler),
|
||||
maxReqSizeHandler(h.MaxBodySize), // limit request max size
|
||||
gzipHandler(h.GzEnabled), // gzip response
|
||||
|
||||
@@ -565,6 +565,84 @@ func TestHttp_health(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttp_withBasicAuth(t *testing.T) {
|
||||
port := rand.Intn(10000) + 40000
|
||||
h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port),
|
||||
AccessLog: io.Discard, Signature: true, ProxyHeaders: []string{"hh1:vv1", "hh2:vv2"}, StdOutEnabled: true,
|
||||
Reporter: &ErrorReporter{Nice: true}, BasicAuthEnabled: true, BasicAuthAllowed: []string{
|
||||
"test:$2y$05$zMxDmK65SjcH2vJQNopVSO/nE8ngVLx65RoETyHpez7yTS/8CLEiW",
|
||||
"test2:$2y$05$TLQqHh6VT4JxysdKGPOlJeSkkMsv.Ku/G45i7ssIm80XuouCrES12",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("req: %v", r)
|
||||
w.Header().Add("h1", "v1")
|
||||
require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
|
||||
fmt.Fprintf(w, "response %s", r.URL.String())
|
||||
}))
|
||||
|
||||
svc := discovery.NewService([]discovery.Provider{
|
||||
&provider.Static{Rules: []string{
|
||||
"localhost,^/api/(.*)," + ds.URL + "/123/$1,",
|
||||
"127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,",
|
||||
"*,/web,spa:testdata,",
|
||||
},
|
||||
}}, time.Millisecond*10)
|
||||
|
||||
go func() {
|
||||
_ = svc.Run(context.Background())
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
h.Matcher, h.Metrics = svc, mgmt.NewMetrics()
|
||||
|
||||
go func() {
|
||||
_ = h.Run(ctx)
|
||||
}()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
client := http.Client{}
|
||||
|
||||
{
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
|
||||
require.NoError(t, err)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
|
||||
{
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
|
||||
req.SetBasicAuth("test", "badpasswd")
|
||||
require.NoError(t, err)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
{
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
|
||||
req.SetBasicAuth("test", "passwd")
|
||||
require.NoError(t, err)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
{
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
|
||||
req.SetBasicAuth("test2", "passwd2")
|
||||
require.NoError(t, err)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttp_toHttp(t *testing.T) {
|
||||
|
||||
tbl := []struct {
|
||||
|
||||
Reference in New Issue
Block a user