VictoriaMetrics/lib/netutil/tls.go
jackyin e5d279bb71
lib/netutil: validate TLS cert and key files immediately (#6621)
Validate files specified via `-tlsKeyFile` and `-tlsCertFile` cmd-line flags on the process start-up. Previously, validation happened on the first connection accepted by the HTTP server.

https://github.com/VictoriaMetrics/VictoriaMetrics/issues/6608

---------

Co-authored-by: hagen1778 <roman@victoriametrics.com>
2024-07-29 13:58:53 +02:00

111 lines
3.1 KiB
Go

package netutil
import (
"crypto/tls"
"fmt"
"strconv"
"strings"
"sync"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime"
)
// GetServerTLSConfig returns TLS config for the server.
//
// tlsCertFile and tlsKeyFile are loaded once up front, so a missing or
// malformed key pair is reported immediately instead of on the first TLS
// handshake. The returned config serves the certificate through
// GetCertificate, which re-reads both files periodically.
//
// tlsMinVersion is parsed by ParseTLSVersion; an empty string keeps the
// crypto/tls default. tlsCipherSuites may hold cipher suite names
// (case-insensitive) or numeric IDs; an empty list keeps the crypto/tls
// default suites.
func GetServerTLSConfig(tlsCertFile, tlsKeyFile, tlsMinVersion string, tlsCipherSuites []string) (*tls.Config, error) {
	// Validate the key pair eagerly. The parsed certificate is not kept here,
	// because GetCertificate below loads (and reloads) it on demand.
	if _, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile); err != nil {
		return nil, fmt.Errorf("cannot load TLS certificate and key files from certFile=%q, keyFile=%q: %w", tlsCertFile, tlsKeyFile, err)
	}
	minVersion, err := ParseTLSVersion(tlsMinVersion)
	if err != nil {
		return nil, fmt.Errorf("cannot use TLS min version from tlsMinVersion=%q. Supported TLS versions (TLS10, TLS11, TLS12, TLS13): %w", tlsMinVersion, err)
	}
	cipherSuites, err := cipherSuitesFromNames(tlsCipherSuites)
	if err != nil {
		return nil, fmt.Errorf("cannot use TLS cipher suites from tlsCipherSuites=%q: %w", tlsCipherSuites, err)
	}
	cfg := &tls.Config{
		MinVersion: minVersion,
		// Do not set MaxVersion, since this has no sense from security PoV.
		// This can only result in lower security level if improperly set.
		CipherSuites:   cipherSuites,
		GetCertificate: newGetCertificateFunc(tlsCertFile, tlsKeyFile),
	}
	return cfg, nil
}
// newGetCertificateFunc builds a tls.Config.GetCertificate callback that
// serves the key pair stored in tlsCertFile and tlsKeyFile.
//
// The pair is cached and re-read from disk once the cache is older than one
// fasttime tick (presumably one second — fasttime.UnixTimestamp is assumed to
// have second resolution). If re-reading fails, an error is returned and the
// cache expiry is left untouched, so the next handshake tries again.
// All calls are serialized by a mutex.
func newGetCertificateFunc(tlsCertFile, tlsKeyFile string) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	var (
		mu       sync.Mutex
		cached   *tls.Certificate
		expireAt uint64
	)
	return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
		mu.Lock()
		defer mu.Unlock()

		// Cache is still fresh: serve it without touching the filesystem.
		if fasttime.UnixTimestamp() <= expireAt {
			return cached, nil
		}

		pair, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile)
		if err != nil {
			return nil, fmt.Errorf("cannot load TLS cert from certFile=%q, keyFile=%q: %w", tlsCertFile, tlsKeyFile, err)
		}
		cached = &pair
		expireAt = fasttime.UnixTimestamp() + 1
		return cached, nil
	}
}
// cipherSuitesFromNames converts cipherSuiteNames into crypto/tls cipher
// suite IDs, preserving input order.
//
// Each entry is first matched case-insensitively against the names returned
// by tls.CipherSuites(). An entry that matches no name is parsed as a numeric
// ID (strconv base prefixes such as "0x" are accepted), which must belong to
// one of those same suites. Suites listed only by tls.InsecureCipherSuites()
// are therefore rejected. An empty input yields nil, which leaves the
// crypto/tls default cipher suites in effect.
func cipherSuitesFromNames(cipherSuiteNames []string) ([]uint16, error) {
	if len(cipherSuiteNames) == 0 {
		return nil, nil
	}

	supported := tls.CipherSuites()
	idByLowerName := make(map[string]uint16, len(supported))
	knownID := make(map[uint16]bool, len(supported))
	for _, suite := range supported {
		idByLowerName[strings.ToLower(suite.Name)] = suite.ID
		knownID[suite.ID] = true
	}

	ids := make([]uint16, 0, len(cipherSuiteNames))
	for _, name := range cipherSuiteNames {
		if id, found := idByLowerName[strings.ToLower(name)]; found {
			ids = append(ids, id)
			continue
		}
		// Not a known name: fall back to interpreting the entry as a numeric ID.
		n, err := strconv.ParseUint(name, 0, 16)
		if err != nil || !knownID[uint16(n)] {
			return nil, fmt.Errorf("unsupported TLS cipher suite name: %s", name)
		}
		ids = append(ids, uint16(n))
	}
	return ids, nil
}
// ParseTLSVersion returns the crypto/tls version constant named by s.
//
// Accepted names are TLS10, TLS11, TLS12 and TLS13, compared
// case-insensitively. An empty s yields 0, which tells crypto/tls to use its
// own default. Any other value results in an error.
func ParseTLSVersion(s string) (uint16, error) {
	if s == "" {
		// Special case - use default TLS version provided by tls package.
		return 0, nil
	}
	var version uint16
	switch strings.ToUpper(s) {
	case "TLS10":
		version = tls.VersionTLS10
	case "TLS11":
		version = tls.VersionTLS11
	case "TLS12":
		version = tls.VersionTLS12
	case "TLS13":
		version = tls.VersionTLS13
	default:
		return 0, fmt.Errorf("unsupported TLS version %q", s)
	}
	return version, nil
}