package netutil

import (
	"bytes"
	"io"
	"net"
	"reflect"
	"testing"
)

// TestParseProxyProtocolSuccess feeds well-formed PROXY protocol v2 frames to
// readProxyProto and verifies three things for each one: no error is returned,
// the reported address matches the expectation (nil for the LOCAL command),
// and every byte after the header is still left unread in the reader.
func TestParseProxyProtocolSuccess(t *testing.T) {
	// sig is the fixed 12-byte signature every v2 frame starts with.
	sig := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
	// ipv4Block holds src ip 127.0.0.1, dst ip 0.0.0.0, src port 80, dst port 0.
	ipv4Block := []byte{0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0}
	// loopback6 is the 16-byte IPv6 address ::1.
	loopback6 := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}

	// join concatenates parts into a freshly allocated slice, so the shared
	// fixtures above are never aliased by the reader under test.
	join := func(parts ...[]byte) []byte {
		return bytes.Join(parts, nil)
	}

	check := func(input, expectedTail []byte, expectedAddr net.Addr) {
		t.Helper()
		buf := bytes.NewBuffer(input)
		addr, err := readProxyProto(buf)
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		if !reflect.DeepEqual(addr, expectedAddr) {
			t.Fatalf("ip not match, got: %v, want: %v", addr, expectedAddr)
		}
		rest, err := io.ReadAll(buf)
		if err != nil {
			t.Fatalf("cannot read tail: %s", err)
		}
		if !bytes.Equal(rest, expectedTail) {
			t.Fatalf("unexpected tail after parsing proxy protocol\ngot:\n%q\nwant:\n%q", rest, expectedTail)
		}
	}

	// LOCAL command (0x20): the address block is consumed but no address is reported.
	check(join(sig, []byte{0x20, 0x11, 0x00, 0x0C}, ipv4Block), nil, nil)

	// PROXY command over TCP/IPv4 with a 12-byte address block.
	check(join(sig, []byte{0x21, 0x11, 0x00, 0x0C}, ipv4Block), nil,
		&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 80})

	// Same frame followed by application payload, which must stay unread.
	check(join(sig, []byte{0x21, 0x11, 0x00, 0x0C}, ipv4Block, ipv4Block), ipv4Block,
		&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 80})

	// PROXY command over TCP/IPv6: src ::1, dst ::1, src port 80, dst port 0.
	check(join(sig, []byte{0x21, 0x21, 0x00, 0x24}, loopback6, loopback6, []byte{0, 80, 0, 0}), nil,
		&net.TCPAddr{IP: net.ParseIP("::1"), Port: 80})
}

// TestParseProxyProtocolFail feeds malformed or unsupported PROXY protocol v2
// frames to readProxyProto and verifies that each one yields a non-nil error
// together with a nil address.
func TestParseProxyProtocolFail(t *testing.T) {
	// sig is the fixed 12-byte signature every v2 frame starts with.
	sig := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
	// ipv4Block holds src ip 127.0.0.1, dst ip 0.0.0.0, src port 80, dst port 0.
	ipv4Block := []byte{0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0}
	// loopback6 is the 16-byte IPv6 address ::1.
	loopback6 := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}

	// join concatenates parts into a freshly allocated slice, so sub-slices of
	// the shared fixtures above are never handed to the reader directly.
	join := func(parts ...[]byte) []byte {
		return bytes.Join(parts, nil)
	}

	f := func(body []byte) {
		t.Helper()
		r := bytes.NewBuffer(body)
		gotAddr, err := readProxyProto(r)
		if err == nil {
			t.Fatalf("expected error at input %v", body)
		}
		if gotAddr != nil {
			t.Fatalf("expected ip to be nil, got: %v", gotAddr)
		}
	}

	// empty input: not even the signature is present
	f(nil)
	// too short protocol prefix: only the first 7 signature bytes
	f(join(sig[:7]))
	// broken protocol prefix: header cut off right after the version/command byte
	f(join(sig, []byte{0x21}))
	// invalid header: second signature byte is 0x1A instead of 0x0A
	f([]byte{0x0D, 0x1A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11, 0x00, 0x0C})
	// invalid version: version nibble 0x3 instead of 0x2
	f(join(sig, []byte{0x31, 0x11, 0x00, 0x0C}))
	// too long block: declared address length 0xFF0C with no data following
	f(join(sig, []byte{0x21, 0x11, 0xff, 0x0C}))
	// missing bytes in address: length 12 declared, but only src ip, dst ip
	// and src port (10 bytes) follow
	f(join(sig, []byte{0x21, 0x11, 0x00, 0x0C}, ipv4Block[:10]))
	// too short address length: declared length 8 is smaller than the 12-byte
	// TCP/IPv4 address block
	f(join(sig, []byte{0x21, 0x11, 0x00, 0x08}, ipv4Block[:8]))
	// unsupported family: family nibble 0x3
	f(join(sig, []byte{0x21, 0x31, 0x00, 0x0C}, ipv4Block))
	// unsupported command: command nibble 0x2
	f(join(sig, []byte{0x22, 0x11, 0x00, 0x0C}, ipv4Block))
	// mismatch ipv6 and ipv4: TCP/IPv6 family, but the declared length (12) and
	// the block itself (src ip, dst ip, src port, dst port) are IPv4-sized
	f(join(sig, []byte{0x21, 0x21, 0x00, 0x0C}, ipv4Block))
	// truncated ipv6 block: length 36 declared, but only the source address follows
	f(join(sig, []byte{0x21, 0x21, 0x00, 0x24}, loopback6))
	// ipv4 udp isn't supported
	f(join(sig, []byte{0x21, 0x12, 0x00, 0x0C}, ipv4Block))
	// ipv6 udp isn't supported: src ::1, dst ::1, src port 80, dst port 0
	f(join(sig, []byte{0x21, 0x22, 0x00, 0x24}, loopback6, loopback6, []byte{0, 80, 0, 0}))
}