lib/netutil: move creation of GetCertificate callback into a separate function

This improves code readability a bit
This commit is contained in:
Aliaksandr Valialkin 2024-04-17 22:10:40 +02:00
parent 8412219781
commit bd454f5063
No known key found for this signature in database
GPG Key ID: 52C003EE2BCDB9EB

View File

@ -12,44 +12,46 @@ import (
// GetServerTLSConfig returns TLS config for the server. // GetServerTLSConfig returns TLS config for the server.
func GetServerTLSConfig(tlsCertFile, tlsKeyFile, tlsMinVersion string, tlsCipherSuites []string) (*tls.Config, error) { func GetServerTLSConfig(tlsCertFile, tlsKeyFile, tlsMinVersion string, tlsCipherSuites []string) (*tls.Config, error) {
var certLock sync.Mutex minVersion, err := ParseTLSVersion(tlsMinVersion)
var certDeadline uint64
var cert *tls.Certificate
c, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot load TLS cert from certFile=%q, keyFile=%q: %w", tlsCertFile, tlsKeyFile, err) return nil, fmt.Errorf("cannnot use TLS min version from tlsMinVersion=%q. Supported TLS versions (TLS10, TLS11, TLS12, TLS13): %w", tlsMinVersion, err)
} }
cipherSuites, err := cipherSuitesFromNames(tlsCipherSuites) cipherSuites, err := cipherSuitesFromNames(tlsCipherSuites)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot use TLS cipher suites from tlsCipherSuites=%q: %w", tlsCipherSuites, err) return nil, fmt.Errorf("cannot use TLS cipher suites from tlsCipherSuites=%q: %w", tlsCipherSuites, err)
} }
minVersion, err := ParseTLSVersion(tlsMinVersion)
if err != nil {
return nil, fmt.Errorf("cannnot use TLS min version from tlsMinVersion=%q. Supported TLS versions (TLS10, TLS11, TLS12, TLS13): %w", tlsMinVersion, err)
}
cert = &c
cfg := &tls.Config{ cfg := &tls.Config{
MinVersion: minVersion, MinVersion: minVersion,
// Do not set MaxVersion, since this has no sense from security PoV. // Do not set MaxVersion, since this has no sense from security PoV.
// This can only result in lower security level if improperly set. // This can only result in lower security level if improperly set.
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
certLock.Lock()
defer certLock.Unlock()
if fasttime.UnixTimestamp() > certDeadline {
c, err = tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile)
if err != nil {
return nil, fmt.Errorf("cannot load TLS cert from certFile=%q, keyFile=%q: %w", tlsCertFile, tlsKeyFile, err)
}
certDeadline = fasttime.UnixTimestamp() + 1
cert = &c
}
return cert, nil
},
CipherSuites: cipherSuites, CipherSuites: cipherSuites,
} }
cfg.GetCertificate = newGetCertificateFunc(tlsCertFile, tlsKeyFile)
return cfg, nil return cfg, nil
} }
func newGetCertificateFunc(tlsCertFile, tlsKeyFile string) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
var certLock sync.Mutex
var certDeadline uint64
var cert *tls.Certificate
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
certLock.Lock()
defer certLock.Unlock()
if fasttime.UnixTimestamp() > certDeadline {
c, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile)
if err != nil {
return nil, fmt.Errorf("cannot load TLS cert from certFile=%q, keyFile=%q: %w", tlsCertFile, tlsKeyFile, err)
}
certDeadline = fasttime.UnixTimestamp() + 1
cert = &c
}
return cert, nil
}
}
func cipherSuitesFromNames(cipherSuiteNames []string) ([]uint16, error) { func cipherSuitesFromNames(cipherSuiteNames []string) ([]uint16, error) {
if len(cipherSuiteNames) == 0 { if len(cipherSuiteNames) == 0 {
return nil, nil return nil, nil