Verified Commit 693ae7ad authored by Jean-Gabriel Doyon's avatar Jean-Gabriel Doyon Committed by GitLab
Browse files

feat(server): add GetGraphStats gRPC endpoint

parent 9385c84c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -3482,6 +3482,7 @@ dependencies = [
 "testcontainers-modules",
 "tokio",
 "tokio-stream",
 "tonic",
 "tower",
 "zstd",
]
+30 −0
Original line number Diff line number Diff line
@@ -37,6 +37,10 @@ service KnowledgeGraphService {
  // Returns cluster health and component status.
  // Used by GET /api/v4/orbit/status.
  rpc GetClusterHealth(GetClusterHealthRequest) returns (GetClusterHealthResponse);

  // Returns entity counts per domain, scoped by traversal_path prefix.
  // Used by admin dashboards to inspect graph coverage.
  rpc GetGraphStats(GetGraphStatsRequest) returns (GetGraphStatsResponse);
}

// ---------------------------------------------------------------------------
@@ -262,3 +266,29 @@ message ReplicaStatus {
  int32 ready = 1;
  int32 desired = 2;
}

// ---------------------------------------------------------------------------
// GetGraphStats — unary, returns entity counts grouped by domain
// ---------------------------------------------------------------------------

// Request for graph entity counts scoped by traversal_path prefix.
message GetGraphStatsRequest {
  string traversal_path = 1;  // traversal_path prefix to scope counts (e.g. "1/2/")
}

// Response containing entity counts grouped by domain.
message GetGraphStatsResponse {
  repeated GraphStatsDomain domains = 1;
}

// Entity counts for a single domain (e.g. "ci", "core", "plan").
message GraphStatsDomain {
  string name = 1;
  repeated GraphStatsItem items = 2;
}

// Count for a single entity type (e.g. "Project": 42).
message GraphStatsItem {
  string name = 1;
  int64 count = 2;
}
+29 −0
Original line number Diff line number Diff line
use ontology::Ontology;

pub struct GraphStatsInput {
    pub traversal_path: String,
    pub nodes: Vec<NodeStatsTarget>,
}

pub struct NodeStatsTarget {
    pub name: String,
    pub table: String,
}

impl GraphStatsInput {
    pub fn from_ontology(ontology: &Ontology, traversal_path: String) -> Self {
        let nodes = ontology
            .nodes()
            .filter(|node| node.has_traversal_path)
            .map(|node| NodeStatsTarget {
                name: node.name.clone(),
                table: node.destination_table.clone(),
            })
            .collect();

        Self {
            traversal_path,
            nodes,
        }
    }
}
+115 −0
Original line number Diff line number Diff line
use query_engine::compiler::{Expr, Node, Query, SelectExpr, TableRef};

use super::input::{GraphStatsInput, NodeStatsTarget};

pub fn lower(input: &GraphStatsInput) -> Node {
    let mut queries = input
        .nodes
        .iter()
        .map(|node| build_node_query(node, &input.traversal_path));

    let mut first = queries.next().expect("lower() requires at least one node");
    first.union_all = queries.collect();

    Node::Query(Box::new(first))
}

fn build_node_query(node: &NodeStatsTarget, traversal_path: &str) -> Query {
    let alias = "t";

    let select = vec![
        SelectExpr::new(Expr::string(&node.name), "entity"),
        SelectExpr::new(Expr::func("count", vec![]), "cnt"),
    ];

    let from = TableRef::scan(&node.table, alias);

    let deleted_filter = Expr::eq(Expr::col(alias, "_deleted"), Expr::int(0));
    let traversal_filter = Expr::func(
        "startsWith",
        vec![
            Expr::col(alias, "traversal_path"),
            Expr::string(traversal_path),
        ],
    );

    Query {
        select,
        from,
        where_clause: Some(Expr::and(deleted_filter, traversal_filter)),
        ..Default::default()
    }
}

#[cfg(test)]
mod tests {
    use query_engine::compiler::{ResultContext, codegen};

    use super::*;

    fn test_input() -> GraphStatsInput {
        GraphStatsInput {
            traversal_path: "1/2/".to_string(),
            nodes: vec![
                NodeStatsTarget {
                    name: "Project".to_string(),
                    table: "gl_project".to_string(),
                },
                NodeStatsTarget {
                    name: "Group".to_string(),
                    table: "gl_group".to_string(),
                },
                NodeStatsTarget {
                    name: "MergeRequest".to_string(),
                    table: "gl_merge_request".to_string(),
                },
            ],
        }
    }

    #[test]
    fn lower_produces_union_all() {
        let input = test_input();
        let ast = lower(&input);
        let result = codegen(&ast, ResultContext::new()).unwrap();

        assert!(result.sql.contains("UNION ALL"), "SQL: {}", result.sql);
        assert!(result.sql.contains("gl_project"), "SQL: {}", result.sql);
        assert!(result.sql.contains("gl_group"), "SQL: {}", result.sql);
        assert!(
            result.sql.contains("gl_merge_request"),
            "SQL: {}",
            result.sql
        );
    }

    #[test]
    fn every_subquery_has_starts_with_filter() {
        let input = test_input();
        let ast = lower(&input);
        let result = codegen(&ast, ResultContext::new()).unwrap();

        let starts_with_count = result.sql.matches("startsWith").count();
        assert_eq!(
            starts_with_count,
            input.nodes.len(),
            "Each subquery should have startsWith filter. SQL: {}",
            result.sql
        );
    }

