Verified Commit 698c80ae authored by Michael Usachenko's avatar Michael Usachenko Committed by GitLab
Browse files

chore(testing): harden assertion enforcement

parent a8e3f5bf
Loading
Loading
Loading
Loading
+31 −16
Original line number Diff line number Diff line
@@ -99,7 +99,7 @@ impl QueryRequirements for Input {
            QueryType::Neighbors => {
                reqs.insert(Requirement::Neighbors);
            }
            _ => {}
            QueryType::Traversal | QueryType::Search => {}
        }

        // Traversal queries with joins produce edges the test must verify per type.
@@ -445,7 +445,7 @@ mod tests {
        );
        let view = ResponseView::for_query(&input, sample_search_response());
        view.assert_filter("User", "username", |n| n.prop_str("username").is_some());
        view.assert_filter("User", "state", |_| true);
        view.assert_filter("User", "state", |n| n.prop_str("username").is_some());
    }

    #[test]
@@ -458,7 +458,7 @@ mod tests {
                "limit": 10}"#,
        );
        let view = ResponseView::for_query(&input, sample_search_response());
        view.assert_filter("User", "username", |_| true);
        view.assert_filter("User", "username", |n| n.prop_str("username").is_some());
        drop(view);
    }

@@ -470,7 +470,7 @@ mod tests {
                "limit": 10}"#,
        );
        let view = ResponseView::for_query(&input, sample_search_response());
        let _ = view.node_ids("User");
        let _ = view.node_ids("User").into_inner();
    }

    #[test]
@@ -482,7 +482,7 @@ mod tests {
                "limit": 10}"#,
        );
        let view = ResponseView::for_query(&input, sample_aggregation_response());
        view.assert_node("User", 1, |_| true);
        view.assert_node("User", 1, |n| n.prop_str("username") == Some("alice"));
    }

    #[test]
@@ -499,7 +499,7 @@ mod tests {
            edges: vec![make_path_edge("User", 1, "Project", 1000, "CONTAINS", 0, 0)],
        };
        let view = ResponseView::for_query(&input, resp);
        let _ = view.path_ids();
        let _ = view.path_ids().into_inner();
    }

    #[test]
@@ -514,7 +514,7 @@ mod tests {
                "limit": 10}"#,
        );
        let view = ResponseView::for_query(&input, sample_response());
        let _ = view.edges_of_type("MEMBER_OF");
        let _ = view.edges_of_type("MEMBER_OF").into_inner();
    }

    #[test]
@@ -533,7 +533,8 @@ mod tests {
    }

    #[test]
    fn for_query_relationship_satisfied_by_assert_edge_absent() {
    #[should_panic(expected = "unsatisfied assertion requirements")]
    fn for_query_relationship_not_satisfied_by_assert_edge_absent() {
        let input = parse_test_input(
            r#"{"query_type": "traversal",
                "nodes": [
@@ -563,8 +564,8 @@ mod tests {
                "limit": 10}"#,
        );
        let view = ResponseView::for_query(&input, sample_response());
        let _ = view.edges_of_type("MEMBER_OF");
        let _ = view.edges_of_type("CONTAINS");
        let _ = view.edges_of_type("MEMBER_OF").into_inner();
        let _ = view.edges_of_type("CONTAINS").into_inner();
    }

    #[test]
@@ -584,7 +585,7 @@ mod tests {
                "limit": 10}"#,
        );
        let view = ResponseView::for_query(&input, sample_response());
        let _ = view.edges_of_type("MEMBER_OF");
        let _ = view.edges_of_type("MEMBER_OF").into_inner();
        drop(view);
    }

@@ -597,7 +598,7 @@ mod tests {
        );
        let view = ResponseView::for_query(&input, sample_neighbors_response());
        view.assert_edge_exists("User", 1, "Group", 100, "MEMBER_OF");
        let _ = view.node_ids("User");
        let _ = view.node_ids("User").into_inner();
    }

    #[test]
@@ -608,8 +609,8 @@ mod tests {
                "neighbors": {"node": "u", "direction": "outgoing"}}"#,
        );
        let view = ResponseView::for_query(&input, sample_neighbors_response());
        let _ = view.edges_of_type("MEMBER_OF");
        let _ = view.node_ids("User");
        let _ = view.edges_of_type("MEMBER_OF").into_inner();
        let _ = view.node_ids("User").into_inner();
    }

    #[test]
@@ -623,7 +624,7 @@ mod tests {
        );
        let view = ResponseView::for_query(&input, sample_aggregation_response());
        view.assert_node_order("User", &[1, 2]);
        view.assert_node("User", 1, |_| true);
        view.assert_node("User", 1, |n| n.prop_str("username") == Some("alice"));
    }

    #[test]
