Verified Commit e838f4b4 authored by Michael Angelo Rivera's avatar Michael Angelo Rivera Committed by GitLab
Browse files

feat(mcp): add format argument to tool parameter schemas

parent aa0aa595
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -3,4 +3,4 @@ mod schema;
mod service;

pub use registry::{ToolDefinition, ToolRegistry};
pub use service::{ExecutorError, ToolPlan, ToolService};
pub use service::{ExecutorError, OutputFormat, ToolPlan, ToolService};
+100 −97
Original line number Diff line number Diff line
@@ -13,6 +13,33 @@ pub struct ToolDefinition {
    pub parameters: serde_json::Value,
}

mod params {
    use serde_json::{Value, json};

    pub fn format() -> Value {
        json!({
            "type": "string",
            "enum": ["llm", "raw"],
            "description": "Output format. 'llm' (default) returns compact text optimized for AI. 'raw' returns structured JSON."
        })
    }

    pub fn query() -> Value {
        json!({
            "type": "object",
            "description": "Graph query following the DSL schema"
        })
    }

    pub fn expand_nodes() -> Value {
        json!({
            "type": "array",
            "items": { "type": "string" },
            "description": "Node types to expand with properties and relationships."
        })
    }
}

pub struct ToolRegistry;

impl ToolRegistry {
@@ -34,16 +61,14 @@ impl ToolRegistry {
        };

