// // cors_test.go // Copyright (C) 2022 tiglog // // Distributed under terms of the MIT license. // package cors_test import ( "context" "net/http" "net/http/httptest" "testing" "git.hexq.cn/tiglog/golib/gtest" "git.hexq.cn/tiglog/golib/gweb/cors" "github.com/gin-gonic/gin" ) func newTestRouter(origins []string) *gin.Engine { gin.SetMode(gin.TestMode) router := gin.New() router.Use(cors.NewCors(origins)) router.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "get") }) router.POST("/", func(c *gin.Context) { c.String(http.StatusOK, "post") }) router.PATCH("/", func(c *gin.Context) { c.String(http.StatusOK, "patch") }) return router } func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder { return performRequestWithHeaders(r, method, origin, http.Header{}) } func performRequestWithHeaders(r http.Handler, method, origin string, header http.Header) *httptest.ResponseRecorder { req, _ := http.NewRequestWithContext(context.Background(), method, "/", nil) // From go/net/http/request.go: // For incoming requests, the Host header is promoted to the // Request.Host field and removed from the Header map. req.Host = header.Get("Host") header.Del("Host") if len(origin) > 0 { header.Set("Origin", origin) } req.Header = header w := httptest.NewRecorder() r.ServeHTTP(w, req) return w } func TestPassesAllowOrigins(t *testing.T) { router := newTestRouter([]string{"http://google.com"}) // no CORS request, origin == "" w := performRequest(router, "GET", "") gtest.Equal(t, "get", w.Body.String()) gtest.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) gtest.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) gtest.Empty(t, w.Header().Get("Access-Control-Expose-Headers")) // no CORS request, origin == host h := http.Header{} h.Set("Host", "facebook.com") w = performRequestWithHeaders(router, "GET", "http://facebook.com", h) gtest.Equal(t, "get", w.Body.String()) gtest.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) gtest.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) gtest.Empty(t, w.Header().Get("Access-Control-Expose-Headers")) // allowed CORS request w = performRequest(router, "GET", "http://google.com") gtest.Equal(t, "get", w.Body.String()) gtest.Equal(t, "http://google.com", w.Header().Get("Access-Control-Allow-Origin")) gtest.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials")) // deny CORS request w = performRequest(router, "GET", "https://google.com") // gtest.Equal(t, http.StatusForbidden, w.Code) gtest.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) gtest.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) gtest.Empty(t, w.Header().Get("Access-Control-Expose-Headers")) // allowed CORS prefligh request w = performRequest(router, "OPTIONS", "http://google.com") gtest.Equal(t, http.StatusNoContent, w.Code) gtest.Equal(t, "http://google.com", w.Header().Get("Access-Control-Allow-Origin")) gtest.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials")) // gtest.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods")) // deny CORS prefligh request w = performRequest(router, "OPTIONS", "http://example.com") // gtest.Equal(t, http.StatusForbidden, w.Code) gtest.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) gtest.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) gtest.Empty(t, w.Header().Get("Access-Control-Allow-Methods")) gtest.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) gtest.Empty(t, w.Header().Get("Access-Control-Max-Age")) }