Verified commit e6a11ef3, authored by Jean-Gabriel Doyon and committed by GitLab
Browse files

feat(analytics): extract coding agent from client User-Agent header

parent ad8ecf36
Loading
Loading
Loading
Loading
+28 −3
Original line number Diff line number Diff line
@@ -32,8 +32,15 @@ pub(crate) fn build_common(
        .ok()
}

pub(crate) fn build_query(claims: &Claims, tool_name: &str) -> Option<OrbitQueryContext> {
pub(crate) fn build_query(
    claims: &Claims,
    tool_name: &str,
    coding_agent: Option<&str>,
) -> Option<OrbitQueryContext> {
    let mut b = OrbitQueryContext::builder(map_source(&claims.source_type)).tool_name(tool_name);
    if let Some(agent) = coding_agent {
        b = b.coding_agent(agent);
    }
    if let Some(ref id) = claims.global_user_id {
        b = b.global_user_id(id);
    }
@@ -127,7 +134,7 @@ mod tests {
    /// Test helper: builds a `query_executed` event for `claims` and `tool`
    /// (default analytics config, schema version "33", no coding agent) and
    /// returns a clone of the data of the event's second context
    /// (`contexts()[1]`) — presumably the query context; confirm against
    /// `GkgEvent`.
    fn query_data(claims: &Claims, tool: &str) -> serde_json::Value {
        use labkit_events::gkg::GkgEvent;
        let common = build_common(&AnalyticsConfig::default(), claims, "33").unwrap();
        // NOTE(review): the two-argument call below looks like the removed side of
        // the diff; the three-argument call after it matches the current
        // `build_query` signature.
        let query = build_query(claims, tool).unwrap();
        let query = build_query(claims, tool, None).unwrap();
        let event = GkgEvent::query_executed(common, query);
        event.contexts()[1].data.clone()
    }
@@ -135,7 +142,7 @@ mod tests {
    /// Test helper: builds a `query_executed` event for `claims` using the given
    /// `schema_version` (default analytics config, tool "query_graph", no coding
    /// agent) and returns a clone of the data of the event's first context
    /// (`contexts()[0]`) — presumably the common context; confirm against
    /// `GkgEvent`.
    fn common_data(claims: &Claims, schema_version: &str) -> serde_json::Value {
        use labkit_events::gkg::GkgEvent;
        let common = build_common(&AnalyticsConfig::default(), claims, schema_version).unwrap();
        // NOTE(review): the two-argument call below looks like the removed side of
        // the diff; the three-argument call after it matches the current
        // `build_query` signature.
        let query = build_query(claims, "query_graph").unwrap();
        let query = build_query(claims, "query_graph", None).unwrap();
        let event = GkgEvent::query_executed(common, query);
        event.contexts()[0].data.clone()
    }
@@ -163,6 +170,24 @@ mod tests {
        assert_eq!(data["tool_name"], "get_graph_schema");
    }

    /// A coding agent handed to `build_query` must show up as the
    /// `coding_agent` field in the data of the event's second context.
    #[test]
    fn build_query_passes_through_coding_agent() {
        use labkit_events::gkg::GkgEvent;

        let claims = claims_with_paths(Vec::new());
        // Arguments are evaluated left to right, so the common context is still
        // built before the query context.
        let event = GkgEvent::query_executed(
            build_common(&AnalyticsConfig::default(), &claims, "33").unwrap(),
            build_query(&claims, "query_graph", Some("claude-code")).unwrap(),
        );

        let query_payload = event.contexts()[1].data.clone();
        assert_eq!(query_payload["coding_agent"], "claude-code");
    }

    /// When no coding agent is supplied, the query context data must not
    /// contain a `coding_agent` key at all.
    #[test]
    fn build_query_omits_coding_agent_when_none() {
        let claims = claims_with_paths(Vec::new());
        let query_payload = query_data(&claims, "query_graph");
        let coding_agent = query_payload.get("coding_agent");
        assert!(coding_agent.is_none());
    }

    #[test]
    fn map_source_recognises_all_jwt_values() {
        let cases = [
+8 −1
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ pub(crate) struct AnalyticsObserver {
    config: Arc<AnalyticsConfig>,
    claims: Claims,
    tool_name: String,
    coding_agent: Option<String>,
    schema_version: String,
    errored: Cell<bool>,
}
@@ -27,6 +28,7 @@ impl AnalyticsObserver {
        config: Arc<AnalyticsConfig>,
        claims: Claims,
        tool_name: impl Into<String>,
        coding_agent: Option<String>,
        schema_version: String,
    ) -> Self {
        Self {
@@ -34,6 +36,7 @@ impl AnalyticsObserver {
            config,
            claims,
            tool_name: tool_name.into(),
            coding_agent,
            schema_version,
            errored: Cell::new(false),
        }
@@ -62,7 +65,8 @@ impl PipelineObserver for AnalyticsObserver {
        let Some(common) = build_common(&self.config, &self.claims, &self.schema_version) else {
            return;
        };
        let Some(query) = build_query(&self.claims, &self.tool_name) else {
        let Some(query) = build_query(&self.claims, &self.tool_name, self.coding_agent.as_deref())
        else {
            return;
        };
        tracker.track(GkgEvent::query_executed(common, query));
@@ -114,6 +118,7 @@ mod tests {
            Arc::new(AnalyticsConfig::default()),
            test_claims(),
            "query_graph",
            None,
            "33".to_string(),
        );
        obs.finish(10, 0);
@@ -128,6 +133,7 @@ mod tests {
            Arc::new(AnalyticsConfig::default()),
            test_claims(),
            "query_graph",
            None,
            "33".to_string(),
        );
        obs.record_error(&PipelineError::Execution("x".into()));
@@ -142,6 +148,7 @@ mod tests {
            Arc::new(AnalyticsConfig::default()),
            test_claims(),
            "query_graph",
            None,
            "33".to_string(),
        );
        obs.finish(1, 0);
+115 −6
Original line number Diff line number Diff line
@@ -2,7 +2,37 @@ use tonic::{Request, Status};

use crate::auth::{Claims, JwtValidator};

pub fn extract_claims<T>(request: &Request<T>, validator: &JwtValidator) -> Result<Claims, Status> {
/// Per-request data pulled out of the incoming gRPC metadata: the validated
/// JWT claims plus the client-reported user agent, if one was sent.
#[derive(Debug)]
pub struct RequestContext {
    /// Claims returned by validating the bearer token from the
    /// `authorization` metadata entry.
    pub claims: Claims,
    /// Value of the `x-client-user-agent` metadata entry; `None` when the
    /// entry is missing or cannot be read as a string.
    pub user_agent: Option<String>,
}

impl RequestContext {
    /// Returns the coding agent advertised in the client user agent, if any.
    ///
    /// The user agent is split on whitespace and the first token of the form
    /// `Coding-Agent/<name>` with a non-empty `<name>` wins. For example,
    /// `glab/1.50.0 (linux, amd64) Coding-Agent/claude-code` yields
    /// `Some("claude-code")`.
    ///
    /// Returns `None` when no user agent was sent, when it carries no
    /// `Coding-Agent/` token, or when every such token has an empty name.
    pub fn coding_agent(&self) -> Option<&str> {
        let user_agent = self.user_agent.as_deref()?;
        user_agent.split_whitespace().find_map(|token| {
            token
                .strip_prefix("Coding-Agent/")
                // The header is client-controlled: a bare `Coding-Agent/` names no
                // agent, so skip it instead of reporting an empty string to the
                // tracing span and analytics.
                .filter(|agent| !agent.is_empty())
        })
    }

    /// Records the request's identifying attributes on the current tracing span.
    ///
    /// Always records `user_id` and `source_type`; records `ai_session_id` and
    /// `coding_agent` only when they are present. Per the `tracing` docs,
    /// recording a field the span did not declare is a no-op, so callers are
    /// expected to declare these four fields on their span.
    pub fn record_in_current_span(&self) {
        let span = tracing::Span::current();
        span.record("user_id", self.claims.user_id);
        span.record("source_type", &self.claims.source_type);
        if let Some(sid) = &self.claims.ai_session_id {
            span.record("ai_session_id", sid.as_str());
        }
        if let Some(agent) = self.coding_agent() {
            span.record("coding_agent", agent);
        }
    }
}

pub fn extract_request_context<T>(
    request: &Request<T>,
    validator: &JwtValidator,
) -> Result<RequestContext, Status> {
    let token = request
        .metadata()
        .get("authorization")
@@ -10,10 +40,18 @@ pub fn extract_claims<T>(request: &Request<T>, validator: &JwtValidator) -> Resu
        .and_then(|s| s.strip_prefix("Bearer "))
        .ok_or_else(|| Status::unauthenticated("Missing or invalid authorization header"))?;

    validator.validate(token).map_err(|e| {
    let claims = validator.validate(token).map_err(|e| {
        tracing::warn!(error = %e, "JWT validation failed");
        Status::unauthenticated(format!("JWT validation failed: {e}"))
    })
    })?;

    let user_agent = request
        .metadata()
        .get("x-client-user-agent")
        .and_then(|v| v.to_str().ok())
        .map(String::from);

    Ok(RequestContext { claims, user_agent })
}

#[cfg(test)]
@@ -30,7 +68,7 @@ mod tests {
        let request: Request<()> = Request::new(());
        let validator = mock_validator();

        let result = extract_claims(&request, &validator);
        let result = extract_request_context(&request, &validator);
        assert!(result.is_err());

        let status = result.unwrap_err();
@@ -47,7 +85,7 @@ mod tests {
        );
        let validator = mock_validator();

        let result = extract_claims(&request, &validator);
        let result = extract_request_context(&request, &validator);
        assert!(result.is_err());

        let status = result.unwrap_err();
@@ -63,11 +101,82 @@ mod tests {
        );
        let validator = mock_validator();

        let result = extract_claims(&request, &validator);
        let result = extract_request_context(&request, &validator);
        assert!(result.is_err());

        let status = result.unwrap_err();
        assert_eq!(status.code(), tonic::Code::Unauthenticated);
        assert!(status.message().contains("JWT validation failed"));
    }

    /// Test helper: wraps the given user agent in a `RequestContext` whose
    /// claims are all empty / zero / `None`, so only the user agent varies
    /// between cases.
    fn request_context(user_agent: Option<&str>) -> RequestContext {
        let blank_claims = crate::auth::Claims {
            // Standard JWT fields.
            sub: String::new(),
            iss: String::new(),
            aud: String::new(),
            iat: 0,
            exp: 0,
            // User and authorization fields.
            user_id: 0,
            username: String::new(),
            admin: false,
            organization_id: None,
            min_access_level: None,
            group_traversal_ids: Vec::new(),
            source_type: String::new(),
            ai_session_id: None,
            // Instance and deployment fields.
            instance_id: None,
            unique_instance_id: None,
            instance_version: None,
            global_user_id: None,
            host_name: None,
            root_namespace_id: None,
            deployment_type: None,
            realm: None,
        };

        RequestContext {
            claims: blank_claims,
            user_agent: user_agent.map(String::from),
        }
    }

    /// Each user agent carrying a `Coding-Agent/<name>` token must yield
    /// exactly `<name>`, regardless of the platform tuple or glab version.
    #[test]
    fn coding_agent_extracts_known_agents() {
        // (user agent, agent name expected to be extracted)
        let cases = [
            ("glab/1.50.0 (linux, amd64) Coding-Agent/claude-code", "claude-code"),
            ("glab/1.50.0 (darwin, arm64) Coding-Agent/codex", "codex"),
            ("glab/1.50.0 (windows, amd64) Coding-Agent/cursor", "cursor"),
            ("glab/1.50.0 (linux, amd64) Coding-Agent/opencode", "opencode"),
            ("glab/DEV (linux, amd64) Coding-Agent/custom-agent_2.1", "custom-agent_2.1"),
        ];

        for (user_agent, agent_name) in cases {
            let ctx = request_context(Some(user_agent));
            let detected = ctx.coding_agent();
            assert_eq!(detected, Some(agent_name), "for user_agent: {user_agent}");
        }
    }

    /// A user agent without any `Coding-Agent/` token reports no coding agent.
    #[test]
    fn coding_agent_none_when_absent() {
        let plain_glab = request_context(Some("glab/1.50.0 (linux, amd64)"));
        let detected = plain_glab.coding_agent();
        assert_eq!(detected, None);
    }

    /// With no user agent at all there is nothing to parse, so no coding
    /// agent is reported.
    #[test]
    fn coding_agent_none_when_no_user_agent() {
        let headerless = request_context(None);
        let detected = headerless.coding_agent();
        assert_eq!(detected, None);
    }
}
+66 −53
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status, Streaming};
use tracing::{Instrument, info, instrument};

use super::auth::extract_claims;
use super::auth::extract_request_context;
use crate::analytics::AnalyticsTracker;
use crate::auth::{Claims, JwtValidator, build_security_context};
use crate::cluster_health::ClusterHealthChecker;
@@ -41,12 +41,6 @@ fn proto_format_name(name: FormatName) -> ProtoFormatName {
    }
}

fn record_ai_session_id(ai_session_id: &Option<String>) {
    if let Some(sid) = ai_session_id {
        tracing::Span::current().record("ai_session_id", sid.as_str());
    }
}

fn proto_tool_definition(t: crate::tools::ToolDefinition) -> ProtoToolDefinition {
    ProtoToolDefinition {
        name: t.name,
@@ -138,15 +132,16 @@ type ExecuteQueryStream =
impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService
    for KnowledgeGraphServiceImpl
{
    #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))]
    #[instrument(
        skip(self, request),
        fields(user_id, source_type, ai_session_id, coding_agent)
    )]
    async fn list_tools(
        &self,
        request: Request<ListToolsRequest>,
    ) -> Result<Response<ListToolsResponse>, Status> {
        let claims = extract_claims(&request, &self.validator)?;
        tracing::Span::current().record("user_id", claims.user_id);
        tracing::Span::current().record("source_type", &claims.source_type);
        record_ai_session_id(&claims.ai_session_id);
        let ctx = extract_request_context(&request, &self.validator)?;
        ctx.record_in_current_span();

        info!("Listing tools for user");

@@ -158,15 +153,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService
        Ok(Response::new(ListToolsResponse { tools }))
    }

    #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))]
    #[instrument(
        skip(self, request),
        fields(user_id, source_type, ai_session_id, coding_agent)
    )]
    async fn list_agent_commands(
        &self,
        request: Request<ListAgentCommandsRequest>,
    ) -> Result<Response<ListAgentCommandsResponse>, Status> {
        let claims = extract_claims(&request, &self.validator)?;
        tracing::Span::current().record("user_id", claims.user_id);
        tracing::Span::current().record("source_type", &claims.source_type);
        record_ai_session_id(&claims.ai_session_id);
        let ctx = extract_request_context(&request, &self.validator)?;
        ctx.record_in_current_span();

        let req = request.get_ref();
        let requested = &req.command_names;
@@ -217,15 +213,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService
        }))
    }

    #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))]
    #[instrument(
        skip(self, request),
        fields(user_id, source_type, ai_session_id, coding_agent)
    )]
    async fn invoke_agent_command(
        &self,
        request: Request<InvokeAgentCommandRequest>,
    ) -> Result<Response<InvokeAgentCommandResponse>, Status> {
        let claims = extract_claims(&request, &self.validator)?;
        tracing::Span::current().record("user_id", claims.user_id);
        tracing::Span::current().record("source_type", &claims.source_type);
        record_ai_session_id(&claims.ai_session_id);
        let ctx = extract_request_context(&request, &self.validator)?;
        ctx.record_in_current_span();

        let req = request.get_ref();
        if req.command_name.trim().is_empty() {
@@ -268,15 +265,18 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService

    type ExecuteQueryStream = ExecuteQueryStream;

    #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))]
    #[instrument(
        skip(self, request),
        fields(user_id, source_type, ai_session_id, coding_agent)
    )]
    async fn execute_query(
        &self,
        request: Request<Streaming<ExecuteQueryMessage>>,
    ) -> Result<Response<Self::ExecuteQueryStream>, Status> {
        let claims = extract_claims(&request, &self.validator)?;
        tracing::Span::current().record("user_id", claims.user_id);
        tracing::Span::current().record("source_type", &claims.source_type);
        record_ai_session_id(&claims.ai_session_id);
        let ctx = extract_request_context(&request, &self.validator)?;
        ctx.record_in_current_span();
        let coding_agent = ctx.coding_agent().map(String::from);
        let claims = ctx.claims;

        let mut stream = request.into_inner();
        let (tx, rx) = mpsc::channel(4);
@@ -298,7 +298,14 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService

                let timeout = std::time::Duration::from_secs(stream_timeout);
                let result = pipeline
                    .run_query(claims, &req.query, tx.clone(), stream, timeout)
                    .run_query(
                        claims,
                        coding_agent,
                        &req.query,
                        tx.clone(),
                        stream,
                        timeout,
                    )
                    .await;

                match result {
@@ -367,15 +374,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService
        Ok(Response::new(Box::pin(ReceiverStream::new(rx))))
    }

    #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))]
    #[instrument(
        skip(self, request),
        fields(user_id, source_type, ai_session_id, coding_agent)
    )]
    async fn get_graph_schema(
        &self,
        request: Request<GetGraphSchemaRequest>,
    ) -> Result<Response<GetGraphSchemaResponse>, Status> {
        let claims = extract_claims(&request, &self.validator)?;
        tracing::Span::current().record("user_id", claims.user_id);
        tracing::Span::current().record("source_type", &claims.source_type);
        record_ai_session_id(&claims.ai_session_id);
        let ctx = extract_request_context(&request, &self.validator)?;
        ctx.record_in_current_span();

        let req = request.get_ref();
        info!(format = ?req.format, "Fetching graph schema for user");
@@ -398,15 +406,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService
        Ok(Response::new(response))
    }

    #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))]
    #[instrument(
        skip(self, request),
        fields(user_id, source_type, ai_session_id, coding_agent)
    )]
    async fn get_response_format(
        &self,
        request: Request<GetResponseFormatRequest>,
    ) -> Result<Response<GetResponseFormatResponse>, Status> {
        let claims = extract_claims(&request, &self.validator)?;
        tracing::Span::current().record("user_id", claims.user_id);
        tracing::Span::current().record("source_type", &claims.source_type);
        record_ai_session_id(&claims.ai_session_id);
        let ctx = extract_request_context(&request, &self.validator)?;
        ctx.record_in_current_span();

        let req = request.get_ref();
        info!(format = ?req.format, "Fetching query response format for user");
@@ -434,15 +443,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService
        Ok(Response::new(response))
    }

    #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))]
    #[instrument(
        skip(self, request),
        fields(user_id, source_type, ai_session_id, coding_agent)
    )]
    async fn get_query_dsl(
        &self,
        request: Request<GetQueryDslRequest>,
    ) -> Result<Response<GetQueryDslResponse>, Status> {
        let claims = extract_claims(&request, &self.validator)?;
        tracing::Span::current().record("user_id", claims.user_id);
        tracing::Span::current().record("source_type", &claims.source_type);
        record_ai_session_id(&claims.ai_session_id);
        let ctx = extract_request_context(&request, &self.validator)?;
        ctx.record_in_current_span();

        let req = request.get_ref();
        info!(format = ?req.format, "Fetching query DSL grammar for user");
@@ -466,15 +476,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService
        Ok(Response::new(response))
    }

    #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))]
    #[instrument(
        skip(self, request),
        fields(user_id, source_type, ai_session_id, coding_agent)
    )]
    async fn get_cluster_health(
        &self,
        request: Request<GetClusterHealthRequest>,
    ) -> Result<Response<GetClusterHealthResponse>, Status> {
        let claims = extract_claims(&request, &self.validator)?;
        tracing::Span::current().record("user_id", claims.user_id);
        tracing::Span::current().record("source_type", &claims.source_type);
        record_ai_session_id(&claims.ai_session_id);
        let ctx = extract_request_context(&request, &self.validator)?;
        ctx.record_in_current_span();

        let req = request.get_ref();
        info!(format = ?req.format, "Fetching cluster health for user");
@@ -483,15 +494,17 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService
        Ok(Response::new(response))
    }

    #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))]
    #[instrument(
        skip(self, request),
        fields(user_id, source_type, ai_session_id, coding_agent)
    )]
    async fn get_graph_status(
        &self,
        request: Request<GetGraphStatusRequest>,
    ) -> Result<Response<GetGraphStatusResponse>, Status> {
        let claims = extract_claims(&request, &self.validator)?;
        tracing::Span::current().record("user_id", claims.user_id);
        tracing::Span::current().record("source_type", &claims.source_type);
        record_ai_session_id(&claims.ai_session_id);
        let ctx = extract_request_context(&request, &self.validator)?;
        ctx.record_in_current_span();
        let claims = ctx.claims;

        let req = request.get_ref();
        authorize_traversal_path(&claims, &req.traversal_path)?;
+2 −0
Original line number Diff line number Diff line
@@ -77,6 +77,7 @@ impl QueryPipelineService {
    pub async fn run_query(
        &self,
        claims: Claims,
        coding_agent: Option<String>,
        query_json: &str,
        tx: mpsc::Sender<Result<ExecuteQueryMessage, Status>>,
        stream: Streaming<ExecuteQueryMessage>,
@@ -93,6 +94,7 @@ impl QueryPipelineService {
                Arc::clone(&self.analytics_config),
                claims.clone(),
                "query_graph",
                coding_agent,
                SCHEMA_VERSION.to_string(),
            )),
        ]);