Commit 4993800f authored by Doug Barrett's avatar Doug Barrett
Browse files

fix(v2/httpserver): Decouple tracing middleware from Chi and add Flusher support

parent 9ac3ddad
Loading
Loading
Loading
Loading
+16 −6
Original line number Diff line number Diff line
@@ -5,7 +5,6 @@ import (
	"log/slog"
	"net/http"

	"github.com/go-chi/chi/v5"
	"go.opentelemetry.io/otel/propagation"

	"gitlab.com/gitlab-org/labkit/v2/trace"
@@ -37,11 +36,21 @@ func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}

// Flush implements http.Flusher by delegating to the underlying ResponseWriter
// when it supports flushing. This enables SSE and streaming responses through
// middleware that wraps the writer. Code using http.ResponseController (Go
// 1.20+) can also reach the underlying Flusher via Unwrap.
func (rw *responseWriter) Flush() {
	if f, ok := rw.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// tracingMiddleware creates a server span for each request. It extracts any
// incoming W3C trace context so that the new span is correctly parented to the
// upstream caller's trace. The span name uses the route pattern (e.g.
// "GET /users/{id}") when available, falling back to "METHOD /path".
func tracingMiddleware(tracer *trace.Tracer) func(http.Handler) http.Handler {
func tracingMiddleware(tracer *trace.Tracer, router Router) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Extract incoming trace context before starting the server span so
@@ -62,10 +71,11 @@ func tracingMiddleware(tracer *trace.Tracer) func(http.Handler) http.Handler {
			rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
			next.ServeHTTP(rw, r.WithContext(ctx))

			// Use the route pattern for a low-cardinality span name when the
			// chi route context is available.
			if rctx := chi.RouteContext(ctx); rctx != nil && rctx.RoutePattern() != "" {
				span.SetName(r.Method + " " + rctx.RoutePattern())
			// Refine to a low-cardinality span name using the route pattern
			// from the Router. This works with any Router implementation,
			// not just Chi.
			if pattern := router.RoutePattern(r); pattern != "" {
				span.SetName(r.Method + " " + pattern)
			}

			span.SetAttribute("http.status_code", rw.statusCode)
+13 −0
Original line number Diff line number Diff line
@@ -68,6 +68,11 @@ type Router interface {
	// exists but the HTTP method does not match.
	MethodNotAllowed(handler http.HandlerFunc)

	// RoutePattern returns the matched route pattern for the given request
	// (e.g. "/users/{id}"), or an empty string if no pattern is available.
	// The tracing middleware uses this to produce low-cardinality span names.
	RoutePattern(r *http.Request) string

	// ServeHTTP dispatches the request to the matching route.
	http.Handler
}
@@ -149,6 +154,14 @@ func (cr *chiRouter) MethodNotAllowed(handler http.HandlerFunc) {
	cr.r.MethodNotAllowed(handler)
}

func (cr *chiRouter) RoutePattern(r *http.Request) string {
	rctx := chi.RouteContext(r.Context())
	if rctx == nil {
		return ""
	}
	return rctx.RoutePattern()
}

func (cr *chiRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	cr.r.ServeHTTP(w, req)
}
+1 −1
Original line number Diff line number Diff line
@@ -124,7 +124,7 @@ func NewWithConfig(cfg *Config) *Server {
	// Tracing is outermost: the span covers the full request duration and
	// enriches the context for the logger and application handlers.
	if cfg.Tracer != nil {
		router.Use(tracingMiddleware(cfg.Tracer))
		router.Use(tracingMiddleware(cfg.Tracer, router))
	}
	if cfg.Logger != nil {
		router.Use(loggingMiddleware(cfg.Logger))
+1 −0
Original line number Diff line number Diff line
@@ -301,6 +301,7 @@ func (r *testRouter) Route(_ string, _ func(httpserver.Router))
func (r *testRouter) Mount(_ string, _ http.Handler)                              {}
func (r *testRouter) NotFound(_ http.HandlerFunc)                                 {}
func (r *testRouter) MethodNotAllowed(_ http.HandlerFunc)                         {}
func (r *testRouter) RoutePattern(_ *http.Request) string                         { return "" }
func (r *testRouter) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	r.served = true
	w.WriteHeader(http.StatusOK)