VictoriaMetrics/app/vmctl/auth/auth.go

223 lines
5.5 KiB
Go
Raw Normal View History

package auth
import (
"encoding/base64"
"fmt"
"net/http"
"strings"
"sync"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime"
)
// HTTPClientConfig represents http client config.
type HTTPClientConfig struct {
BasicAuth *BasicAuthConfig
BearerToken string
Headers string
}
// NewConfig creates auth config for the given hcc.
func (hcc *HTTPClientConfig) NewConfig() (*Config, error) {
opts := &Options{
BasicAuth: hcc.BasicAuth,
BearerToken: hcc.BearerToken,
Headers: hcc.Headers,
}
return opts.NewConfig()
}
// BasicAuthConfig represents basic auth config.
type BasicAuthConfig struct {
Username string
Password string
PasswordFile string
}
// ConfigOptions options which helps build Config
type ConfigOptions func(config *HTTPClientConfig)
// Generate returns Config based on the given params
func Generate(filterOptions ...ConfigOptions) (*Config, error) {
authCfg := &HTTPClientConfig{}
for _, option := range filterOptions {
option(authCfg)
}
return authCfg.NewConfig()
}
// WithBasicAuth returns AuthConfigOptions and initialized BasicAuthConfig based on given params
func WithBasicAuth(username, password string) ConfigOptions {
return func(config *HTTPClientConfig) {
if username != "" || password != "" {
config.BasicAuth = &BasicAuthConfig{
Username: username,
Password: password,
}
}
}
}
// WithBearer returns AuthConfigOptions and set BearerToken or BearerTokenFile based on given params
func WithBearer(token string) ConfigOptions {
return func(config *HTTPClientConfig) {
if token != "" {
config.BearerToken = token
}
}
}
// WithHeaders returns AuthConfigOptions and set Headers based on the given params
func WithHeaders(headers string) ConfigOptions {
return func(config *HTTPClientConfig) {
if headers != "" {
config.Headers = headers
}
}
}
// Config is auth config.
type Config struct {
getAuthHeader func() string
authHeaderLock sync.Mutex
authHeader string
authHeaderDeadline uint64
headers []keyValue
authDigest string
}
// SetHeaders sets the configured ac headers to req.
func (ac *Config) SetHeaders(req *http.Request, setAuthHeader bool) {
reqHeaders := req.Header
for _, h := range ac.headers {
reqHeaders.Set(h.key, h.value)
}
if setAuthHeader {
if ah := ac.GetAuthHeader(); ah != "" {
reqHeaders.Set("Authorization", ah)
}
}
}
// GetAuthHeader returns optional `Authorization: ...` http header.
func (ac *Config) GetAuthHeader() string {
f := ac.getAuthHeader
if f == nil {
return ""
}
ac.authHeaderLock.Lock()
defer ac.authHeaderLock.Unlock()
if fasttime.UnixTimestamp() > ac.authHeaderDeadline {
ac.authHeader = f()
// Cache the authHeader for a second.
ac.authHeaderDeadline = fasttime.UnixTimestamp() + 1
}
return ac.authHeader
}
type authContext struct {
// getAuthHeader must return <value> for 'Authorization: <value>' http request header
getAuthHeader func() string
// authDigest must contain the digest for the used authorization
// The digest must be changed whenever the original config changes.
authDigest string
}
func (ac *authContext) initFromBasicAuthConfig(ba *BasicAuthConfig) error {
if ba.Username == "" {
return fmt.Errorf("missing `username` in `basic_auth` section")
}
if ba.Password != "" {
ac.getAuthHeader = func() string {
// See https://en.wikipedia.org/wiki/Basic_access_authentication
token := ba.Username + ":" + ba.Password
token64 := base64.StdEncoding.EncodeToString([]byte(token))
return "Basic " + token64
}
ac.authDigest = fmt.Sprintf("basic(username=%q, password=%q)", ba.Username, ba.Password)
return nil
}
return nil
}
func (ac *authContext) initFromBearerToken(bearerToken string) error {
ac.getAuthHeader = func() string {
return "Bearer " + bearerToken
}
ac.authDigest = fmt.Sprintf("bearer(token=%q)", bearerToken)
return nil
}
// Options contain options, which must be passed to NewConfig.
type Options struct {
// BasicAuth contains optional BasicAuthConfig.
BasicAuth *BasicAuthConfig
// BearerToken contains optional bearer token.
BearerToken string
// Headers contains optional http request headers in the form 'Foo: bar'.
Headers string
}
// NewConfig creates auth config from the given opts.
func (opts *Options) NewConfig() (*Config, error) {
var ac authContext
if opts.BasicAuth != nil {
if ac.getAuthHeader != nil {
return nil, fmt.Errorf("cannot use both `authorization` and `basic_auth`")
}
if err := ac.initFromBasicAuthConfig(opts.BasicAuth); err != nil {
return nil, err
}
}
if opts.BearerToken != "" {
if ac.getAuthHeader != nil {
return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token`")
}
if err := ac.initFromBearerToken(opts.BearerToken); err != nil {
return nil, err
}
}
headers, err := parseHeaders(opts.Headers)
if err != nil {
return nil, err
}
c := &Config{
getAuthHeader: ac.getAuthHeader,
headers: headers,
authDigest: ac.authDigest,
}
return c, nil
}
type keyValue struct {
key string
value string
}
func parseHeaders(headers string) ([]keyValue, error) {
if len(headers) == 0 {
return nil, nil
}
var headersSplitByDelimiter = strings.Split(headers, "^^")
kvs := make([]keyValue, len(headersSplitByDelimiter))
for i, h := range headersSplitByDelimiter {
n := strings.IndexByte(h, ':')
if n < 0 {
return nil, fmt.Errorf(`missing ':' in header %q; expecting "key: value" format`, h)
}
kv := &kvs[i]
kv.key = strings.TrimSpace(h[:n])
kv.value = strings.TrimSpace(h[n+1:])
}
return kvs, nil
}