Verified Commit 028caa80 authored by Michael Usachenko's avatar Michael Usachenko Committed by GitLab
Browse files

chore(querying): more test coverage + cleanup normalize phase

parent ae7525ab
Loading
Loading
Loading
Loading
+40 −0
Original line number Diff line number Diff line
@@ -12,3 +12,43 @@ impl ExtractionStage {
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use query_engine::ResultContext;
    use std::sync::Arc;

    #[test]
    fn wires_batches_and_context_into_query_result() {
        // Schema mirrors the internal redaction id/type columns for node "p".
        let fields = vec![
            Field::new("_gkg_p_id", DataType::Int64, false),
            Field::new("_gkg_p_type", DataType::Utf8, false),
        ];
        let ids = Int64Array::from(vec![1, 2]);
        let kinds = StringArray::from(vec!["Project", "Project"]);
        let record_batch = RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(ids), Arc::new(kinds)],
        )
        .unwrap();

        let mut result_ctx = ResultContext::new();
        result_ctx.add_node("p", "Project");

        let execution_output = ExecutionOutput {
            batches: vec![record_batch],
            result_context: result_ctx,
        };
        let output = ExtractionStage::execute(execution_output, &PipelineObserver::start());

        // Both rows survive extraction and the node binding is carried through.
        assert_eq!(output.query_result.len(), 2);
        assert!(output.query_result.ctx().get("p").is_some());
    }
}
+64 −0
Original line number Diff line number Diff line
@@ -39,3 +39,67 @@ impl<F: ResultFormatter> FormattingStage<F> {
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use query_engine::{ParameterizedQuery, ResultContext};
    use serde_json::{Value, json};
    use std::collections::HashMap;
    // Import `Arc` explicitly (it is used for the schema and the formatter
    // below) instead of relying on it leaking in through `use super::*`;
    // this matches the sibling stage test modules.
    use std::sync::Arc;

    use crate::redaction::QueryResult;

    /// Stub formatter that ignores every input and returns a fixed JSON
    /// value, isolating the stage's wiring from real formatting logic.
    struct ConstFormatter(Value);

    impl ResultFormatter for ConstFormatter {
        fn format(&self, _: &QueryResult, _: &ResultContext, _: &Ontology) -> Value {
            self.0.clone()
        }
    }

    #[test]
    fn assembles_output_with_correct_counts() {
        // Three "Project" rows; the first is marked unauthorized below so we
        // can check row/redaction accounting.
        let schema = Arc::new(Schema::new(vec![
            Field::new("_gkg_p_id", DataType::Int64, false),
            Field::new("_gkg_p_type", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Project", "Project", "Project"])),
            ],
        )
        .unwrap();

        let mut ctx = ResultContext::new();
        ctx.add_node("p", "Project");

        let mut qr = QueryResult::from_batches(&[batch], &ctx);
        qr.rows_mut()[0].set_unauthorized();

        let input = HydrationOutput {
            result_context: qr.ctx().clone(),
            query_result: qr,
            redacted_count: 1,
        };
        let compiled = CompilationOutput {
            compiled_query: ParameterizedQuery {
                sql: "SELECT 1".to_string(),
                params: HashMap::new(),
                result_context: ResultContext::new(),
            },
        };

        let stage = FormattingStage::new(ConstFormatter(json!(["ok"])), Arc::new(Ontology::new()));
        let output = stage.execute(input, &compiled, &PipelineObserver::start());

        // The stub's payload passes through untouched, the compiled SQL is
        // surfaced, and counts reflect the one redacted row.
        assert_eq!(output.formatted_result, json!(["ok"]));
        assert_eq!(output.generated_sql.as_deref(), Some("SELECT 1"));
        assert_eq!(output.row_count, 2); // 3 total - 1 redacted
        assert_eq!(output.redacted_count, 1);
    }
}
+65 −0
Original line number Diff line number Diff line
@@ -15,3 +15,68 @@ impl RedactionStage {
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use query_engine::{EntityAuthConfig, ResultContext};
    use std::sync::Arc;

    use crate::redaction::{QueryResult, ResourceAuthorization};

    // Builds a three-row "Project" result (ids 10/20/30) with the entity-auth
    // config wired into the context, paired with the supplied authorizations.
    fn make_input(authorizations: Vec<ResourceAuthorization>) -> AuthorizationOutput {
        let ids = Int64Array::from(vec![10, 20, 30]);
        let kinds = StringArray::from(vec!["Project", "Project", "Project"]);
        let schema = Schema::new(vec![
            Field::new("_gkg_p_id", DataType::Int64, false),
            Field::new("_gkg_p_type", DataType::Utf8, false),
        ]);
        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(ids), Arc::new(kinds)]).unwrap();

        let mut ctx = ResultContext::new();
        ctx.add_node("p", "Project");
        ctx.add_entity_auth(
            "Project",
            EntityAuthConfig {
                resource_type: "project".to_string(),
                ability: "read".to_string(),
                auth_id_column: "id".to_string(),
                owner_entity: None,
            },
        );

        AuthorizationOutput {
            query_result: QueryResult::from_batches(&[batch], &ctx),
            authorizations,
        }
    }

    #[test]
    fn denied_rows_are_redacted() {
        // Row 20 is explicitly denied; rows 10 and 30 are granted.
        let decisions = [(10, true), (20, false), (30, true)];
        let grants = vec![ResourceAuthorization {
            resource_type: "project".to_string(),
            authorized: decisions.into_iter().collect(),
        }];

        let output = RedactionStage::execute(make_input(grants), &PipelineObserver::start());

        assert_eq!(output.redacted_count, 1);
        assert_eq!(output.query_result.authorized_count(), 2);
    }

    #[test]
    fn no_authorizations_redacts_all() {
        // With no authorization data at all, every row must be redacted.
        let output = RedactionStage::execute(make_input(Vec::new()), &PipelineObserver::start());

        assert_eq!(output.redacted_count, 3);
        assert_eq!(output.query_result.authorized_count(), 0);
    }
}
+1 −1
Original line number Diff line number Diff line
@@ -331,7 +331,7 @@ impl TableRef {
        }
    }

    pub fn union(queries: Vec<Query>, alias: impl Into<String>) -> Self {
    pub fn union_all(queries: Vec<Query>, alias: impl Into<String>) -> Self {
        TableRef::Union {
            queries,
            alias: alias.into(),
+138 −1
Original line number Diff line number Diff line
@@ -131,10 +131,18 @@ fn enforce_return_columns(
            let has_type = q.select.iter().any(|s| s.alias.as_ref() == Some(&type_col));

            if !has_id {
                let id_expr = Expr::col(&node.id, &node.redaction_id_column);
                q.select.push(SelectExpr {
                    expr: Expr::col(&node.id, &node.redaction_id_column),
                    expr: id_expr.clone(),
                    alias: Some(id_col.clone()),
                });
                // Push down id column to aggregation group by if not already present.
                if input.query_type == QueryType::Aggregation
                    && !q.group_by.is_empty()
                    && !q.group_by.contains(&id_expr)
                {
                    q.group_by.push(id_expr);
                }
            }

            if !has_type {
@@ -455,12 +463,141 @@ mod tests {
                .any(|s| s.alias.as_ref() == Some(&"_gkg_n_type".to_string()))
        );

        // Enforced id column should be added to GROUP BY (no duplicate since u.id already present)
        assert_eq!(q.group_by.len(), 1);
        assert_eq!(q.group_by[0], Expr::col("u", "id"));

        // Context should only have the group_by node
        assert_eq!(ctx.len(), 1);
        assert!(ctx.get("u").is_some());
        assert!(ctx.get("n").is_none());
    }

    #[test]
    fn aggregation_adds_redaction_id_to_group_by() {
        use crate::input::{AggFunction, InputAggregation};

        // COUNT(mr) grouped by the user node "u".
        let count_mrs = InputAggregation {
            function: AggFunction::Count,
            target: Some("mr".to_string()),
            group_by: Some("u".to_string()),
            property: None,
            alias: Some("mr_count".to_string()),
        };
        let user_node = InputNode {
            id: "u".to_string(),
            entity: Some("User".to_string()),
            table: Some("gl_user".to_string()),
            ..Default::default()
        };
        let mr_node = InputNode {
            id: "mr".to_string(),
            entity: Some("MergeRequest".to_string()),
            table: Some("gl_merge_request".to_string()),
            ..Default::default()
        };
        let input = Input {
            query_type: QueryType::Aggregation,
            nodes: vec![user_node, mr_node],
            relationships: vec![],
            aggregations: vec![count_mrs],
            path: None,
            neighbors: None,
            limit: 10,
            range: None,
            order_by: None,
            aggregation_sort: None,
            entity_auth: Default::default(),
        };

        // Seed plan groups only by username; enforcement should add u's id.
        let seed = Query {
            select: vec![SelectExpr {
                expr: Expr::col("u", "username"),
                alias: Some("u_username".into()),
            }],
            from: TableRef::scan("gl_user", "u"),
            group_by: vec![Expr::col("u", "username")],
            limit: Some(10),
            ..Default::default()
        };

        let mut plan = Node::Query(Box::new(seed));
        enforce_return(&mut plan, &input).unwrap();

        let Node::Query(q) = plan else {
            panic!("expected Query")
        };

        assert!(
            q.group_by.contains(&Expr::col("u", "id")),
            "redaction id column must be in GROUP BY: {:?}",
            q.group_by
        );
        assert_eq!(q.group_by.len(), 2); // username + id
    }

    #[test]
    fn uses_correct_redaction_id_column_per_node() {
        // "d" overrides the redaction id column; "p" falls back to the default.
        let definition_node = InputNode {
            id: "d".to_string(),
            entity: Some("Definition".to_string()),
            table: Some("gl_definition".to_string()),
            redaction_id_column: "project_id".to_string(),
            ..Default::default()
        };
        let project_node = InputNode {
            id: "p".to_string(),
            entity: Some("Project".to_string()),
            table: Some("gl_project".to_string()),
            ..Default::default()
        };
        let input = Input {
            query_type: QueryType::Traversal,
            nodes: vec![definition_node, project_node],
            relationships: vec![],
            aggregations: vec![],
            path: None,
            neighbors: None,
            limit: 10,
            range: None,
            order_by: None,
            aggregation_sort: None,
            entity_auth: Default::default(),
        };

        // Start from an empty SELECT so every enforced column is observable.
        let mut plan = Node::Query(Box::new(Query {
            select: vec![],
            from: TableRef::scan("gl_definition", "d"),
            limit: Some(10),
            ..Default::default()
        }));
        let ctx = enforce_return(&mut plan, &input).unwrap();

        let Node::Query(q) = plan else {
            panic!("expected Query")
        };

        // One id + one type column per node.
        assert_eq!(q.select.len(), 4);

        // Definition: custom redaction column + type literal
        assert_eq!(q.select[0].alias, Some("_gkg_d_id".into()));
        assert!(matches!(&q.select[0].expr, Expr::Column { column, .. } if column == "project_id"));
        assert_eq!(q.select[1].alias, Some("_gkg_d_type".into()));
        assert!(matches!(&q.select[1].expr, Expr::Literal(v) if v == "Definition"));

        // Project: default id column + type literal
        assert_eq!(q.select[2].alias, Some("_gkg_p_id".into()));
        assert!(matches!(&q.select[2].expr, Expr::Column { column, .. } if column == "id"));
        assert_eq!(q.select[3].alias, Some("_gkg_p_type".into()));
        assert!(matches!(&q.select[3].expr, Expr::Literal(v) if v == "Project"));

        assert_eq!(ctx.len(), 2);
        assert_eq!(ctx.get("d").unwrap().entity_type, "Definition");
        assert_eq!(ctx.get("p").unwrap().entity_type, "Project");
    }

    #[test]
    fn path_finding_uses_gkg_path_column() {
        use crate::ast::Cte;
Loading