app/vmauth: use lib/promauth for creating backend roundtripper

This simplifies further maintenance and opens doors for additional config options
supported by lib/promauth. For example, an ability to specify client TLS certificates.
This commit is contained in:
Aliaksandr Valialkin 2024-04-17 16:46:27 +02:00
parent 21a9b1f920
commit 2d5e5badcf
No known key found for this signature in database
GPG Key ID: 52C003EE2BCDB9EB
3 changed files with 30 additions and 76 deletions

View File

@ -81,7 +81,7 @@ type UserInfo struct {
concurrencyLimitCh chan struct{}
concurrencyLimitReached *metrics.Counter
httpTransport *http.Transport
rt http.RoundTripper
requests *metrics.Counter
backendErrors *metrics.Counter
@ -729,11 +729,11 @@ func parseAuthConfig(data []byte) (*AuthConfig, error) {
return float64(len(ui.concurrencyLimitCh))
})
tr, err := getTransport(ui.TLSInsecureSkipVerify, ui.TLSCAFile)
rt, err := newRoundTripper(ui.TLSInsecureSkipVerify, ui.TLSCAFile)
if err != nil {
return nil, fmt.Errorf("cannot initialize HTTP transport: %w", err)
return nil, fmt.Errorf("cannot initialize HTTP RoundTripper: %w", err)
}
ui.httpTransport = tr
ui.rt = rt
}
return ac, nil
}
@ -777,11 +777,11 @@ func parseAuthConfigUsers(ac *AuthConfig) (map[string]*UserInfo, error) {
return float64(len(ui.concurrencyLimitCh))
})
tr, err := getTransport(ui.TLSInsecureSkipVerify, ui.TLSCAFile)
rt, err := newRoundTripper(ui.TLSInsecureSkipVerify, ui.TLSCAFile)
if err != nil {
return nil, fmt.Errorf("cannot initialize HTTP transport: %w", err)
return nil, fmt.Errorf("cannot initialize HTTP RoundTripper: %w", err)
}
ui.httpTransport = tr
ui.rt = rt
for _, at := range ats {
byAuthToken[at] = ui

View File

@ -578,11 +578,11 @@ unauthorized_user:
}
ui := m[getHTTPAuthBasicToken("foo", "bar")]
if !isSetBool(ui.TLSInsecureSkipVerify, true) || !ui.httpTransport.TLSClientConfig.InsecureSkipVerify {
if !isSetBool(ui.TLSInsecureSkipVerify, true) {
t.Fatalf("unexpected TLSInsecureSkipVerify value for user foo")
}
if !isSetBool(ac.UnauthorizedUser.TLSInsecureSkipVerify, false) || ac.UnauthorizedUser.httpTransport.TLSClientConfig.InsecureSkipVerify {
if !isSetBool(ac.UnauthorizedUser.TLSInsecureSkipVerify, false) {
t.Fatalf("unexpected TLSInsecureSkipVerify value for unauthorized_user")
}
}

View File

@ -2,8 +2,6 @@ package main
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"flag"
"fmt"
@ -22,14 +20,13 @@ import (
"github.com/VictoriaMetrics/VictoriaMetrics/lib/buildinfo"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/bytesutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/encoding"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/envflag"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/fs/fscore"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/httpserver"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/logger"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/procutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/promauth"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/pushmetrics"
)
@ -240,7 +237,7 @@ func tryProcessingRequest(w http.ResponseWriter, r *http.Request, targetURL *url
req.Host = targetURL.Host
}
updateHeadersByConfig(req.Header, hc.RequestHeaders)
res, err := ui.httpTransport.RoundTrip(req)
res, err := ui.rt.RoundTrip(req)
rtb, rtbOK := req.Body.(*readTrackingBody)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
@ -392,50 +389,26 @@ var (
missingRouteRequests = metrics.NewCounter(`vmauth_http_request_errors_total{reason="missing_route"}`)
)
func getTransport(insecureSkipVerifyP *bool, caFile string) (*http.Transport, error) {
if insecureSkipVerifyP == nil {
insecureSkipVerifyP = backendTLSInsecureSkipVerify
func newRoundTripper(insecureSkipVerifyP *bool, caFileP string) (http.RoundTripper, error) {
insecureSkipVerify := *backendTLSInsecureSkipVerify
if p := insecureSkipVerifyP; p != nil {
insecureSkipVerify = *p
}
insecureSkipVerify := *insecureSkipVerifyP
if caFile == "" {
caFile = *backendTLSCAFile
caFile := *backendTLSCAFile
if caFileP != "" {
caFile = caFileP
}
opts := &promauth.Options{
TLSConfig: &promauth.TLSConfig{
InsecureSkipVerify: insecureSkipVerify,
CAFile: caFile,
},
}
cfg, err := opts.NewConfig()
if err != nil {
return nil, fmt.Errorf("cannot initialize promauth.Config: %w", err)
}
bb := bbPool.Get()
defer bbPool.Put(bb)
bb.B = appendTransportKey(bb.B[:0], insecureSkipVerify, caFile)
transportMapLock.Lock()
defer transportMapLock.Unlock()
tr := transportMap[string(bb.B)]
if tr == nil {
trLocal, err := newTransport(insecureSkipVerify, caFile)
if err != nil {
return nil, err
}
transportMap[string(bb.B)] = trLocal
tr = trLocal
}
return tr, nil
}
var (
transportMap = make(map[string]*http.Transport)
transportMapLock sync.Mutex
)
func appendTransportKey(dst []byte, insecureSkipVerify bool, caFile string) []byte {
dst = encoding.MarshalBool(dst, insecureSkipVerify)
dst = encoding.MarshalBytes(dst, bytesutil.ToUnsafeBytes(caFile))
return dst
}
var bbPool bytesutil.ByteBufferPool
func newTransport(insecureSkipVerify bool, caFile string) (*http.Transport, error) {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.ResponseHeaderTimeout = *responseTimeout
// Automatic compression must be disabled in order to fix https://github.com/VictoriaMetrics/VictoriaMetrics/issues/535
@ -444,27 +417,8 @@ func newTransport(insecureSkipVerify bool, caFile string) (*http.Transport, erro
if tr.MaxIdleConns != 0 && tr.MaxIdleConns < tr.MaxIdleConnsPerHost {
tr.MaxIdleConns = tr.MaxIdleConnsPerHost
}
tlsCfg := tr.TLSClientConfig
if tlsCfg == nil {
tlsCfg = &tls.Config{}
tr.TLSClientConfig = tlsCfg
}
if insecureSkipVerify || caFile != "" {
tlsCfg.ClientSessionCache = tls.NewLRUClientSessionCache(0)
tlsCfg.InsecureSkipVerify = insecureSkipVerify
if caFile != "" {
data, err := fscore.ReadFileOrHTTP(caFile)
if err != nil {
return nil, fmt.Errorf("cannot read tls_ca_file: %w", err)
}
rootCA := x509.NewCertPool()
if !rootCA.AppendCertsFromPEM(data) {
return nil, fmt.Errorf("cannot parse data read from tls_ca_file %q", caFile)
}
tlsCfg.RootCAs = rootCA
}
}
return tr, nil
rt := cfg.NewRoundTripper(tr)
return rt, nil
}
var (