Skip to content
Snippets Groups Projects

Redirect user based on DefaultDomainRedirect for the given Pages project

Merged Naman Jagdish Gala requested to merge ngala/default-domain-redirect into master
All threads resolved!
package defaultdomainredirect
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestMiddlewareRedirectsToDefaultDomain(t *testing.T) {
finalHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := NewMiddleware(finalHandler)
tests := []struct {
name string
requestHost string
requestPath string
defaultDomain string
prefix string
expectedStatusCode int
expectedLocation string
}{
{
name: "redirect to default domain",
requestHost: "example.com",
requestPath: "/path/to/resource",
defaultDomain: "https://default.example.com",
expectedStatusCode: http.StatusPermanentRedirect,
expectedLocation: "https://default.example.com/path/to/resource",
},
{
name: "no redirect if already on default domain",
requestHost: "default.example.com",
requestPath: "/path/to/resource",
defaultDomain: "https://default.example.com",
expectedStatusCode: http.StatusOK,
},
{
name: "no redirect if defaultDomainRedirect is empty",
requestHost: "example.com",
requestPath: "/path/to/resource",
defaultDomain: "",
expectedStatusCode: http.StatusOK,
},
{
name: "invalid default domain",
requestHost: "example.com",
requestPath: "/path/to/resource",
defaultDomain: "://invalid-url",
expectedStatusCode: http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
getDefaultDomainRedirectAndPrefixFunc = func(r *http.Request) (string, string) {
return tc.defaultDomain, ""
}
r := httptest.NewRequest(http.MethodGet, "http://"+tc.requestHost+tc.requestPath, nil)
r.Host = tc.requestHost
rec := httptest.NewRecorder()
middleware.ServeHTTP(rec, r)
require.Equal(t, tc.expectedStatusCode, rec.Code)
if tc.expectedStatusCode == http.StatusPermanentRedirect {
require.Equal(t, tc.expectedLocation, rec.Header().Get("Location"))
} else {
require.Empty(t, rec.Header().Get("Location"))
}
})
}
}
Loading