@@ -701,7 +702,7 @@ mod tests {
                "neighbors": {"node": "u", "direction": "outgoing"}}"#,
        );
        let view = ResponseView::for_query(&input, sample_neighbors_response());
        let _ = view.node_ids("User");
        let _ = view.node_ids("User").into_inner();
    }

    #[test]
@@ -751,4 +752,18 @@ mod tests {
        let view = ResponseView::new(sample_response());
        drop(view);
    }

    #[test]
    #[should_panic(expected = "trivial predicate")]
    fn assert_node_rejects_trivial_predicate() {
        let view = ResponseView::new(sample_response());
        view.assert_node("User", 1, |_| true);
    }

    #[test]
    #[should_panic(expected = "trivial predicate")]
    fn assert_filter_rejects_trivial_predicate() {
        let view = ResponseView::new(sample_search_response());
        view.assert_filter("User", "username", |_| true);
    }
}
+142 −22
Original line number Diff line number Diff line
@@ -29,6 +29,74 @@ use gkg_server::query_pipeline::{GraphEdge, GraphNode, GraphResponse};
use query_engine::input::{Input, QueryType};
use serde_json::Value;

// ─────────────────────────────────────────────────────────────────────────────
// MustInspect — drop-enforced result wrapper
// ─────────────────────────────────────────────────────────────────────────────

/// Wrapper that panics on drop if the inner value was never inspected.
///
/// Returned by [`ResponseView`] methods that satisfy enforcement requirements
/// (`node_ids`, `edges_of_type`, `path_ids`). Transparent in normal use —
/// implements [`Deref`], [`PartialEq`], and [`Debug`] so callers can compare,
/// iterate, or call methods without ceremony. Panics on drop only if the
/// value was never accessed at all (the "satisfy and discard" pattern).
///
/// Use [`into_inner`](Self::into_inner) to take ownership when needed
/// (e.g. in enforcement tests that satisfy the tracker without data checks).
pub struct MustInspect<T> {
    value: Option<T>,
    accessed: std::cell::Cell<bool>,
    context: &'static str,
}

impl<T> MustInspect<T> {
    fn new(value: T, context: &'static str) -> Self {
        Self {
            value: Some(value),
            accessed: std::cell::Cell::new(false),
            context,
        }
    }

    /// Extract the inner value, consuming the wrapper.
    pub fn into_inner(mut self) -> T {
        self.accessed.set(true);
        self.value.take().unwrap()
    }
}

impl<T> std::ops::Deref for MustInspect<T> {
    type Target = T;
    fn deref(&self) -> &T {
        self.accessed.set(true);
        self.value.as_ref().unwrap()
    }
}

impl<T: PartialEq> PartialEq<T> for MustInspect<T> {
    fn eq(&self, other: &T) -> bool {
        self.accessed.set(true);
        self.value.as_ref().unwrap() == other
    }
}

impl<T: std::fmt::Debug> std::fmt::Debug for MustInspect<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.value.as_ref().unwrap().fmt(f)
    }
}

