From 2d5e5badcf4006c841df6ed4f05e2fef8500ab32 Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Wed, 17 Apr 2024 16:46:27 +0200 Subject: [PATCH] app/vmauth: use lib/promauth for creating backend roundtripper This simplifies further maintenance and opens doors for additional config options supported by lib/promauth. For example, an ability to specify client TLS certificates. --- app/vmauth/auth_config.go | 14 +++--- app/vmauth/auth_config_test.go | 4 +- app/vmauth/main.go | 88 ++++++++-------------------------- 3 files changed, 30 insertions(+), 76 deletions(-) diff --git a/app/vmauth/auth_config.go b/app/vmauth/auth_config.go index 395351cdc..53b47b42b 100644 --- a/app/vmauth/auth_config.go +++ b/app/vmauth/auth_config.go @@ -81,7 +81,7 @@ type UserInfo struct { concurrencyLimitCh chan struct{} concurrencyLimitReached *metrics.Counter - httpTransport *http.Transport + rt http.RoundTripper requests *metrics.Counter backendErrors *metrics.Counter @@ -729,11 +729,11 @@ func parseAuthConfig(data []byte) (*AuthConfig, error) { return float64(len(ui.concurrencyLimitCh)) }) - tr, err := getTransport(ui.TLSInsecureSkipVerify, ui.TLSCAFile) + rt, err := newRoundTripper(ui.TLSInsecureSkipVerify, ui.TLSCAFile) if err != nil { - return nil, fmt.Errorf("cannot initialize HTTP transport: %w", err) + return nil, fmt.Errorf("cannot initialize HTTP RoundTripper: %w", err) } - ui.httpTransport = tr + ui.rt = rt } return ac, nil } @@ -777,11 +777,11 @@ func parseAuthConfigUsers(ac *AuthConfig) (map[string]*UserInfo, error) { return float64(len(ui.concurrencyLimitCh)) }) - tr, err := getTransport(ui.TLSInsecureSkipVerify, ui.TLSCAFile) + rt, err := newRoundTripper(ui.TLSInsecureSkipVerify, ui.TLSCAFile) if err != nil { - return nil, fmt.Errorf("cannot initialize HTTP transport: %w", err) + return nil, fmt.Errorf("cannot initialize HTTP RoundTripper: %w", err) } - ui.httpTransport = tr + ui.rt = rt for _, at := range ats { byAuthToken[at] = ui diff --git a/app/vmauth/auth_config_test.go b/app/vmauth/auth_config_test.go index 3ba719b40..8dd89f49c 100644 --- a/app/vmauth/auth_config_test.go +++ b/app/vmauth/auth_config_test.go @@ -578,11 +578,11 @@ unauthorized_user: } ui := m[getHTTPAuthBasicToken("foo", "bar")] - if !isSetBool(ui.TLSInsecureSkipVerify, true) || !ui.httpTransport.TLSClientConfig.InsecureSkipVerify { + if !isSetBool(ui.TLSInsecureSkipVerify, true) { t.Fatalf("unexpected TLSInsecureSkipVerify value for user foo") } - if !isSetBool(ac.UnauthorizedUser.TLSInsecureSkipVerify, false) || ac.UnauthorizedUser.httpTransport.TLSClientConfig.InsecureSkipVerify { + if !isSetBool(ac.UnauthorizedUser.TLSInsecureSkipVerify, false) { t.Fatalf("unexpected TLSInsecureSkipVerify value for unauthorized_user") } } diff --git a/app/vmauth/main.go b/app/vmauth/main.go index 979fac519..0743dc079 100644 --- a/app/vmauth/main.go +++ b/app/vmauth/main.go @@ -2,8 +2,6 @@ package main import ( "context" - "crypto/tls" - "crypto/x509" "errors" "flag" "fmt" @@ -22,14 +20,13 @@ import ( "github.com/VictoriaMetrics/VictoriaMetrics/lib/buildinfo" "github.com/VictoriaMetrics/VictoriaMetrics/lib/bytesutil" - "github.com/VictoriaMetrics/VictoriaMetrics/lib/encoding" "github.com/VictoriaMetrics/VictoriaMetrics/lib/envflag" "github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil" - "github.com/VictoriaMetrics/VictoriaMetrics/lib/fs/fscore" "github.com/VictoriaMetrics/VictoriaMetrics/lib/httpserver" "github.com/VictoriaMetrics/VictoriaMetrics/lib/logger" "github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/procutil" + "github.com/VictoriaMetrics/VictoriaMetrics/lib/promauth" "github.com/VictoriaMetrics/VictoriaMetrics/lib/pushmetrics" ) @@ -240,7 +237,7 @@ func tryProcessingRequest(w http.ResponseWriter, r *http.Request, targetURL *url req.Host = targetURL.Host } updateHeadersByConfig(req.Header, hc.RequestHeaders) - res, err := ui.httpTransport.RoundTrip(req) + res, err := ui.rt.RoundTrip(req) rtb, rtbOK := req.Body.(*readTrackingBody) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { @@ -392,50 +389,26 @@ var ( missingRouteRequests = metrics.NewCounter(`vmauth_http_request_errors_total{reason="missing_route"}`) ) -func getTransport(insecureSkipVerifyP *bool, caFile string) (*http.Transport, error) { - if insecureSkipVerifyP == nil { - insecureSkipVerifyP = backendTLSInsecureSkipVerify +func newRoundTripper(insecureSkipVerifyP *bool, caFileP string) (http.RoundTripper, error) { + insecureSkipVerify := *backendTLSInsecureSkipVerify + if p := insecureSkipVerifyP; p != nil { + insecureSkipVerify = *p } - insecureSkipVerify := *insecureSkipVerifyP - if caFile == "" { - caFile = *backendTLSCAFile + caFile := *backendTLSCAFile + if caFileP != "" { + caFile = caFileP + } + opts := &promauth.Options{ + TLSConfig: &promauth.TLSConfig{ + InsecureSkipVerify: insecureSkipVerify, + CAFile: caFile, + }, + } + cfg, err := opts.NewConfig() + if err != nil { + return nil, fmt.Errorf("cannot initialize promauth.Config: %w", err) } - bb := bbPool.Get() - defer bbPool.Put(bb) - - bb.B = appendTransportKey(bb.B[:0], insecureSkipVerify, caFile) - - transportMapLock.Lock() - defer transportMapLock.Unlock() - - tr := transportMap[string(bb.B)] - if tr == nil { - trLocal, err := newTransport(insecureSkipVerify, caFile) - if err != nil { - return nil, err - } - transportMap[string(bb.B)] = trLocal - tr = trLocal - } - - return tr, nil -} - -var ( - transportMap = make(map[string]*http.Transport) - transportMapLock sync.Mutex -) - -func appendTransportKey(dst []byte, insecureSkipVerify bool, caFile string) []byte { - dst = encoding.MarshalBool(dst, insecureSkipVerify) - dst = encoding.MarshalBytes(dst, bytesutil.ToUnsafeBytes(caFile)) - return dst -} - -var bbPool bytesutil.ByteBufferPool - -func newTransport(insecureSkipVerify bool, caFile string) (*http.Transport, error) { tr := http.DefaultTransport.(*http.Transport).Clone() tr.ResponseHeaderTimeout = *responseTimeout // Automatic compression must be disabled in order to fix https://github.com/VictoriaMetrics/VictoriaMetrics/issues/535 @@ -444,27 +417,8 @@ func newTransport(insecureSkipVerify bool, caFile string) (*http.Transport, erro if tr.MaxIdleConns != 0 && tr.MaxIdleConns < tr.MaxIdleConnsPerHost { tr.MaxIdleConns = tr.MaxIdleConnsPerHost } - tlsCfg := tr.TLSClientConfig - if tlsCfg == nil { - tlsCfg = &tls.Config{} - tr.TLSClientConfig = tlsCfg - } - if insecureSkipVerify || caFile != "" { - tlsCfg.ClientSessionCache = tls.NewLRUClientSessionCache(0) - tlsCfg.InsecureSkipVerify = insecureSkipVerify - if caFile != "" { - data, err := fscore.ReadFileOrHTTP(caFile) - if err != nil { - return nil, fmt.Errorf("cannot read tls_ca_file: %w", err) - } - rootCA := x509.NewCertPool() - if !rootCA.AppendCertsFromPEM(data) { - return nil, fmt.Errorf("cannot parse data read from tls_ca_file %q", caFile) - } - tlsCfg.RootCAs = rootCA - } - } - return tr, nil + rt := cfg.NewRoundTripper(tr) + return rt, nil } var (