Verified Commit f79f0a65 authored by Michael Usachenko's avatar Michael Usachenko Committed by GitLab
Browse files

fix(querying): graph formatter removes node sort order with hashmap - use indexmap instead

parent f9c56de1
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -2783,6 +2783,7 @@ dependencies = [
 "gkg-utils",
 "health-check",
 "indexer",
 "indexmap 2.13.0",
 "jsonschema",
 "jsonwebtoken",
 "labkit-rs",
+1 −0
Original line number Diff line number Diff line
@@ -89,6 +89,7 @@ testcontainers = "0.27.1"
testcontainers-modules = { version = "0.15.0", features = ["nats", "clickhouse"] }

# Query engine dependencies
indexmap = "2"
jsonschema = "0.44.0"
const_format = "0.2.35"

+1 −0
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ config = { workspace = true }
clickhouse-client = { path = "../clickhouse-client" }
gkg-utils = { path = "../utils" }
health-check = { path = "../health-check" }
indexmap = { workspace = true }
indexer = { path = "../indexer" }
ontology = { path = "../ontology" }
opentelemetry = { workspace = true }
+68 −6
Original line number Diff line number Diff line
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;

use indexmap::IndexMap;
use query_engine::{
    EdgeMeta, NEIGHBOR_IS_OUTGOING_COLUMN, QueryType, RELATIONSHIP_TYPE_COLUMN, ResultContext,
};
@@ -72,7 +73,7 @@ impl GraphFormatter {
            .map(|qt| qt.to_string())
            .unwrap_or_default();

        let mut node_map: HashMap<(String, i64), GraphNode> = HashMap::new();
        let mut node_map: IndexMap<(String, i64), GraphNode> = IndexMap::new();
        let mut edges: Vec<GraphEdge> = Vec::new();
        let mut edge_set: HashSet<EdgeKey> = HashSet::new();

@@ -134,7 +135,7 @@ impl GraphFormatter {
        result: &QueryResult,
        result_context: &ResultContext,
        edge_prefixes: &[&str],
        node_map: &mut HashMap<(String, i64), GraphNode>,
        node_map: &mut IndexMap<(String, i64), GraphNode>,
    ) {
        for row in result.authorized_rows() {
            for node in result_context.nodes() {
@@ -226,7 +227,7 @@ impl GraphFormatter {
        result_context: &ResultContext,
        edge_prefixes: &[&str],
        aggregations: Option<&Vec<query_engine::input::InputAggregation>>,
        node_map: &mut HashMap<(String, i64), GraphNode>,
        node_map: &mut IndexMap<(String, i64), GraphNode>,
    ) {
        let Some(aggs) = aggregations else { return };

@@ -269,7 +270,7 @@ impl GraphFormatter {
    fn extract_path_finding(
        &self,
        result: &QueryResult,
        node_map: &mut HashMap<(String, i64), GraphNode>,
        node_map: &mut IndexMap<(String, i64), GraphNode>,
        edges: &mut Vec<GraphEdge>,
    ) {
        for (row_idx, row) in result.authorized_rows().enumerate() {
@@ -316,7 +317,7 @@ impl GraphFormatter {
        result_context: &ResultContext,
        edge_prefixes: &[&str],
        ctx: &QueryPipelineContext,
        node_map: &mut HashMap<(String, i64), GraphNode>,
        node_map: &mut IndexMap<(String, i64), GraphNode>,
        edges: &mut Vec<GraphEdge>,
    ) {
        let direction = ctx
@@ -408,6 +409,8 @@ impl GraphFormatter {
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
@@ -547,4 +550,63 @@ mod tests {

        assert_eq!(response.nodes.len(), 2);
    }

    #[test]
    fn node_ordering_matches_row_order() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("_gkg_p_id", DataType::Int64, false),
            Field::new("_gkg_p_type", DataType::Utf8, false),
            Field::new("p_name", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![3, 1, 4, 1, 5, 2])),
                Arc::new(StringArray::from(vec![
                    "Project", "Project", "Project", "Project", "Project", "Project",
                ])),
                Arc::new(StringArray::from(vec![
                    "Charlie", "Alpha", "Delta", "Alpha", "Echo", "Beta",
                ])),
            ],
        )
        .unwrap();

        let mut result_ctx = ResultContext::new();
        result_ctx.add_node("p", "Project");
        result_ctx.query_type = Some(QueryType::Search);

        let qr = QueryResult::from_batches(&[batch], &result_ctx);

        let ctx = QueryPipelineContext {
            compiled: Some(Arc::new(CompiledQueryContext {
                query_type: QueryType::Search,
                base: ParameterizedQuery {
                    sql: "SELECT 1".to_string(),
                    params: HashMap::new(),
                    result_context: ResultContext::new(),
                },
                hydration: HydrationPlan::None,
                input: serde_json::from_value(serde_json::json!({
                    "query_type": "search",
                    "node": {"id": "p", "entity": "Project"},
                    "limit": 10
                }))
                .unwrap(),
            })),
            ontology: Arc::new(Ontology::new()),
            client: crate::query_pipeline::types::dummy_clickhouse_client(),
            security_context: None,
        };

        let formatter = GraphFormatter;
        let response = formatter.build_response(&qr, &result_ctx, &ctx);

        let ids: Vec<i64> = response.nodes.iter().map(|n| n.id).collect();
        assert_eq!(
            ids,
            vec![3, 1, 4, 5, 2],
            "node order must match row order (dedup keeps first occurrence)"
        );
    }
}