        ToolDefinition {
            name: "query_graph".to_string(),
            name: "query_graph".into(),
            description,
            parameters: json!({
                "type": "object",
                "required": ["query"],
                "properties": {
                    "query": {
                        "type": "object",
                        "description": "Graph query following the DSL schema"
                    }
                    "query": params::query(),
                    "format": params::format()
                },
                "additionalProperties": false
            }),
@@ -52,19 +77,16 @@ impl ToolRegistry {

    fn get_graph_schema() -> ToolDefinition {
        ToolDefinition {
            name: "get_graph_schema".to_string(),
            name: "get_graph_schema".into(),
            description: "List the GitLab Knowledge Graph schema. Returns the available nodes \
                          and edges with their source/target types. Use expand_nodes to get \
                          property details for specific types."
                .to_string(),
                .into(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "expand_nodes": {
                        "type": "array",
                        "items": { "type": "string" },
                        "description": "Node types to expand with properties and relationships."
                    }
                    "expand_nodes": params::expand_nodes(),
                    "format": params::format()
                },
                "additionalProperties": false
            }),
@@ -76,15 +98,22 @@ impl ToolRegistry {
mod tests {
    use super::*;

    fn test_ontology() -> Arc<Ontology> {
        Arc::new(Ontology::load_embedded().expect("Failed to load ontology"))
    fn all_tools() -> Vec<ToolDefinition> {
        let ontology = Arc::new(Ontology::load_embedded().expect("Failed to load ontology"));
        ToolRegistry::get_all_tools(&ontology)
    }

    fn find_tool(name: &str) -> ToolDefinition {
        all_tools()
            .into_iter()
            .find(|t| t.name == name)
            .unwrap_or_else(|| panic!("tool '{name}' not found"))
    }

    #[test]
    fn test_all_tools_have_valid_schemas() {
        let ontology = test_ontology();
        let tools = ToolRegistry::get_all_tools(&ontology);
        assert_eq!(tools.len(), 2, "Should have exactly 2 tools");
    fn all_tools_have_valid_schemas() {
        let tools = all_tools();
        assert_eq!(tools.len(), 2);

        for tool in &tools {
            assert!(!tool.name.is_empty());
@@ -94,105 +123,79 @@ mod tests {
    }

    #[test]
    fn test_tool_names_are_unique() {
        let ontology = test_ontology();
        let tools = ToolRegistry::get_all_tools(&ontology);
    fn tool_names_are_unique() {
        let tools = all_tools();
        let mut names = std::collections::HashSet::new();

        for tool in &tools {
            assert!(
                names.insert(&tool.name),
                "Duplicate tool name found: {}",
                tool.name
            );
            assert!(names.insert(&tool.name), "Duplicate tool: {}", tool.name);
        }
    }

    #[test]
    fn test_tool_names() {
        let ontology = test_ontology();
        let tools = ToolRegistry::get_all_tools(&ontology);
        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();

        assert!(names.contains(&"query_graph"));
        assert!(names.contains(&"get_graph_schema"));
    fn expected_tools_are_registered() {
        let names: Vec<String> = all_tools().into_iter().map(|t| t.name).collect();
        assert!(names.contains(&"query_graph".into()));
        assert!(names.contains(&"get_graph_schema".into()));
    }

    #[test]
    fn test_get_graph_schema_has_expand_nodes_param() {
        let ontology = test_ontology();
        let tools = ToolRegistry::get_all_tools(&ontology);
        let get_schema = tools
    fn all_tools_have_format_parameter() {
        for tool in &all_tools() {
            let format = &tool.parameters["properties"]["format"];
            assert!(format.is_object(), "{} missing format parameter", tool.name);
            assert_eq!(format["type"], "string");

            let values: Vec<&str> = format["enum"]
                .as_array()
                .expect("format should have enum")
                .iter()
            .find(|t| t.name == "get_graph_schema")
            .expect("get_graph_schema tool should exist");

        let params = &get_schema.parameters;
        assert_eq!(params["type"], "object");
        assert!(params["properties"]["expand_nodes"].is_object());
                .map(|v| v.as_str().unwrap())
                .collect();
            assert_eq!(values, vec!["llm", "raw"]);
        }
    }

    #[test]
    fn test_query_graph_has_query_parameter() {
        let ontology = test_ontology();
        let tools = ToolRegistry::get_all_tools(&ontology);
        let query_graph = tools
            .iter()
            .find(|t| t.name == "query_graph")
            .expect("query_graph tool should exist");

        let params = &query_graph.parameters;
        assert_eq!(params["type"], "object");
        assert!(params["properties"]["query"].is_object());

        let required = params["required"].as_array().expect("Should have required");
    fn format_is_never_required() {
        for tool in &all_tools() {
            if let Some(required) = tool.parameters.get("required").and_then(|r| r.as_array()) {
                assert!(
            required.iter().any(|v| v == "query"),
            "query should be required"
                    !required.iter().any(|v| v == "format"),
                    "{} should not require format",
                    tool.name
                );
            }
        }
    }

    #[test]
    fn test_query_graph_description_contains_schema() {
        let ontology = test_ontology();
        let tools = ToolRegistry::get_all_tools(&ontology);
        let query_graph = tools
            .iter()
            .find(|t| t.name == "query_graph")
            .expect("query_graph tool should exist");
    fn query_graph_requires_query_parameter() {
        let tool = find_tool("query_graph");
        let params = &tool.parameters;

        let desc = &query_graph.description;
        assert!(
            desc.contains("query_type"),
            "Description should contain query_type"
        );
        assert!(
            desc.contains("traversal"),
            "Description should contain traversal"
        );
        assert!(
            desc.contains("get_graph_schema"),
            "Description should reference get_graph_schema for entity discovery"
        );
        assert!(params["properties"]["query"].is_object());
        let required = params["required"].as_array().expect("should have required");
        assert!(required.iter().any(|v| v == "query"));
    }

    #[test]
    fn test_query_graph_excludes_ontology_data() {
        let ontology = test_ontology();
        let tools = ToolRegistry::get_all_tools(&ontology);
        let query_graph = tools
            .iter()
            .find(|t| t.name == "query_graph")
            .expect("query_graph tool should exist");
    fn query_graph_description_contains_dsl_schema() {
        let tool = find_tool("query_graph");
        assert!(tool.description.contains("query_type"));
        assert!(tool.description.contains("traversal"));
        assert!(tool.description.contains("get_graph_schema"));
    }

        let desc = &query_graph.description;
        assert!(
            !desc.contains("username"),
            "Description should not contain entity-specific fields"
        );
        assert!(
            !desc.contains("AUTHORED"),
            "Description should not contain relationship types (use get_graph_schema)"
        );
    #[test]
    fn query_graph_excludes_ontology_data() {
        let tool = find_tool("query_graph");
        assert!(!tool.description.contains("username"));
        assert!(!tool.description.contains("AUTHORED"));
    }

    #[test]
    fn get_graph_schema_has_expand_nodes_param() {
        let tool = find_tool("get_graph_schema");
        assert!(tool.parameters["properties"]["expand_nodes"].is_object());
    }
}
+115 −5
Original line number Diff line number Diff line
@@ -25,10 +25,31 @@ impl ExecutorError {
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OutputFormat {
    #[default]
    Llm,
    Raw,
}

impl OutputFormat {
    pub fn from_str_lossy(s: &str) -> Self {
        match s {
            "raw" => Self::Raw,
            _ => Self::Llm,
        }
    }
}

#[derive(Debug)]
pub enum ToolPlan {
    RunGraphQuery { query_json: String },
    Immediate { result: Value },
    RunGraphQuery {
        query_json: String,
        format: OutputFormat,
    },
    Immediate {
        result: Value,
    },
}

#[derive(Debug, Clone)]
@@ -71,16 +92,24 @@ impl ToolService {
        let query_json = serde_json::to_string(query)
            .map_err(|e| ExecutorError::InvalidArguments(e.to_string()))?;

        Ok(ToolPlan::RunGraphQuery { query_json })
        let format = parse_format(arguments);

        Ok(ToolPlan::RunGraphQuery { query_json, format })
    }

    fn execute_get_graph_schema(&self, arguments: &Value) -> Result<ToolPlan, ExecutorError> {
        let args: GetGraphSchemaArgs = serde_json::from_value(arguments.clone())
            .map_err(|e| ExecutorError::InvalidArguments(e.to_string()))?;

        let format = parse_format(arguments);
        let expand_nodes = args.expand_nodes.as_deref().unwrap_or(&[]);
        let response = self.build_graph_schema_response(expand_nodes)?;
        let result = self.format_as_toon(&response)?;

        let result = match format {
            OutputFormat::Llm => self.format_as_toon(&response)?,
            OutputFormat::Raw => serde_json::to_value(&response)
                .map_err(|e| ExecutorError::InvalidArguments(e.to_string()))?,
        };

        Ok(ToolPlan::Immediate { result })
    }
@@ -211,6 +240,14 @@ impl ToolService {
    }
}

fn parse_format(arguments: &Value) -> OutputFormat {
    arguments
        .get("format")
        .and_then(|v| v.as_str())
        .map(OutputFormat::from_str_lossy)
        .unwrap_or_default()
}

#[derive(Debug, Deserialize)]
struct GetGraphSchemaArgs {
    #[serde(default)]
@@ -383,8 +420,9 @@ mod tests {
            .expect("Should resolve");

        match plan {
            ToolPlan::RunGraphQuery { query_json } => {
            ToolPlan::RunGraphQuery { query_json, format } => {
                assert!(query_json.contains("match"));
                assert_eq!(format, OutputFormat::Llm);
            }
            _ => panic!("Expected RunGraphQuery plan"),
        }
@@ -478,6 +516,78 @@ mod tests {
        assert!(toon.contains("domains"), "Should still return valid schema");
    }

    #[test]
    fn get_graph_schema_raw_format_returns_json() {
        let ontology = Arc::new(Ontology::load_embedded().expect("Failed to load ontology"));
        let service = ToolService::new(ontology);

        let plan = service
            .resolve("get_graph_schema", r#"{"format": "raw"}"#)
            .expect("Should resolve");

        match plan {
            ToolPlan::Immediate { result } => {
                assert!(result.is_object(), "Raw format should return a JSON object");
                assert!(result.get("domains").is_some(), "Should have domains key");
                assert!(result.get("edges").is_some(), "Should have edges key");
            }
            _ => panic!("Expected Immediate plan"),
        }
    }

    #[test]
    fn get_graph_schema_llm_format_returns_toon() {
        let ontology = Arc::new(Ontology::load_embedded().expect("Failed to load ontology"));
        let service = ToolService::new(ontology);

        let plan = service
            .resolve("get_graph_schema", r#"{"format": "llm"}"#)
            .expect("Should resolve");

        match plan {
            ToolPlan::Immediate { result } => {
                assert!(result.is_string(), "LLM format should return a TOON string");
                let text = result.as_str().unwrap();
                assert!(text.contains("domains"));
            }
            _ => panic!("Expected Immediate plan"),
        }
    }

    #[test]
    fn get_graph_schema_default_format_is_llm() {
        let ontology = Arc::new(Ontology::load_embedded().expect("Failed to load ontology"));
        let service = ToolService::new(ontology);

        let plan = service
            .resolve("get_graph_schema", r#"{}"#)
            .expect("Should resolve");

        match plan {
            ToolPlan::Immediate { result } => {
                assert!(result.is_string(), "Default format should be TOON string");
            }
            _ => panic!("Expected Immediate plan"),
        }
    }

    #[test]
    fn query_graph_raw_format_is_carried_in_plan() {
        let ontology = Arc::new(Ontology::load_embedded().expect("Failed to load ontology"));
        let service = ToolService::new(ontology);

        let plan = service
            .resolve("query_graph", r#"{"query":{"match":{}}, "format": "raw"}"#)
            .expect("Should resolve");

        match plan {
            ToolPlan::RunGraphQuery { format, .. } => {
                assert_eq!(format, OutputFormat::Raw);
            }
            _ => panic!("Expected RunGraphQuery plan"),
        }
    }

    #[test]
    fn test_expand_all_wildcard() {
        let output = get_toon_output(r#"{"expand_nodes": ["*"]}"#);