Verified Commit 83615e91 authored by Elliot Forbes's avatar Elliot Forbes 2️⃣ Committed by GitLab
Browse files

feat(v2/httpserver): merge full access logger from v2/http into httpserver

parent 78591307
Loading
Loading
Loading
Loading
+22 −0
Original line number Diff line number Diff line
// Package correlation provides correlation ID propagation via context.
package correlation

import "context"

type ctxKey int

const keyCorrelationID ctxKey = iota

// ExtractFromContext returns the correlation ID stored in ctx, or an empty
// string if none is present.
func ExtractFromContext(ctx context.Context) string {
	if v, ok := ctx.Value(keyCorrelationID).(string); ok {
		return v
	}
	return ""
}

// InjectToContext returns a copy of ctx with the correlation ID set.
func InjectToContext(ctx context.Context, id string) context.Context {
	return context.WithValue(ctx, keyCorrelationID, id)
}
+28 −0
Original line number Diff line number Diff line
package correlation

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestExtractFromContext_returnsID(t *testing.T) {
	ctx := InjectToContext(context.Background(), "abc-123")
	assert.Equal(t, "abc-123", ExtractFromContext(ctx))
}

func TestExtractFromContext_emptyWhenMissing(t *testing.T) {
	assert.Equal(t, "", ExtractFromContext(context.Background()))
}

func TestExtractFromContext_emptyStringIsValid(t *testing.T) {
	ctx := InjectToContext(context.Background(), "")
	assert.Equal(t, "", ExtractFromContext(ctx))
}

func TestInjectToContext_overwritesPrevious(t *testing.T) {
	ctx := InjectToContext(context.Background(), "first")
	ctx = InjectToContext(ctx, "second")
	assert.Equal(t, "second", ExtractFromContext(ctx))
}
+40 −0
Original line number Diff line number Diff line
@@ -55,6 +55,46 @@ const (
	// (e.g. "0.0.0.0:8080" or "127.0.0.1:9090").
	TCPAddress = "tcp_address"

	// HTTPURI - a string field that captures the request URI of an HTTP
	// request, including the path and query string with sensitive parameters
	// masked (e.g. "?password=[FILTERED]"). Use [HTTPURL] for clean URLs
	// without query strings.
	HTTPURI = "uri"

	// HTTPHost - a string field that captures the HTTP Host header of a
	// request (e.g. "api.gitlab.com").
	HTTPHost = "host"

	// HTTPProto - a string field that captures the HTTP protocol version of a
	// request (e.g. "HTTP/1.1", "HTTP/2.0").
	HTTPProto = "proto"

	// RemoteAddr - a string field that captures the raw remote socket address
	// of a connection in "host:port" format (e.g. "10.0.0.1:54321"). Use
	// [RemoteIP] when only the IP address is needed.
	RemoteAddr = "remote_addr"

	// HTTPReferrer - a string field that captures the Referer header of an
	// HTTP request with sensitive query parameters masked.
	HTTPReferrer = "referrer"

	// HTTPUserAgent - a string field that captures the User-Agent header of
	// an HTTP request.
	HTTPUserAgent = "user_agent"

	// WrittenBytes - an int64 field that captures the number of bytes written
	// to the HTTP response body.
	WrittenBytes = "written_bytes"

	// ContentType - a string field that captures the Content-Type of an HTTP
	// response (e.g. "application/json").
	ContentType = "content_type"

	// TTFBS - a float64 field that captures the time to first byte of an HTTP
	// response in seconds. Measures the duration from when the request was
	// received to when the first byte of the response was written.
	TTFBS = "ttfb_s"

	// New fields being added to this section should have
	// the appropriate doc comments added above. These
	// should clearly indicate what the intended use of the
+147 −0
Original line number Diff line number Diff line
package httpserver

import (
	"bufio"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"strings"
	"time"

	"gitlab.com/gitlab-org/labkit/v2/correlation"
	"gitlab.com/gitlab-org/labkit/v2/log"
)

// AccessLogger returns HTTP middleware that emits a structured access log
// entry for each request. If logger is nil, slog.Default() is used.
func AccessLogger(next http.Handler, logger *slog.Logger) http.Handler {
	if logger == nil {
		logger = slog.Default()
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		lrw := &loggingResponseWriter{rw: w, started: time.Now()}

		var wrapped http.ResponseWriter = lrw
		if _, ok := w.(http.Hijacker); ok {
			wrapped = &hijackingResponseWriter{loggingResponseWriter: lrw}
		}

		var completed bool
		defer func() {
			// Default to 200 only when the handler returned normally without
			// calling WriteHeader. On panics, completed is false and status
			// remains 0, indicating abnormal termination.
			if completed && !lrw.wroteHeader {
				lrw.status = http.StatusOK
			}

			remoteAddr := r.RemoteAddr
			remoteIP, _, err := net.SplitHostPort(remoteAddr)
			if err != nil {
				remoteIP = remoteAddr
			}

			// Use X-Forwarded-For when present to extract the real client
			// IP from behind load balancers and reverse proxies.
			clientIP := remoteIP
			if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
				if ip := firstXForwardedFor(xff); ip != "" {
					clientIP = ip
				}
			}

			dur := time.Since(lrw.started)

			logger.LogAttrs(r.Context(), slog.LevelInfo, "access",
				log.CorrelationID(correlation.ExtractFromContext(r.Context())),
				log.HTTPMethod(r.Method),
				log.HTTPURI(maskURL(r.RequestURI)),
				log.HTTPStatusCode(lrw.status),
				log.DurationS(dur),
				slog.String("system", "http"),
				log.HTTPHost(r.Host),
				log.HTTPProto(r.Proto),
				log.RemoteAddr(remoteAddr),
				log.RemoteIP(clientIP),
				log.HTTPReferrer(maskURL(r.Referer())),
				log.HTTPUserAgent(r.UserAgent()),
				log.WrittenBytes(lrw.written),
				log.ContentType(lrw.contentType),
				log.TTFBS(lrw.ttfb),
			)
		}()

		next.ServeHTTP(wrapped, r)
		completed = true
	})
}

