Verified Commit 1ceba831 authored by Michael Angelo Rivera's avatar Michael Angelo Rivera Committed by GitLab
Browse files

fix(query-engine): track edge direction in bidirectional neighbor queries

parent 6136514a
Loading
Loading
Loading
Loading
+90 −9
Original line number Diff line number Diff line
@@ -911,20 +911,22 @@ async fn neighbors_both_exact(ctx: &TestContext) {
        "should have CONTAINS edges"
    );

    // MEMBER_OF edges: User→Group (incoming to center), so from=User, to=Group
    assert!(edges.iter().any(|e| {
        e["from"] == "Group"
            && e["from_id"] == 100
            && e["to"] == "User"
            && e["to_id"] == 1
        e["from"] == "User"
            && e["from_id"] == 1
            && e["to"] == "Group"
            && e["to_id"] == 100
            && e["type"] == "MEMBER_OF"
    }));
    assert!(edges.iter().any(|e| {
        e["from"] == "Group"
            && e["from_id"] == 100
            && e["to"] == "User"
            && e["to_id"] == 5
        e["from"] == "User"
            && e["from_id"] == 5
            && e["to"] == "Group"
            && e["to_id"] == 100
            && e["type"] == "MEMBER_OF"
    }));
    // CONTAINS edge: Group→Project (outgoing from center), so from=Group, to=Project
    assert!(edges.iter().any(|e| {
        e["from"] == "Group"
            && e["from_id"] == 100
@@ -934,12 +936,89 @@ async fn neighbors_both_exact(ctx: &TestContext) {
    }));

    for edge in edges {
        assert_eq!(edge["from_id"], 100, "center node should be from_id");
        assert!(edge["from_id"].is_i64());
        assert!(edge["to_id"].is_i64());
        assert!(edge.get("path_id").is_none());
    }
}

async fn neighbors_both_direction_edges_correct(ctx: &TestContext) {
    seed(ctx).await;

    // User 1 is MEMBER_OF Group 100 (User→Group edge in gl_edge)
    // Query neighbors of User 1 in both directions
    let value = run_pipeline(
        ctx,
        r#"{
            "query_type": "neighbors",
            "node": {"id": "u", "entity": "User", "node_ids": [1]},
            "neighbors": {"node": "u", "direction": "both"}
        }"#,
        &allow_all(),
    )
    .await;

    let edges = value["edges"].as_array().unwrap();
    assert!(!edges.is_empty(), "should have neighbor edges");

    // User 1 is the source of MEMBER_OF edges (outgoing from center)
    // so from=User, to=Group
    for edge in edges {
        if edge["type"] == "MEMBER_OF" {
            assert_eq!(edge["from"], "User", "MEMBER_OF is outgoing from User");
            assert_eq!(edge["from_id"], 1);
        }
    }
}

async fn neighbors_both_direction_mixed_entity(ctx: &TestContext) {
    seed(ctx).await;

    // MR 2000 has incoming AUTHORED from User 1 and outgoing HAS_NOTE to Notes 3000, 3002, 3003
    let value = run_pipeline(
        ctx,
        r#"{
            "query_type": "neighbors",
            "node": {"id": "mr", "entity": "MergeRequest", "node_ids": [2000]},
            "neighbors": {"node": "mr", "direction": "both"}
        }"#,
        &allow_all(),
    )
    .await;

    let edges = value["edges"].as_array().unwrap();
    assert!(!edges.is_empty(), "should have neighbor edges");

    // AUTHORED: User→MergeRequest (incoming to center), so from=User, to=MergeRequest
    assert!(
        edges.iter().any(|e| {
            e["from"] == "User"
                && e["from_id"] == 1
                && e["to"] == "MergeRequest"
                && e["to_id"] == 2000
                && e["type"] == "AUTHORED"
        }),
        "AUTHORED edge should show User as source"
    );

    // HAS_NOTE: MergeRequest→Note (outgoing from center), so from=MergeRequest, to=Note
    let has_note_edges: Vec<_> = edges.iter().filter(|e| e["type"] == "HAS_NOTE").collect();
    assert!(!has_note_edges.is_empty(), "should have HAS_NOTE edges");
    for edge in &has_note_edges {
        assert_eq!(edge["from"], "MergeRequest", "HAS_NOTE is outgoing from MR");
        assert_eq!(edge["from_id"], 2000);
        assert_eq!(edge["to"], "Note");
    }

    let note_ids: HashSet<i64> = has_note_edges
        .iter()
        .filter_map(|e| e["to_id"].as_i64())
        .collect();
    assert!(note_ids.contains(&3000));
    assert!(note_ids.contains(&3002));
    assert!(note_ids.contains(&3003));
}

