diff --git a/collector/fixtures/proc/net/tcpstat b/collector/fixtures/proc/net/tcpstat deleted file mode 100644 index 352c00bb..00000000 --- a/collector/fixtures/proc/net/tcpstat +++ /dev/null @@ -1,3 +0,0 @@ - sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode - 0: 00000000:0016 00000000:0000 0A 00000015:00000000 00:00000000 00000000 0 0 2740 1 ffff88003d3af3c0 100 0 0 10 0 - 1: 0F02000A:0016 0202000A:8B6B 01 00000015:00000001 02:000AC99B 00000000 0 0 3652 4 ffff88003d3ae040 21 4 31 47 46 diff --git a/collector/tcpstat_linux.go b/collector/tcpstat_linux.go index 47c3f3e0..99e33bc6 100644 --- a/collector/tcpstat_linux.go +++ b/collector/tcpstat_linux.go @@ -18,13 +18,12 @@ package collector import ( "fmt" - "io" - "io/ioutil" "os" - "strconv" - "strings" + "syscall" + "unsafe" "github.com/go-kit/log" + "github.com/mdlayher/netlink" "github.com/prometheus/client_golang/prometheus" ) @@ -80,16 +79,64 @@ func NewTCPStatCollector(logger log.Logger) (Collector, error) { }, nil } +// InetDiagSockID (inet_diag_sockid) contains the socket identity. +// https://github.com/torvalds/linux/blob/v4.0/include/uapi/linux/inet_diag.h#L13 +type InetDiagSockID struct { + SourcePort [2]byte + DestPort [2]byte + SourceIP [4][4]byte + DestIP [4][4]byte + Interface uint32 + Cookie [2]uint32 +} + +// InetDiagReqV2 (inet_diag_req_v2) is used to request diagnostic data. +// https://github.com/torvalds/linux/blob/v4.0/include/uapi/linux/inet_diag.h#L37 +type InetDiagReqV2 struct { + Family uint8 + Protocol uint8 + Ext uint8 + Pad uint8 + States uint32 + ID InetDiagSockID +} + +const sizeOfDiagRequest = 0x38 + +func (req *InetDiagReqV2) Serialize() []byte { + return (*(*[sizeOfDiagRequest]byte)(unsafe.Pointer(req)))[:] +} + +func (req *InetDiagReqV2) Len() int { + return sizeOfDiagRequest +} + +type InetDiagMsg struct { + Family uint8 + State uint8 + Timer uint8 + Retrans uint8 + ID InetDiagSockID + Expires uint32 + RQueue uint32 + WQueue uint32 + UID uint32 + Inode uint32 +} + +func parseInetDiagMsg(b []byte) *InetDiagMsg { + return (*InetDiagMsg)(unsafe.Pointer(&b[0])) +} + func (c *tcpStatCollector) Update(ch chan<- prometheus.Metric) error { - tcpStats, err := getTCPStats(procFilePath("net/tcp")) + tcpStats, err := getTCPStats(syscall.AF_INET) if err != nil { return fmt.Errorf("couldn't get tcpstats: %w", err) } // if enabled ipv6 system - tcp6File := procFilePath("net/tcp6") - if _, hasIPv6 := os.Stat(tcp6File); hasIPv6 == nil { - tcp6Stats, err := getTCPStats(tcp6File) + if _, hasIPv6 := os.Stat(procFilePath("net/tcp6")); hasIPv6 == nil { + tcp6Stats, err := getTCPStats(syscall.AF_INET6) if err != nil { return fmt.Errorf("couldn't get tcp6stats: %w", err) } @@ -102,59 +149,51 @@ func (c *tcpStatCollector) Update(ch chan<- prometheus.Metric) error { for st, value := range tcpStats { ch <- c.desc.mustNewConstMetric(value, st.String()) } + return nil } -func getTCPStats(statsFile string) (map[tcpConnectionState]float64, error) { - file, err := os.Open(statsFile) +func getTCPStats(family uint8) (map[tcpConnectionState]float64, error) { + const TCPFAll = 0xFFF + const InetDiagInfo = 2 + const SockDiagByFamily = 20 + + conn, err := netlink.Dial(syscall.NETLINK_INET_DIAG, nil) + if err != nil { + return nil, fmt.Errorf("couldn't connect netlink: %w", err) + } + defer conn.Close() + + msg := netlink.Message{ + Header: netlink.Header{ + Type: SockDiagByFamily, + Flags: syscall.NLM_F_REQUEST | syscall.NLM_F_DUMP, + }, + Data: (&InetDiagReqV2{ + Family: family, + Protocol: syscall.IPPROTO_TCP, + States: TCPFAll, + Ext: 0 | 1<<(InetDiagInfo-1), + }).Serialize(), + } + + messages, err := conn.Execute(msg) if err != nil { return nil, err } - defer file.Close() - return parseTCPStats(file) + return parseTCPStats(messages) } -func parseTCPStats(r io.Reader) (map[tcpConnectionState]float64, error) { +func parseTCPStats(msgs []netlink.Message) (map[tcpConnectionState]float64, error) { tcpStats := map[tcpConnectionState]float64{} - contents, err := ioutil.ReadAll(r) - if err != nil { - return nil, err - } - for _, line := range strings.Split(string(contents), "\n")[1:] { - parts := strings.Fields(line) - if len(parts) == 0 { - continue - } - if len(parts) < 5 { - return nil, fmt.Errorf("invalid TCP stats line: %q", line) - } - - qu := strings.Split(parts[4], ":") - if len(qu) < 2 { - return nil, fmt.Errorf("cannot parse tx_queues and rx_queues: %q", line) - } - - tx, err := strconv.ParseUint(qu[0], 16, 64) - if err != nil { - return nil, err - } - tcpStats[tcpConnectionState(tcpTxQueuedBytes)] += float64(tx) - - rx, err := strconv.ParseUint(qu[1], 16, 64) - if err != nil { - return nil, err - } - tcpStats[tcpConnectionState(tcpRxQueuedBytes)] += float64(rx) - - st, err := strconv.ParseInt(parts[3], 16, 8) - if err != nil { - return nil, err - } - - tcpStats[tcpConnectionState(st)]++ + for _, m := range msgs { + msg := parseInetDiagMsg(m.Data) + tcpStats[tcpTxQueuedBytes] += float64(msg.WQueue) + tcpStats[tcpRxQueuedBytes] += float64(msg.RQueue) + tcpStats[tcpConnectionState(msg.State)]++ } return tcpStats, nil diff --git a/collector/tcpstat_linux_test.go b/collector/tcpstat_linux_test.go index b609b846..37dc1eee 100644 --- a/collector/tcpstat_linux_test.go +++ b/collector/tcpstat_linux_test.go @@ -14,66 +14,56 @@ package collector import ( - "os" - "strings" + "bytes" + "encoding/binary" + "syscall" "testing" + + "github.com/mdlayher/netlink" ) -func Test_parseTCPStatsError(t *testing.T) { - tests := []struct { - name string - in string - }{ - { - name: "too few fields", - in: "sl local_address\n 0: 00000000:0016", - }, - { - name: "missing colon in tx-rx field", - in: "sl local_address rem_address st tx_queue rx_queue\n" + - " 1: 0F02000A:0016 0202000A:8B6B 01 0000000000000001", - }, - { - name: "tx parsing issue", - in: "sl local_address rem_address st tx_queue rx_queue\n" + - " 1: 0F02000A:0016 0202000A:8B6B 01 0000000x:00000001", - }, - { - name: "rx parsing issue", - in: "sl local_address rem_address st tx_queue rx_queue\n" + - " 1: 0F02000A:0016 0202000A:8B6B 01 00000000:0000000x", - }, - { - name: "state parsing issue", - in: "sl local_address rem_address st tx_queue rx_queue\n" + - " 1: 0F02000A:0016 0202000A:8B6B 0H 00000000:00000001", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if _, err := parseTCPStats(strings.NewReader(tt.in)); err == nil { - t.Fatal("expected an error, but none occurred") - } - }) - } -} - -func TestTCPStat(t *testing.T) { - - noFile, _ := os.Open("follow the white rabbit") - defer noFile.Close() - - if _, err := parseTCPStats(noFile); err == nil { - t.Fatal("expected an error, but none occurred") +func Test_parseTCPStats(t *testing.T) { + encode := func(m InetDiagMsg) []byte { + var buf bytes.Buffer + err := binary.Write(&buf, binary.LittleEndian, m) + if err != nil { + panic(err) + } + return buf.Bytes() } - file, err := os.Open("fixtures/proc/net/tcpstat") - if err != nil { - t.Fatal(err) + msg := []netlink.Message{ + { + Data: encode(InetDiagMsg{ + Family: syscall.AF_INET, + State: uint8(tcpEstablished), + Timer: 0, + Retrans: 0, + ID: InetDiagSockID{}, + Expires: 0, + RQueue: 11, + WQueue: 21, + UID: 0, + Inode: 0, + }), + }, + { + Data: encode(InetDiagMsg{ + Family: syscall.AF_INET, + State: uint8(tcpListen), + Timer: 0, + Retrans: 0, + ID: InetDiagSockID{}, + Expires: 0, + RQueue: 11, + WQueue: 21, + UID: 0, + Inode: 0, + }), + }, } - defer file.Close() - tcpStats, err := parseTCPStats(file) + tcpStats, err := parseTCPStats(msg) if err != nil { t.Fatal(err) } @@ -89,35 +79,8 @@ func TestTCPStat(t *testing.T) { if want, got := 42, int(tcpStats[tcpTxQueuedBytes]); want != got { t.Errorf("want tcpstat number of bytes in tx queue %d, got %d", want, got) } - if want, got := 1, int(tcpStats[tcpRxQueuedBytes]); want != got { + if want, got := 22, int(tcpStats[tcpRxQueuedBytes]); want != got { t.Errorf("want tcpstat number of bytes in rx queue %d, got %d", want, got) } } - -func Test_getTCPStats(t *testing.T) { - type args struct { - statsFile string - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "file not found", - args: args{statsFile: "somewhere over the rainbow"}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := getTCPStats(tt.args.statsFile) - if (err != nil) != tt.wantErr { - t.Errorf("getTCPStats() error = %v, wantErr %v", err, tt.wantErr) - return - } - // other cases are covered by TestTCPStat() - }) - } -} diff --git a/go.mod b/go.mod index a78dc8ff..d1be25f0 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/jsimonetti/rtnetlink v1.1.1 github.com/lufia/iostat v1.2.1 github.com/mattn/go-xmlrpc v0.0.3 + github.com/mdlayher/netlink v1.6.0 github.com/mdlayher/wifi v0.0.0-20220320220353-954ff73a19a5 github.com/prometheus/client_golang v1.12.1 github.com/prometheus/client_model v0.2.0