// firstXForwardedFor returns the first (leftmost) IP from an
// X-Forwarded-For header value, which represents the original client.
func firstXForwardedFor(xff string) string {
	ip, _, _ := strings.Cut(xff, ",")
	ip = strings.TrimSpace(ip)
	// Validate it looks like an IP address.
	if ip == "" || net.ParseIP(ip) == nil {
		return ""
	}
	return ip
}

// loggingResponseWriter wraps http.ResponseWriter to capture response
// metadata for access logging.
type loggingResponseWriter struct {
	rw          http.ResponseWriter
	status      int
	wroteHeader bool
	written     int64
	contentType string
	ttfb        time.Duration
	started     time.Time
}

func (l *loggingResponseWriter) Header() http.Header {
	return l.rw.Header()
}

func (l *loggingResponseWriter) Write(data []byte) (int, error) {
	if !l.wroteHeader {
		l.WriteHeader(http.StatusOK)
	}
	n, err := l.rw.Write(data)
	l.written += int64(n)
	return n, err
}

func (l *loggingResponseWriter) WriteHeader(status int) {
	if l.wroteHeader {
		return
	}
	l.wroteHeader = true
	l.status = status
	l.contentType = l.rw.Header().Get("Content-Type")
	l.ttfb = time.Since(l.started)
	l.rw.WriteHeader(status)
}

// Unwrap returns the underlying http.ResponseWriter, enabling
// http.ResponseController to access optional interfaces (e.g. Flusher).
func (l *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return l.rw
}

// hijackingResponseWriter extends loggingResponseWriter with http.Hijacker
// support for WebSocket upgrades.
type hijackingResponseWriter struct {
	*loggingResponseWriter
}

