Verified Commit c8c388d5 authored by Raphael Rösch's avatar Raphael Rösch Committed by GitLab
Browse files

feat(gitlaboauth2): support ephemeral ports in CallbackServer

Changelog: Improvements
parent 20c15357
Loading
Loading
Loading
Loading
+62 −7
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import (
	_ "embed"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"
@@ -81,13 +82,41 @@ type BrowserFunc func(url string) error
//	}
//	fmt.Printf("Access token: %s\n", token.AccessToken)
func NewCallbackServer(config *oauth2.Config, addr string, browser BrowserFunc) *CallbackServer {
	cfg := *config
	return &CallbackServer{
		config:  config,
		config:  &cfg,
		addr:    addr,
		browser: browser,
	}
}

// NewLocalCallbackServer creates a new callback server that listens on localhost with an ephemeral port.
//
// Parameters:
//   - config: The OAuth2 configuration created with NewOAuth2Config
//   - browser: A function that opens URLs in the user's browser
//
// Returns:
//   - *CallbackServer: A configured callback server ready to handle OAuth2 flow on localhost with an ephemeral port
//
// Example usage:
//
//	config := gitlaboauth2.NewOAuth2Config("", "client-id", []string{"read_user"})
//	browserFunc := func(url string) error {
//		return exec.Command("open", url).Start() // macOS
//	}
//	server := gitlaboauth2.NewLocalCallbackServer(config, browserFunc)
//
//	ctx := context.Background()
//	token, err := server.GetToken(ctx)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("Access token: %s\n", token.AccessToken)
func NewLocalCallbackServer(config *oauth2.Config, browser BrowserFunc) *CallbackServer {
	return NewCallbackServer(config, "127.0.0.1:0", browser)
}

// GetToken performs the complete OAuth2 flow and returns an access token.
//
// This method orchestrates the entire OAuth2 authentication flow by:
@@ -140,21 +169,39 @@ func (s *CallbackServer) GetToken(ctx context.Context) (*oauth2.Token, error) {
	errorChan := make(chan error, 1)
	defer close(errorChan)

	state := oauth2.GenerateVerifier()
	verifier := oauth2.GenerateVerifier()
	authURL := s.config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier))

	u, err := url.Parse(s.config.RedirectURL)
	if err != nil {
		return nil, err
	}

	// Create the listener first so we know the actual bound port before
	// generating the auth URL. Passing port 0 (e.g. addr ":0") lets the OS
	// assign a free ephemeral port.
	ln, err := net.Listen("tcp", s.addr)
	if err != nil {
		return nil, fmt.Errorf("server failed: %w", err)
	}

	// When the redirect URL contains port 0, replace it with the actual port
	// that was assigned by the OS.
	if u.Port() == "0" {
		actualPort, err := listenerPort(ln)
		if err != nil {
			return nil, fmt.Errorf("failed to get listener port: %w", err)
		}
		u.Host = fmt.Sprintf("%s:%d", u.Hostname(), actualPort)
		s.config.RedirectURL = u.String()
	}

	state := oauth2.GenerateVerifier()
	verifier := oauth2.GenerateVerifier()
	authURL := s.config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier))

	// Set up HTTP server
	mux := http.NewServeMux()
	mux.HandleFunc("/"+strings.TrimPrefix(u.Path, "/"), s.callbackHandler(ctx, tokenChan, errorChan, state, verifier))

	s.server = &http.Server{
		Addr:         s.addr,
		Handler:      mux,
		ReadTimeout:  15 * time.Second,
		WriteTimeout: 15 * time.Second,
@@ -166,7 +213,7 @@ func (s *CallbackServer) GetToken(ctx context.Context) (*oauth2.Token, error) {

	// Start server
	wg.Go(func() {
		if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		if err := s.server.Serve(ln); err != nil && err != http.ErrServerClosed {
			errorChan <- fmt.Errorf("server failed: %w", err)
		}
	})
@@ -194,6 +241,14 @@ func (s *CallbackServer) GetToken(ctx context.Context) (*oauth2.Token, error) {
	return token, nil
}

func listenerPort(ln net.Listener) (int, error) {
	tcpAddr, ok := ln.Addr().(*net.TCPAddr)
	if !ok {
		return 0, fmt.Errorf("listener is not TCP: %T", ln.Addr())
	}
	return tcpAddr.Port, nil
}

func (s *CallbackServer) callbackHandler(ctx context.Context, tokenChan chan *oauth2.Token, errorChan chan error, expectedState, verifier string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Check for errors
+62 −0
Original line number Diff line number Diff line
@@ -241,6 +241,68 @@ func TestCallbackServer_GetToken(t *testing.T) { //nolint:paralleltest
		assert.Nil(t, token)
		assert.Contains(t, err.Error(), "server failed")
	})

	t.Run("ephemeral port assigned by OS", func(t *testing.T) { //nolint:paralleltest
		// GIVEN a mock OAuth2 provider
		mockProvider := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/oauth/token" {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusOK)
				fmt.Fprint(w, `{
					"access_token": "ephemeral-access-token",
					"token_type": "Bearer",
					"expires_in": 3600
				}`)
			}
		}))
		defer mockProvider.Close()

		// GIVEN a config with port 0 in the redirect URL so the OS picks the port
		config := &oauth2.Config{
			ClientID:    "test-client-id",
			RedirectURL: "http://localhost:0/auth/redirect",
			Endpoint: oauth2.Endpoint{
				TokenURL: mockProvider.URL + "/oauth/token",
			},
			Scopes: []string{"read_user"},
		}

		var capturedURL string
		browserFunc := func(authURL string) error {
			capturedURL = authURL
			// WHEN the browser opens, simulate the user completing the OAuth flow
			go func() {
				parsedURL, err := url.Parse(capturedURL)
				if err != nil {
					return
				}
				state := parsedURL.Query().Get("state")
				redirectURI := parsedURL.Query().Get("redirect_uri")

				callbackURL := fmt.Sprintf("%s?code=test-code&state=%s", redirectURI, state)
				client := &http.Client{Timeout: 5 * time.Second}
				resp, err := client.Get(callbackURL)
				if err == nil {
					resp.Body.Close()
				}
			}()
			return nil
		}

		// GIVEN a local callback server with an ephemeral port
		server := NewLocalCallbackServer(config, browserFunc)

		ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
		defer cancel()

		// WHEN GetToken is called
		token, err := server.GetToken(ctx)

		// THEN a token is returned
		require.NoError(t, err)
		assert.NotNil(t, token)
		assert.Equal(t, "ephemeral-access-token", token.AccessToken)
	})
}

func TestCallbackServer_CallbackHandler(t *testing.T) { //nolint:paralleltest