Loading crates/gkg-server/src/analytics/context.rs +28 −3 Original line number Diff line number Diff line Loading @@ -32,8 +32,15 @@ pub(crate) fn build_common( .ok() } pub(crate) fn build_query(claims: &Claims, tool_name: &str) -> Option<OrbitQueryContext> { pub(crate) fn build_query( claims: &Claims, tool_name: &str, coding_agent: Option<&str>, ) -> Option<OrbitQueryContext> { let mut b = OrbitQueryContext::builder(map_source(&claims.source_type)).tool_name(tool_name); if let Some(agent) = coding_agent { b = b.coding_agent(agent); } if let Some(ref id) = claims.global_user_id { b = b.global_user_id(id); } Loading Loading @@ -127,7 +134,7 @@ mod tests { fn query_data(claims: &Claims, tool: &str) -> serde_json::Value { use labkit_events::gkg::GkgEvent; let common = build_common(&AnalyticsConfig::default(), claims, "33").unwrap(); let query = build_query(claims, tool).unwrap(); let query = build_query(claims, tool, None).unwrap(); let event = GkgEvent::query_executed(common, query); event.contexts()[1].data.clone() } Loading @@ -135,7 +142,7 @@ mod tests { fn common_data(claims: &Claims, schema_version: &str) -> serde_json::Value { use labkit_events::gkg::GkgEvent; let common = build_common(&AnalyticsConfig::default(), claims, schema_version).unwrap(); let query = build_query(claims, "query_graph").unwrap(); let query = build_query(claims, "query_graph", None).unwrap(); let event = GkgEvent::query_executed(common, query); event.contexts()[0].data.clone() } Loading Loading @@ -163,6 +170,24 @@ mod tests { assert_eq!(data["tool_name"], "get_graph_schema"); } #[test] fn build_query_passes_through_coding_agent() { use labkit_events::gkg::GkgEvent; let claims = claims_with_paths(vec![]); let common = build_common(&AnalyticsConfig::default(), &claims, "33").unwrap(); let query = build_query(&claims, "query_graph", Some("claude-code")).unwrap(); let event = GkgEvent::query_executed(common, query); let data = event.contexts()[1].data.clone(); assert_eq!(data["coding_agent"], "claude-code"); } #[test] fn build_query_omits_coding_agent_when_none() { let claims = claims_with_paths(vec![]); let data = query_data(&claims, "query_graph"); assert!(data.get("coding_agent").is_none()); } #[test] fn map_source_recognises_all_jwt_values() { let cases = [ Loading crates/gkg-server/src/analytics/observer.rs +8 −1 Original line number Diff line number Diff line Loading @@ -17,6 +17,7 @@ pub(crate) struct AnalyticsObserver { config: Arc<AnalyticsConfig>, claims: Claims, tool_name: String, coding_agent: Option<String>, schema_version: String, errored: Cell<bool>, } Loading @@ -27,6 +28,7 @@ impl AnalyticsObserver { config: Arc<AnalyticsConfig>, claims: Claims, tool_name: impl Into<String>, coding_agent: Option<String>, schema_version: String, ) -> Self { Self { Loading @@ -34,6 +36,7 @@ impl AnalyticsObserver { config, claims, tool_name: tool_name.into(), coding_agent, schema_version, errored: Cell::new(false), } Loading Loading @@ -62,7 +65,8 @@ impl PipelineObserver for AnalyticsObserver { let Some(common) = build_common(&self.config, &self.claims, &self.schema_version) else { return; }; let Some(query) = build_query(&self.claims, &self.tool_name) else { let Some(query) = build_query(&self.claims, &self.tool_name, self.coding_agent.as_deref()) else { return; }; tracker.track(GkgEvent::query_executed(common, query)); Loading Loading @@ -114,6 +118,7 @@ mod tests { Arc::new(AnalyticsConfig::default()), test_claims(), "query_graph", None, "33".to_string(), ); obs.finish(10, 0); Loading @@ -128,6 +133,7 @@ mod tests { Arc::new(AnalyticsConfig::default()), test_claims(), "query_graph", None, "33".to_string(), ); obs.record_error(&PipelineError::Execution("x".into())); Loading @@ -142,6 +148,7 @@ mod tests { Arc::new(AnalyticsConfig::default()), test_claims(), "query_graph", None, "33".to_string(), ); obs.finish(1, 0); Loading crates/gkg-server/src/grpc/auth.rs +115 −6 Original line number Diff line number Diff line Loading @@ -2,7 +2,37 @@ use tonic::{Request, Status}; use crate::auth::{Claims, JwtValidator}; pub fn extract_claims<T>(request: &Request<T>, validator: &JwtValidator) -> Result<Claims, Status> { #[derive(Debug)] pub struct RequestContext { pub claims: Claims, pub user_agent: Option<String>, } impl RequestContext { pub fn coding_agent(&self) -> Option<&str> { self.user_agent.as_deref().and_then(|ua| { ua.split_whitespace() .find_map(|token| token.strip_prefix("Coding-Agent/")) }) } pub fn record_in_current_span(&self) { let span = tracing::Span::current(); span.record("user_id", self.claims.user_id); span.record("source_type", &self.claims.source_type); if let Some(sid) = &self.claims.ai_session_id { span.record("ai_session_id", sid.as_str()); } if let Some(agent) = self.coding_agent() { span.record("coding_agent", agent); } } } pub fn extract_request_context<T>( request: &Request<T>, validator: &JwtValidator, ) -> Result<RequestContext, Status> { let token = request .metadata() .get("authorization") Loading @@ -10,10 +40,18 @@ pub fn extract_claims<T>(request: &Request<T>, validator: &JwtValidator) -> Resu .and_then(|s| s.strip_prefix("Bearer ")) .ok_or_else(|| Status::unauthenticated("Missing or invalid authorization header"))?; validator.validate(token).map_err(|e| { let claims = validator.validate(token).map_err(|e| { tracing::warn!(error = %e, "JWT validation failed"); Status::unauthenticated(format!("JWT validation failed: {e}")) }) })?; let user_agent = request .metadata() .get("x-client-user-agent") .and_then(|v| v.to_str().ok()) .map(String::from); Ok(RequestContext { claims, user_agent }) } #[cfg(test)] Loading @@ -30,7 +68,7 @@ mod tests { let request: Request<()> = Request::new(()); let validator = mock_validator(); let result = extract_claims(&request, &validator); let result = extract_request_context(&request, &validator); assert!(result.is_err()); let status = result.unwrap_err(); Loading @@ -47,7 +85,7 @@ mod tests { ); let validator = mock_validator(); let result = extract_claims(&request, &validator); let result = extract_request_context(&request, &validator); assert!(result.is_err()); let status = result.unwrap_err(); Loading @@ -63,11 +101,82 @@ mod tests { ); let validator = mock_validator(); let result = extract_claims(&request, &validator); let result = extract_request_context(&request, &validator); assert!(result.is_err()); let status = result.unwrap_err(); assert_eq!(status.code(), tonic::Code::Unauthenticated); assert!(status.message().contains("JWT validation failed")); } fn request_context(user_agent: Option<&str>) -> RequestContext { RequestContext { claims: crate::auth::Claims { sub: String::new(), iss: String::new(), aud: String::new(), iat: 0, exp: 0, user_id: 0, username: String::new(), admin: false, organization_id: None, min_access_level: None, group_traversal_ids: vec![], source_type: String::new(), ai_session_id: None, instance_id: None, unique_instance_id: None, instance_version: None, global_user_id: None, host_name: None, root_namespace_id: None, deployment_type: None, realm: None, }, user_agent: user_agent.map(Into::into), } } #[test] fn coding_agent_extracts_known_agents() { let cases = [ ( "glab/1.50.0 (linux, amd64) Coding-Agent/claude-code", Some("claude-code"), ), ( "glab/1.50.0 (darwin, arm64) Coding-Agent/codex", Some("codex"), ), ( "glab/1.50.0 (windows, amd64) Coding-Agent/cursor", Some("cursor"), ), ( "glab/1.50.0 (linux, amd64) Coding-Agent/opencode", Some("opencode"), ), ( "glab/DEV (linux, amd64) Coding-Agent/custom-agent_2.1", Some("custom-agent_2.1"), ), ]; for (user_agent, expected) in cases { let ctx = request_context(Some(user_agent)); assert_eq!(ctx.coding_agent(), expected, "for user_agent: {user_agent}"); } } #[test] fn coding_agent_none_when_absent() { let ctx = request_context(Some("glab/1.50.0 (linux, amd64)")); assert_eq!(ctx.coding_agent(), None); } #[test] fn coding_agent_none_when_no_user_agent() { let ctx = request_context(None); assert_eq!(ctx.coding_agent(), None); } } crates/gkg-server/src/grpc/service.rs +66 −53 Original line number Diff line number Diff line Loading @@ -12,7 +12,7 @@ use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status, Streaming}; use tracing::{Instrument, info, instrument}; use super::auth::extract_claims; use super::auth::extract_request_context; use crate::analytics::AnalyticsTracker; use crate::auth::{Claims, JwtValidator, build_security_context}; use crate::cluster_health::ClusterHealthChecker; Loading Loading @@ -41,12 +41,6 @@ fn proto_format_name(name: FormatName) -> ProtoFormatName { } } fn record_ai_session_id(ai_session_id: &Option<String>) { if let Some(sid) = ai_session_id { tracing::Span::current().record("ai_session_id", sid.as_str()); } } fn proto_tool_definition(t: crate::tools::ToolDefinition) -> ProtoToolDefinition { ProtoToolDefinition { name: t.name, Loading Loading @@ -138,15 +132,16 @@ type ExecuteQueryStream = impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService for KnowledgeGraphServiceImpl { #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn list_tools( &self, request: Request<ListToolsRequest>, ) -> Result<Response<ListToolsResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); info!("Listing tools for user"); Loading @@ -158,15 +153,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(ListToolsResponse { tools })) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn list_agent_commands( &self, request: Request<ListAgentCommandsRequest>, ) -> Result<Response<ListAgentCommandsResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); let requested = &req.command_names; Loading Loading @@ -217,15 +213,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService })) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn invoke_agent_command( &self, request: Request<InvokeAgentCommandRequest>, ) -> Result<Response<InvokeAgentCommandResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); if req.command_name.trim().is_empty() { Loading Loading @@ -268,15 +265,18 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService type ExecuteQueryStream = ExecuteQueryStream; #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn execute_query( &self, request: Request<Streaming<ExecuteQueryMessage>>, ) -> Result<Response<Self::ExecuteQueryStream>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let coding_agent = ctx.coding_agent().map(String::from); let claims = ctx.claims; let mut stream = request.into_inner(); let (tx, rx) = mpsc::channel(4); Loading @@ -298,7 +298,14 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService let timeout = std::time::Duration::from_secs(stream_timeout); let result = pipeline .run_query(claims, &req.query, tx.clone(), stream, timeout) .run_query( claims, coding_agent, &req.query, tx.clone(), stream, timeout, ) .await; match result { Loading Loading @@ -367,15 +374,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(Box::pin(ReceiverStream::new(rx)))) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn get_graph_schema( &self, request: Request<GetGraphSchemaRequest>, ) -> Result<Response<GetGraphSchemaResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); info!(format = ?req.format, "Fetching graph schema for user"); Loading @@ -398,15 +406,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(response)) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn get_response_format( &self, request: Request<GetResponseFormatRequest>, ) -> Result<Response<GetResponseFormatResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); info!(format = ?req.format, "Fetching query response format for user"); Loading Loading @@ -434,15 +443,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(response)) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn get_query_dsl( &self, request: Request<GetQueryDslRequest>, ) -> Result<Response<GetQueryDslResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); info!(format = ?req.format, "Fetching query DSL grammar for user"); Loading @@ -466,15 +476,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(response)) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn get_cluster_health( &self, request: Request<GetClusterHealthRequest>, ) -> Result<Response<GetClusterHealthResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); info!(format = ?req.format, "Fetching cluster health for user"); Loading @@ -483,15 +494,17 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(response)) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn get_graph_status( &self, request: Request<GetGraphStatusRequest>, ) -> Result<Response<GetGraphStatusResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let claims = ctx.claims; let req = request.get_ref(); authorize_traversal_path(&claims, &req.traversal_path)?; Loading crates/gkg-server/src/pipeline/service.rs +2 −0 Original line number Diff line number Diff line Loading @@ -77,6 +77,7 @@ impl QueryPipelineService { pub async fn run_query( &self, claims: Claims, coding_agent: Option<String>, query_json: &str, tx: mpsc::Sender<Result<ExecuteQueryMessage, Status>>, stream: Streaming<ExecuteQueryMessage>, Loading @@ -93,6 +94,7 @@ impl QueryPipelineService { Arc::clone(&self.analytics_config), claims.clone(), "query_graph", coding_agent, SCHEMA_VERSION.to_string(), )), ]); Loading Loading
crates/gkg-server/src/analytics/context.rs +28 −3 Original line number Diff line number Diff line Loading @@ -32,8 +32,15 @@ pub(crate) fn build_common( .ok() } pub(crate) fn build_query(claims: &Claims, tool_name: &str) -> Option<OrbitQueryContext> { pub(crate) fn build_query( claims: &Claims, tool_name: &str, coding_agent: Option<&str>, ) -> Option<OrbitQueryContext> { let mut b = OrbitQueryContext::builder(map_source(&claims.source_type)).tool_name(tool_name); if let Some(agent) = coding_agent { b = b.coding_agent(agent); } if let Some(ref id) = claims.global_user_id { b = b.global_user_id(id); } Loading Loading @@ -127,7 +134,7 @@ mod tests { fn query_data(claims: &Claims, tool: &str) -> serde_json::Value { use labkit_events::gkg::GkgEvent; let common = build_common(&AnalyticsConfig::default(), claims, "33").unwrap(); let query = build_query(claims, tool).unwrap(); let query = build_query(claims, tool, None).unwrap(); let event = GkgEvent::query_executed(common, query); event.contexts()[1].data.clone() } Loading @@ -135,7 +142,7 @@ mod tests { fn common_data(claims: &Claims, schema_version: &str) -> serde_json::Value { use labkit_events::gkg::GkgEvent; let common = build_common(&AnalyticsConfig::default(), claims, schema_version).unwrap(); let query = build_query(claims, "query_graph").unwrap(); let query = build_query(claims, "query_graph", None).unwrap(); let event = GkgEvent::query_executed(common, query); event.contexts()[0].data.clone() } Loading Loading @@ -163,6 +170,24 @@ mod tests { assert_eq!(data["tool_name"], "get_graph_schema"); } #[test] fn build_query_passes_through_coding_agent() { use labkit_events::gkg::GkgEvent; let claims = claims_with_paths(vec![]); let common = build_common(&AnalyticsConfig::default(), &claims, "33").unwrap(); let query = build_query(&claims, "query_graph", Some("claude-code")).unwrap(); let event = GkgEvent::query_executed(common, query); let data = event.contexts()[1].data.clone(); assert_eq!(data["coding_agent"], "claude-code"); } #[test] fn build_query_omits_coding_agent_when_none() { let claims = claims_with_paths(vec![]); let data = query_data(&claims, "query_graph"); assert!(data.get("coding_agent").is_none()); } #[test] fn map_source_recognises_all_jwt_values() { let cases = [ Loading
crates/gkg-server/src/analytics/observer.rs +8 −1 Original line number Diff line number Diff line Loading @@ -17,6 +17,7 @@ pub(crate) struct AnalyticsObserver { config: Arc<AnalyticsConfig>, claims: Claims, tool_name: String, coding_agent: Option<String>, schema_version: String, errored: Cell<bool>, } Loading @@ -27,6 +28,7 @@ impl AnalyticsObserver { config: Arc<AnalyticsConfig>, claims: Claims, tool_name: impl Into<String>, coding_agent: Option<String>, schema_version: String, ) -> Self { Self { Loading @@ -34,6 +36,7 @@ impl AnalyticsObserver { config, claims, tool_name: tool_name.into(), coding_agent, schema_version, errored: Cell::new(false), } Loading Loading @@ -62,7 +65,8 @@ impl PipelineObserver for AnalyticsObserver { let Some(common) = build_common(&self.config, &self.claims, &self.schema_version) else { return; }; let Some(query) = build_query(&self.claims, &self.tool_name) else { let Some(query) = build_query(&self.claims, &self.tool_name, self.coding_agent.as_deref()) else { return; }; tracker.track(GkgEvent::query_executed(common, query)); Loading Loading @@ -114,6 +118,7 @@ mod tests { Arc::new(AnalyticsConfig::default()), test_claims(), "query_graph", None, "33".to_string(), ); obs.finish(10, 0); Loading @@ -128,6 +133,7 @@ mod tests { Arc::new(AnalyticsConfig::default()), test_claims(), "query_graph", None, "33".to_string(), ); obs.record_error(&PipelineError::Execution("x".into())); Loading @@ -142,6 +148,7 @@ mod tests { Arc::new(AnalyticsConfig::default()), test_claims(), "query_graph", None, "33".to_string(), ); obs.finish(1, 0); Loading
crates/gkg-server/src/grpc/auth.rs +115 −6 Original line number Diff line number Diff line Loading @@ -2,7 +2,37 @@ use tonic::{Request, Status}; use crate::auth::{Claims, JwtValidator}; pub fn extract_claims<T>(request: &Request<T>, validator: &JwtValidator) -> Result<Claims, Status> { #[derive(Debug)] pub struct RequestContext { pub claims: Claims, pub user_agent: Option<String>, } impl RequestContext { pub fn coding_agent(&self) -> Option<&str> { self.user_agent.as_deref().and_then(|ua| { ua.split_whitespace() .find_map(|token| token.strip_prefix("Coding-Agent/")) }) } pub fn record_in_current_span(&self) { let span = tracing::Span::current(); span.record("user_id", self.claims.user_id); span.record("source_type", &self.claims.source_type); if let Some(sid) = &self.claims.ai_session_id { span.record("ai_session_id", sid.as_str()); } if let Some(agent) = self.coding_agent() { span.record("coding_agent", agent); } } } pub fn extract_request_context<T>( request: &Request<T>, validator: &JwtValidator, ) -> Result<RequestContext, Status> { let token = request .metadata() .get("authorization") Loading @@ -10,10 +40,18 @@ pub fn extract_claims<T>(request: &Request<T>, validator: &JwtValidator) -> Resu .and_then(|s| s.strip_prefix("Bearer ")) .ok_or_else(|| Status::unauthenticated("Missing or invalid authorization header"))?; validator.validate(token).map_err(|e| { let claims = validator.validate(token).map_err(|e| { tracing::warn!(error = %e, "JWT validation failed"); Status::unauthenticated(format!("JWT validation failed: {e}")) }) })?; let user_agent = request .metadata() .get("x-client-user-agent") .and_then(|v| v.to_str().ok()) .map(String::from); Ok(RequestContext { claims, user_agent }) } #[cfg(test)] Loading @@ -30,7 +68,7 @@ mod tests { let request: Request<()> = Request::new(()); let validator = mock_validator(); let result = extract_claims(&request, &validator); let result = extract_request_context(&request, &validator); assert!(result.is_err()); let status = result.unwrap_err(); Loading @@ -47,7 +85,7 @@ mod tests { ); let validator = mock_validator(); let result = extract_claims(&request, &validator); let result = extract_request_context(&request, &validator); assert!(result.is_err()); let status = result.unwrap_err(); Loading @@ -63,11 +101,82 @@ mod tests { ); let validator = mock_validator(); let result = extract_claims(&request, &validator); let result = extract_request_context(&request, &validator); assert!(result.is_err()); let status = result.unwrap_err(); assert_eq!(status.code(), tonic::Code::Unauthenticated); assert!(status.message().contains("JWT validation failed")); } fn request_context(user_agent: Option<&str>) -> RequestContext { RequestContext { claims: crate::auth::Claims { sub: String::new(), iss: String::new(), aud: String::new(), iat: 0, exp: 0, user_id: 0, username: String::new(), admin: false, organization_id: None, min_access_level: None, group_traversal_ids: vec![], source_type: String::new(), ai_session_id: None, instance_id: None, unique_instance_id: None, instance_version: None, global_user_id: None, host_name: None, root_namespace_id: None, deployment_type: None, realm: None, }, user_agent: user_agent.map(Into::into), } } #[test] fn coding_agent_extracts_known_agents() { let cases = [ ( "glab/1.50.0 (linux, amd64) Coding-Agent/claude-code", Some("claude-code"), ), ( "glab/1.50.0 (darwin, arm64) Coding-Agent/codex", Some("codex"), ), ( "glab/1.50.0 (windows, amd64) Coding-Agent/cursor", Some("cursor"), ), ( "glab/1.50.0 (linux, amd64) Coding-Agent/opencode", Some("opencode"), ), ( "glab/DEV (linux, amd64) Coding-Agent/custom-agent_2.1", Some("custom-agent_2.1"), ), ]; for (user_agent, expected) in cases { let ctx = request_context(Some(user_agent)); assert_eq!(ctx.coding_agent(), expected, "for user_agent: {user_agent}"); } } #[test] fn coding_agent_none_when_absent() { let ctx = request_context(Some("glab/1.50.0 (linux, amd64)")); assert_eq!(ctx.coding_agent(), None); } #[test] fn coding_agent_none_when_no_user_agent() { let ctx = request_context(None); assert_eq!(ctx.coding_agent(), None); } }
crates/gkg-server/src/grpc/service.rs +66 −53 Original line number Diff line number Diff line Loading @@ -12,7 +12,7 @@ use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status, Streaming}; use tracing::{Instrument, info, instrument}; use super::auth::extract_claims; use super::auth::extract_request_context; use crate::analytics::AnalyticsTracker; use crate::auth::{Claims, JwtValidator, build_security_context}; use crate::cluster_health::ClusterHealthChecker; Loading Loading @@ -41,12 +41,6 @@ fn proto_format_name(name: FormatName) -> ProtoFormatName { } } fn record_ai_session_id(ai_session_id: &Option<String>) { if let Some(sid) = ai_session_id { tracing::Span::current().record("ai_session_id", sid.as_str()); } } fn proto_tool_definition(t: crate::tools::ToolDefinition) -> ProtoToolDefinition { ProtoToolDefinition { name: t.name, Loading Loading @@ -138,15 +132,16 @@ type ExecuteQueryStream = impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService for KnowledgeGraphServiceImpl { #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn list_tools( &self, request: Request<ListToolsRequest>, ) -> Result<Response<ListToolsResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); info!("Listing tools for user"); Loading @@ -158,15 +153,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(ListToolsResponse { tools })) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn list_agent_commands( &self, request: Request<ListAgentCommandsRequest>, ) -> Result<Response<ListAgentCommandsResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); let requested = &req.command_names; Loading Loading @@ -217,15 +213,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService })) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn invoke_agent_command( &self, request: Request<InvokeAgentCommandRequest>, ) -> Result<Response<InvokeAgentCommandResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); if req.command_name.trim().is_empty() { Loading Loading @@ -268,15 +265,18 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService type ExecuteQueryStream = ExecuteQueryStream; #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn execute_query( &self, request: Request<Streaming<ExecuteQueryMessage>>, ) -> Result<Response<Self::ExecuteQueryStream>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let coding_agent = ctx.coding_agent().map(String::from); let claims = ctx.claims; let mut stream = request.into_inner(); let (tx, rx) = mpsc::channel(4); Loading @@ -298,7 +298,14 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService let timeout = std::time::Duration::from_secs(stream_timeout); let result = pipeline .run_query(claims, &req.query, tx.clone(), stream, timeout) .run_query( claims, coding_agent, &req.query, tx.clone(), stream, timeout, ) .await; match result { Loading Loading @@ -367,15 +374,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(Box::pin(ReceiverStream::new(rx)))) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn get_graph_schema( &self, request: Request<GetGraphSchemaRequest>, ) -> Result<Response<GetGraphSchemaResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); info!(format = ?req.format, "Fetching graph schema for user"); Loading @@ -398,15 +406,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(response)) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn get_response_format( &self, request: Request<GetResponseFormatRequest>, ) -> Result<Response<GetResponseFormatResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); info!(format = ?req.format, "Fetching query response format for user"); Loading Loading @@ -434,15 +443,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(response)) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn get_query_dsl( &self, request: Request<GetQueryDslRequest>, ) -> Result<Response<GetQueryDslResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); info!(format = ?req.format, "Fetching query DSL grammar for user"); Loading @@ -466,15 +476,16 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(response)) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn get_cluster_health( &self, request: Request<GetClusterHealthRequest>, ) -> Result<Response<GetClusterHealthResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let req = request.get_ref(); info!(format = ?req.format, "Fetching cluster health for user"); Loading @@ -483,15 +494,17 @@ impl crate::proto::knowledge_graph_service_server::KnowledgeGraphService Ok(Response::new(response)) } #[instrument(skip(self, request), fields(user_id, source_type, ai_session_id))] #[instrument( skip(self, request), fields(user_id, source_type, ai_session_id, coding_agent) )] async fn get_graph_status( &self, request: Request<GetGraphStatusRequest>, ) -> Result<Response<GetGraphStatusResponse>, Status> { let claims = extract_claims(&request, &self.validator)?; tracing::Span::current().record("user_id", claims.user_id); tracing::Span::current().record("source_type", &claims.source_type); record_ai_session_id(&claims.ai_session_id); let ctx = extract_request_context(&request, &self.validator)?; ctx.record_in_current_span(); let claims = ctx.claims; let req = request.get_ref(); authorize_traversal_path(&claims, &req.traversal_path)?; Loading
crates/gkg-server/src/pipeline/service.rs +2 −0 Original line number Diff line number Diff line Loading @@ -77,6 +77,7 @@ impl QueryPipelineService { pub async fn run_query( &self, claims: Claims, coding_agent: Option<String>, query_json: &str, tx: mpsc::Sender<Result<ExecuteQueryMessage, Status>>, stream: Streaming<ExecuteQueryMessage>, Loading @@ -93,6 +94,7 @@ impl QueryPipelineService { Arc::clone(&self.analytics_config), claims.clone(), "query_graph", coding_agent, SCHEMA_VERSION.to_string(), )), ]); Loading