Loading gitlaboauth2/callback_server.go +62 −7 Original line number Diff line number Diff line Loading @@ -5,6 +5,7 @@ import ( _ "embed" "errors" "fmt" "net" "net/http" "net/url" "strings" Loading Loading @@ -81,13 +82,41 @@ type BrowserFunc func(url string) error // } // fmt.Printf("Access token: %s\n", token.AccessToken) func NewCallbackServer(config *oauth2.Config, addr string, browser BrowserFunc) *CallbackServer { cfg := *config return &CallbackServer{ config: config, config: &cfg, addr: addr, browser: browser, } } // NewLocalCallbackServer creates a new callback server that listens on localhost with an ephemeral port. // // Parameters: // - config: The OAuth2 configuration created with NewOAuth2Config // - browser: A function that opens URLs in the user's browser // // Returns: // - *CallbackServer: A configured callback server ready to handle OAuth2 flow on localhost with an ephemeral port // // Example usage: // // config := gitlaboauth2.NewOAuth2Config("", "client-id", []string{"read_user"}) // browserFunc := func(url string) error { // return exec.Command("open", url).Start() // macOS // } // server := gitlaboauth2.NewLocalCallbackServer(config, browserFunc) // // ctx := context.Background() // token, err := server.GetToken(ctx) // if err != nil { // log.Fatal(err) // } // fmt.Printf("Access token: %s\n", token.AccessToken) func NewLocalCallbackServer(config *oauth2.Config, browser BrowserFunc) *CallbackServer { return NewCallbackServer(config, "127.0.0.1:0", browser) } // GetToken performs the complete OAuth2 flow and returns an access token. // // This method orchestrates the entire OAuth2 authentication flow by: Loading Loading @@ -140,21 +169,39 @@ func (s *CallbackServer) GetToken(ctx context.Context) (*oauth2.Token, error) { errorChan := make(chan error, 1) defer close(errorChan) state := oauth2.GenerateVerifier() verifier := oauth2.GenerateVerifier() authURL := s.config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) u, err := url.Parse(s.config.RedirectURL) if err != nil { return nil, err } // Create the listener first so we know the actual bound port before // generating the auth URL. Passing port 0 (e.g. addr ":0") lets the OS // assign a free ephemeral port. ln, err := net.Listen("tcp", s.addr) if err != nil { return nil, fmt.Errorf("server failed: %w", err) } // When the redirect URL contains port 0, replace it with the actual port // that was assigned by the OS. if u.Port() == "0" { actualPort, err := listenerPort(ln) if err != nil { return nil, fmt.Errorf("failed to get listener port: %w", err) } u.Host = fmt.Sprintf("%s:%d", u.Hostname(), actualPort) s.config.RedirectURL = u.String() } state := oauth2.GenerateVerifier() verifier := oauth2.GenerateVerifier() authURL := s.config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) // Set up HTTP server mux := http.NewServeMux() mux.HandleFunc("/"+strings.TrimPrefix(u.Path, "/"), s.callbackHandler(ctx, tokenChan, errorChan, state, verifier)) s.server = &http.Server{ Addr: s.addr, Handler: mux, ReadTimeout: 15 * time.Second, WriteTimeout: 15 * time.Second, Loading @@ -166,7 +213,7 @@ func (s *CallbackServer) GetToken(ctx context.Context) (*oauth2.Token, error) { // Start server wg.Go(func() { if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := s.server.Serve(ln); err != nil && err != http.ErrServerClosed { errorChan <- fmt.Errorf("server failed: %w", err) } }) Loading Loading @@ -194,6 +241,14 @@ func (s *CallbackServer) GetToken(ctx context.Context) (*oauth2.Token, error) { return token, nil } func listenerPort(ln net.Listener) (int, error) { tcpAddr, ok := ln.Addr().(*net.TCPAddr) if !ok { return 0, fmt.Errorf("listener is not TCP: %T", ln.Addr()) } return tcpAddr.Port, nil } func (s *CallbackServer) callbackHandler(ctx context.Context, tokenChan chan *oauth2.Token, errorChan chan error, expectedState, verifier string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Check for errors Loading gitlaboauth2/callback_server_test.go +62 −0 Original line number Diff line number Diff line Loading @@ -241,6 +241,68 @@ func TestCallbackServer_GetToken(t *testing.T) { //nolint:paralleltest assert.Nil(t, token) assert.Contains(t, err.Error(), "server failed") }) t.Run("ephemeral port assigned by OS", func(t *testing.T) { //nolint:paralleltest // GIVEN a mock OAuth2 provider mockProvider := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/oauth/token" { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) fmt.Fprint(w, `{ "access_token": "ephemeral-access-token", "token_type": "Bearer", "expires_in": 3600 }`) } })) defer mockProvider.Close() // GIVEN a config with port 0 in the redirect URL so the OS picks the port config := &oauth2.Config{ ClientID: "test-client-id", RedirectURL: "http://localhost:0/auth/redirect", Endpoint: oauth2.Endpoint{ TokenURL: mockProvider.URL + "/oauth/token", }, Scopes: []string{"read_user"}, } var capturedURL string browserFunc := func(authURL string) error { capturedURL = authURL // WHEN the browser opens, simulate the user completing the OAuth flow go func() { parsedURL, err := url.Parse(capturedURL) if err != nil { return } state := parsedURL.Query().Get("state") redirectURI := parsedURL.Query().Get("redirect_uri") callbackURL := fmt.Sprintf("%s?code=test-code&state=%s", redirectURI, state) client := &http.Client{Timeout: 5 * time.Second} resp, err := client.Get(callbackURL) if err == nil { resp.Body.Close() } }() return nil } // GIVEN a local callback server with an ephemeral port server := NewLocalCallbackServer(config, browserFunc) ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() // WHEN GetToken is called token, err := server.GetToken(ctx) // THEN a token is returned require.NoError(t, err) assert.NotNil(t, token) assert.Equal(t, "ephemeral-access-token", token.AccessToken) }) } func TestCallbackServer_CallbackHandler(t *testing.T) { //nolint:paralleltest Loading Loading
gitlaboauth2/callback_server.go +62 −7 Original line number Diff line number Diff line Loading @@ -5,6 +5,7 @@ import ( _ "embed" "errors" "fmt" "net" "net/http" "net/url" "strings" Loading Loading @@ -81,13 +82,41 @@ type BrowserFunc func(url string) error // } // fmt.Printf("Access token: %s\n", token.AccessToken) func NewCallbackServer(config *oauth2.Config, addr string, browser BrowserFunc) *CallbackServer { cfg := *config return &CallbackServer{ config: config, config: &cfg, addr: addr, browser: browser, } } // NewLocalCallbackServer creates a new callback server that listens on localhost with an ephemeral port. // // Parameters: // - config: The OAuth2 configuration created with NewOAuth2Config // - browser: A function that opens URLs in the user's browser // // Returns: // - *CallbackServer: A configured callback server ready to handle OAuth2 flow on localhost with an ephemeral port // // Example usage: // // config := gitlaboauth2.NewOAuth2Config("", "client-id", []string{"read_user"}) // browserFunc := func(url string) error { // return exec.Command("open", url).Start() // macOS // } // server := gitlaboauth2.NewLocalCallbackServer(config, browserFunc) // // ctx := context.Background() // token, err := server.GetToken(ctx) // if err != nil { // log.Fatal(err) // } // fmt.Printf("Access token: %s\n", token.AccessToken) func NewLocalCallbackServer(config *oauth2.Config, browser BrowserFunc) *CallbackServer { return NewCallbackServer(config, "127.0.0.1:0", browser) } // GetToken performs the complete OAuth2 flow and returns an access token. // // This method orchestrates the entire OAuth2 authentication flow by: Loading Loading @@ -140,21 +169,39 @@ func (s *CallbackServer) GetToken(ctx context.Context) (*oauth2.Token, error) { errorChan := make(chan error, 1) defer close(errorChan) state := oauth2.GenerateVerifier() verifier := oauth2.GenerateVerifier() authURL := s.config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) u, err := url.Parse(s.config.RedirectURL) if err != nil { return nil, err } // Create the listener first so we know the actual bound port before // generating the auth URL. Passing port 0 (e.g. addr ":0") lets the OS // assign a free ephemeral port. ln, err := net.Listen("tcp", s.addr) if err != nil { return nil, fmt.Errorf("server failed: %w", err) } // When the redirect URL contains port 0, replace it with the actual port // that was assigned by the OS. if u.Port() == "0" { actualPort, err := listenerPort(ln) if err != nil { return nil, fmt.Errorf("failed to get listener port: %w", err) } u.Host = fmt.Sprintf("%s:%d", u.Hostname(), actualPort) s.config.RedirectURL = u.String() } state := oauth2.GenerateVerifier() verifier := oauth2.GenerateVerifier() authURL := s.config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) // Set up HTTP server mux := http.NewServeMux() mux.HandleFunc("/"+strings.TrimPrefix(u.Path, "/"), s.callbackHandler(ctx, tokenChan, errorChan, state, verifier)) s.server = &http.Server{ Addr: s.addr, Handler: mux, ReadTimeout: 15 * time.Second, WriteTimeout: 15 * time.Second, Loading @@ -166,7 +213,7 @@ func (s *CallbackServer) GetToken(ctx context.Context) (*oauth2.Token, error) { // Start server wg.Go(func() { if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := s.server.Serve(ln); err != nil && err != http.ErrServerClosed { errorChan <- fmt.Errorf("server failed: %w", err) } }) Loading Loading @@ -194,6 +241,14 @@ func (s *CallbackServer) GetToken(ctx context.Context) (*oauth2.Token, error) { return token, nil } func listenerPort(ln net.Listener) (int, error) { tcpAddr, ok := ln.Addr().(*net.TCPAddr) if !ok { return 0, fmt.Errorf("listener is not TCP: %T", ln.Addr()) } return tcpAddr.Port, nil } func (s *CallbackServer) callbackHandler(ctx context.Context, tokenChan chan *oauth2.Token, errorChan chan error, expectedState, verifier string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Check for errors Loading
gitlaboauth2/callback_server_test.go +62 −0 Original line number Diff line number Diff line Loading @@ -241,6 +241,68 @@ func TestCallbackServer_GetToken(t *testing.T) { //nolint:paralleltest assert.Nil(t, token) assert.Contains(t, err.Error(), "server failed") }) t.Run("ephemeral port assigned by OS", func(t *testing.T) { //nolint:paralleltest // GIVEN a mock OAuth2 provider mockProvider := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/oauth/token" { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) fmt.Fprint(w, `{ "access_token": "ephemeral-access-token", "token_type": "Bearer", "expires_in": 3600 }`) } })) defer mockProvider.Close() // GIVEN a config with port 0 in the redirect URL so the OS picks the port config := &oauth2.Config{ ClientID: "test-client-id", RedirectURL: "http://localhost:0/auth/redirect", Endpoint: oauth2.Endpoint{ TokenURL: mockProvider.URL + "/oauth/token", }, Scopes: []string{"read_user"}, } var capturedURL string browserFunc := func(authURL string) error { capturedURL = authURL // WHEN the browser opens, simulate the user completing the OAuth flow go func() { parsedURL, err := url.Parse(capturedURL) if err != nil { return } state := parsedURL.Query().Get("state") redirectURI := parsedURL.Query().Get("redirect_uri") callbackURL := fmt.Sprintf("%s?code=test-code&state=%s", redirectURI, state) client := &http.Client{Timeout: 5 * time.Second} resp, err := client.Get(callbackURL) if err == nil { resp.Body.Close() } }() return nil } // GIVEN a local callback server with an ephemeral port server := NewLocalCallbackServer(config, browserFunc) ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() // WHEN GetToken is called token, err := server.GetToken(ctx) // THEN a token is returned require.NoError(t, err) assert.NotNil(t, token) assert.Equal(t, "ephemeral-access-token", token.AccessToken) }) } func TestCallbackServer_CallbackHandler(t *testing.T) { //nolint:paralleltest Loading