mirror of
https://github.com/VictoriaMetrics/VictoriaMetrics.git
synced 2024-12-24 19:30:06 +01:00
126 lines
3.1 KiB
Go
126 lines
3.1 KiB
Go
|
package netutil
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"sync/atomic"
|
||
|
|
||
|
"github.com/VictoriaMetrics/metrics"
|
||
|
)
|
||
|
|
||
|
// NewStatDialFunc returns dialer function that supports DNS SRV records and registers stats metrics for conns.
|
||
|
func NewStatDialFunc(metricPrefix string) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
|
return func(ctx context.Context, _, addr string) (net.Conn, error) {
|
||
|
sc := &statDialConn{
|
||
|
dialsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_dials_total`, metricPrefix)),
|
||
|
dialErrors: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_dial_errors_total`, metricPrefix)),
|
||
|
conns: metrics.GetOrCreateGauge(fmt.Sprintf(`%s_conns`, metricPrefix), nil),
|
||
|
|
||
|
readsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_reads_total`, metricPrefix)),
|
||
|
writesTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_writes_total`, metricPrefix)),
|
||
|
readErrorsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_read_errors_total`, metricPrefix)),
|
||
|
writeErrorsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_write_errors_total`, metricPrefix)),
|
||
|
bytesReadTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_bytes_read_total`, metricPrefix)),
|
||
|
bytesWrittenTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_bytes_written_total`, metricPrefix)),
|
||
|
}
|
||
|
|
||
|
network := GetTCPNetwork()
|
||
|
conn, err := DialMaybeSRV(ctx, network, addr)
|
||
|
sc.dialsTotal.Inc()
|
||
|
if err != nil {
|
||
|
sc.dialErrors.Inc()
|
||
|
if !TCP6Enabled() && !isTCPv4Addr(addr) {
|
||
|
err = fmt.Errorf("%w; try -enableTCP6 command-line flag for dialing ipv6 addresses", err)
|
||
|
}
|
||
|
return nil, err
|
||
|
}
|
||
|
sc.Conn = conn
|
||
|
sc.conns.Inc()
|
||
|
return sc, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type statDialConn struct {
|
||
|
closed atomic.Int32
|
||
|
net.Conn
|
||
|
|
||
|
dialsTotal *metrics.Counter
|
||
|
dialErrors *metrics.Counter
|
||
|
conns *metrics.Gauge
|
||
|
|
||
|
readsTotal *metrics.Counter
|
||
|
writesTotal *metrics.Counter
|
||
|
readErrorsTotal *metrics.Counter
|
||
|
writeErrorsTotal *metrics.Counter
|
||
|
bytesReadTotal *metrics.Counter
|
||
|
bytesWrittenTotal *metrics.Counter
|
||
|
}
|
||
|
|
||
|
func (sc *statDialConn) Read(p []byte) (int, error) {
|
||
|
n, err := sc.Conn.Read(p)
|
||
|
sc.readsTotal.Inc()
|
||
|
if err != nil {
|
||
|
sc.readErrorsTotal.Inc()
|
||
|
}
|
||
|
sc.bytesReadTotal.Add(n)
|
||
|
return n, err
|
||
|
}
|
||
|
|
||
|
func (sc *statDialConn) Write(p []byte) (int, error) {
|
||
|
n, err := sc.Conn.Write(p)
|
||
|
sc.writesTotal.Inc()
|
||
|
if err != nil {
|
||
|
sc.writeErrorsTotal.Inc()
|
||
|
}
|
||
|
sc.bytesWrittenTotal.Add(n)
|
||
|
return n, err
|
||
|
}
|
||
|
|
||
|
func (sc *statDialConn) Close() error {
|
||
|
err := sc.Conn.Close()
|
||
|
if sc.closed.Add(1) == 1 {
|
||
|
sc.conns.Dec()
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// isTCPv4Addr returns true if addr has the form `a.b.c.d:port`, where a, b, c and d
// are unsigned decimal numbers in the range [0..255] and port is an unsigned decimal
// number in the range [0..65535].
//
// Numbers with an explicit sign such as `+1` or `-0` are rejected.
func isTCPv4Addr(addr string) bool {
	s := addr

	// The first three octets must be terminated by a dot.
	for i := 0; i < 3; i++ {
		octet, tail, ok := strings.Cut(s, ".")
		if !ok || !isUint8NumString(octet) {
			return false
		}
		s = tail
	}

	// The last octet must be terminated by a colon, which is followed by the port.
	octet, port, ok := strings.Cut(s, ":")
	if !ok || !isUint8NumString(octet) {
		return false
	}

	// Verify TCP port. ParseUint is used instead of Atoi, since Atoi accepts a leading sign.
	_, err := strconv.ParseUint(port, 10, 16)
	return err == nil
}

// isUint8NumString returns true if s is an unsigned decimal number in the range [0..255].
//
// ParseUint is used instead of Atoi, since Atoi accepts a leading sign, e.g. `+1` or `-0`.
func isUint8NumString(s string) bool {
	_, err := strconv.ParseUint(s, 10, 8)
	return err == nil
}
|