package netutil

import (
	"context"
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync/atomic"

	"github.com/VictoriaMetrics/metrics"
)

// NewStatDialFuncWithDial returns a dialer function that dials via dialFunc and
// registers stats metrics, named with the given metricPrefix, for the dials and
// for the connections they establish.
//
// The network argument passed to the returned function is ignored: dialFunc is
// always called with the network returned by GetTCPNetwork().
func NewStatDialFuncWithDial(metricPrefix string, dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
	return newStatDialFunc(metricPrefix, dialFunc)
}

// NewStatDialFunc returns dialer function that supports DNS SRV records and registers stats metrics for conns.
//
// Dialing is delegated to DialMaybeSRV. The registered metrics are
// `<metricPrefix>_dials_total`, `<metricPrefix>_dial_errors_total`,
// `<metricPrefix>_conns` and the `<metricPrefix>_conn_*` read/write counters.
// The network argument passed to the returned function is ignored; the network
// returned by GetTCPNetwork() is used instead.
func NewStatDialFunc(metricPrefix string) func(ctx context.Context, network, addr string) (net.Conn, error) {
	return newStatDialFunc(metricPrefix, DialMaybeSRV)
}

func newStatDialFunc(metricPrefix string, dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
	return func(ctx context.Context, _, addr string) (net.Conn, error) {
		sc := &statDialConn{
			dialsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_dials_total`, metricPrefix)),
			dialErrors: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_dial_errors_total`, metricPrefix)),
			conns:      metrics.GetOrCreateGauge(fmt.Sprintf(`%s_conns`, metricPrefix), nil),

			readsTotal:        metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_reads_total`, metricPrefix)),
			writesTotal:       metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_writes_total`, metricPrefix)),
			readErrorsTotal:   metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_read_errors_total`, metricPrefix)),
			writeErrorsTotal:  metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_write_errors_total`, metricPrefix)),
			bytesReadTotal:    metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_bytes_read_total`, metricPrefix)),
			bytesWrittenTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_bytes_written_total`, metricPrefix)),
		}

		network := GetTCPNetwork()
		conn, err := dialFunc(ctx, network, addr)
		sc.dialsTotal.Inc()
		if err != nil {
			sc.dialErrors.Inc()
			if !TCP6Enabled() && !isTCPv4Addr(addr) {
				err = fmt.Errorf("%w; try -enableTCP6 command-line flag for dialing ipv6 addresses", err)
			}
			return nil, err
		}
		sc.Conn = conn
		sc.conns.Inc()
		return sc, nil
	}
}

// statDialConn wraps a dialed net.Conn and records metrics for Read, Write and
// Close calls. All other net.Conn methods are forwarded to the embedded
// connection unchanged.
type statDialConn struct {
	net.Conn

	// closed is incremented on every Close call; only the call that moves it
	// to 1 decrements conns.
	closed atomic.Int32

	// Dial-level metrics: dial attempts, failed dials and open connections.
	dialsTotal *metrics.Counter
	dialErrors *metrics.Counter
	conns      *metrics.Gauge

	// Per-call I/O metrics updated by Read and Write.
	readsTotal        *metrics.Counter
	writesTotal       *metrics.Counter
	readErrorsTotal   *metrics.Counter
	writeErrorsTotal  *metrics.Counter
	bytesReadTotal    *metrics.Counter
	bytesWrittenTotal *metrics.Counter
}

// Read reads from the underlying connection and accounts the call in the read
// metrics. The returned byte count is added even when an error is returned,
// since a reader may deliver data together with an error. Any non-nil error,
// io.EOF included, is counted as a read error.
func (sc *statDialConn) Read(p []byte) (int, error) {
	got, readErr := sc.Conn.Read(p)
	sc.readsTotal.Inc()
	sc.bytesReadTotal.Add(got)
	if readErr != nil {
		sc.readErrorsTotal.Inc()
	}
	return got, readErr
}

// Write writes p to the underlying connection and accounts the call in the
// write metrics. The returned byte count is added even when an error is
// returned, because a failed Write may still have sent part of p.
func (sc *statDialConn) Write(p []byte) (int, error) {
	written, writeErr := sc.Conn.Write(p)
	sc.writesTotal.Inc()
	sc.bytesWrittenTotal.Add(written)
	if writeErr != nil {
		sc.writeErrorsTotal.Inc()
	}
	return written, writeErr
}

// Close closes the underlying connection and returns its error. The conns
// gauge is decremented only on the first Close call, so repeated calls do not
// drive it below the number of open connections.
func (sc *statDialConn) Close() error {
	closeErr := sc.Conn.Close()
	if firstClose := sc.closed.Add(1) == 1; firstClose {
		sc.conns.Dec()
	}
	return closeErr
}

// isTCPv4Addr reports whether addr has the form `a.b.c.d:port`.
//
// Each of a, b, c and d must satisfy isUint8NumString. The port must parse with
// strconv.Atoi into the range [0, 65535]. Anything else, e.g. a hostname or an
// IPv6 address, yields false.
func isTCPv4Addr(addr string) bool {
	rest := addr
	// The first three octets are each terminated by a dot.
	for octet := 0; octet < 3; octet++ {
		head, tail, found := strings.Cut(rest, ".")
		if !found || !isUint8NumString(head) {
			return false
		}
		rest = tail
	}

	// The last octet is terminated by the colon that precedes the port.
	lastOctet, portStr, found := strings.Cut(rest, ":")
	if !found || !isUint8NumString(lastOctet) {
		return false
	}

	port, err := strconv.Atoi(portStr)
	return err == nil && port >= 0 && port <= 65535
}

// isUint8NumString reports whether s is a base-10 integer, as accepted by
// strconv.Atoi (an optional leading sign is allowed), whose value lies in
// [0, 255].
func isUint8NumString(s string) bool {
	v, err := strconv.Atoi(s)
	return err == nil && v >= 0 && v <= 255
}