    #[test]
    fn every_subquery_has_deleted_filter() {
        let input = test_input();
        let ast = lower(&input);
        let result = codegen(&ast, ResultContext::new()).unwrap();

        let deleted_count = result.sql.matches("_deleted").count();
        assert_eq!(
            deleted_count,
            input.nodes.len(),
            "Each subquery should have _deleted filter. SQL: {}",
            result.sql
        );
    }
}
+183 −0
Original line number Diff line number Diff line
mod input;
mod lower;

use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::{Array, StringArray, UInt64Array};
use clickhouse_client::ArrowClickHouseClient;
use gkg_utils::arrow::ArrowUtils;
use ontology::Ontology;
use query_engine::compiler::{ResultContext, codegen};
use tonic::Status;
use tracing::{debug, info};

use crate::proto::{GetGraphStatsResponse, GraphStatsDomain, GraphStatsItem};

use self::input::GraphStatsInput;

pub struct GraphStatsService {
    client: Arc<ArrowClickHouseClient>,
    ontology: Arc<Ontology>,
}

impl GraphStatsService {
    pub fn new(client: Arc<ArrowClickHouseClient>, ontology: Arc<Ontology>) -> Self {
        Self { client, ontology }
    }

    pub async fn get_stats(&self, traversal_path: &str) -> Result<GetGraphStatsResponse, Status> {
        if traversal_path.is_empty() {
            return Err(Status::invalid_argument("traversal_path is required"));
        }

        let input = GraphStatsInput::from_ontology(&self.ontology, traversal_path.to_string());

        if input.nodes.is_empty() {
            return Ok(GetGraphStatsResponse { domains: vec![] });
        }

        let ast = lower::lower(&input);
        let parameterized = codegen(&ast, ResultContext::new())
            .map_err(|e| Status::internal(format!("codegen error: {e}")))?;

        debug!(sql = %parameterized.sql, "Graph stats query compiled");

        let mut query = self.client.query(&parameterized.sql);
        for (key, param) in &parameterized.params {
            query = ArrowClickHouseClient::bind_param(query, key, &param.value, &param.ch_type);
        }

        let batches = query
            .fetch_arrow()
            .await
            .map_err(|e| Status::internal(format!("ClickHouse error: {e}")))?;

        let mut entity_counts: HashMap<String, i64> = HashMap::new();
        for batch in &batches {
            let Some(entities) = ArrowUtils::get_column_by_name::<StringArray>(batch, "entity")
            else {
                continue;
            };
            let Some(counts) = ArrowUtils::get_column_by_name::<UInt64Array>(batch, "cnt") else {
                continue;
            };
            for row in 0..batch.num_rows() {
                if entities.is_null(row) || counts.is_null(row) {
                    continue;
                }
                let entity = entities.value(row);
                let count = counts.value(row) as i64;
                if let Some(existing) = entity_counts.get_mut(entity) {
                    *existing += count;
                } else {
                    entity_counts.insert(entity.to_string(), count);
                }
            }
        }

        info!(entity_count = entity_counts.len(), "Graph stats fetched");

        let domains = present_domain_response(&self.ontology, &entity_counts);
        Ok(GetGraphStatsResponse { domains })
    }
}

fn present_domain_response(
    ontology: &Ontology,
    entity_counts: &HashMap<String, i64>,
) -> Vec<GraphStatsDomain> {
    ontology
        .domains()
        .map(|domain| {
            let items = domain
                .node_names
                .iter()
                .map(|node_name| GraphStatsItem {
                    name: node_name.clone(),
                    count: entity_counts.get(node_name).copied().unwrap_or(0),
                })
                .collect();

            GraphStatsDomain {
                name: domain.name.clone(),
                items,
            }
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    fn test_ontology() -> Arc<Ontology> {
        Arc::new(Ontology::load_embedded().expect("ontology must load"))
    }

    #[test]
    fn presents_domain_response_groups_by_domain() {
        let ontology = test_ontology();
        let mut entity_counts = HashMap::new();
        entity_counts.insert("Project".to_string(), 42);
        entity_counts.insert("User".to_string(), 10);

        let domains = present_domain_response(&ontology, &entity_counts);

        assert!(!domains.is_empty());

        let core_domain = domains.iter().find(|d| d.name == "core");
        assert!(core_domain.is_some(), "should have core domain");

        let core = core_domain.unwrap();
        let project_item = core.items.iter().find(|i| i.name == "Project");
        assert!(project_item.is_some());
        assert_eq!(project_item.unwrap().count, 42);

        let user_item = core.items.iter().find(|i| i.name == "User");
        assert!(user_item.is_some());
        assert_eq!(user_item.unwrap().count, 10);
    }

    #[test]
    fn presents_domain_response_missing_entity_defaults_to_zero() {
        let ontology = test_ontology();
        let entity_counts = HashMap::new();

        let domains = present_domain_response(&ontology, &entity_counts);

        for domain in &domains {
            for item in &domain.items {
                assert_eq!(
                    item.count, 0,
                    "missing entity {} should default to 0",
                    item.name
                );
            }
        }
    }

    #[test]
    fn presents_domain_response_covers_all_domains() {
        let ontology = test_ontology();
        let entity_counts = HashMap::new();

        let domains = present_domain_response(&ontology, &entity_counts);
        let domain_count = ontology.domains().count();

        assert_eq!(domains.len(), domain_count);
    }

    #[tokio::test]
    async fn empty_traversal_path_rejected() {
        let client = Arc::new(clickhouse_client::ClickHouseConfiguration::default().build_client());
        let service = GraphStatsService::new(client, test_ontology());

        let result = service.get_stats("").await;

        assert!(result.is_err());
        let status = result.unwrap_err();
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
        assert!(status.message().contains("traversal_path"));
    }
}
Loading