func (h *hijackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hj, ok := h.loggingResponseWriter.rw.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("httpserver.AccessLogger: underlying ResponseWriter does not implement http.Hijacker")
	}
	return hj.Hijack()
}
+119 −0
Original line number Diff line number Diff line
package httpserver

import (
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"testing"
)

// discardHandler writes a small JSON response, simulating a real endpoint.
var discardHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write([]byte(`{"status":"ok"}`))
})

// discardLogger sends all log output to io.Discard so logging overhead is
// measured but I/O is not.
func discardLogger() *slog.Logger {
	return slog.New(slog.NewJSONHandler(io.Discard, nil))
}

// BenchmarkAccessLogger measures the full middleware overhead: response writer
// wrapping, handler execution, field assembly, and log emission.
func BenchmarkAccessLogger(b *testing.B) {
	logger := discardLogger()
	handler := AccessLogger(discardHandler, logger)
	req := httptest.NewRequest(http.MethodGet, "/api/v4/projects/123/merge_requests?page=1&per_page=20", nil)
	req.Header.Set("User-Agent", "Go-http-client/1.1")
	req.Header.Set("Referer", "https://gitlab.com/dashboard")
	req.RemoteAddr = "192.168.1.100:12345"

	b.ReportAllocs()
	b.ResetTimer()
	for range b.N {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
	}
}

// BenchmarkAccessLogger_SensitiveParams measures overhead when the URI
// contains sensitive query parameters that trigger masking.
func BenchmarkAccessLogger_SensitiveParams(b *testing.B) {
	logger := discardLogger()
	handler := AccessLogger(discardHandler, logger)
	req := httptest.NewRequest(http.MethodGet, "/api/v4/session?password=secret&private_token=abc123&page=1", nil)
	req.Header.Set("Referer", "https://gitlab.com/login?password=leaked")
	req.RemoteAddr = "10.0.0.1:54321"

	b.ReportAllocs()
	b.ResetTimer()
	for range b.N {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
	}
}

// BenchmarkAccessLogger_LargeResponse measures overhead when the handler
// writes a larger response body to exercise the byte counting path.
func BenchmarkAccessLogger_LargeResponse(b *testing.B) {
	largeBody := make([]byte, 64*1024) // 64 KiB
	for i := range largeBody {
		largeBody[i] = 'x'
	}
	largeHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/octet-stream")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write(largeBody)
	})

	logger := discardLogger()
	handler := AccessLogger(largeHandler, logger)
	req := httptest.NewRequest(http.MethodGet, "/download", nil)
	req.RemoteAddr = "10.0.0.1:54321"

	b.ReportAllocs()
	b.ResetTimer()
	for range b.N {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
	}
}

// BenchmarkAccessLogger_NilLogger measures the baseline handler overhead when
// no logger is configured (nil logger falls back to slog.Default, which
// writes to stderr).
func BenchmarkAccessLogger_NilLogger(b *testing.B) {
	// Override slog default to discard so benchmark isn't I/O bound
	orig := slog.Default()
	slog.SetDefault(discardLogger())
	defer slog.SetDefault(orig)

	handler := AccessLogger(discardHandler, nil)
	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	req.RemoteAddr = "127.0.0.1:1234"

	b.ReportAllocs()
	b.ResetTimer()
	for range b.N {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
	}
}

// BenchmarkResponseWriterWrite measures the isolated write path of the
// loggingResponseWriter without the full middleware stack.
func BenchmarkResponseWriterWrite(b *testing.B) {
	data := []byte(`{"status":"ok"}`)

	b.ReportAllocs()
	b.ResetTimer()
	for range b.N {
		rec := httptest.NewRecorder()
		lrw := &loggingResponseWriter{rw: rec}
		lrw.WriteHeader(http.StatusOK)
		_, _ = lrw.Write(data)
	}
}
Loading