package httpserver

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

// TestGetQuotedRemoteAddr verifies that GetQuotedRemoteAddr returns the
// expected quoted string for a request's RemoteAddr and optional
// X-Forwarded-For header, and that the result parses as a JSON string.
func TestGetQuotedRemoteAddr(t *testing.T) {
	f := func(remoteAddr, xForwardedFor, expectedAddr string) {
		t.Helper()

		req := &http.Request{
			RemoteAddr: remoteAddr,
		}
		if xForwardedFor != "" {
			req.Header = map[string][]string{
				"X-Forwarded-For": {xForwardedFor},
			}
		}
		addr := GetQuotedRemoteAddr(req)
		if addr != expectedAddr {
			t.Fatalf("unexpected remote addr;\ngot\n%s\nwant\n%s", addr, expectedAddr)
		}

		// Verify that the addr can be unmarshaled as JSON string
		var s string
		if err := json.Unmarshal([]byte(addr), &s); err != nil {
			t.Fatalf("cannot unmarshal addr: %s", err)
		}
	}

	f("1.2.3.4", "", `"1.2.3.4"`)
	f("1.2.3.4", "foo.bar", `"1.2.3.4, X-Forwarded-For: foo.bar"`)
	f("1.2\n\"3.4", "foo\nb\"ar", `"1.2\n\"3.4, X-Forwarded-For: foo\nb\"ar"`)
}

// TestBasicAuthMetrics checks the status code produced by CheckBasicAuth for
// matching and non-matching credentials, first with a username/password
// configured and then with both left empty.
func TestBasicAuthMetrics(t *testing.T) {
	// Remember the package-level auth settings so they can be restored
	// once the test finishes.
	origUsername := *httpAuthUsername
	origPasswd := httpAuthPassword.Get()
	defer func() {
		if err := httpAuthPassword.Set(origPasswd); err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		*httpAuthUsername = origUsername
	}()

	f := func(user, pass string, expCode int) {
		t.Helper()

		req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
		req.SetBasicAuth(user, pass)

		w := httptest.NewRecorder()
		CheckBasicAuth(w, req)

		res := w.Result()
		_ = res.Body.Close()
		if expCode != res.StatusCode {
			// Report the expected code as "wanted" and the actual one as "got".
			t.Fatalf("wanted status code: %d, got: %d\n", expCode, res.StatusCode)
		}
	}

	*httpAuthUsername = "test"
	if err := httpAuthPassword.Set("pass"); err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	f("test", "pass", 200)
	f("test", "wrong", 401)
	f("wrong", "pass", 401)
	f("wrong", "wrong", 401)

	// With empty username and password every request is expected to pass.
	*httpAuthUsername = ""
	if err := httpAuthPassword.Set(""); err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	f("test", "pass", 200)
	f("test", "wrong", 200)
	f("wrong", "pass", 200)
	f("wrong", "wrong", 200)
}

// TestAuthKeyMetrics checks the status code produced by CheckAuthFlag, both
// when an expected auth key is passed to it and when the expected key is
// empty and the request carries basic-auth credentials instead.
func TestAuthKeyMetrics(t *testing.T) {
	// Remember the package-level auth settings so they can be restored
	// once the test finishes.
	origUsername := *httpAuthUsername
	origPasswd := httpAuthPassword.Get()
	defer func() {
		if err := httpAuthPassword.Set(origPasswd); err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		*httpAuthUsername = origUsername
	}()

	tstWithAuthKey := func(key string, expCode int) {
		t.Helper()

		// The key is sent as a form field in the POST body.
		req := httptest.NewRequest(http.MethodPost, "/metrics", strings.NewReader("authKey="+key))
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded;param=value")

		w := httptest.NewRecorder()
		CheckAuthFlag(w, req, "rightKey", "metricsAuthkey")

		res := w.Result()
		_ = res.Body.Close()
		if expCode != res.StatusCode {
			t.Fatalf("Unexpected status code: %d, Expected code is: %d\n", res.StatusCode, expCode)
		}
	}

	tstWithAuthKey("rightKey", 200)
	tstWithAuthKey("wrongKey", 401)

	tstWithOutAuthKey := func(user, pass string, expCode int) {
		t.Helper()

		req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
		req.SetBasicAuth(user, pass)

		w := httptest.NewRecorder()
		CheckAuthFlag(w, req, "", "metricsAuthkey")

		res := w.Result()
		_ = res.Body.Close()
		if expCode != res.StatusCode {
			// Report the expected code as "wanted" and the actual one as "got".
			t.Fatalf("wanted status code: %d, got: %d\n", expCode, res.StatusCode)
		}
	}

	*httpAuthUsername = "test"
	if err := httpAuthPassword.Set("pass"); err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	tstWithOutAuthKey("test", "pass", 200)
	tstWithOutAuthKey("test", "wrong", 401)
	tstWithOutAuthKey("wrong", "pass", 401)
	tstWithOutAuthKey("wrong", "wrong", 401)
}

// TestHandlerWrapper verifies that handlerWrapper sets the
// Strict-Transport-Security, X-Frame-Options and Content-Security-Policy
// response headers from the corresponding header flags.
func TestHandlerWrapper(t *testing.T) {
	const hstsHeader = "foo"
	const frameOptionsHeader = "bar"
	const cspHeader = "baz"

	// Save the current flag values and restore them afterwards, so the
	// test does not leak modified state into other tests.
	origHSTS := *headerHSTS
	origFrameOptions := *headerFrameOptions
	origCSP := *headerCSP
	*headerHSTS = hstsHeader
	*headerFrameOptions = frameOptionsHeader
	*headerCSP = cspHeader
	defer func() {
		*headerHSTS = origHSTS
		*headerFrameOptions = origFrameOptions
		*headerCSP = origCSP
	}()

	req, err := http.NewRequest(http.MethodGet, "/health", nil)
	if err != nil {
		t.Fatalf("cannot create request: %s", err)
	}

	srv := &server{s: &http.Server{}}
	w := httptest.NewRecorder()

	handlerWrapper(srv, w, req, func(_ http.ResponseWriter, _ *http.Request) bool {
		return true
	})

	h := w.Header()
	if got := h.Get("Strict-Transport-Security"); got != hstsHeader {
		t.Fatalf("unexpected HSTS header; got %q; want %q", got, hstsHeader)
	}
	if got := h.Get("X-Frame-Options"); got != frameOptionsHeader {
		t.Fatalf("unexpected X-Frame-Options header; got %q; want %q", got, frameOptionsHeader)
	}
	if got := h.Get("Content-Security-Policy"); got != cspHeader {
		t.Fatalf("unexpected CSP header; got %q; want %q", got, cspHeader)
	}
}