package netutil

import (
	"crypto/tls"
	"fmt"
	"strconv"
	"strings"
	"sync"

	"github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime"
)

// GetServerTLSConfig returns TLS config for the server.
func GetServerTLSConfig(tlsCertFile, tlsKeyFile, tlsMinVersion string, tlsCipherSuites []string) (*tls.Config, error) {
	minVersion, err := ParseTLSVersion(tlsMinVersion)
	if err != nil {
		return nil, fmt.Errorf("cannnot use TLS min version from tlsMinVersion=%q. Supported TLS versions (TLS10, TLS11, TLS12, TLS13): %w", tlsMinVersion, err)
	}
	cipherSuites, err := cipherSuitesFromNames(tlsCipherSuites)
	if err != nil {
		return nil, fmt.Errorf("cannot use TLS cipher suites from tlsCipherSuites=%q: %w", tlsCipherSuites, err)
	}
	cfg := &tls.Config{
		MinVersion: minVersion,

		// Do not set MaxVersion, since this has no sense from security PoV.
		// This can only result in lower security level if improperly set.

		CipherSuites: cipherSuites,
	}

	cfg.GetCertificate = newGetCertificateFunc(tlsCertFile, tlsKeyFile)
	return cfg, nil
}

func newGetCertificateFunc(tlsCertFile, tlsKeyFile string) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	var certLock sync.Mutex
	var certDeadline uint64
	var cert *tls.Certificate
	return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
		certLock.Lock()
		defer certLock.Unlock()
		if fasttime.UnixTimestamp() > certDeadline {
			c, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile)
			if err != nil {
				return nil, fmt.Errorf("cannot load TLS cert from certFile=%q, keyFile=%q: %w", tlsCertFile, tlsKeyFile, err)
			}
			certDeadline = fasttime.UnixTimestamp() + 1
			cert = &c
		}
		return cert, nil
	}
}

func cipherSuitesFromNames(cipherSuiteNames []string) ([]uint16, error) {
	if len(cipherSuiteNames) == 0 {
		return nil, nil
	}

	css := tls.CipherSuites()

	cssByName := make(map[string]uint16, len(css))
	for _, cs := range css {
		cssByName[strings.ToLower(cs.Name)] = cs.ID
	}

	cssByID := make(map[uint16]bool, len(css))
	for _, cs := range css {
		cssByID[cs.ID] = true
	}

	cipherSuites := make([]uint16, 0, len(cipherSuiteNames))
	for _, name := range cipherSuiteNames {
		id, ok := cssByName[strings.ToLower(name)]
		if !ok {
			// Try searching by ID
			idKey, err := strconv.ParseUint(name, 0, 16)
			if err != nil || !cssByID[uint16(idKey)] {
				return nil, fmt.Errorf("unsupported TLS cipher suite name: %s", name)
			}
			id = uint16(idKey)
		}
		cipherSuites = append(cipherSuites, id)
	}
	return cipherSuites, nil
}

// ParseTLSVersion returns tls version from the given string s.
func ParseTLSVersion(s string) (uint16, error) {
	switch strings.ToUpper(s) {
	case "":
		// Special case - use default TLS version provided by tls package.
		return 0, nil
	case "TLS13":
		return tls.VersionTLS13, nil
	case "TLS12":
		return tls.VersionTLS12, nil
	case "TLS11":
		return tls.VersionTLS11, nil
	case "TLS10":
		return tls.VersionTLS10, nil
	default:
		return 0, fmt.Errorf("unsupported TLS version %q", s)
	}
}