Showing 213 additions and 111 deletions
 ---
 name: Claude 3 Write Tests agent
 model:
-  name: claude-3-5-sonnet-20241022
+  config_file: conversation_performant
   params:
     model_class_provider: litellm
     temperature: 0.1
     max_tokens: 2_048
-    max_retries: 1
 unit_primitives:
   - write_tests
 prompt_template:
...
+    {% include "chat/_library_section.jinja" %}
     You are a software developer.
     You can write new tests.
-    {{language_info}}
-    {% include "chat/_library_section.jinja" %}
\ No newline at end of file
+    {{language_info}}
\ No newline at end of file
 ---
 name: Claude 3 Generate Commit Message
 model:
-  name: claude-3-5-sonnet-20241022
-  params:
-    model_class_provider: anthropic
-    temperature: 0.0
-    max_tokens: 4096
-    max_retries: 1
+  config_file: conversation_performant
 unit_primitives:
   - generate_commit_message
 prompt_template:
...
 ---
 name: Claude 3 Generate Description
 model:
-  name: claude-3-5-sonnet-20241022
-  params:
-    model_class_provider: anthropic
-    temperature: 0.0
-    max_tokens: 4096
-    max_retries: 1
+  config_file: conversation_performant
 unit_primitives:
   - generate_issue_description
 prompt_template:
...
 ---
 name: Claude 3 Generate Description
 model:
-  name: claude-3-5-sonnet-20241022
+  config_file: conversation_performant
   params:
     model_class_provider: litellm
-    temperature: 0.0
-    max_tokens: 4096
-    max_retries: 1
 unit_primitives:
   - generate_issue_description
 prompt_template:
...
+name: claude-3-5-sonnet-20241022
+params:
+  temperature: 0.0
+  max_tokens: 4096
+  max_retries: 1
+  model_class_provider: anthropic
@@ -7,7 +7,7 @@ from poetry.core.constraints.version import Version, parse_constraint
 from ai_gateway.internal_events.client import InternalEventsClient
 from ai_gateway.prompts.base import BasePromptRegistry, Prompt
-from ai_gateway.prompts.config import ModelClassProvider, PromptConfig
+from ai_gateway.prompts.config import BaseModelConfig, ModelClassProvider, PromptConfig
 from ai_gateway.prompts.typing import ModelMetadata, TypeModelFactory

 __all__ = ["LocalPromptRegistry", "PromptRegistered"]
@@ -127,17 +127,26 @@ class LocalPromptRegistry(BasePromptRegistry):
             used if no matching override is provided in `class_overrides`.
         """
-        prompts_definitions_dir = Path(__file__).parent / "definitions"
+        base_path = Path(__file__).parent
+        prompts_definitions_dir = base_path / "definitions"
+        model_configs_dir = (
+            base_path / "model_configs"
+        )  # New directory for model configs
         prompts_registered = {}

+        # Parse model config YAML files
+        model_configs = {
+            file.stem: cls._parse_base_model(file)
+            for file in model_configs_dir.glob("*.yml")
+        }
+
         # Iterate over each folder
         for path in prompts_definitions_dir.glob("**"):
-            versions = {}
-
             # Iterate over each version file
-            for version in path.glob("*.yml"):
-                with open(version, "r") as fp:
-                    versions[version.stem] = PromptConfig(**yaml.safe_load(fp))
+            versions = {
+                version.stem: cls._process_version_file(version, model_configs)
+                for version in path.glob("*.yml")
+            }

             # If there were no yml files in this folder, skip it
             if not versions:
@@ -165,3 +174,66 @@ class LocalPromptRegistry(BasePromptRegistry):
             custom_models_enabled,
             disable_streaming,
         )
+
+    @classmethod
+    def _parse_base_model(cls, file_name: Path) -> BaseModelConfig:
+        """Parses a YAML file and converts its content to a BaseModelConfig object.
+
+        This method reads the specified YAML file, extracts the configuration
+        parameters, and constructs a BaseModelConfig object. It handles the
+        conversion of YAML data types to appropriate Python types.
+
+        Args:
+            file_name (Path): A Path object pointing to the YAML file to be parsed.
+
+        Returns:
+            BaseModelConfig: An instance of BaseModelConfig containing the
+                parsed configuration data.
+        """
+        with open(file_name, "r") as fp:
+            return BaseModelConfig(**yaml.safe_load(fp))
+
+    @classmethod
+    def _process_version_file(
+        cls, version_file: Path, model_configs: dict[str, BaseModelConfig]
+    ) -> PromptConfig:
+        """Processes a single version YAML file and returns a PromptConfig.
+
+        Args:
+            version_file: Path to the version YAML file
+            model_configs: Dictionary of model configurations
+
+        Returns:
+            PromptConfig: Processed prompt configuration
+        """
+        with open(version_file, "r") as fp:
+            prompt_config_params = yaml.safe_load(fp)
+
+        if "config_file" in prompt_config_params["model"]:
+            model_config = prompt_config_params["model"]["config_file"]
+            config_for_general_model = model_configs.get(model_config)
+
+            if config_for_general_model:
+                prompt_config_params = cls._patch_model_configuration(
+                    config_for_general_model, prompt_config_params
+                )
+
+        return PromptConfig(**prompt_config_params)
+
+    @classmethod
+    def _patch_model_configuration(
+        cls, config_for_general_model: BaseModelConfig, prompt_config_params: dict
+    ) -> dict:
+        params = {
+            **config_for_general_model.params.model_dump(),
+            **prompt_config_params["model"].get("params", {}),
+        }
+
+        return {
+            **prompt_config_params,
+            "model": {
+                "name": config_for_general_model.name,
+                "params": params,
+            },
+        }
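To make the precedence in `_patch_model_configuration` concrete, here is a sketch of the effective model section for the Write Tests agent above, whose definition layers `model_class_provider: litellm`, `temperature: 0.1`, and `max_tokens: 2_048` over the `conversation_performant` config added earlier in this diff (the merged result is illustrative, derived from those two files):

```yaml
# Effective model section after _patch_model_configuration merges the two sources
model:
  name: claude-3-5-sonnet-20241022  # taken from the model config file
  params:
    model_class_provider: litellm   # definition overrides the config's "anthropic"
    temperature: 0.1                # definition overrides the config's 0.0
    max_tokens: 2_048               # definition overrides the config's 4096
    max_retries: 1                  # inherited from the config
```

The merge is a plain dict spread: the config's `params` are dumped first, then the definition's `params` are spread over them, so definition keys always win key by key.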
@@ -11,5 +11,5 @@ class ContainerXRay(containers.DeclarativeContainer):
     models = providers.DependenciesContainer()

     anthropic_claude = providers.Factory(
-        models.anthropic_claude, name=KindAnthropicModel.CLAUDE_2_0
+        models.anthropic_claude, name=KindAnthropicModel.CLAUDE_2_1
     )
@@ -375,11 +375,9 @@ def mock_litellm_acompletion():
             AsyncMock(
                 message=AsyncMock(content="Test response"),
                 text="Test text completion response",
+                logprobs=AsyncMock(token_logprobs=[999]),
             ),
         ],
-        _hidden_params={
-            "original_response": {"choices": [{"logprobs": AsyncMock(token_logprobs=[999])}]}
-        },
         usage=AsyncMock(completion_tokens=999),
     )
...
@@ -21,10 +21,19 @@ The AI Gateway Prompts serve as structured templates that allow AI feature to se
 AI Gateway Prompts are defined as `.yml` files located in `prompts/definitions/` and specify parameters for different LLMs. Each configuration file includes:

-- **Model Parameters:** Specify LLM provider and configurations, including model name and parameters such as `temperature`, `top_p`, `top_k`, `max_tokens`, and `stop`
+- **Model Parameters:** Specify LLM provider and configurations, including model name (or `config_file` name) and parameters such as `temperature`, `top_p`, `top_k`, `max_tokens`, and `stop`
 - **Prompt Templates:** Define prompt templates that support [Jinja expressions](https://jinja.palletsprojects.com/en/stable/) for multiple roles such as `user` and `system`
 - **Control Parameters:** Parameters such as `max_retries` and `timeout` to manage retries and session handling

+#### AI Gateway Model Configuration
+
+AI Gateway can store a model's name and parameters together in a shared config and reuse them without re-defining them in every definition file.
+
+AI Gateway Model Configs are defined as `.yml` files located in `prompts/model_configs/` and specify parameters for different models. Parameters from definition files take precedence over parameters from the model config file.
+
+Each configuration file includes:
+
+- **Model Parameters:** Specify LLM provider and configurations, including model name and parameters such as `temperature`, `top_p`, `top_k`, `max_tokens`, and `stop`
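For example, the `conversation_performant` model config added in this merge request stores the model name and its default parameters in a single file:

```yaml
# prompts/model_configs/conversation_performant.yml
name: claude-3-5-sonnet-20241022
params:
  temperature: 0.0
  max_tokens: 4096
  max_retries: 1
  model_class_provider: anthropic
```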
 ### Integration with LangChain and LiteLLM

 LangChain and LiteLLM enable structured prompt construction and manage LLM interactions within the AI Gateway Prompt system.
@@ -91,7 +100,8 @@ Each prompt configuration file in `prompts/definitions/` requires the following

 ```yaml
 name: <string> # Required. Unique identifier for the prompt
 model:
-  name: <string> # Required. Model identifier (e.g. "claude-3-sonnet-20240229")
+  name: <string> # Optional. Model identifier (e.g. "claude-3-sonnet-20240229"). Either `name` or `config_file` must be present
+  config_file: <string> # Optional. Config identifier. Either `name` or `config_file` must be present
   params:
     model_class_provider: litellm # Required. Provider interface
     temperature: <float> # Optional. 0.0-1.0. Controls randomness (default: 0.7)
...
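Putting the two together: a definition can reference a shared config through `config_file` and override individual parameters, as the Write Tests agent in this merge request does (abridged sketch based on the definition diffed above):

```yaml
name: Claude 3 Write Tests agent
model:
  config_file: conversation_performant  # pulls the model name and default params
  params:
    model_class_provider: litellm  # definition params take precedence over the config's
    temperature: 0.1
    max_tokens: 2_048
unit_primitives:
  - write_tests
```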
@@ -48,7 +48,7 @@ AIGW_INSTRUMENTATOR__THREAD_MONITORING_INTERVAL=60

 # Feature flags
 AIGW_FEATURE_FLAGS__DISALLOWED_FLAGS='{}'
-AIGW_FEATURE_FLAGS__EXCL_POST_PROC='[]'
+AIGW_FEATURE_FLAGS__EXCL_POST_PROCESS='[]'
 AIGW_FEATURE_FLAGS__FIREWORKS_QWEN_SCORE_THRESHOLD=-9999.0
...
@@ -268,18 +268,18 @@ uvloop = ["uvloop (>=0.15.2)"]

 [[package]]
 name = "boto3"
-version = "1.36.16"
+version = "1.36.21"
 description = "The AWS SDK for Python"
 optional = false
 python-versions = ">=3.8"
 groups = ["main"]
 files = [
-    {file = "boto3-1.36.16-py3-none-any.whl", hash = "sha256:b10583bf8bd35be1b4027ee7e26b7cdf2078c79eab18357fd602cecb6d39400b"},
-    {file = "boto3-1.36.16.tar.gz", hash = "sha256:0cf92ca0538ab115447e1c58050d43e1273e88c58ddfea2b6f133fdc508b400a"},
+    {file = "boto3-1.36.21-py3-none-any.whl", hash = "sha256:f94faa7cf932d781f474d87f8b4c14a033af95ac1460136b40d75e7a30086ef0"},
+    {file = "boto3-1.36.21.tar.gz", hash = "sha256:41eb2b73eb612d300e629e3328b83f1ffea0fc6633e75c241a72a76746c1db26"},
 ]

 [package.dependencies]
-botocore = ">=1.36.16,<1.37.0"
+botocore = ">=1.36.21,<1.37.0"
 jmespath = ">=0.7.1,<2.0.0"
 s3transfer = ">=0.11.0,<0.12.0"
@@ -288,14 +288,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]

 [[package]]
 name = "botocore"
-version = "1.36.16"
+version = "1.36.21"
 description = "Low-level, data-driven core of boto 3."
 optional = false
 python-versions = ">=3.8"
 groups = ["main"]
 files = [
-    {file = "botocore-1.36.16-py3-none-any.whl", hash = "sha256:aca0348ccd730332082489b6817fdf89e1526049adcf6e9c8c11c96dd9f42c03"},
-    {file = "botocore-1.36.16.tar.gz", hash = "sha256:10c6aa386ba1a9a0faef6bb5dbfc58fc2563a3c6b95352e86a583cd5f14b11f3"},
+    {file = "botocore-1.36.21-py3-none-any.whl", hash = "sha256:24a7052e792639dc2726001bd474cd0aaa959c1e18ddd92c17f3adc6efa1b132"},
+    {file = "botocore-1.36.21.tar.gz", hash = "sha256:da746240e2ad64fd4997f7f3664a0a8e303d18075fc1d473727cb6375080ea16"},
 ]

 [package.dependencies]
@@ -655,18 +655,6 @@ files = [
 marshmallow = ">=3.18.0,<4.0.0"
 typing-inspect = ">=0.4.0,<1"

-[[package]]
-name = "defusedxml"
-version = "0.7.1"
-description = "XML bomb protection for Python stdlib modules"
-optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-groups = ["main"]
-files = [
-    {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"},
-    {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
-]
-
 [[package]]
 name = "dependency-injector"
 version = "4.45.0"
@@ -1024,14 +1012,14 @@ tqdm = ["tqdm"]

 [[package]]
 name = "gitlab-cloud-connector"
-version = "2.0.2"
+version = "2.1.0"
 description = "Cloud Connector library for Python backends"
 optional = false
 python-versions = ">=3.11,<4.0"
 groups = ["main"]
 files = [
-    {file = "gitlab_cloud_connector-2.0.2-py3-none-any.whl", hash = "sha256:3a20dbe9ca841571c09bc0f4aa946df03c94ef6e83ab29a6fc2f5bc90b206ddf"},
-    {file = "gitlab_cloud_connector-2.0.2.tar.gz", hash = "sha256:2ea36c28d648b5b2763d41a5c78ebdb46ee72fafc04419f121cb9b7c46b29568"},
+    {file = "gitlab_cloud_connector-2.1.0-py3-none-any.whl", hash = "sha256:2548d5b5715f0cdc44213e654e7ffbbe0fc7b8f8f2123c4b88d0196b45f09381"},
+    {file = "gitlab_cloud_connector-2.1.0.tar.gz", hash = "sha256:0a4ea0079a6bbd77e1e455846985a704593dc455d42c41193559c45f59ddde3a"},
 ]

 [package.dependencies]
@@ -1231,14 +1219,14 @@ grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"]

 [[package]]
 name = "google-cloud-discoveryengine"
-version = "0.13.5"
+version = "0.13.6"
 description = "Google Cloud Discoveryengine API client library"
 optional = false
 python-versions = ">=3.7"
 groups = ["main"]
 files = [
-    {file = "google_cloud_discoveryengine-0.13.5-py3-none-any.whl", hash = "sha256:93f466edc1842abcf66339a40cbdfc7f3928c57c403ce0a38e20af01059948c0"},
-    {file = "google_cloud_discoveryengine-0.13.5.tar.gz", hash = "sha256:4a3bc9cc625ddd015922cbb5b5e1f2eaa79a82b737a9783c4e927852ab5c1e07"},
+    {file = "google_cloud_discoveryengine-0.13.6-py3-none-any.whl", hash = "sha256:5c67ca0d5d14ee509756e97ab59ba4b59003a5173d38b427e4019763ec521503"},
+    {file = "google_cloud_discoveryengine-0.13.6.tar.gz", hash = "sha256:f1c942d51c87d585947f56647c8768faa0ec07bb57af06224c23eb4475f7d095"},
 ]

 [package.dependencies]
@@ -2036,20 +2024,19 @@ together = ["langchain-together"]

 [[package]]
 name = "langchain-anthropic"
-version = "0.3.5"
+version = "0.3.7"
 description = "An integration package connecting AnthropicMessages and LangChain"
 optional = false
 python-versions = "<4.0,>=3.9"
 groups = ["main"]
 files = [
-    {file = "langchain_anthropic-0.3.5-py3-none-any.whl", hash = "sha256:bad34b02d7b4bdca9a9471bc391b01269fd8dc4600b83ca2a3e76925b7c27fe6"},
-    {file = "langchain_anthropic-0.3.5.tar.gz", hash = "sha256:2aa1673511056061680492871f386d68a8b62947e0eb1f15303ef10db16c8357"},
+    {file = "langchain_anthropic-0.3.7-py3-none-any.whl", hash = "sha256:adec0a1daabd3c25249753c6cd625654917fb9e3feee68e72c7dc3f4449c0f3c"},
+    {file = "langchain_anthropic-0.3.7.tar.gz", hash = "sha256:534cd1867bc41711cd8c3d0a0bc055e6c5a4215953c87260209a90dc5816f30d"},
 ]

 [package.dependencies]
-anthropic = ">=0.41.0,<1"
-defusedxml = ">=0.7.1,<0.8.0"
-langchain-core = ">=0.3.33,<0.4.0"
+anthropic = ">=0.45.0,<1"
+langchain-core = ">=0.3.34,<1.0.0"
 pydantic = ">=2.7.4,<3.0.0"

 [[package]]
@@ -4050,14 +4037,14 @@ telegram = ["requests"]

 [[package]]
 name = "transformers"
-version = "4.48.2"
+version = "4.48.3"
 description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
 optional = false
 python-versions = ">=3.9.0"
 groups = ["main"]
 files = [
-    {file = "transformers-4.48.2-py3-none-any.whl", hash = "sha256:493bc5b0268b116eff305edf6656367fc89cf570e7a9d5891369e04751db698a"},
-    {file = "transformers-4.48.2.tar.gz", hash = "sha256:dcfb73473e61f22fb3366fe2471ed2e42779ecdd49527a1bdf1937574855d516"},
+    {file = "transformers-4.48.3-py3-none-any.whl", hash = "sha256:78697f990f5ef350c23b46bf86d5081ce96b49479ab180b2de7687267de8fd36"},
+    {file = "transformers-4.48.3.tar.gz", hash = "sha256:a5e8f1e9a6430aa78215836be70cecd3f872d99eeda300f41ad6cc841724afdb"},
 ]

 [package.dependencies]
@@ -4743,4 +4730,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.1"
 python-versions = "~3.11.0"
-content-hash = "e2e30a6cc057f5c256b5abaf8708f42b798819c98bbdd4c24dd8cf83a0f4fb42"
+content-hash = "847ec2d01317a9fdb8352a681f3f417dd6d23352fc026fcc6e91218e8676d06d"
@@ -75,7 +75,7 @@ optional = true

 [tool.poetry.group.lint.dependencies]
 flake8 = "^7.0.0"
-isort = "^5.12.0"
+isort = ">=5.12.0, <6.0.0"
 black = "^25.0.0"
 pylint = "^3.0.3"
 astroid = "^3.0.2"
...
@@ -17,5 +17,12 @@
     "stanhu",
     "tle_gitlab"
   ],
-  "ignoreDeps": ["errata-ai/vale", "DavidAnson/markdownlint-cli2", "node", "ruby"]
+  "ignoreDeps": ["errata-ai/vale", "DavidAnson/markdownlint-cli2", "node", "ruby"],
+  "packageRules": [
+    {
+      "matchPackageNames": ["isort"],
+      "matchUpdateTypes": ["major"],
+      "enabled": false
+    }
+  ]
 }
@@ -6,11 +6,15 @@ from starlette.requests import Request
 from starlette_context import context, request_cycle_context

 from ai_gateway.api.middleware import (
+    X_GITLAB_CLIENT_NAME,
+    X_GITLAB_CLIENT_TYPE,
+    X_GITLAB_CLIENT_VERSION,
     X_GITLAB_FEATURE_ENABLED_BY_NAMESPACE_IDS_HEADER,
     X_GITLAB_FEATURE_ENABLEMENT_TYPE_HEADER,
     X_GITLAB_GLOBAL_USER_ID_HEADER,
     X_GITLAB_HOST_NAME_HEADER,
     X_GITLAB_INSTANCE_ID_HEADER,
+    X_GITLAB_INTERFACE,
     X_GITLAB_REALM_HEADER,
     X_GITLAB_SAAS_DUO_PRO_NAMESPACE_IDS_HEADER,
     X_GITLAB_TEAM_MEMBER_HEADER,
@@ -123,6 +127,10 @@ async def test_middleware_set_context(internal_event_middleware):
                 (X_GITLAB_DUO_SEAT_COUNT_HEADER.lower().encode(), b"100"),
                 (X_GITLAB_TEAM_MEMBER_HEADER.lower().encode(), b"true"),
                 (X_GITLAB_FEATURE_ENABLEMENT_TYPE_HEADER.lower().encode(), b"add_on"),
+                (X_GITLAB_CLIENT_NAME.lower().encode(), b"vscode"),
+                (X_GITLAB_CLIENT_TYPE.lower().encode(), b"ide"),
+                (X_GITLAB_CLIENT_VERSION.lower().encode(), b"1.97.0"),
+                (X_GITLAB_INTERFACE.lower().encode(), b"duo_chat"),
             ],
         }
     )
@@ -148,6 +156,10 @@ async def test_middleware_set_context(internal_event_middleware):
         duo_seat_count="100",
         is_gitlab_team_member="true",
         feature_enablement_type="add_on",
+        client_name="vscode",
+        client_version="1.97.0",
+        client_type="ide",
+        interface="duo_chat",
         feature_enabled_by_namespace_ids=[],
         context_generated_at=mock_event_context.set.call_args[0][
             0
...
@@ -60,7 +60,6 @@ class TestAgentSuccessfulRequest:
     @pytest.mark.parametrize(
         ("content_fixture", "provider", "model", "params"),
         [
-            ("text_content", "anthropic", "claude-2.0", None),
             (
                 "text_content",
                 "anthropic",
@@ -206,7 +205,7 @@ class TestAgentSuccessfulStream:
                     "payload": {
                         "content": content,
                         "provider": "anthropic",
-                        "model": KindAnthropicModel.CLAUDE_2_0.value,
+                        "model": KindAnthropicModel.CLAUDE_2_1.value,
                     },
                 },
             ],
@@ -256,7 +255,7 @@ class TestAgentUnsupportedProvider:
                     "payload": {
                         "content": text_content,
                         "provider": "UNSUPPORTED_PROVIDER",
-                        "model": "claude-2.0",
+                        "model": "claude-2.1",
                    },
                },
            ]
@@ -333,7 +332,7 @@ class TestAnthropicInvalidScope:
                    "payload": {
                        "content": text_content,
                        "provider": "anthropic",
-                        "model": "claude-2.0",
+                        "model": "claude-2.1",
                    },
                }
            ]
@@ -364,7 +363,7 @@ class TestAgentInvalidRequestMissingFields:
                    "metadata": {"source": "gitlab-rails-sm"},
                    "payload": {
                        "provider": "anthropic",
-                        "model": "claude-2.0",
+                        "model": "claude-2.1",
                    },
                },
            ]
@@ -384,7 +383,7 @@ class TestAgentInvalidRequestMissingFields:
                    "type": "missing",
                    "loc": ["body", "prompt_components", 0, "payload", "content"],
                    "msg": "Field required",
-                    "input": {"provider": "anthropic", "model": "claude-2.0"},
+                    "input": {"provider": "anthropic", "model": "claude-2.1"},
                },
            ]
        }
@@ -415,7 +414,7 @@ class TestAgentInvalidRequestManyPromptComponents:
                    "payload": {
                        "content": text_content,
                        "provider": "anthropic",
-                        "model": "claude-2.0",
+                        "model": "claude-2.1",
                    },
                },
                {
@@ -427,7 +426,7 @@ class TestAgentInvalidRequestManyPromptComponents:
                    "payload": {
                        "content": "SECOND PROMPT COMPONENT (NOT EXPECTED)",
                        "provider": "anthropic",
-                        "model": "claude-2.0",
+                        "model": "claude-2.1",
                    },
                },
            ]
@@ -451,7 +450,7 @@ class TestAgentInvalidRequestManyPromptComponents:
                    "payload": {
                        "content": text_content,
                        "provider": "anthropic",
-                        "model": "claude-2.0",
+                        "model": "claude-2.1",
                    },
                },
                {
@@ -460,7 +459,7 @@ class TestAgentInvalidRequestManyPromptComponents:
                    "payload": {
                        "content": "SECOND PROMPT COMPONENT (NOT EXPECTED)",
                        "provider": "anthropic",
-                        "model": "claude-2.0",
+                        "model": "claude-2.1",
                    },
                },
            ],
@@ -525,7 +524,7 @@ class TestAgentUnsuccessfulAnthropicRequest:
                    "payload": {
                        "content": request.getfixturevalue(content_fixture),
                        "provider": "anthropic",
-                        "model": "claude-2.0",
+                        "model": "claude-2.1",
                    },
                }
            ]
...
-from typing import Any, List, Optional, Type
-from unittest.mock import patch
+from typing import List, Optional, Type
+from unittest.mock import ANY, patch

 import pytest
 from fastapi import HTTPException
@@ -12,8 +12,7 @@ from pydantic import AnyUrl
 from ai_gateway.api.v1 import api_router
 from ai_gateway.config import Config
 from ai_gateway.prompts import Prompt
-from ai_gateway.prompts.config.base import PromptConfig
-from ai_gateway.prompts.typing import ModelMetadata, TypeModelFactory
+from ai_gateway.prompts.typing import ModelMetadata, ModelMetadataType


 class FakeModel(SimpleChatModel):
@@ -25,7 +24,7 @@ class FakeModel(SimpleChatModel):
         return "fake-provider"

     @property
-    def _identifying_params(self) -> dict[str, Any]:
+    def _identifying_params(self) -> dict[str, ANY]:
         return {"model": "fake-model"}

     def _call(
@@ -33,7 +32,7 @@ class FakeModel(SimpleChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
+        **kwargs: ANY,
     ) -> str:
         assert self.expected_message == messages[0].content
@@ -118,7 +117,7 @@ class TestPrompt:
             {"name": "John", "age": 20},
             None,
             None,
-            ("test", "^1.0.0", None),
+            ("test", "^1.0.0", ANY, None),
             200,
             "Hi John!",
             ["1.0.0"],
@@ -128,7 +127,7 @@ class TestPrompt:
             {"name": "John", "age": 20},
             "^2.0.0",
             None,
-            ("test", "^2.0.0", None),
+            ("test", "^2.0.0", ANY, None),
             200,
             "Hi John!",
             ["2.0.0"],
@@ -146,6 +145,7 @@ class TestPrompt:
             (
                 "test",
                 "^1.0.0",
+                ANY,
                 ModelMetadata(
                     name="mistral",
                     provider="litellm",
@@ -162,7 +162,7 @@ class TestPrompt:
             {"name": "John", "age": 20},
             "^2.0.0",
             None,
-            ("test", "^2.0.0", None),
+            ("test", "^2.0.0", ANY, None),
             400,
             {"detail": "No prompt version found matching the query"},
             [],
@@ -172,7 +172,7 @@ class TestPrompt:
             {"name": "John", "age": 20},
             None,
             None,
-            ("test", "^1.0.0", None),
+            ("test", "^1.0.0", ANY, None),
             404,
             {"detail": "Prompt 'test' not found"},
             None,
@@ -182,7 +182,7 @@ class TestPrompt:
             {"name": "John"},
             None,
             None,
-            ("test", "^1.0.0", None),
+            ("test", "^1.0.0", ANY, None),
             422,
             {
                 "detail": "\"Input to ChatPromptTemplate is missing variables {'age'}. Expected: ['age', 'name'] Received: ['name']"
@@ -199,10 +199,10 @@ class TestPrompt:
         mock_track_internal_event,
         inputs: dict[str, str],
         prompt_version: Optional[str],
-        model_metadata: Optional[ModelMetadata],
+        model_metadata: Optional[ModelMetadataType],
         expected_get_args: dict,
         expected_status: int,
-        expected_response: Any,
+        expected_response: ANY,
         compatible_versions: Optional[List[str]],
     ):
         response = mock_client.post(
@@ -256,7 +256,7 @@ class TestPrompt:
             },
         )

-        mock_registry_get.assert_called_with("test", "^2.0.0", None)
+        mock_registry_get.assert_called_with("test", "^2.0.0", ANY, None)
         assert response.status_code == 200
         assert response.text == "Hi John!"
         assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
@@ -311,6 +311,6 @@ class TestMisdirectedRequest:
                 and model_metadata.model_dump(mode="json"),
             },
         )
-        mock_registry_get.assert_called_with("test", "^1.0.0", model_metadata)
+        mock_registry_get.assert_called_with("test", "^1.0.0", ANY, model_metadata)
         assert response.status_code == 421
         assert response.json() == {"detail": "401: Unauthorized"}
 import json
 from datetime import datetime
 from typing import AsyncIterator
-from unittest.mock import Mock, PropertyMock, call, patch
+from unittest.mock import ANY, Mock, PropertyMock, call, patch

 import pytest
 from gitlab_cloud_connector import CloudConnectorUser, UserClaims
@@ -47,7 +47,7 @@ def mock_date(mocker):
 def auth_user():
     return CloudConnectorUser(
         authenticated=True,
-        claims=UserClaims(scopes=["duo_chat"]),
+        claims=UserClaims(scopes=["duo_chat", "amazon_q_integration"]),
     )
@@ -345,7 +345,7 @@ class TestReActAgentStream:
         mocked_stream.assert_called_once_with(inputs=agent_inputs)

         mock_track_internal_event.assert_called_once_with(
-            "request_duo_chat",
+            "request_amazon_q_integration",
             category="ai_gateway.api.v2.chat.agent",
         )
@@ -355,12 +355,18 @@ class TestReActAgentStream:
     [
         (
             CloudConnectorUser(
-                authenticated=True, claims=UserClaims(scopes=["duo_chat"])
+                authenticated=True,
+                claims=UserClaims(scopes=["duo_chat", "amazon_q_integration"]),
             ),
             AgentRequest(messages=[Message(role=Role.USER, content="Hi")]),
             200,
             "",
-            [call("request_duo_chat", category="ai_gateway.api.v2.chat.agent")],
+            [
+                call(
+                    "request_amazon_q_integration",
+                    category="ai_gateway.api.v2.chat.agent",
+                )
+            ],
         ),
         (
             CloudConnectorUser(
@@ -374,7 +380,13 @@ class TestReActAgentStream:
         (
             CloudConnectorUser(
                 authenticated=True,
-                claims=UserClaims(scopes=["duo_chat", "include_file_context"]),
+                claims=UserClaims(
+                    scopes=[
+                        "duo_chat",
+                        "include_file_context",
+                        "amazon_q_integration",
+                    ]
+                ),
             ),
             AgentRequest(
                 messages=[
@@ -396,7 +408,8 @@ class TestReActAgentStream:
         ),
         (
             CloudConnectorUser(
-                authenticated=True, claims=UserClaims(scopes=["duo_chat"])
+                authenticated=True,
+                claims=UserClaims(scopes=["duo_chat", "amazon_q_integration"]),
             ),
             AgentRequest(
                 messages=[
@@ -409,11 +422,17 @@ class TestReActAgentStream:
             ),
             200,
             "",
-            [call("request_duo_chat", category="ai_gateway.api.v2.chat.agent")],
+            [
+                call(
+                    "request_amazon_q_integration",
+                    category="ai_gateway.api.v2.chat.agent",
+                )
+            ],
         ),
         (
             CloudConnectorUser(
-                authenticated=True, claims=UserClaims(scopes=["duo_chat"])
+                authenticated=True,
+                claims=UserClaims(scopes=["duo_chat", "amazon_q_integration"]),
            ),
            AgentRequest(
                messages=[
@@ -426,11 +445,17 @@ class TestReActAgentStream:
            ),
            200,
            "",
-            [call("request_duo_chat", category="ai_gateway.api.v2.chat.agent")],
+            [
+                call(
+                    "request_amazon_q_integration",
+                    category="ai_gateway.api.v2.chat.agent",
+                )
+            ],
        ),
        (
            CloudConnectorUser(
-                authenticated=True, claims=UserClaims(scopes=["duo_chat"])
+                authenticated=True,
+                claims=UserClaims(scopes=["duo_chat", "amazon_q_integration"]),
            ),
            AgentRequest(
                messages=[
@@ -554,6 +579,6 @@ class TestChatAgent:
         assert actual_actions == expected_actions

         mock_track_internal_event.assert_called_once_with(
-            "request_duo_chat",
+            "request_amazon_q_integration",
             category="ai_gateway.api.v2.chat.agent",
         )
@@ -1317,7 +1317,7 @@ class TestCodeGenerations:
                 "foo",
                 "bar",
                 "anthropic",
-                "claude-2.0",
+                "claude-2.1",
                 None,
                 None,
                 "foo",
@@ -1343,7 +1343,7 @@ class TestCodeGenerations:
                 "foo",
                 None,
                 "anthropic",
-                "claude-2.0",
+                "claude-2.1",
                 None,
                 None,
                 "foo",
...