Commit 9417155f authored by Amer Khaled's avatar Amer Khaled Committed by Patrick Rice
Browse files

feat: improve URL validation and error handling in client initialization

Changelog: Improvements
parent 7a08ced5
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ package gitlab

import (
	"errors"
	"log/slog"
	"net/http"
	"time"

@@ -150,6 +151,20 @@ func WithUserAgent(userAgent string) ClientOptionFunc {
	}
}

// WithURLWarningLogger sets a custom logger for URL validation warnings.
// By default, warnings are logged using slog.Default().
// Pass slog.New(slog.NewTextHandler(io.Discard, nil)) to disable warnings.
// TODO: Use slog.NewDiscardHandler() when we upgrade to Go 1.25+
func WithURLWarningLogger(logger *slog.Logger) ClientOptionFunc {
	return func(c *Client) error {
		if logger == nil {
			return errors.New("logger cannot be nil, use slog.New(slog.NewTextHandler(io.Discard, nil)) to discard warnings")
		}
		c.urlWarningLogger = logger
		return nil
	}
}

// WithCookieJar can be used to configure a cookie jar.
func WithCookieJar(jar http.CookieJar) ClientOptionFunc {
	return func(c *Client) error {
+80 −5
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ import (
	"errors"
	"fmt"
	"io"
	"log/slog"
	"maps"
	"math"
	"math/rand"
@@ -74,6 +75,21 @@ const (

var ErrNotFound = errors.New("404 Not Found")

// URLValidationError wraps URL parsing errors with helpful context
type URLValidationError struct {
	URL  string
	Err  error
	Hint string
}

func (e *URLValidationError) Error() string {
	msg := fmt.Sprintf("invalid base URL %q: %v", e.URL, e.Err)
	if e.Hint != "" {
		msg += fmt.Sprintf(" (hint: %s)", e.Hint)
	}
	return msg
}

// A Client manages communication with the GitLab API.
type Client struct {
	// HTTP client used to communicate with the API.
@@ -111,6 +127,9 @@ type Client struct {
	// which are used to decorate the http.Client#Transport value.
	interceptors []Interceptor

	// urlWarningLogger is used to print URL validation warnings
	urlWarningLogger *slog.Logger

	// User agent used when communicating with the GitLab API.
	UserAgent string

@@ -381,6 +400,7 @@ func NewAuthSourceClient(as AuthSource, options ...ClientOptionFunc) (*Client, e
	c := &Client{
		UserAgent:        userAgent,
		authSource:       as,
		urlWarningLogger: slog.Default(),
	}

	// Configure the HTTP client.
@@ -798,16 +818,71 @@ func (c *Client) BaseURL() *url.URL {
	return &u
}

// validateBaseURL checks for common real-world mistakes and returns them as errors.
// Returns the parsed URL if validation succeeds.
func validateBaseURL(baseURL string) (*url.URL, error) {
	if baseURL == "" {
		return nil, &URLValidationError{
			URL:  baseURL,
			Err:  errors.New("empty URL"),
			Hint: `provide a valid GitLab instance URL (e.g., "https://gitlab.com")`,
		}
	}

	if !strings.Contains(baseURL, "://") {
		return nil, &URLValidationError{
			URL:  baseURL,
			Err:  errors.New("missing scheme"),
			Hint: fmt.Sprintf(`try "https://%s"`, baseURL),
		}
	}

	parsedURL, err := url.Parse(baseURL)
	if err != nil {
		return nil, &URLValidationError{
			URL: baseURL,
			Err: err,
			Hint: `possible issues:
		  - missing hostname
		  - invalid characters/spaces
		  - invalid port (must be 1-65535)
		  - query parameters (?)
		  - fragments (#)
		  - invalid URL encoding`,
		}
	}

	if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
		return nil, &URLValidationError{
			URL:  baseURL,
			Err:  fmt.Errorf("unsupported scheme %q", parsedURL.Scheme),
			Hint: fmt.Sprintf(`GitLab API requires http or https (try "https://%s")`, parsedURL.Host),
		}
	}

	return parsedURL, nil
}

// setBaseURL sets the base URL for API requests to a custom endpoint.
func (c *Client) setBaseURL(urlStr string) error {
	// Make sure the given URL end with a slash
	// Make sure the given URL ends with a slash
	if !strings.HasSuffix(urlStr, "/") {
		urlStr += "/"
	}

	baseURL, err := url.Parse(urlStr)
	// Validate and parse
	baseURL, err := validateBaseURL(urlStr)
	if err != nil {
		return err
		// Log the validation warning
		c.urlWarningLogger.Warn("URL validation warning", "error", err)

		// Don't return the error - just warn and continue
		// Try to parse anyway as a fallback
		baseURL, err = url.Parse(urlStr)
		if err != nil {
			// If we really can't parse it, we have to give up
			return fmt.Errorf("failed to parse base URL: %w", err)
		}
	}

	if !strings.HasSuffix(baseURL.Path, apiVersionPath) {
+90 −0
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ import (
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/http/cookiejar"
@@ -1538,3 +1539,92 @@ func TestParseID(t *testing.T) {
		})
	}
}

func TestSetBaseURL(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		input       string
		wantBaseURL string
	}{
		{
			name:        "valid HTTPS URL",
			input:       "https://gitlab.com",
			wantBaseURL: "https://gitlab.com/api/v4/",
		},
		{
			name:        "valid URL with custom path and port",
			input:       "https://git.company.com:8443/gitlab",
			wantBaseURL: "https://git.company.com:8443/gitlab/api/v4/",
		},
		{
			name:        "URL with trailing slash",
			input:       "https://gitlab.com/",
			wantBaseURL: "https://gitlab.com/api/v4/",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			c := &Client{}
			err := c.setBaseURL(tt.input)

			require.NoError(t, err)
			require.NotNil(t, c.baseURL)
			assert.Equal(t, tt.wantBaseURL, c.baseURL.String())
		})
	}
}

func TestSetBaseURL_ValidationWarnings(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		input       string
		expectError bool
	}{
		{
			name:        "empty URL logs warning but continues",
			input:       "",
			expectError: false,
		},
		{
			name:        "missing scheme logs warning but continues",
			input:       "gitlab.com",
			expectError: false,
		},
		{
			name:        "wrong scheme logs warning but continues",
			input:       "git://gitlab.com",
			expectError: false,
		},
		{
			name:        "unparseable URL returns error",
			input:       "://invalid",
			expectError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			c := &Client{
				// TODO: Use slog.NewDiscardHandler() when we upgrade to Go 1.25+
				urlWarningLogger: slog.New(slog.NewTextHandler(io.Discard, nil)),
			}

			err := c.setBaseURL(tt.input)

			if tt.expectError {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}