Loading crates/gkg-server/src/query_pipeline/stages/extraction.rs +40 −0 Original line number Diff line number Diff line Loading @@ -12,3 +12,43 @@ impl ExtractionStage { } } } #[cfg(test)] mod tests { use super::*; use arrow::array::{Int64Array, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use query_engine::ResultContext; use std::sync::Arc; #[test] fn wires_batches_and_context_into_query_result() { let schema = Arc::new(Schema::new(vec![ Field::new("_gkg_p_id", DataType::Int64, false), Field::new("_gkg_p_type", DataType::Utf8, false), ])); let batch = RecordBatch::try_new( schema, vec![ Arc::new(Int64Array::from(vec![1, 2])), Arc::new(StringArray::from(vec!["Project", "Project"])), ], ) .unwrap(); let mut ctx = ResultContext::new(); ctx.add_node("p", "Project"); let output = ExtractionStage::execute( ExecutionOutput { batches: vec![batch], result_context: ctx, }, &PipelineObserver::start(), ); assert_eq!(output.query_result.len(), 2); assert!(output.query_result.ctx().get("p").is_some()); } } crates/gkg-server/src/query_pipeline/stages/formatting.rs +64 −0 Original line number Diff line number Diff line Loading @@ -39,3 +39,67 @@ impl<F: ResultFormatter> FormattingStage<F> { } } } #[cfg(test)] mod tests { use super::*; use arrow::array::{Int64Array, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use query_engine::{ParameterizedQuery, ResultContext}; use serde_json::{Value, json}; use std::collections::HashMap; use crate::redaction::QueryResult; struct ConstFormatter(Value); impl ResultFormatter for ConstFormatter { fn format(&self, _: &QueryResult, _: &ResultContext, _: &Ontology) -> Value { self.0.clone() } } #[test] fn assembles_output_with_correct_counts() { let schema = Arc::new(Schema::new(vec![ Field::new("_gkg_p_id", DataType::Int64, false), Field::new("_gkg_p_type", DataType::Utf8, false), ])); let batch = RecordBatch::try_new( schema, vec![ 
Arc::new(Int64Array::from(vec![1, 2, 3])), Arc::new(StringArray::from(vec!["Project", "Project", "Project"])), ], ) .unwrap(); let mut ctx = ResultContext::new(); ctx.add_node("p", "Project"); let mut qr = QueryResult::from_batches(&[batch], &ctx); qr.rows_mut()[0].set_unauthorized(); let input = HydrationOutput { result_context: qr.ctx().clone(), query_result: qr, redacted_count: 1, }; let compiled = CompilationOutput { compiled_query: ParameterizedQuery { sql: "SELECT 1".to_string(), params: HashMap::new(), result_context: ResultContext::new(), }, }; let stage = FormattingStage::new(ConstFormatter(json!(["ok"])), Arc::new(Ontology::new())); let output = stage.execute(input, &compiled, &PipelineObserver::start()); assert_eq!(output.formatted_result, json!(["ok"])); assert_eq!(output.generated_sql.as_deref(), Some("SELECT 1")); assert_eq!(output.row_count, 2); // 3 total - 1 redacted assert_eq!(output.redacted_count, 1); } } crates/gkg-server/src/query_pipeline/stages/redaction.rs +65 −0 Original line number Diff line number Diff line Loading @@ -15,3 +15,68 @@ impl RedactionStage { } } } #[cfg(test)] mod tests { use super::*; use arrow::array::{Int64Array, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use query_engine::{EntityAuthConfig, ResultContext}; use std::sync::Arc; use crate::redaction::{QueryResult, ResourceAuthorization}; fn make_input(authorizations: Vec<ResourceAuthorization>) -> AuthorizationOutput { let schema = Arc::new(Schema::new(vec![ Field::new("_gkg_p_id", DataType::Int64, false), Field::new("_gkg_p_type", DataType::Utf8, false), ])); let batch = RecordBatch::try_new( schema, vec![ Arc::new(Int64Array::from(vec![10, 20, 30])), Arc::new(StringArray::from(vec!["Project", "Project", "Project"])), ], ) .unwrap(); let mut ctx = ResultContext::new(); ctx.add_node("p", "Project"); ctx.add_entity_auth( "Project", EntityAuthConfig { resource_type: "project".to_string(), ability: "read".to_string(), 
auth_id_column: "id".to_string(), owner_entity: None, }, ); AuthorizationOutput { query_result: QueryResult::from_batches(&[batch], &ctx), authorizations, } } #[test] fn denied_rows_are_redacted() { let auth = vec![ResourceAuthorization { resource_type: "project".to_string(), authorized: [(10, true), (20, false), (30, true)].into_iter().collect(), }]; let output = RedactionStage::execute(make_input(auth), &PipelineObserver::start()); assert_eq!(output.redacted_count, 1); assert_eq!(output.query_result.authorized_count(), 2); } #[test] fn no_authorizations_redacts_all() { let output = RedactionStage::execute(make_input(vec![]), &PipelineObserver::start()); assert_eq!(output.redacted_count, 3); assert_eq!(output.query_result.authorized_count(), 0); } } crates/query-engine/src/ast.rs +1 −1 Original line number Diff line number Diff line Loading @@ -331,7 +331,7 @@ impl TableRef { } } pub fn union(queries: Vec<Query>, alias: impl Into<String>) -> Self { pub fn union_all(queries: Vec<Query>, alias: impl Into<String>) -> Self { TableRef::Union { queries, alias: alias.into(), Loading crates/query-engine/src/enforce.rs +138 −1 Original line number Diff line number Diff line Loading @@ -131,10 +131,18 @@ fn enforce_return_columns( let has_type = q.select.iter().any(|s| s.alias.as_ref() == Some(&type_col)); if !has_id { let id_expr = Expr::col(&node.id, &node.redaction_id_column); q.select.push(SelectExpr { expr: Expr::col(&node.id, &node.redaction_id_column), expr: id_expr.clone(), alias: Some(id_col.clone()), }); // Push down id column to aggregation group by if not already present. 
if input.query_type == QueryType::Aggregation && !q.group_by.is_empty() && !q.group_by.contains(&id_expr) { q.group_by.push(id_expr); } } if !has_type { Loading Loading @@ -455,12 +463,141 @@ mod tests { .any(|s| s.alias.as_ref() == Some(&"_gkg_n_type".to_string())) ); // Enforced id column should be added to GROUP BY (no duplicate since u.id already present) assert_eq!(q.group_by.len(), 1); assert_eq!(q.group_by[0], Expr::col("u", "id")); // Context should only have the group_by node assert_eq!(ctx.len(), 1); assert!(ctx.get("u").is_some()); assert!(ctx.get("n").is_none()); } #[test] fn aggregation_adds_redaction_id_to_group_by() { use crate::input::{AggFunction, InputAggregation}; let input = Input { query_type: QueryType::Aggregation, nodes: vec![ InputNode { id: "u".to_string(), entity: Some("User".to_string()), table: Some("gl_user".to_string()), ..Default::default() }, InputNode { id: "mr".to_string(), entity: Some("MergeRequest".to_string()), table: Some("gl_merge_request".to_string()), ..Default::default() }, ], relationships: vec![], aggregations: vec![InputAggregation { function: AggFunction::Count, target: Some("mr".to_string()), group_by: Some("u".to_string()), property: None, alias: Some("mr_count".to_string()), }], path: None, neighbors: None, limit: 10, range: None, order_by: None, aggregation_sort: None, entity_auth: Default::default(), }; let query = Query { select: vec![SelectExpr { expr: Expr::col("u", "username"), alias: Some("u_username".into()), }], from: TableRef::scan("gl_user", "u"), group_by: vec![Expr::col("u", "username")], limit: Some(10), ..Default::default() }; let mut node = Node::Query(Box::new(query)); enforce_return(&mut node, &input).unwrap(); let Node::Query(q) = node else { panic!("expected Query") }; assert!( q.group_by.contains(&Expr::col("u", "id")), "redaction id column must be in GROUP BY: {:?}", q.group_by ); assert_eq!(q.group_by.len(), 2); // username + id } #[test] fn uses_correct_redaction_id_column_per_node() { let 
mut node = Node::Query(Box::new(Query { select: vec![], from: TableRef::scan("gl_definition", "d"), limit: Some(10), ..Default::default() })); let input = Input { query_type: QueryType::Traversal, nodes: vec![ InputNode { id: "d".to_string(), entity: Some("Definition".to_string()), table: Some("gl_definition".to_string()), redaction_id_column: "project_id".to_string(), ..Default::default() }, InputNode { id: "p".to_string(), entity: Some("Project".to_string()), table: Some("gl_project".to_string()), ..Default::default() }, ], relationships: vec![], aggregations: vec![], path: None, neighbors: None, limit: 10, range: None, order_by: None, aggregation_sort: None, entity_auth: Default::default(), }; let ctx = enforce_return(&mut node, &input).unwrap(); let Node::Query(q) = node else { panic!("expected Query") }; assert_eq!(q.select.len(), 4); // Definition: custom redaction column + type literal assert_eq!(q.select[0].alias, Some("_gkg_d_id".into())); assert!(matches!(&q.select[0].expr, Expr::Column { column, .. } if column == "project_id")); assert_eq!(q.select[1].alias, Some("_gkg_d_type".into())); assert!(matches!(&q.select[1].expr, Expr::Literal(v) if v == "Definition")); // Project: default id column + type literal assert_eq!(q.select[2].alias, Some("_gkg_p_id".into())); assert!(matches!(&q.select[2].expr, Expr::Column { column, .. } if column == "id")); assert_eq!(q.select[3].alias, Some("_gkg_p_type".into())); assert!(matches!(&q.select[3].expr, Expr::Literal(v) if v == "Project")); assert_eq!(ctx.len(), 2); assert_eq!(ctx.get("d").unwrap().entity_type, "Definition"); assert_eq!(ctx.get("p").unwrap().entity_type, "Project"); } #[test] fn path_finding_uses_gkg_path_column() { use crate::ast::Cte; Loading Loading
crates/gkg-server/src/query_pipeline/stages/extraction.rs +40 −0 Original line number Diff line number Diff line Loading @@ -12,3 +12,43 @@ impl ExtractionStage { } } }
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use query_engine::ResultContext;
    use std::sync::Arc;

    /// Builds a two-column batch (`_gkg_p_id`, `_gkg_p_type`) holding one
    /// "Project" row per supplied id.
    fn project_rows(ids: Vec<i64>) -> RecordBatch {
        // One type tag per id so both columns have the same length.
        let type_tags = vec!["Project"; ids.len()];
        let layout = Schema::new(vec![
            Field::new("_gkg_p_id", DataType::Int64, false),
            Field::new("_gkg_p_type", DataType::Utf8, false),
        ]);

        RecordBatch::try_new(
            Arc::new(layout),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(StringArray::from(type_tags)),
            ],
        )
        .unwrap()
    }

    /// Feeds one batch plus a context containing node "p" through the stage
    /// and checks that both show up in the resulting query result.
    #[test]
    fn wires_batches_and_context_into_query_result() {
        let mut context = ResultContext::new();
        context.add_node("p", "Project");

        let stage_input = ExecutionOutput {
            batches: vec![project_rows(vec![1, 2])],
            result_context: context,
        };
        let extracted = ExtractionStage::execute(stage_input, &PipelineObserver::start());

        // Two input rows in, two rows out; the "p" node survives the hand-off.
        assert_eq!(extracted.query_result.len(), 2);
        assert!(extracted.query_result.ctx().get("p").is_some());
    }
}
crates/gkg-server/src/query_pipeline/stages/formatting.rs +64 −0 Original line number Diff line number Diff line Loading @@ -39,3 +39,67 @@ impl<F: ResultFormatter> FormattingStage<F> { } } }
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use query_engine::{ParameterizedQuery, ResultContext};
    use serde_json::{Value, json};
    use std::collections::HashMap;

    use crate::redaction::QueryResult;

    /// Formatter stub that ignores its inputs and always returns the wrapped value.
    struct ConstFormatter(Value);

    impl ResultFormatter for ConstFormatter {
        fn format(&self, _result: &QueryResult, _ctx: &ResultContext, _ontology: &Ontology) -> Value {
            let ConstFormatter(canned) = self;
            canned.clone()
        }
    }

    /// Builds a two-column batch (`_gkg_p_id`, `_gkg_p_type`) holding one
    /// "Project" row per supplied id.
    fn project_rows(ids: Vec<i64>) -> RecordBatch {
        // One type tag per id so both columns have the same length.
        let type_tags = vec!["Project"; ids.len()];
        let layout = Schema::new(vec![
            Field::new("_gkg_p_id", DataType::Int64, false),
            Field::new("_gkg_p_type", DataType::Utf8, false),
        ]);

        RecordBatch::try_new(
            Arc::new(layout),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(StringArray::from(type_tags)),
            ],
        )
        .unwrap()
    }

    /// Runs the formatting stage over three rows, one of which is marked
    /// unauthorized, and checks the formatter output, SQL, and counters.
    #[test]
    fn assembles_output_with_correct_counts() {
        let mut context = ResultContext::new();
        context.add_node("p", "Project");

        let mut rows = QueryResult::from_batches(&[project_rows(vec![1, 2, 3])], &context);
        rows.rows_mut()[0].set_unauthorized();

        // Snapshot the context before `rows` is moved into the stage input.
        let context_snapshot = rows.ctx().clone();
        let hydrated = HydrationOutput {
            query_result: rows,
            result_context: context_snapshot,
            redacted_count: 1,
        };

        let compiled_query = ParameterizedQuery {
            sql: String::from("SELECT 1"),
            params: HashMap::new(),
            result_context: ResultContext::new(),
        };
        let compiled = CompilationOutput { compiled_query };

        let expected = json!(["ok"]);
        let stage = FormattingStage::new(
            ConstFormatter(expected.clone()),
            Arc::new(Ontology::new()),
        );
        let formatted = stage.execute(hydrated, &compiled, &PipelineObserver::start());

        assert_eq!(formatted.formatted_result, expected);
        assert_eq!(formatted.generated_sql.as_deref(), Some("SELECT 1"));
        assert_eq!(formatted.row_count, 2); // 3 total - 1 redacted
        assert_eq!(formatted.redacted_count, 1);
    }
}
crates/gkg-server/src/query_pipeline/stages/redaction.rs +65 −0 Original line number Diff line number Diff line Loading @@ -15,3 +15,68 @@ impl RedactionStage { } } }
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use query_engine::{EntityAuthConfig, ResultContext};
    use std::sync::Arc;

    use crate::redaction::{QueryResult, ResourceAuthorization};

    /// Builds stage input with three "Project" rows (ids 10, 20, 30), a
    /// "project"/"read" auth config keyed on `id`, and the given authorizations.
    fn make_input(authorizations: Vec<ResourceAuthorization>) -> AuthorizationOutput {
        let ids: Vec<i64> = vec![10, 20, 30];
        // One type tag per id so both columns have the same length.
        let type_tags = vec!["Project"; ids.len()];
        let layout = Schema::new(vec![
            Field::new("_gkg_p_id", DataType::Int64, false),
            Field::new("_gkg_p_type", DataType::Utf8, false),
        ]);
        let rows = RecordBatch::try_new(
            Arc::new(layout),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(StringArray::from(type_tags)),
            ],
        )
        .unwrap();

        let project_auth = EntityAuthConfig {
            resource_type: String::from("project"),
            ability: String::from("read"),
            auth_id_column: String::from("id"),
            owner_entity: None,
        };
        let mut context = ResultContext::new();
        context.add_node("p", "Project");
        context.add_entity_auth("Project", project_auth);

        let query_result = QueryResult::from_batches(&[rows], &context);
        AuthorizationOutput {
            query_result,
            authorizations,
        }
    }

    /// With id 20 denied and ids 10 and 30 allowed, exactly one row is redacted.
    #[test]
    fn denied_rows_are_redacted() {
        let decisions = [(10, true), (20, false), (30, true)];
        let grants = ResourceAuthorization {
            resource_type: String::from("project"),
            authorized: decisions.into_iter().collect(),
        };

        let redacted =
            RedactionStage::execute(make_input(vec![grants]), &PipelineObserver::start());

        assert_eq!(redacted.redacted_count, 1);
        assert_eq!(redacted.query_result.authorized_count(), 2);
    }

    /// With an empty authorization list, all three rows are redacted.
    #[test]
    fn no_authorizations_redacts_all() {
        let redacted =
            RedactionStage::execute(make_input(Vec::new()), &PipelineObserver::start());

        assert_eq!(redacted.redacted_count, 3);
        assert_eq!(redacted.query_result.authorized_count(), 0);
    }
}
crates/query-engine/src/ast.rs +1 −1 Original line number Diff line number Diff line Loading @@ -331,7 +331,7 @@ impl TableRef { } } pub fn union(queries: Vec<Query>, alias: impl Into<String>) -> Self { pub fn union_all(queries: Vec<Query>, alias: impl Into<String>) -> Self { TableRef::Union { queries, alias: alias.into(), Loading
crates/query-engine/src/enforce.rs +138 −1 Original line number Diff line number Diff line Loading @@ -131,10 +131,18 @@ fn enforce_return_columns( let has_type = q.select.iter().any(|s| s.alias.as_ref() == Some(&type_col)); if !has_id { let id_expr = Expr::col(&node.id, &node.redaction_id_column); q.select.push(SelectExpr { expr: Expr::col(&node.id, &node.redaction_id_column), expr: id_expr.clone(), alias: Some(id_col.clone()), }); // Push down id column to aggregation group by if not already present. if input.query_type == QueryType::Aggregation && !q.group_by.is_empty() && !q.group_by.contains(&id_expr) { q.group_by.push(id_expr); } } if !has_type { Loading Loading @@ -455,12 +463,141 @@ mod tests { .any(|s| s.alias.as_ref() == Some(&"_gkg_n_type".to_string())) ); // Enforced id column should be added to GROUP BY (no duplicate since u.id already present) assert_eq!(q.group_by.len(), 1); assert_eq!(q.group_by[0], Expr::col("u", "id")); // Context should only have the group_by node assert_eq!(ctx.len(), 1); assert!(ctx.get("u").is_some()); assert!(ctx.get("n").is_none()); } #[test] fn aggregation_adds_redaction_id_to_group_by() { use crate::input::{AggFunction, InputAggregation}; let input = Input { query_type: QueryType::Aggregation, nodes: vec![ InputNode { id: "u".to_string(), entity: Some("User".to_string()), table: Some("gl_user".to_string()), ..Default::default() }, InputNode { id: "mr".to_string(), entity: Some("MergeRequest".to_string()), table: Some("gl_merge_request".to_string()), ..Default::default() }, ], relationships: vec![], aggregations: vec![InputAggregation { function: AggFunction::Count, target: Some("mr".to_string()), group_by: Some("u".to_string()), property: None, alias: Some("mr_count".to_string()), }], path: None, neighbors: None, limit: 10, range: None, order_by: None, aggregation_sort: None, entity_auth: Default::default(), }; let query = Query { select: vec![SelectExpr { expr: Expr::col("u", "username"), alias: Some("u_username".into()), 
}], from: TableRef::scan("gl_user", "u"), group_by: vec![Expr::col("u", "username")], limit: Some(10), ..Default::default() }; let mut node = Node::Query(Box::new(query)); enforce_return(&mut node, &input).unwrap(); let Node::Query(q) = node else { panic!("expected Query") }; assert!( q.group_by.contains(&Expr::col("u", "id")), "redaction id column must be in GROUP BY: {:?}", q.group_by ); assert_eq!(q.group_by.len(), 2); // username + id } #[test] fn uses_correct_redaction_id_column_per_node() { let mut node = Node::Query(Box::new(Query { select: vec![], from: TableRef::scan("gl_definition", "d"), limit: Some(10), ..Default::default() })); let input = Input { query_type: QueryType::Traversal, nodes: vec![ InputNode { id: "d".to_string(), entity: Some("Definition".to_string()), table: Some("gl_definition".to_string()), redaction_id_column: "project_id".to_string(), ..Default::default() }, InputNode { id: "p".to_string(), entity: Some("Project".to_string()), table: Some("gl_project".to_string()), ..Default::default() }, ], relationships: vec![], aggregations: vec![], path: None, neighbors: None, limit: 10, range: None, order_by: None, aggregation_sort: None, entity_auth: Default::default(), }; let ctx = enforce_return(&mut node, &input).unwrap(); let Node::Query(q) = node else { panic!("expected Query") }; assert_eq!(q.select.len(), 4); // Definition: custom redaction column + type literal assert_eq!(q.select[0].alias, Some("_gkg_d_id".into())); assert!(matches!(&q.select[0].expr, Expr::Column { column, .. } if column == "project_id")); assert_eq!(q.select[1].alias, Some("_gkg_d_type".into())); assert!(matches!(&q.select[1].expr, Expr::Literal(v) if v == "Definition")); // Project: default id column + type literal assert_eq!(q.select[2].alias, Some("_gkg_p_id".into())); assert!(matches!(&q.select[2].expr, Expr::Column { column, .. 
} if column == "id")); assert_eq!(q.select[3].alias, Some("_gkg_p_type".into())); assert!(matches!(&q.select[3].expr, Expr::Literal(v) if v == "Project")); assert_eq!(ctx.len(), 2); assert_eq!(ctx.get("d").unwrap().entity_type, "Definition"); assert_eq!(ctx.get("p").unwrap().entity_type, "Project"); } #[test] fn path_finding_uses_gkg_path_column() { use crate::ast::Cte; Loading