async fn neighbors_redaction(ctx: &TestContext) {
    seed(ctx).await;

@@ -1933,6 +2012,8 @@ async fn graph_formatter_e2e() {
        neighbors_outgoing_exact,
        neighbors_incoming_exact,
        neighbors_both_exact,
        neighbors_both_direction_edges_correct,
        neighbors_both_direction_mixed_entity,
        neighbors_with_rel_types_filter,
        neighbors_dynamic_columns_all,
        neighbors_redaction,
+5 −0
Original line number Diff line number Diff line
@@ -377,6 +377,11 @@ mod tests {
        assert!(result.base.sql.contains("_gkg_neighbor_id"));
        assert!(result.base.sql.contains("_gkg_neighbor_type"));
        assert!(result.base.sql.contains("_gkg_relationship_type"));
        assert!(
            result.base.sql.contains("_gkg_neighbor_is_outgoing"),
            "bidirectional neighbor query should include direction column: {}",
            result.base.sql
        );
        assert!(result.base.sql.contains("INNER JOIN"));
    }

+61 −9
Original line number Diff line number Diff line
@@ -4,7 +4,8 @@

use crate::ast::{ChType, Cte, Expr, JoinType, Node, Op, OrderExpr, Query, SelectExpr, TableRef};
use crate::constants::{
    EDGE_ALIAS_SUFFIXES, NEIGHBOR_ID_COLUMN, NEIGHBOR_TYPE_COLUMN, RELATIONSHIP_TYPE_COLUMN,
    EDGE_ALIAS_SUFFIXES, NEIGHBOR_ID_COLUMN, NEIGHBOR_IS_OUTGOING_COLUMN, NEIGHBOR_TYPE_COLUMN,
    RELATIONSHIP_TYPE_COLUMN,
};
use crate::error::{QueryError, Result};
use crate::input::{
@@ -446,16 +447,24 @@ fn lower_neighbors(input: &Input) -> Result<Node> {
        join_cond,
    );

    let center_matches_source = Expr::and(
        Expr::eq(
            Expr::col(&center_node.id, DEFAULT_PRIMARY_KEY),
            Expr::col(edge_alias, "source_id"),
        ),
        Expr::eq(
            Expr::col(edge_alias, "source_kind"),
            Expr::string(center_entity),
        ),
    );

    let neighbor_id_expr = match neighbors_config.direction {
        Direction::Outgoing => Expr::col(edge_alias, "target_id"),
        Direction::Incoming => Expr::col(edge_alias, "source_id"),
        Direction::Both => Expr::func(
            "if",
            vec![
                Expr::eq(
                    Expr::col(&center_node.id, DEFAULT_PRIMARY_KEY),
                    Expr::col(edge_alias, "source_id"),
                ),
                center_matches_source.clone(),
                Expr::col(edge_alias, "target_id"),
                Expr::col(edge_alias, "source_id"),
            ],
@@ -468,10 +477,7 @@ fn lower_neighbors(input: &Input) -> Result<Node> {
        Direction::Both => Expr::func(
            "if",
            vec![
                Expr::eq(
                    Expr::col(&center_node.id, DEFAULT_PRIMARY_KEY),
                    Expr::col(edge_alias, "source_id"),
                ),
                center_matches_source.clone(),
                Expr::col(edge_alias, "target_kind"),
                Expr::col(edge_alias, "source_kind"),
            ],
@@ -485,6 +491,17 @@ fn lower_neighbors(input: &Input) -> Result<Node> {
            Expr::col(edge_alias, "relationship_kind"),
            RELATIONSHIP_TYPE_COLUMN,
        ),
        SelectExpr::new(
            match neighbors_config.direction {
                Direction::Outgoing => Expr::int(1),
                Direction::Incoming => Expr::int(0),
                Direction::Both => Expr::func(
                    "if",
                    vec![center_matches_source, Expr::int(1), Expr::int(0)],
                ),
            },
            NEIGHBOR_IS_OUTGOING_COLUMN,
        ),
    ];

    let where_clause = id_filter(&center_node.id, DEFAULT_PRIMARY_KEY, &center_node.node_ids);
@@ -1556,6 +1573,7 @@ mod tests {
        assert!(aliases.contains(&&"_gkg_neighbor_id".to_string()));
        assert!(aliases.contains(&&"_gkg_neighbor_type".to_string()));
        assert!(aliases.contains(&&"_gkg_relationship_type".to_string()));
        assert!(aliases.contains(&&"_gkg_neighbor_is_outgoing".to_string()));

        // Should NOT have raw edge columns (indirect auth uses static/dynamic nodes instead)
        assert!(!aliases.contains(&&"e_path".to_string()));
@@ -1563,6 +1581,40 @@ mod tests {
        assert!(!aliases.contains(&&"e_dst".to_string()));
    }

    #[test]
    fn test_lower_neighbors_both_direction() {
        use crate::input::{Direction, InputNeighbors};

        let input = Input {
            query_type: QueryType::Neighbors,
            nodes: vec![InputNode {
                id: "g".to_string(),
                entity: Some("Group".to_string()),
                table: Some("gl_group".to_string()),
                node_ids: vec![100],
                ..Default::default()
            }],
            neighbors: Some(InputNeighbors {
                node: "g".to_string(),
                direction: Direction::Both,
                rel_types: vec![],
            }),
            limit: 10,
            ..Input::default()
        };

        let Node::Query(q) = lower(&input).unwrap() else {
            panic!("expected Query");
        };

        let aliases: Vec<_> = q.select.iter().filter_map(|s| s.alias.as_ref()).collect();

        assert!(aliases.contains(&&"_gkg_neighbor_is_outgoing".to_string()));
        assert!(aliases.contains(&&"_gkg_neighbor_id".to_string()));
        assert!(aliases.contains(&&"_gkg_neighbor_type".to_string()));
        assert!(aliases.contains(&&"_gkg_relationship_type".to_string()));
    }

    #[test]
    fn test_multi_relationship_has_multiple_edge_columns() {
        let input = validated_input(