Skip to content
Snippets Groups Projects
Commit 5efc2fbb authored by Britney Tong's avatar Britney Tong
Browse files

chore(q): remove payload transformation logic

parent 3403e445
No related branches found
No related tags found
1 merge request!50Q integration bt
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
class JsonProcessingError(Exception):
"""Custom exception for JSON processing errors"""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
@dataclass
class JsonProcessor:
exclude_fields: List[str] = None
ignore_null: bool = False
def __post_init__(self):
self.exclude_fields = self.exclude_fields or []
def process(self, data: Any) -> Any:
"""
Process JSON data recursively based on configuration
Args:
data: Input data to process (can be dict, list, or primitive types)
Returns:
Processed data with applied configurations
"""
if isinstance(data, dict):
return self._process_dict(data)
if isinstance(data, list):
return self._process_list(data)
return data
def add_exclude_field(self, field: str) -> None:
"""Add a field to exclude list"""
if not isinstance(field, str):
raise TypeError("Field must be a string")
if field not in self.exclude_fields:
self.exclude_fields.append(field)
def remove_exclude_field(self, field: str) -> None:
"""Remove a field from exclude list"""
if not isinstance(field, str):
raise TypeError("Field must be a string")
if field in self.exclude_fields:
self.exclude_fields.remove(field)
def set_ignore_null(self, ignore: bool) -> None:
"""Update ignore_null setting"""
self.ignore_null = ignore
def _process_dict(self, data: Dict) -> Dict:
"""Process dictionary objects"""
result = {}
if not isinstance(data, dict):
raise AttributeError("Input must be a dictionary")
for key, value in data.items():
# Skip if key is in exclude list
if key in self.exclude_fields:
continue
# Skip null values if ignore_null is True
if self.ignore_null and value is None:
continue
# Recursively process nested structures
processed_value = self.process(value)
# Test if the value is JSON serializable
json.dumps(processed_value)
result[key] = processed_value
return result
def _process_list(self, data: List) -> List:
"""Process list objects"""
return [self.process(item) for item in data]
def process_json(
data: Dict, exclude_fields: Optional[List[str]] = None, ignore_null: bool = False
) -> Dict:
"""
Convenience function to process JSON data
Args:
data: Input JSON data as dictionary
exclude_fields: List of field names to exclude
ignore_null: Whether to ignore null values
Returns:
Processed JSON data
"""
processor = JsonProcessor(exclude_fields=exclude_fields, ignore_null=ignore_null)
return processor.process(data)
def safe_process_json(
data: Dict, exclude_fields: Optional[List[str]] = None, ignore_null: bool = False
) -> Dict:
"""
Safely process JSON data with error handling
"""
try:
return process_json(data, exclude_fields, ignore_null)
except (TypeError, ValueError) as e:
raise JsonProcessingError(f"Error processing JSON: {str(e)}")
except Exception as e:
raise JsonProcessingError(f"Unexpected error: {str(e)}")
......@@ -56,13 +56,14 @@ class EventIssuePayload(EventRequestPayload):
class EventHookPayload(BaseModel):
source: Literal["system_hook"]
source: Literal["web_hook"]
data: dict[str, Any]
class EventRequest(BaseModel):
role_arn: Annotated[str, StringConstraints(max_length=2048)]
code: Annotated[str, StringConstraints(max_length=255)]
event_id: Annotated[str, StringConstraints(max_length=255)]
payload: Union[EventMergeRequestPayload, EventIssuePayload, EventHookPayload] = (
Field(discriminator="source")
)
......
import json
import boto3
from botocore.exceptions import ClientError
from fastapi import HTTPException, status
from q_developer_boto3 import boto3 as q_boto3
from ai_gateway.api.auth_utils import StarletteUser
from ai_gateway.api.json_utils import safe_process_json
from ai_gateway.auth.glgo import GlgoAuthority
from ai_gateway.integrations.amazon_q.errors import (
AccessDeniedExceptionReason,
......@@ -162,8 +159,8 @@ class AmazonQClient:
@raise_aws_errors
def send_event(self, event_request):
event_id = self._resolve_event_id(event_request)
payload = self._get_payload(event_request)
event_id = event_request.event_id
payload = event_request.payload.model_dump_json(exclude_none=True)
if not event_id:
raise HTTPException(
......@@ -178,6 +175,7 @@ class AmazonQClient:
)
print("DEBUG [AmazonQClient]: send_event payload", payload)
print("DEBUG [AmazonQClient]: event_id", event_id)
try:
self._send_event(event_id, payload)
except ClientError as ex:
......@@ -217,11 +215,6 @@ class AmazonQClient:
event=payload,
)
def _send_message(self, payload):
return self.client.send_message(
message=payload["message"], conversationId=payload["conversation_id"]
)
def _retry_send_event(self, error, code, payload, event_id):
self._is_retry(error, code)
......@@ -238,37 +231,3 @@ class AmazonQClient:
status_code=status.HTTP_403_FORBIDDEN,
detail=str(error),
)
def _resolve_event_id(self, event_request):
payload = event_request.payload
if payload.__class__.__name__ == "EventHookPayload":
# Use class name comparison to avoid circular import for dependency injection
return SYSTEM_HOOK_EVENT_MAP.get(payload.data.get("object_kind"), None)
elif payload.__class__.__name__ in [
"EventMergeRequestPayload",
"EventIssuePayload",
]:
return QUICK_ACTION_EVENT_ID
request_log.warn("Unknown event payload, ignore")
return None
def _get_payload(self, event_request):
payload = event_request.payload
if payload.__class__.__name__ == "EventHookPayload":
updated_payload = safe_process_json(
payload.model_dump(exclude_none=True),
EXCLUDE_EVENT_ATTRIBUTES,
ignore_null=True,
)
return json.dumps(updated_payload)
elif payload.__class__.__name__ in [
"EventMergeRequestPayload",
"EventIssuePayload",
]:
return payload.model_dump_json(exclude_none=True)
request_log.warn("Unknown event payload, ignore")
return None
import pytest
from ai_gateway.api.json_utils import JsonProcessor
class TestJsonUtils:
@pytest.mark.parametrize(
"input_dict, expected_result",
[
(
{"key1": "value1", "key2": None, "key3": "value3"},
{"key1": "value1", "key3": "value3"},
),
({"a": None, "b": None, "c": None}, {}),
({}, {}),
],
)
def test_process_dict_with_null_values(self, input_dict, expected_result):
"""Test _process_dict method with various null value scenarios when ignore_null is True"""
processor = JsonProcessor(ignore_null=True)
result = processor._process_dict(input_dict)
assert result == expected_result
@pytest.mark.parametrize(
"exclude_fields, input_data, expected_output",
[
(
["exclude_me"],
{
"exclude_me": "should be excluded",
"include_me": "should be included",
"nested": {"key": "value"},
},
{"include_me": "should be included", "nested": {"key": "value"}},
),
(["a", "b", "c"], {"a": 1, "b": 2, "c": 3}, {}),
(
["exclude_me"],
{
"exclude_me": "should not appear",
"null_value": None,
"keep_me": "should appear",
},
{"keep_me": "should appear"},
),
],
)
def test_process_dict_with_exclude_fields(
self, exclude_fields, input_data, expected_output
):
"""Test _process_dict with various exclude field scenarios"""
processor = JsonProcessor(exclude_fields=exclude_fields, ignore_null=True)
result = processor._process_dict(input_data)
assert result == expected_output
@pytest.mark.parametrize(
"invalid_input, expected_exception",
[
("not a dictionary", AttributeError),
(None, AttributeError),
({"key": set()}, TypeError), # sets are not JSON serializable
],
)
def test_process_dict_invalid_inputs(self, invalid_input, expected_exception):
"""Test _process_dict with various invalid inputs"""
processor = JsonProcessor()
with pytest.raises(expected_exception):
processor._process_dict(invalid_input)
@pytest.mark.parametrize(
"field, expected_type_error",
[(123, True), (None, True), ("valid_field", False), ("", False)],
)
def test_add_exclude_field_validation(self, field, expected_type_error):
"""Test add_exclude_field with various inputs"""
processor = JsonProcessor()
if expected_type_error:
with pytest.raises(TypeError):
processor.add_exclude_field(field)
else:
processor.add_exclude_field(field)
assert field in processor.exclude_fields
@pytest.mark.parametrize(
"primitive_input, expected",
[(42, 42), ("hello", "hello"), (True, True), (None, None)],
)
def test_process_primitive_data(self, primitive_input, expected):
"""Test process method with primitive data types"""
processor = JsonProcessor()
assert processor.process(primitive_input) == expected
......@@ -96,6 +96,7 @@ def perform_request(mock_client, payload):
"role_arn": "arn:aws:iam::123456789012:role/q-dev-role",
"code": "code-123",
"payload": payload,
"event_id": "Quick Action",
},
)
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment