Verified Commit 8dbf5767 authored by Michael Angelo Rivera's avatar Michael Angelo Rivera Committed by GitLab
Browse files

fix(query-engine): remove traversal path prefix from join conditions

parent 03812719
Loading
Loading
Loading
Loading
+100 −0
Original line number Diff line number Diff line
@@ -2148,6 +2148,103 @@ async fn traversal_variable_length_with_redaction_at_depth(ctx: &TestContext) {
    );
}

// ─────────────────────────────────────────────────────────────────────────────
// Traversal chains (no traversal_path prefix in joins)
// ─────────────────────────────────────────────────────────────────────────────

/// End-to-end check of a two-relationship chain:
/// `User -MEMBER_OF-> Group -CONTAINS-> Project`.
///
/// Seeds the context, runs the traversal through the pipeline with
/// `allow_all()`, and asserts that user 1, group 100 and project 1000 are
/// among the returned nodes and that at least one edge of each relationship
/// type is returned. (The ids are presumably provided by `seed` — confirm
/// against the seed data.)
async fn traversal_chain_user_group_project(ctx: &TestContext) {
    seed(ctx).await;

    let query = r#"{
            "query_type": "traversal",
            "nodes": [
                {"id": "u", "entity": "User", "columns": ["username"]},
                {"id": "g", "entity": "Group", "columns": ["name"]},
                {"id": "p", "entity": "Project", "columns": ["name"]}
            ],
            "relationships": [
                {"type": "MEMBER_OF", "from": "u", "to": "g"},
                {"type": "CONTAINS", "from": "g", "to": "p"}
            ],
            "limit": 50
        }"#;
    let response = run_pipeline(ctx, query, &allow_all()).await;

    assert_valid(&response);
    assert_eq!(response["query_type"], "traversal");

    // Collect the ids per entity first, then check one known id from each
    // link of the chain.
    let returned_nodes = response["nodes"].as_array().unwrap();
    let users = node_ids(returned_nodes, "User");
    let groups = node_ids(returned_nodes, "Group");
    let projects = node_ids(returned_nodes, "Project");

    assert!(users.contains(&1), "alice should be present");
    assert!(groups.contains(&100), "group 100 should be present");
    assert!(projects.contains(&1000), "project 1000 should be present");

    // Both hops of the chain must be represented in the edge list.
    let returned_edges = response["edges"].as_array().unwrap();
    let has_edge_of_type =
        |edge_type: &str| returned_edges.iter().any(|edge| edge["type"] == edge_type);

    assert!(has_edge_of_type("MEMBER_OF"), "MEMBER_OF edge should exist");
    assert!(has_edge_of_type("CONTAINS"), "CONTAINS edge should exist");
}

/// End-to-end check of a two-relationship chain:
/// `User -AUTHORED-> MergeRequest -HAS_NOTE-> Note`.
///
/// Seeds the context, runs the traversal through the pipeline with
/// `allow_all()`, and asserts that user 1, merge request 2000 and note 3000
/// are among the returned nodes and that at least one edge of each
/// relationship type is returned. (The ids are presumably provided by
/// `seed` — confirm against the seed data.)
async fn traversal_chain_user_mr_note(ctx: &TestContext) {
    seed(ctx).await;

    let query = r#"{
            "query_type": "traversal",
            "nodes": [
                {"id": "u", "entity": "User", "columns": ["username"]},
                {"id": "mr", "entity": "MergeRequest", "columns": ["title"]},
                {"id": "n", "entity": "Note", "columns": ["note"]}
            ],
            "relationships": [
                {"type": "AUTHORED", "from": "u", "to": "mr"},
                {"type": "HAS_NOTE", "from": "mr", "to": "n"}
            ],
            "limit": 50
        }"#;
    let response = run_pipeline(ctx, query, &allow_all()).await;

    assert_valid(&response);
    assert_eq!(response["query_type"], "traversal");

    // Collect the ids per entity first, then check one known id from each
    // link of the chain.
    let returned_nodes = response["nodes"].as_array().unwrap();
    let users = node_ids(returned_nodes, "User");
    let merge_requests = node_ids(returned_nodes, "MergeRequest");
    let notes = node_ids(returned_nodes, "Note");

    assert!(users.contains(&1), "alice should be present");
    assert!(merge_requests.contains(&2000), "MR 2000 should be present");
    assert!(notes.contains(&3000), "note 3000 should be present");

    // Both hops of the chain must be represented in the edge list.
    let returned_edges = response["edges"].as_array().unwrap();
    let has_edge_of_type =
        |edge_type: &str| returned_edges.iter().any(|edge| edge["type"] == edge_type);

    assert!(has_edge_of_type("AUTHORED"), "AUTHORED edge should exist");
    assert!(has_edge_of_type("HAS_NOTE"), "HAS_NOTE edge should exist");
}

// ─────────────────────────────────────────────────────────────────────────────
// Test runner
// ─────────────────────────────────────────────────────────────────────────────
@@ -2183,6 +2280,9 @@ async fn graph_formatter_e2e() {
        traversal_both_direction,
        // Traversal — fan-in
        traversal_shared_target_node,
        // Traversal — chains (no traversal_path prefix in joins)
        traversal_chain_user_group_project,
        traversal_chain_user_mr_note,
        // Aggregation — all functions
        aggregation_count_exact,
        aggregation_sum,
+0 −6
Original line number Diff line number Diff line
@@ -172,11 +172,6 @@ pub struct InputNode {
    /// Always set before enforce.rs runs; do not add fallbacks in downstream code.
    #[serde(skip)]
    pub redaction_id_column: String,
    /// Whether this node's table has a `traversal_path` column.
    /// Populated during normalization from the ontology. Used by the lowerer to add
    /// `traversal_path` join conditions between node and edge tables.
    #[serde(skip)]
    pub has_traversal_path: bool,
}

impl Default for InputNode {
@@ -191,7 +186,6 @@ impl Default for InputNode {
            id_range: None,
            id_property: DEFAULT_PRIMARY_KEY.to_string(),
            redaction_id_column: DEFAULT_PRIMARY_KEY.to_string(),
            has_traversal_path: false,
        }
    }
}
+123 −68
Original line number Diff line number Diff line
@@ -18,21 +18,6 @@ use ontology::constants::{
use serde_json::Value;
use std::collections::{HashMap, HashSet};

/// Construct the predicate `startsWith(edge.traversal_path, node.traversal_path)`.
///
/// The first argument is the `traversal_path` column of `edge_alias`, the
/// second is the same column of `node_alias`.
///
/// Rationale, as stated by the original author: an edge's path is always
/// equal to or deeper than the node's path in the namespace hierarchy, so a
/// prefix match holds on both the source and the target side, and ClickHouse
/// can still use the ORDER BY key prefix for this predicate.
fn edge_path_starts_with(edge_alias: &str, node_alias: &str) -> Expr {
    let edge_path = Expr::col(edge_alias, TRAVERSAL_PATH_COLUMN);
    let node_path = Expr::col(node_alias, TRAVERSAL_PATH_COLUMN);

    Expr::func("startsWith", vec![edge_path, node_path])
}

/// Generate SELECT expressions for all edge columns with the given table alias.
fn edge_select_exprs(alias: &str) -> Vec<SelectExpr> {
    EDGE_RESERVED_COLUMNS
@@ -449,7 +434,6 @@ fn lower_neighbors(input: &Input) -> Result<Node> {
        edge_alias,
        center_entity,
        neighbors_config.direction,
        center_node.has_traversal_path,
    );
    if let Some(tc) = edge_type_cond {
        join_cond = Expr::and(join_cond, tc);
@@ -692,25 +676,18 @@ fn build_joins(
            let alias = format!("hop_e{i}");
            edge_aliases.insert(i, alias.clone());

            let from_node = find_node(nodes, &rel.from)?;
            let union = build_hop_union_all(rel, &alias);
            let (from_col, to_col) = rel.direction.union_columns();

            let mut source_cond = Expr::eq(
            let source_cond = Expr::eq(
                Expr::col(&rel.from, DEFAULT_PRIMARY_KEY),
                Expr::col(&alias, from_col),
            );
            if from_node.has_traversal_path {
                source_cond = Expr::and(edge_path_starts_with(&alias, &rel.from), source_cond);
            }

            let mut target_cond = Expr::eq(
            let target_cond = Expr::eq(
                Expr::col(&alias, to_col),
                Expr::col(&rel.to, DEFAULT_PRIMARY_KEY),
            );
            if target.has_traversal_path {
                target_cond = Expr::and(edge_path_starts_with(&alias, &rel.to), target_cond);
            }

            let union_join_cond = match (source_joined, target_joined) {
                (true, true) => Expr::and(source_cond.clone(), target_cond.clone()),
@@ -727,6 +704,7 @@ fn build_joins(
            result = TableRef::join(JoinType::Inner, result, union, union_join_cond);

            if !source_joined {
                let from_node = find_node(nodes, &rel.from)?;
                let source_table = resolve_table(from_node)?;
                result = TableRef::join(
                    JoinType::Inner,
@@ -749,16 +727,9 @@ fn build_joins(
            let alias = format!("e{i}");
            edge_aliases.insert(i, alias.clone());

            let from_node = find_node(nodes, &rel.from)?;
            let (edge, edge_type_cond) = edge_scan(&alias, &type_filter(&rel.types));
            let source_cond = source_join_cond(
                &rel.from,
                &alias,
                rel.direction,
                from_node.has_traversal_path,
            );
            let target_cond =
                target_join_cond(&alias, &rel.to, rel.direction, target.has_traversal_path);
            let source_cond = source_join_cond(&rel.from, &alias, rel.direction);
            let target_cond = target_join_cond(&alias, &rel.to, rel.direction);

            let mut edge_join_cond = match (source_joined, target_joined) {
                (true, true) => Expr::and(source_cond.clone(), target_cond.clone()),
@@ -778,6 +749,7 @@ fn build_joins(
            result = TableRef::join(JoinType::Inner, result, edge, edge_join_cond);

            if !source_joined {
                let from_node = find_node(nodes, &rel.from)?;
                let source_table = resolve_table(from_node)?;
                result = TableRef::join(
                    JoinType::Inner,
@@ -803,12 +775,8 @@ fn build_joins(
}

/// Join from source node to edge table.
/// When `with_path` is true, adds `startsWith(edge.traversal_path, node.traversal_path)`
/// to leverage ClickHouse's ORDER BY key on the edge table. The edge's path
/// is always equal to or deeper than either endpoint's path in the namespace
/// hierarchy, so a prefix match is safe for all directions.
fn source_join_cond(node: &str, edge: &str, dir: Direction, with_path: bool) -> Expr {
    let id_cond = match dir {
fn source_join_cond(node: &str, edge: &str, dir: Direction) -> Expr {
    match dir {
        Direction::Outgoing => Expr::eq(
            Expr::col(node, DEFAULT_PRIMARY_KEY),
            Expr::col(edge, "source_id"),
@@ -827,25 +795,13 @@ fn source_join_cond(node: &str, edge: &str, dir: Direction, with_path: bool) ->
                Expr::col(edge, "target_id"),
            ),
        ),
    };
    if with_path {
        Expr::and(edge_path_starts_with(edge, node), id_cond)
    } else {
        id_cond
    }
}

/// Join from source node to edge table, with entity type filter.
/// Unlike `source_join_cond`, this also filters on source_kind/target_kind
/// to prevent ID collisions across entity types.
/// When `with_path` is true, adds `startsWith(edge.traversal_path, node.traversal_path)`.
fn source_join_cond_with_kind(
    node: &str,
    edge: &str,
    entity: &str,
    dir: Direction,
    with_path: bool,
) -> Expr {
fn source_join_cond_with_kind(node: &str, edge: &str, entity: &str, dir: Direction) -> Expr {
    let id_and_kind = |id_col, kind_col| {
        Expr::and(
            Expr::eq(
@@ -856,26 +812,19 @@ fn source_join_cond_with_kind(
        )
    };

    let id_cond = match dir {
    match dir {
        Direction::Outgoing => id_and_kind("source_id", "source_kind"),
        Direction::Incoming => id_and_kind("target_id", "target_kind"),
        Direction::Both => Expr::or(
            id_and_kind("source_id", "source_kind"),
            id_and_kind("target_id", "target_kind"),
        ),
    };
    if with_path {
        Expr::and(edge_path_starts_with(edge, node), id_cond)
    } else {
        id_cond
    }
}

/// Join from edge table to target node.
///
/// When `with_path` is true, adds `startsWith(edge.traversal_path, node.traversal_path)`.
fn target_join_cond(edge: &str, node: &str, dir: Direction, with_path: bool) -> Expr {
    let id_cond = match dir {
fn target_join_cond(edge: &str, node: &str, dir: Direction) -> Expr {
    match dir {
        Direction::Outgoing => Expr::eq(
            Expr::col(edge, "target_id"),
            Expr::col(node, DEFAULT_PRIMARY_KEY),
@@ -894,11 +843,6 @@ fn target_join_cond(edge: &str, node: &str, dir: Direction, with_path: bool) ->
                Expr::col(node, DEFAULT_PRIMARY_KEY),
            ),
        ),
    };
    if with_path {
        Expr::and(edge_path_starts_with(edge, node), id_cond)
    } else {
        id_cond
    }
}

@@ -1878,4 +1822,115 @@ mod tests {
        let on = extract_join_on(&q.from).expect("expected join");
        assert!(!has_type_filter(on), "wildcard should not have type filter");
    }

    /// Reports whether `expr` contains a call to the `startsWith` function.
    ///
    /// The search descends through the operand of unary operators and both
    /// operands of binary operators. A function call counts only when its own
    /// name is `startsWith`; its arguments are not inspected. Every other
    /// expression kind yields `false`.
    fn contains_starts_with(expr: &Expr) -> bool {
        match expr {
            Expr::UnaryOp { expr: operand, .. } => contains_starts_with(operand),
            Expr::BinaryOp {
                left: lhs,
                right: rhs,
                ..
            } => contains_starts_with(lhs) || contains_starts_with(rhs),
            Expr::FuncCall { name, .. } => name == "startsWith",
            _ => false,
        }
    }

    /// Reports whether any join condition reachable from `table_ref` contains
    /// a `startsWith` call.
    ///
    /// Joins are checked on their `on` expression and then on both sides;
    /// unions and subqueries are followed through the `from` of their inner
    /// queries; plain scans have no join condition and yield `false`.
    fn table_ref_has_starts_with(table_ref: &TableRef) -> bool {
        match table_ref {
            TableRef::Scan { .. } => false,
            TableRef::Subquery { query, .. } => table_ref_has_starts_with(&query.from),
            TableRef::Union { queries, .. } => queries
                .iter()
                .any(|branch| table_ref_has_starts_with(&branch.from)),
            TableRef::Join {
                on, left, right, ..
            } => {
                // Check this join's own condition before walking its inputs.
                if contains_starts_with(on) {
                    return true;
                }
                table_ref_has_starts_with(left) || table_ref_has_starts_with(right)
            }
        }
    }

    /// A single-relationship `User -AUTHORED-> Note` traversal must lower to
    /// a query whose FROM tree has no `startsWith` call in a join condition.
    #[test]
    fn no_starts_with_in_single_hop_join() {
        let request = r#"{
            "query_type": "traversal",
            "nodes": [
                {"id": "u", "entity": "User"},
                {"id": "n", "entity": "Note"}
            ],
            "relationships": [{"type": "AUTHORED", "from": "u", "to": "n"}],
            "limit": 10
        }"#;
        let input = validated_input(request);

        let query = match lower(&input).unwrap() {
            Node::Query(query) => query,
            _ => panic!("expected Query"),
        };

        assert!(
            !table_ref_has_starts_with(&query.from),
            "single-hop join should not contain startsWith"
        );
    }

    /// A variable-length (`min_hops` 1, `max_hops` 3) `User -MEMBER_OF-> Project`
    /// traversal must lower to a query whose FROM tree has no `startsWith`
    /// call in a join condition.
    #[test]
    fn no_starts_with_in_multi_hop_join() {
        let request = r#"{
            "query_type": "traversal",
            "nodes": [
                {"id": "u", "entity": "User"},
                {"id": "p", "entity": "Project"}
            ],
            "relationships": [{
                "type": "MEMBER_OF",
                "from": "u",
                "to": "p",
                "min_hops": 1,
                "max_hops": 3
            }],
            "limit": 10
        }"#;
        let input = validated_input(request);

        let query = match lower(&input).unwrap() {
            Node::Query(query) => query,
            _ => panic!("expected Query"),
        };

        assert!(
            !table_ref_has_starts_with(&query.from),
            "multi-hop join should not contain startsWith"
        );
    }

    /// A neighbors query centred on a `Group` node (both directions, no
    /// relationship-type filter) must lower to a query whose FROM tree has no
    /// `startsWith` call in a join condition.
    #[test]
    fn no_starts_with_in_neighbors_join() {
        use crate::input::{Direction, InputNeighbors};

        // The input is built by hand rather than parsed from JSON, so the
        // node's table is set explicitly here.
        let center = InputNode {
            id: "g".to_string(),
            entity: Some("Group".to_string()),
            table: Some("gl_group".to_string()),
            node_ids: vec![100],
            ..Default::default()
        };
        let neighbors = InputNeighbors {
            node: "g".to_string(),
            direction: Direction::Both,
            rel_types: vec![],
        };
        let input = Input {
            query_type: QueryType::Neighbors,
            nodes: vec![center],
            neighbors: Some(neighbors),
            limit: 10,
            ..Input::default()
        };

        let query = match lower(&input).unwrap() {
            Node::Query(query) => query,
            _ => panic!("expected Query"),
        };

        assert!(
            !table_ref_has_starts_with(&query.from),
            "neighbors join should not contain startsWith"
        );
    }
}
+0 −2
Original line number Diff line number Diff line
@@ -91,8 +91,6 @@ pub fn normalize(mut input: Input, ontology: &Ontology) -> Result<Input> {
            .map(|r| r.id_column.clone())
            .unwrap_or_else(|| DEFAULT_PRIMARY_KEY.to_string());

        node.has_traversal_path = node_entity.has_traversal_path;

        // Expand wildcard/empty column selections to explicit lists for lowering.
        // Redaction columns (_gkg_*) are added separately by enforce.rs.
        match &mut node.columns {