impl<T> Drop for MustInspect<T> {
    fn drop(&mut self) {
        if !self.accessed.get() && !std::thread::panicking() {
            panic!(
                "{}: return value was discarded without inspection",
                self.context
            );
        }
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// ResponseView
// ─────────────────────────────────────────────────────────────────────────────
@@ -67,11 +135,14 @@ impl ResponseView {
    /// - Search and aggregation responses have zero edges (the formatter never
    ///   produces edges for these query types)
    pub fn for_query(input: &Input, response: GraphResponse) -> Self {
        let expected_type: &str = input.query_type.into();
        let response_type: QueryType = serde_json::from_value(Value::String(
            response.query_type.clone(),
        ))
        .unwrap_or_else(|_| panic!("unknown response query_type '{}'", response.query_type));
        assert_eq!(
            response.query_type, expected_type,
            "response query_type '{}' does not match input '{expected_type}'",
            response.query_type,
            response_type, input.query_type,
            "response query_type '{}' does not match input '{}'",
            response.query_type, input.query_type,
        );

        if matches!(input.query_type, QueryType::Search | QueryType::Aggregation) {
@@ -110,9 +181,11 @@ impl ResponseView {
        self.response.nodes.len()
    }

    /// Assert exact node count. Satisfies [`Requirement::NodeIds`] and [`Requirement::Range`].
    /// Assert exact node count. Satisfies [`Requirement::Range`].
    ///
    /// Does NOT satisfy [`Requirement::NodeIds`] — use [`node_ids`](Self::node_ids)
    /// or [`assert_node_order`](Self::assert_node_order) to verify which IDs were returned.
    pub fn assert_node_count(&self, expected: usize) {
        self.tracker.satisfy(Requirement::NodeIds);
        self.tracker.satisfy(Requirement::Range);
        assert_eq!(
            self.response.nodes.len(),
@@ -144,14 +217,16 @@ impl ResponseView {
    }

    /// Satisfies [`Requirement::NodeIds`].
    pub fn node_ids(&self, entity_type: &str) -> HashSet<i64> {
    pub fn node_ids(&self, entity_type: &str) -> MustInspect<HashSet<i64>> {
        self.tracker.satisfy(Requirement::NodeIds);
        self.response
        let ids = self
            .response
            .nodes
            .iter()
            .filter(|n| n.entity_type == entity_type)
            .map(|n| n.id)
            .collect()
            .collect();
        MustInspect::new(ids, "node_ids()")
    }

    /// Return IDs of nodes with the given type, preserving response order.
@@ -208,16 +283,18 @@ impl ResponseView {
    }

    /// Satisfies [`Requirement::Relationship`] for the given edge type, and [`Requirement::Neighbors`].
    pub fn edges_of_type(&self, edge_type: &str) -> Vec<&GraphEdge> {
    pub fn edges_of_type(&self, edge_type: &str) -> MustInspect<Vec<&GraphEdge>> {
        self.tracker.satisfy(Requirement::Relationship {
            edge_type: edge_type.to_string(),
        });
        self.tracker.satisfy(Requirement::Neighbors);
        self.response
        let edges = self
            .response
            .edges
            .iter()
            .filter(|e| e.edge_type == edge_type)
            .collect()
            .collect();
        MustInspect::new(edges, "edges_of_type()")
    }

    pub fn edge_tuples(&self) -> HashSet<(String, i64, String, i64, String)> {
@@ -241,13 +318,15 @@ impl ResponseView {
    /// Tests should use this to discover which paths exist, then call
    /// [`path`] for each one explicitly.
    /// Satisfies [`Requirement::PathFinding`].
    pub fn path_ids(&self) -> HashSet<usize> {
    pub fn path_ids(&self) -> MustInspect<HashSet<usize>> {
        self.tracker.satisfy(Requirement::PathFinding);
        self.response
        let ids = self
            .response
            .edges
            .iter()
            .filter_map(|e| e.path_id)
            .collect()
            .collect();
        MustInspect::new(ids, "path_ids()")
    }

    /// Return edges belonging to a specific `path_id`, sorted by `step`.
@@ -298,9 +377,15 @@ impl ResponseView {
    }

    /// Assert a node exists and satisfies a predicate.
    ///
    /// Panics if the predicate also passes for a blank node (same type/id,
    /// empty properties) — this catches trivial predicates like `|_| true`
    /// that don't actually inspect the data.
    ///
    /// Satisfies [`Requirement::Aggregation`] (property value was checked).
    pub fn assert_node(&self, entity_type: &str, id: i64, predicate: impl Fn(&GraphNode) -> bool) {
        self.tracker.satisfy(Requirement::Aggregation);
        assert_predicate_is_nontrivial(entity_type, id, &predicate);
        let node = self
            .find_node(entity_type, id)
            .unwrap_or_else(|| panic!("node {entity_type}:{id} not found"));
@@ -331,7 +416,12 @@ impl ResponseView {
        );
    }

    /// Satisfies [`Requirement::Relationship`] for the given edge type, and [`Requirement::Neighbors`].
    /// Assert that a specific edge does NOT exist.
    ///
    /// Does NOT satisfy [`Requirement::Relationship`] or [`Requirement::Neighbors`] —
    /// a negative assertion proves nothing about what edges exist. Use
    /// [`assert_edge_exists`](Self::assert_edge_exists) or
    /// [`edges_of_type`](Self::edges_of_type) for positive verification.
    pub fn assert_edge_absent(
        &self,
        from: &str,
@@ -340,10 +430,6 @@ impl ResponseView {
        to_id: i64,
        edge_type: &str,
    ) {
        self.tracker.satisfy(Requirement::Relationship {
            edge_type: edge_type.to_string(),
        });
        self.tracker.satisfy(Requirement::Neighbors);
        assert!(
            self.find_edge(from, from_id, to, to_id, edge_type)
                .is_none(),
@@ -368,6 +454,13 @@ impl ResponseView {
    /// Assert that a filter on `field` produced correct results for nodes of
    /// `entity_type`. Checks that every node of the given type satisfies the predicate.
    ///
    /// Panics if:
    /// - Zero nodes match `entity_type` (use [`assert_node_count`](Self::assert_node_count)
    ///   to assert empty results instead — `assert_filter` requires at least one node
    ///   because there is nothing to run the predicate against).
    /// - The predicate passes for a blank node with no properties
    ///   (catches trivial predicates like `|_| true`).
    ///
    /// Satisfies [`Requirement::Filter`] for the specific `field`.
    pub fn assert_filter(
        &self,
@@ -378,12 +471,19 @@ impl ResponseView {
        self.tracker.satisfy(Requirement::Filter {
            field: field.to_string(),
        });
        for node in self
        assert_predicate_is_nontrivial(entity_type, 0, &predicate);
        let matching: Vec<&GraphNode> = self
            .response
            .nodes
            .iter()
            .filter(|n| n.entity_type == entity_type)
        {
            .collect();
        assert!(
            !matching.is_empty(),
            "assert_filter('{entity_type}', '{field}'): zero nodes of type '{entity_type}' \
             in response — use assert_node_count(0) to assert empty results",
        );
        for node in matching {
            assert!(
                predicate(node),
                "{}:{} failed filter assertion on '{field}'",
@@ -422,6 +522,26 @@ impl ResponseView {
    }
}

/// Panic if the predicate returns `true` for a blank node (same type/id, no
/// properties). Catches trivial predicates like `|_| true` or `|n| n.has_prop("x")`
/// that don't actually verify a value.
fn assert_predicate_is_nontrivial(
    entity_type: &str,
    id: i64,
    predicate: &impl Fn(&GraphNode) -> bool,
) {
    let blank = GraphNode {
        entity_type: entity_type.to_string(),
        id,
        properties: serde_json::Map::new(),
    };
    assert!(
        !predicate(&blank),
        "trivial predicate: passes for a blank {entity_type} node with no properties. \
         Check actual property values instead of using |_| true or has_prop().",
    );
}

// ─────────────────────────────────────────────────────────────────────────────
// NodeExt — typed property access for GraphNode
// ─────────────────────────────────────────────────────────────────────────────
+9 −8
Original line number Diff line number Diff line
@@ -526,6 +526,7 @@ async fn traversal_redaction_removes_unauthorized_data(ctx: &TestContext) {
    assert_eq!(resp.node_ids("Group"), HashSet::from([100]));
    resp.assert_node_absent("User", 2);
    resp.assert_node_absent("Group", 102);
    resp.assert_edge_exists("User", 1, "Group", 100, "MEMBER_OF");
    resp.assert_edge_absent("User", 1, "Group", 102, "MEMBER_OF");
}

@@ -622,7 +623,7 @@ async fn path_finding_returns_valid_complete_paths(ctx: &TestContext) {
        "exactly one shortest path from User 1 to Project 1000"
    );

    for &pid in &pids {
    for &pid in pids.iter() {
        let path = resp.path(pid);
        assert_eq!(path.len(), 2, "path {pid}: User→Group→Project = 2 edges");

@@ -707,7 +708,7 @@ async fn path_finding_consecutive_edges_connect(ctx: &TestContext) {
        "exactly 2 paths: to 1000 (via 100) and 1004 (via 102)"
    );

    for &pid in &pids {
    for &pid in pids.iter() {
        let path = resp.path(pid);
        assert_eq!(path.len(), 2, "path {pid}: User→Group→Project = 2 edges");
        for window in path.windows(2) {
@@ -747,8 +748,10 @@ async fn neighbors_outgoing_returns_correct_targets(ctx: &TestContext) {
    resp.assert_edge_exists("User", 1, "Group", 100, "MEMBER_OF");
    resp.assert_edge_exists("User", 1, "Group", 102, "MEMBER_OF");

    resp.assert_node("Group", 100, |n| n.has_prop("name"));
    resp.assert_node("Group", 102, |n| n.has_prop("name"));
    resp.assert_node("Group", 100, |n| n.prop_str("name") == Some("Public Group"));
    resp.assert_node("Group", 102, |n| {
        n.prop_str("name") == Some("Internal Group")
    });
}

async fn neighbors_incoming_returns_correct_sources(ctx: &TestContext) {
@@ -807,10 +810,8 @@ async fn neighbors_both_direction_returns_all_connected(ctx: &TestContext) {
    let user_ids = resp.node_ids("User");
    let project_ids = resp.node_ids("Project");

    assert!(user_ids.contains(&1));
    assert!(user_ids.contains(&2));
    assert!(project_ids.contains(&1000));
    assert!(project_ids.contains(&1002));
    assert_eq!(user_ids, HashSet::from([1, 2]));
    assert_eq!(project_ids, HashSet::from([1000, 1002]));

    resp.assert_referential_integrity();
    resp.assert_edge_exists("User", 1, "Group", 100, "MEMBER_OF");