mirror of
https://github.com/VictoriaMetrics/VictoriaMetrics.git
synced 2024-12-20 23:46:23 +01:00
223 lines
5.5 KiB
Go
223 lines
5.5 KiB
Go
|
package auth
|
||
|
|
||
|
import (
|
||
|
"encoding/base64"
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime"
|
||
|
)
|
||
|
|
||
|
// HTTPClientConfig represents http client config.
|
||
|
type HTTPClientConfig struct {
|
||
|
BasicAuth *BasicAuthConfig
|
||
|
BearerToken string
|
||
|
Headers string
|
||
|
}
|
||
|
|
||
|
// NewConfig creates auth config for the given hcc.
|
||
|
func (hcc *HTTPClientConfig) NewConfig() (*Config, error) {
|
||
|
opts := &Options{
|
||
|
BasicAuth: hcc.BasicAuth,
|
||
|
BearerToken: hcc.BearerToken,
|
||
|
Headers: hcc.Headers,
|
||
|
}
|
||
|
return opts.NewConfig()
|
||
|
}
|
||
|
|
||
|
// BasicAuthConfig holds HTTP basic auth credentials.
type BasicAuthConfig struct {
	// Username is the basic auth user name.
	Username string
	// Password is the basic auth password.
	Password string
	// PasswordFile presumably names a file holding the password.
	// NOTE(review): no code in this file reads PasswordFile.
	PasswordFile string
}
|
||
|
|
||
|
// ConfigOptions options which helps build Config
|
||
|
type ConfigOptions func(config *HTTPClientConfig)
|
||
|
|
||
|
// Generate returns Config based on the given params
|
||
|
func Generate(filterOptions ...ConfigOptions) (*Config, error) {
|
||
|
authCfg := &HTTPClientConfig{}
|
||
|
for _, option := range filterOptions {
|
||
|
option(authCfg)
|
||
|
}
|
||
|
|
||
|
return authCfg.NewConfig()
|
||
|
}
|
||
|
|
||
|
// WithBasicAuth returns AuthConfigOptions and initialized BasicAuthConfig based on given params
|
||
|
func WithBasicAuth(username, password string) ConfigOptions {
|
||
|
return func(config *HTTPClientConfig) {
|
||
|
if username != "" || password != "" {
|
||
|
config.BasicAuth = &BasicAuthConfig{
|
||
|
Username: username,
|
||
|
Password: password,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WithBearer returns AuthConfigOptions and set BearerToken or BearerTokenFile based on given params
|
||
|
func WithBearer(token string) ConfigOptions {
|
||
|
return func(config *HTTPClientConfig) {
|
||
|
if token != "" {
|
||
|
config.BearerToken = token
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WithHeaders returns AuthConfigOptions and set Headers based on the given params
|
||
|
func WithHeaders(headers string) ConfigOptions {
|
||
|
return func(config *HTTPClientConfig) {
|
||
|
if headers != "" {
|
||
|
config.Headers = headers
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Config is auth config.
|
||
|
type Config struct {
|
||
|
getAuthHeader func() string
|
||
|
authHeaderLock sync.Mutex
|
||
|
authHeader string
|
||
|
authHeaderDeadline uint64
|
||
|
|
||
|
headers []keyValue
|
||
|
|
||
|
authDigest string
|
||
|
}
|
||
|
|
||
|
// SetHeaders sets the configured ac headers to req.
|
||
|
func (ac *Config) SetHeaders(req *http.Request, setAuthHeader bool) {
|
||
|
reqHeaders := req.Header
|
||
|
for _, h := range ac.headers {
|
||
|
reqHeaders.Set(h.key, h.value)
|
||
|
}
|
||
|
if setAuthHeader {
|
||
|
if ah := ac.GetAuthHeader(); ah != "" {
|
||
|
reqHeaders.Set("Authorization", ah)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// GetAuthHeader returns optional `Authorization: ...` http header.
|
||
|
func (ac *Config) GetAuthHeader() string {
|
||
|
f := ac.getAuthHeader
|
||
|
if f == nil {
|
||
|
return ""
|
||
|
}
|
||
|
ac.authHeaderLock.Lock()
|
||
|
defer ac.authHeaderLock.Unlock()
|
||
|
if fasttime.UnixTimestamp() > ac.authHeaderDeadline {
|
||
|
ac.authHeader = f()
|
||
|
// Cache the authHeader for a second.
|
||
|
ac.authHeaderDeadline = fasttime.UnixTimestamp() + 1
|
||
|
}
|
||
|
return ac.authHeader
|
||
|
}
|
||
|
|
||
|
// authContext collects the authorization settings produced by the init*
// helpers before they are copied into a Config.
type authContext struct {
	// getAuthHeader must return <value> for 'Authorization: <value>' http
	// request header.
	getAuthHeader func() string

	// authDigest must contain the digest for the used authorization.
	// It must change whenever the original config changes.
	authDigest string
}
|
||
|
|
||
|
func (ac *authContext) initFromBasicAuthConfig(ba *BasicAuthConfig) error {
|
||
|
if ba.Username == "" {
|
||
|
return fmt.Errorf("missing `username` in `basic_auth` section")
|
||
|
}
|
||
|
if ba.Password != "" {
|
||
|
ac.getAuthHeader = func() string {
|
||
|
// See https://en.wikipedia.org/wiki/Basic_access_authentication
|
||
|
token := ba.Username + ":" + ba.Password
|
||
|
token64 := base64.StdEncoding.EncodeToString([]byte(token))
|
||
|
return "Basic " + token64
|
||
|
}
|
||
|
ac.authDigest = fmt.Sprintf("basic(username=%q, password=%q)", ba.Username, ba.Password)
|
||
|
return nil
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (ac *authContext) initFromBearerToken(bearerToken string) error {
|
||
|
ac.getAuthHeader = func() string {
|
||
|
return "Bearer " + bearerToken
|
||
|
}
|
||
|
ac.authDigest = fmt.Sprintf("bearer(token=%q)", bearerToken)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Options contain options, which must be passed to NewConfig.
|
||
|
type Options struct {
|
||
|
// BasicAuth contains optional BasicAuthConfig.
|
||
|
BasicAuth *BasicAuthConfig
|
||
|
|
||
|
// BearerToken contains optional bearer token.
|
||
|
BearerToken string
|
||
|
|
||
|
// Headers contains optional http request headers in the form 'Foo: bar'.
|
||
|
Headers string
|
||
|
}
|
||
|
|
||
|
// NewConfig creates auth config from the given opts.
|
||
|
func (opts *Options) NewConfig() (*Config, error) {
|
||
|
var ac authContext
|
||
|
if opts.BasicAuth != nil {
|
||
|
if ac.getAuthHeader != nil {
|
||
|
return nil, fmt.Errorf("cannot use both `authorization` and `basic_auth`")
|
||
|
}
|
||
|
if err := ac.initFromBasicAuthConfig(opts.BasicAuth); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
if opts.BearerToken != "" {
|
||
|
if ac.getAuthHeader != nil {
|
||
|
return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token`")
|
||
|
}
|
||
|
if err := ac.initFromBearerToken(opts.BearerToken); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
headers, err := parseHeaders(opts.Headers)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
c := &Config{
|
||
|
getAuthHeader: ac.getAuthHeader,
|
||
|
headers: headers,
|
||
|
authDigest: ac.authDigest,
|
||
|
}
|
||
|
return c, nil
|
||
|
}
|
||
|
|
||
|
// keyValue is a single http header produced by parseHeaders.
type keyValue struct {
	key   string
	value string
}

// parseHeaders parses headers given in the form
// "Key1: value1^^Key2: value2", where "^^" separates headers.
//
// Keys and values are trimmed of surrounding whitespace. Only the first ':'
// splits key from value, so values may contain ':'. An empty input yields
// (nil, nil).
//
// It returns an error if any header lacks a ':' or has an empty key. An empty
// key would otherwise produce an invalid header name on the outgoing request.
func parseHeaders(headers string) ([]keyValue, error) {
	if len(headers) == 0 {
		return nil, nil
	}

	parts := strings.Split(headers, "^^")
	kvs := make([]keyValue, len(parts))
	for i, h := range parts {
		n := strings.IndexByte(h, ':')
		if n < 0 {
			return nil, fmt.Errorf(`missing ':' in header %q; expecting "key: value" format`, h)
		}
		key := strings.TrimSpace(h[:n])
		if key == "" {
			return nil, fmt.Errorf(`missing key in header %q; expecting "key: value" format`, h)
		}
		kvs[i] = keyValue{
			key:   key,
			value: strings.TrimSpace(h[n+1:]),
		}
	}
	return kvs, nil
}
|