Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (63)
Showing 283 additions and 115 deletions
@@ -61,13 +61,13 @@ include:
     inputs:
       runway_service_id: ai-gateway
       image: "$CI_REGISTRY_IMAGE/model-gateway:${CI_COMMIT_SHORT_SHA}"
-      runway_version: v3.55.6
+      runway_version: v3.58.3
   - project: "gitlab-com/gl-infra/platform/runway/runwayctl"
     file: "ci-tasks/service-project/runway.yml"
     inputs:
       runway_service_id: ai-gateway-custom
       image: "$SELF_HOSTED_TARGET_IMAGE"
-      runway_version: v3.55.6
+      runway_version: v3.58.3
   - component: ${CI_SERVER_FQDN}/gitlab-org/components/danger-review/danger-review@2.0.0
     rules:
       - if: $CI_SERVER_HOST == "gitlab.com"
...
 include:
   # see https://gitlab.com/gitlab-com/gl-infra/common-ci-tasks/-/blob/main/oidc.md
   - project: 'gitlab-com/gl-infra/common-ci-tasks'
-    ref: v2.61 # renovate:managed
+    ref: v2.62 # renovate:managed
     file: 'oidc.yml'

 .ingest-base:
...
-lefthook 1.10.10 # datasource=github-releases depName=evilmartians/lefthook
+lefthook 1.11.0 # datasource=github-releases depName=evilmartians/lefthook
 python 3.11.11 # datasource=github-tags depName=python/cpython
 gcloud 428.0.0 # datasource=github-tags depName=GoogleCloudPlatform/cloud-sdk-docker
 poetry 2.0.1 # datasource=pypi depName=poetry
...
@@ -5,7 +5,10 @@ from gitlab_cloud_connector import GitLabFeatureCategory, GitLabUnitPrimitive
 from ai_gateway.api.auth_utils import StarletteUser, get_current_user
 from ai_gateway.api.feature_category import feature_category
-from ai_gateway.api.v1.amazon_q.typing import ApplicationRequest
+from ai_gateway.api.v1.amazon_q.typing import (
+    ApplicationDeleteRequest,
+    ApplicationRequest,
+)
 from ai_gateway.async_dependency_resolver import (
     get_amazon_q_client_factory,
     get_internal_event_client,
@@ -54,3 +57,38 @@ async def oauth_create_application(
         raise e.to_http_exception()

     return Response(status_code=status.HTTP_204_NO_CONTENT)
+
+
+@router.post("/application/delete")
+@feature_category(GitLabFeatureCategory.DUO_CHAT)
+async def oauth_delete_application(
+    request: Request,
+    application_request: ApplicationDeleteRequest,
+    current_user: Annotated[StarletteUser, Depends(get_current_user)],
+    internal_event_client: InternalEventsClient = Depends(get_internal_event_client),
+    amazon_q_client_factory: AmazonQClientFactory = Depends(
+        get_amazon_q_client_factory
+    ),
+):
+    if not current_user.can(GitLabUnitPrimitive.AMAZON_Q_INTEGRATION):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail="Unauthorized to perform action",
+        )
+
+    internal_event_client.track_event(
+        f"request_{GitLabUnitPrimitive.AMAZON_Q_INTEGRATION}",
+        category=__name__,
+    )
+
+    try:
+        q_client = amazon_q_client_factory.get_client(
+            current_user=current_user,
+            role_arn=application_request.role_arn,
+        )
+        q_client.delete_o_auth_app_connection()
+    except AWSException as e:
+        raise e.to_http_exception()
+
+    return Response(status_code=status.HTTP_204_NO_CONTENT)
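
The new `oauth_delete_application` endpoint mirrors the create flow: permission check, telemetry event, then the Amazon Q client call. For orientation, a minimal client-side sketch; the host, path prefix, and auth header here are assumptions, not taken from this diff:

import httpx

# Hypothetical call against the new route; only the "application/delete"
# path segment and the 204 No Content response come from the diff above.
response = httpx.post(
    "https://ai-gateway.example.com/v1/amazon_q/application/delete",
    json={"role_arn": "arn:aws:iam::123456789012:role/GitLabDuoQ"},
    headers={"Authorization": "Bearer <token>"},
)
assert response.status_code == 204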
@@ -4,6 +4,7 @@ from pydantic import BaseModel, Field, StringConstraints
 __all__ = [
     "ApplicationRequest",
+    "ApplicationDeleteRequest",
     "EventRequest",
 ]
@@ -16,6 +17,10 @@ class ApplicationRequest(BaseModel):
     role_arn: Annotated[str, StringConstraints(max_length=2048)]


+class ApplicationDeleteRequest(BaseModel):
+    role_arn: Annotated[str, StringConstraints(max_length=2048)]
+
+
 class EventRequestPayload(BaseModel):
     command: Annotated[str, StringConstraints(max_length=255)]
     source: Annotated[str, StringConstraints(max_length=255)]
...
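`ApplicationDeleteRequest` reuses the 2048-character ARN constraint from `ApplicationRequest`. A self-contained sketch of what that constraint enforces (pydantic v2 assumed):

from typing import Annotated

from pydantic import BaseModel, StringConstraints, ValidationError

class ApplicationDeleteRequest(BaseModel):
    role_arn: Annotated[str, StringConstraints(max_length=2048)]

ApplicationDeleteRequest(role_arn="arn:aws:iam::123456789012:role/GitLabDuoQ")  # ok
try:
    ApplicationDeleteRequest(role_arn="x" * 4096)  # exceeds max_length
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # string_too_long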
@@ -9,7 +9,7 @@ from starlette.responses import StreamingResponse
 from ai_gateway.api.auth_utils import StarletteUser, get_current_user
 from ai_gateway.api.feature_category import feature_category
-from ai_gateway.async_dependency_resolver import get_container_application
+from ai_gateway.async_dependency_resolver import get_prompt_registry
 from ai_gateway.prompts import BasePromptRegistry, Prompt
 from ai_gateway.prompts.typing import TypeModelMetadata
@@ -28,10 +28,6 @@ class PromptRequest(BaseModel):
 router = APIRouter()


-async def get_prompt_registry():
-    yield get_container_application().pkg_prompts.prompt_registry()
-
-
 @router.post(
     "/{prompt_id:path}",
     response_model=str,
...
@@ -2,7 +2,11 @@ from datetime import datetime
 from typing import Annotated, AsyncIterator

 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from gitlab_cloud_connector import GitLabFeatureCategory, GitLabUnitPrimitive
+from gitlab_cloud_connector import (
+    GitLabFeatureCategory,
+    GitLabUnitPrimitive,
+    WrongUnitPrimitives,
+)
 from starlette.responses import StreamingResponse

 from ai_gateway.api.auth_utils import StarletteUser, get_current_user
@@ -12,16 +16,18 @@ from ai_gateway.api.v2.chat.typing import AgentRequest
 from ai_gateway.async_dependency_resolver import (
     get_container_application,
     get_internal_event_client,
+    get_prompt_registry,
 )
 from ai_gateway.chat.agents import (
-    AdditionalContext,
     AgentStep,
     AgentToolAction,
+    ReActAgent,
     ReActAgentInputs,
     TypeAgentEvent,
 )
 from ai_gateway.chat.executor import GLAgentRemoteExecutor
 from ai_gateway.internal_events import InternalEventsClient
+from ai_gateway.prompts import BasePromptRegistry

 __all__ = [
     "router",
@@ -34,53 +40,57 @@ request_log = get_request_logger("chat")
 router = APIRouter()


-async def get_gl_agent_remote_executor():
-    yield get_container_application().chat.gl_agent_remote_executor()
+async def get_gl_agent_remote_executor_factory():
+    yield get_container_application().chat.gl_agent_remote_executor_factory


 def authorize_additional_context(
     current_user: StarletteUser,
-    additional_context: AdditionalContext,
+    agent_request: AgentRequest,
     internal_event_client: InternalEventsClient,
 ):
-    unit_primitive = GitLabUnitPrimitive[
-        f"include_{additional_context.category}_context".upper()
-    ]
-    if current_user.can(unit_primitive):
-        internal_event_client.track_event(
-            f"request_{unit_primitive}",
-            category=__name__,
-        )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=f"Unauthorized to access {unit_primitive}",
-        )
-
-
-def authorize_agent_request(
+    if agent_request.messages:
+        for message in agent_request.messages:
+            if message.additional_context:
+                for additional_context in message.additional_context:
+                    unit_primitive = GitLabUnitPrimitive[
+                        f"include_{additional_context.category}_context".upper()
+                    ]
+                    if current_user.can(unit_primitive):
+                        internal_event_client.track_event(
+                            f"request_{unit_primitive}",
+                            category=__name__,
+                        )
+                    else:
+                        raise HTTPException(
+                            status_code=status.HTTP_403_FORBIDDEN,
+                            detail=f"Unauthorized to access {unit_primitive}",
+                        )
+
+
+def get_agent(
     current_user: StarletteUser,
     agent_request: AgentRequest,
-    internal_event_client: InternalEventsClient,
-):
-    if current_user.can(GitLabUnitPrimitive.DUO_CHAT):
-        internal_event_client.track_event(
-            f"request_{GitLabUnitPrimitive.DUO_CHAT}",
-            category=__name__,
-        )
-    else:
+    prompt_registry: BasePromptRegistry,
+) -> ReActAgent:
+    try:
+        if agent_request.model_metadata:
+            agent_request.model_metadata.add_user(current_user)
+
+        prompt = prompt_registry.get_on_behalf(
+            current_user,
+            "chat/react",
+            None,
+            agent_request.model_metadata,
+            __name__,
+        )
+    except WrongUnitPrimitives:
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN,
             detail="Unauthorized to access duo chat",
         )
-    if agent_request.messages:
-        for message in agent_request.messages:
-            if message.additional_context:
-                for ctx in message.additional_context:
-                    authorize_additional_context(
-                        current_user, ctx, internal_event_client
-                    )
+
+    return prompt


 @router.post("/agent")
@@ -89,12 +99,15 @@ async def chat(
     request: Request,
     agent_request: AgentRequest,
     current_user: Annotated[StarletteUser, Depends(get_current_user)],
-    gl_agent_remote_executor: GLAgentRemoteExecutor[
+    prompt_registry: Annotated[BasePromptRegistry, Depends(get_prompt_registry)],
+    gl_agent_remote_executor_factory: GLAgentRemoteExecutor[
         ReActAgentInputs, TypeAgentEvent
-    ] = Depends(get_gl_agent_remote_executor),
+    ] = Depends(get_gl_agent_remote_executor_factory),
    internal_event_client: InternalEventsClient = Depends(get_internal_event_client),
 ):
-    authorize_agent_request(current_user, agent_request, internal_event_client)
+    agent = get_agent(current_user, agent_request, prompt_registry)
+    authorize_additional_context(current_user, agent_request, internal_event_client)

     async def _stream_handler(stream_events: AsyncIterator[TypeAgentEvent]):
         async for event in stream_events:
@@ -118,12 +131,12 @@ async def chat(
     inputs = ReActAgentInputs(
         messages=agent_request.messages,
         agent_scratchpad=scratchpad,
-        model_metadata=agent_request.model_metadata,
         unavailable_resources=agent_request.unavailable_resources,
         current_date=datetime.now().strftime("%A, %B %d, %Y"),
     )

     gl_version = request.headers.get(X_GITLAB_VERSION_HEADER, "")
+    gl_agent_remote_executor = gl_agent_remote_executor_factory(agent=agent)
     gl_agent_remote_executor.on_behalf(current_user, gl_version)

     request_log.info("Request to V2 Chat Agent", source=__name__, inputs=inputs)
...
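After this refactor the executor is no longer built once with an agent factory; the handler resolves the agent per request and passes it to `gl_agent_remote_executor_factory`. A minimal, self-contained sketch of the `dependency-injector` pattern involved, with toy stand-ins for the real classes:

from dependency_injector import containers, providers

class Executor:
    # Stands in for GLAgentRemoteExecutor: some dependencies are pre-wired
    # in the container, the agent arrives at call time.
    def __init__(self, *, agent, tools_registry):
        self.agent = agent
        self.tools_registry = tools_registry

class Container(containers.DeclarativeContainer):
    executor_factory = providers.Factory(Executor, tools_registry="stub-registry")

container = Container()
executor = container.executor_factory(agent="react-agent")  # per-request argument
print(executor.agent, executor.tools_registry)  # react-agent stub-registry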
@@ -33,6 +33,7 @@ from ai_gateway.async_dependency_resolver import (
     get_code_suggestions_completions_anthropic_provider,
     get_code_suggestions_completions_fireworks_qwen_factory_provider,
     get_code_suggestions_completions_litellm_factory_provider,
+    get_code_suggestions_completions_litellm_vertex_codestral_factory_provider,
     get_code_suggestions_completions_vertex_legacy_provider,
     get_code_suggestions_generations_agent_factory_provider,
     get_code_suggestions_generations_anthropic_chat_factory_provider,
@@ -54,11 +55,16 @@ from ai_gateway.code_suggestions.base import CodeSuggestionsOutput
 from ai_gateway.code_suggestions.processing.base import ModelEngineOutput
 from ai_gateway.code_suggestions.processing.ops import lang_from_filename
 from ai_gateway.config import Config
-from ai_gateway.feature_flags.context import current_feature_flag_context
+from ai_gateway.feature_flags.context import (
+    FeatureFlag,
+    current_feature_flag_context,
+    is_feature_enabled,
+)
 from ai_gateway.instrumentators.base import TelemetryInstrumentator
 from ai_gateway.internal_events import InternalEventsClient
 from ai_gateway.models import KindAnthropicModel, KindModelProvider
 from ai_gateway.models.base import TokensConsumptionMetadata
+from ai_gateway.models.vertex_text import KindVertexTextModel
 from ai_gateway.prompts import BasePromptRegistry
 from ai_gateway.prompts.typing import ModelMetadata
 from ai_gateway.structured_logging import get_request_logger
@@ -108,6 +114,9 @@ async def completions(
     completions_amazon_q_factory: Factory[CodeCompletions] = Depends(
         get_code_suggestions_completions_amazon_q_factory_provider
     ),
+    completions_litellm_vertex_codestral_factory: Factory[CodeCompletions] = Depends(
+        get_code_suggestions_completions_litellm_vertex_codestral_factory_provider
+    ),
     completions_agent_factory: Factory[CodeCompletions] = Depends(
         get_code_suggestions_completions_agent_factory_provider
     ),
@@ -127,6 +136,7 @@ async def completions(
         completions_fireworks_qwen_factory,
         completions_agent_factory,
         completions_amazon_q_factory,
+        completions_litellm_vertex_codestral_factory,
         internal_event_client,
     )
@@ -438,6 +448,7 @@ def _build_code_completions(
     completions_fireworks_qwen_factory: Factory[CodeCompletions],
     completions_agent_factory: Factory[CodeCompletions],
     completions_amazon_q_factory: Factory[CodeCompletions],
+    completions_litellm_vertex_codestral_factory: Factory[CodeCompletions],
     internal_event_client: InternalEventsClient,
 ) -> tuple[CodeCompletions | CodeCompletionsLegacy, dict]:
     kwargs = {}
@@ -464,7 +475,10 @@ def _build_code_completions(
         )
         return code_completions, kwargs
-    elif payload.model_provider == KindModelProvider.FIREWORKS:
+    elif payload.model_provider == KindModelProvider.FIREWORKS or (
+        not _allow_vertex_codestral()
+        and is_feature_enabled(FeatureFlag.DISABLE_CODE_GECKO_DEFAULT)
+    ):
         FireworksHandler(payload, request, kwargs).update_completion_params()
         code_completions = _resolve_code_completions_litellm(
             payload=payload,
@@ -482,6 +496,35 @@ def _build_code_completions(
             model__current_user=current_user,
             model__role_arn=payload.role_arn,
         )
+    elif (
+        (
+            (
+                payload.model_provider == KindModelProvider.VERTEX_AI
+                and payload.model_name == KindVertexTextModel.CODESTRAL_2501
+            )
+            or is_feature_enabled(FeatureFlag.DISABLE_CODE_GECKO_DEFAULT)
+        )
+        # Codestral is currently not supported in asia-* locations
+        and _allow_vertex_codestral()
+    ):
+        code_completions = _resolve_code_completions_vertex_codestral(
+            payload=payload,
+            completions_litellm_vertex_codestral_factory=completions_litellm_vertex_codestral_factory,
+        )
+        # We need to pass this here since litellm.LiteLlmTextGenModel
+        # sets the default temperature and max_output_tokens in the `generate` function signature
+        # To override those values, the kwargs passed to `generate` is updated here
+        # For further details, see:
+        # https://gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/-/merge_requests/1172#note_2060587592
+        #
+        # The temperature value is taken from Mistral's docs: https://docs.mistral.ai/api/#operation/createFIMCompletion
+        # context_max_percent is set to 0.3 to limit the amount of context right now because latency increases with larger context
+        kwargs.update(
+            {"temperature": 0.7, "max_output_tokens": 64, "context_max_percent": 0.3}
+        )
+
+        if payload.context:
+            kwargs.update({"code_context": [ctx.content for ctx in payload.context]})
     else:
         code_completions = completions_legacy_factory()
         LegacyHandler(payload, request, kwargs).update_completion_params()
@@ -502,6 +545,19 @@ def _build_code_completions(
     return code_completions, kwargs


+def _resolve_code_completions_vertex_codestral(
+    payload: SuggestionsRequest,
+    completions_litellm_vertex_codestral_factory: Factory[CodeCompletions],
+) -> CodeCompletions:
+    if payload.prompt_version == 2 and payload.prompt is not None:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="You cannot specify a prompt with the given provider and model combination",
+        )
+
+    return completions_litellm_vertex_codestral_factory()
+
+
 def _resolve_agent_code_completions(
     model_metadata: ModelMetadata,
     current_user: StarletteUser,
@@ -560,6 +616,14 @@ def _generation_suggestion_choices(text: str) -> list:
     return [SuggestionsResponse.Choice(text=text)] if text else []


+def _allow_vertex_codestral():
+    return not _get_gcp_location().startswith("asia-")
+
+
+def _get_gcp_location():
+    return Config().google_cloud_platform.location
+
+
 async def _handle_stream(
     response: AsyncIterator[CodeSuggestionsChunk],
 ) -> StreamSuggestionsResponse:
...
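`_allow_vertex_codestral` gates the new branch on the configured GCP location. The same check, parameterized so it can be verified standalone (the real helper reads the location from `Config()`):

def allow_vertex_codestral(location: str) -> bool:
    # Codestral on Vertex AI is not offered in asia-* locations.
    return not location.startswith("asia-")

assert allow_vertex_codestral("us-central1")
assert not allow_vertex_codestral("asia-northeast1")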
@@ -3,6 +3,7 @@ from fastapi import Request
 from ai_gateway.api.middleware import X_GITLAB_LANGUAGE_SERVER_VERSION
 from ai_gateway.api.v2.code.typing import CompletionsRequestWithVersion
 from ai_gateway.code_suggestions.language_server import LanguageServerVersion
+from ai_gateway.models.base import KindModelProvider


 class BaseModelProviderHandler:
@@ -47,6 +48,12 @@ class FireworksHandler(BaseModelProviderHandler):
         if self.payload.context:
             self._update_code_context()

+        if not self.payload.model_provider:
+            self.payload.model_provider = KindModelProvider.FIREWORKS
+
+        if not self.payload.model_name:
+            self.payload.model_name = "qwen2p5-coder-7b"
+

 class LegacyHandler(BaseModelProviderHandler):
     def update_completion_params(self):
...
@@ -64,6 +64,10 @@ async def get_code_suggestions_completions_amazon_q_factory_provider():
     yield get_container_application().code_suggestions.completions.amazon_q_factory


+async def get_code_suggestions_completions_litellm_vertex_codestral_factory_provider():
+    yield get_container_application().code_suggestions.completions.litellm_vertex_codestral_factory
+
+
 async def get_code_suggestions_completions_agent_factory_provider():
     yield get_container_application().code_suggestions.completions.agent_factory
@@ -150,3 +154,7 @@ async def get_amazon_q_client_factory(
     ],
 ):
     return amazon_q_client_factory
+
+
+async def get_prompt_registry():
+    yield get_container_application().pkg_prompts.prompt_registry()
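
`get_prompt_registry` moves into the shared resolver module as a generator dependency: FastAPI calls it per request and injects the yielded registry. A minimal standalone equivalent with a toy registry value:

from fastapi import Depends, FastAPI

app = FastAPI()

async def get_registry():
    # Async generator dependencies get a per-request lifetime in FastAPI.
    yield {"chat/react": "prompt definition"}  # stand-in for the real registry

@app.get("/demo")
async def demo(registry: dict = Depends(get_registry)):
    return {"prompt_ids": list(registry)}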
@@ -23,7 +23,7 @@ from ai_gateway.chat.tools.base import BaseTool
 from ai_gateway.feature_flags import FeatureFlag, is_feature_enabled
 from ai_gateway.models.base_chat import Role
 from ai_gateway.prompts import Prompt, jinja2_formatter
-from ai_gateway.prompts.typing import TypeModelMetadata
+from ai_gateway.prompts.config import ModelClassProvider, ModelConfig

 __all__ = [
     "ReActAgentInputs",
@@ -41,7 +41,6 @@ request_log = get_request_logger("react")
 class ReActAgentInputs(BaseModel):
     messages: list[Message]
     agent_scratchpad: Optional[list[AgentStep]] = None
-    model_metadata: Optional[TypeModelMetadata] = None
     unavailable_resources: Optional[list[str]] = None
     tools: Optional[list[BaseTool]] = None
     current_date: Optional[str] = None
@@ -131,8 +130,9 @@ class ReActPlainTextParser(BaseCumulativeTransformOutputParser):
 class ReActPromptTemplate(Runnable[ReActAgentInputs, PromptValue]):
-    def __init__(self, prompt_template: dict[str, str]):
+    def __init__(self, prompt_template: dict[str, str], model_config: ModelConfig):
         self.prompt_template = prompt_template
+        self.model_config = model_config

     def invoke(
         self,
@@ -151,7 +151,8 @@ class ReActPromptTemplate(Runnable[ReActAgentInputs, PromptValue]):
         )
         if (
             is_feature_enabled(FeatureFlag.ENABLE_ANTHROPIC_PROMPT_CACHING)
-            and input.model_metadata is None
+            and self.model_config.params.model_class_provider
+            == ModelClassProvider.ANTHROPIC
         ):
             content = [
                 {
@@ -171,7 +172,15 @@ class ReActPromptTemplate(Runnable[ReActAgentInputs, PromptValue]):
                     )
                 )
             elif m.role is Role.ASSISTANT:
-                messages.append(AIMessage(m.content))
+                messages.append(
+                    AIMessage(
+                        jinja2_formatter(
+                            self.prompt_template["assistant"],
+                            agent_scratchpad=m.agent_scratchpad,
+                            final_answer=m.content,
+                        )
+                    )
+                )
             else:
                 raise ValueError("Unsupported message")
@@ -201,8 +210,10 @@ class ReActAgent(Prompt[ReActAgentInputs, TypeAgentEvent]):
         return chain | ReActPlainTextParser()

     @classmethod
-    def _build_prompt_template(cls, prompt_template: dict[str, str]) -> Runnable:
-        return ReActPromptTemplate(prompt_template)
+    def _build_prompt_template(
+        cls, prompt_template: dict[str, str], model_config: ModelConfig
+    ) -> Runnable:
+        return ReActPromptTemplate(prompt_template, model_config)

     async def astream(
         self,
...
 import json
-from typing import Literal, Optional, TypeVar
+from typing import Literal, Optional, Self, TypeVar

-from pydantic import BaseModel
+import fastapi
+from pydantic import BaseModel, model_validator

 from ai_gateway.chat.context.current_page import CurrentPageContext
 from ai_gateway.models.base_chat import Role
@@ -59,7 +60,7 @@ TypeAgentInputs = TypeVar("TypeAgentInputs")
 class AgentStep(BaseModel):
-    action: AgentToolAction
+    action: Optional[AgentToolAction] = None
     observation: str
@@ -85,4 +86,13 @@ class Message(BaseModel):
     context: Optional[CurrentPageContext] = None
     current_file: Optional[CurrentFile] = None
     additional_context: Optional[list[AdditionalContext]] = None
-    resource_content: Optional[str] = None
+    agent_scratchpad: Optional[list[AgentStep]] = None
+
+    @model_validator(mode="after")
+    def validate_agent_scratchpad_role(self) -> Self:
+        if self.agent_scratchpad is not None and self.role != Role.ASSISTANT:
+            raise fastapi.HTTPException(
+                status_code=400,
+                detail="agent_scratchpad can only be present when role is ASSISTANT",
+            )
+        return self
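
The new validator rejects `agent_scratchpad` on non-assistant messages at parse time. A self-contained reproduction with only the relevant fields (pydantic v2 and Python 3.11 assumed; a plain `ValueError` stands in for the `fastapi.HTTPException` raised in production):

from enum import StrEnum
from typing import Optional, Self

from pydantic import BaseModel, ValidationError, model_validator

class Role(StrEnum):
    USER = "user"
    ASSISTANT = "assistant"

class Message(BaseModel):
    role: Role
    agent_scratchpad: Optional[list[str]] = None

    @model_validator(mode="after")
    def validate_agent_scratchpad_role(self) -> Self:
        if self.agent_scratchpad is not None and self.role != Role.ASSISTANT:
            raise ValueError("agent_scratchpad can only be present when role is ASSISTANT")
        return self

Message(role=Role.ASSISTANT, agent_scratchpad=["Thought: ..."])  # accepted
try:
    Message(role=Role.USER, agent_scratchpad=["Thought: ..."])
except ValidationError as exc:
    print("rejected:", exc.errors()[0]["msg"])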
...
-from typing import TYPE_CHECKING
-
 from dependency_injector import containers, providers

-from ai_gateway.chat.agents import ReActAgent, TypeAgentEvent
-from ai_gateway.chat.executor import GLAgentRemoteExecutor, TypeAgentFactory
+from ai_gateway.chat.executor import GLAgentRemoteExecutor
 from ai_gateway.chat.toolset import DuoChatToolsRegistry

-if TYPE_CHECKING:
-    from ai_gateway.prompts import BasePromptRegistry
-
 __all__ = [
     "ContainerChat",
 ]


-def _react_agent_factory(
-    prompt_registry: "BasePromptRegistry",
-) -> TypeAgentFactory[TypeAgentEvent]:
-    def _fn(**kwargs) -> ReActAgent:
-        return prompt_registry.get("chat/react", "^1.0.0", **kwargs)
-
-    return _fn
-
-
 class ContainerChat(containers.DeclarativeContainer):
     prompts = providers.DependenciesContainer()
     models = providers.DependenciesContainer()
@@ -34,11 +19,6 @@ class ContainerChat(containers.DeclarativeContainer):
     _anthropic_claude_llm_factory = providers.Factory(models.anthropic_claude)
     _anthropic_claude_chat_factory = providers.Factory(models.anthropic_claude_chat)

-    _react_agent_factory = providers.Factory(
-        _react_agent_factory,
-        prompt_registry=prompts.prompt_registry,
-    )
-
     # We need to resolve the model based on model name provided in request payload
     # Hence, `models._anthropic_claude` and `models._anthropic_claude_chat_factory` are only partially applied here.
     anthropic_claude_factory = providers.FactoryAggregate(
@@ -52,9 +32,8 @@ class ContainerChat(containers.DeclarativeContainer):
         self_hosted_documentation_enabled=config.custom_models.enabled,
     )

-    gl_agent_remote_executor = providers.Factory(
+    gl_agent_remote_executor_factory = providers.Factory(
         GLAgentRemoteExecutor,
-        agent_factory=_react_agent_factory,
         tools_registry=_tools_registry,
         internal_event_client=internal_event.client,
     )
...
-from typing import AsyncIterator, Generic, Protocol
+from typing import AsyncIterator, Generic

 import starlette_context
-from langchain_core.runnables import Runnable

 from ai_gateway.api.auth_utils import StarletteUser
 from ai_gateway.chat.agents import (
@@ -14,10 +13,8 @@ from ai_gateway.chat.agents import (
 from ai_gateway.chat.base import BaseToolsRegistry
 from ai_gateway.chat.tools import BaseTool
 from ai_gateway.internal_events import InternalEventsClient
-from ai_gateway.prompts.typing import TypeModelMetadata

 __all__ = [
-    "TypeAgentFactory",
     "GLAgentRemoteExecutor",
 ]
@@ -28,23 +25,15 @@ _REACT_AGENT_AVAILABLE_TOOL_NAMES_CONTEXT_KEY = "duo_chat.agent_available_tools"
 log = get_request_logger("gl_agent_remote_executor")


-class TypeAgentFactory(Protocol[TypeAgentEvent]):
-    def __call__(
-        self,
-        *,
-        model_metadata: TypeModelMetadata,
-    ) -> Runnable[TypeAgentInputs, TypeAgentEvent]: ...
-
-
 class GLAgentRemoteExecutor(Generic[TypeAgentInputs, TypeAgentEvent]):
     def __init__(
         self,
         *,
-        agent_factory: TypeAgentFactory,
+        agent: ReActAgent,
         tools_registry: BaseToolsRegistry,
         internal_event_client: InternalEventsClient,
     ):
-        self.agent_factory = agent_factory
+        self.agent = agent
         self.tools_registry = tools_registry
         self.internal_event_client = internal_event_client
         self._tools: list[BaseTool] | None = None
@@ -69,7 +58,6 @@ class GLAgentRemoteExecutor(Generic[TypeAgentInputs, TypeAgentEvent]):
     async def stream(self, *, inputs: TypeAgentInputs) -> AsyncIterator[TypeAgentEvent]:
         inputs.tools = self.tools
-        agent: ReActAgent = self.agent_factory(model_metadata=inputs.model_metadata)

         tools_by_name = self.tools_by_name
@@ -79,7 +67,7 @@ class GLAgentRemoteExecutor(Generic[TypeAgentInputs, TypeAgentEvent]):
         log.info("Processed inputs", source=__name__, inputs=inputs)

-        async for event in agent.astream(inputs):
+        async for event in self.agent.astream(inputs):
             if isinstance(event, AgentToolAction):
                 if event.tool in tools_by_name:
                     tool = tools_by_name[event.tool]
...
@@ -7,6 +7,7 @@ from ai_gateway.code_suggestions.processing.base import LANGUAGE_COUNTER
 from ai_gateway.code_suggestions.processing.ops import (
     lang_from_editor_lang,
     lang_from_filename,
+    lang_name_from_filename,
 )
 from ai_gateway.experimentation import ExperimentTelemetry
 from ai_gateway.models import (
@@ -55,6 +56,7 @@ USE_CASES_MODELS_MAP = {
         KindAnthropicModel.CLAUDE_3_5_SONNET,
         KindAnthropicModel.CLAUDE_2_1,
         KindVertexTextModel.CODE_GECKO_002,
+        KindVertexTextModel.CODESTRAL_2501,
         KindLiteLlmModel.CODEGEMMA,
         KindLiteLlmModel.CODELLAMA,
         KindLiteLlmModel.CODESTRAL,
@@ -96,6 +98,10 @@ SAAS_PROMPT_MODEL_MAP = {
         "model_provider": ModelProvider.ANTHROPIC,
         "model_version": KindAnthropicModel.CLAUDE_3_5_SONNET,
     },
+    "3.0.2-dev": {
+        "model_provider": ModelProvider.ANTHROPIC,
+        "model_version": KindAnthropicModel.CLAUDE_3_7_SONNET,
+    },
     "2.0.0": {
         "model_provider": ModelProvider.VERTEX_AI,
         "model_version": KindAnthropicModel.CLAUDE_3_5_SONNET,
@@ -138,6 +144,12 @@ def resolve_lang_id(
     return lang_id


+def resolve_lang_name(file_name: str) -> Optional[str]:
+    lang_name = lang_name_from_filename(file_name)
+
+    return lang_name
+
+
 # TODO: https://gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/-/issues/292
 def increment_lang_counter(
     filename: str,
...
@@ -9,6 +9,7 @@ from ai_gateway.code_suggestions.base import (
     LanguageId,
     increment_lang_counter,
     resolve_lang_id,
+    resolve_lang_name,
 )
 from ai_gateway.code_suggestions.processing import (
     ModelEngineCompletions,
@@ -201,13 +202,16 @@ class CodeCompletions:
                 prompt.prefix, stream=stream, **kwargs
             )
         elif isinstance(self.model, AmazonQModel):
-            res = await self.model.generate(
-                prompt.prefix,
-                prompt.suffix,
-                file_name,
-                lang_id.name.lower(),
-                **kwargs,
-            )
+            if lang := (editor_lang or resolve_lang_name(file_name)):
+                res = await self.model.generate(
+                    prompt.prefix,
+                    prompt.suffix,
+                    file_name,
+                    lang.lower(),
+                    **kwargs,
+                )
+            else:
+                res = None
         else:
             res = await self.model.generate(
                 prompt.prefix, prompt.suffix, stream, **kwargs
...
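The Amazon Q branch now prefers the editor-reported language and falls back to filename resolution, skipping the model call when neither resolves. The control flow in isolation (helper names here are illustrative):

from typing import Callable, Optional

def q_language(
    editor_lang: Optional[str],
    file_name: str,
    resolve_lang_name: Callable[[str], Optional[str]],
) -> Optional[str]:
    # Editor language wins; filename lookup is the fallback; None means
    # the generate() call is skipped and res stays None.
    if lang := (editor_lang or resolve_lang_name(file_name)):
        return lang.lower()
    return None

assert q_language("Python", "main.py", lambda _: None) == "python"
assert q_language(None, "main.py", lambda _: "Python") == "python"
assert q_language(None, "notes.xyz", lambda _: None) is None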
@@ -178,6 +178,23 @@ class ContainerCodeCompletions(containers.DeclarativeContainer):
         ).provider,
     )

+    litellm_vertex_codestral_factory = providers.Factory(
+        CodeCompletions,
+        model=providers.Factory(
+            litellm,
+            name=KindVertexTextModel.CODESTRAL_2501,
+            provider=KindModelProvider.VERTEX_AI,
+        ),
+        tokenization_strategy=providers.Factory(
+            TokenizerTokenStrategy, tokenizer=tokenizer
+        ),
+        post_processor=providers.Factory(
+            PostProcessorCompletions,
+            extras=[PostProcessorOperation.STRIP_ASTERISKS],
+            exclude=config.excl_post_process,
+        ).provider,
+    )
+
     agent_factory = providers.Factory(
         CodeCompletions,
         model=providers.Factory(agent_model),
...
@@ -7,6 +7,7 @@ from ai_gateway.code_suggestions.base import (
     ModelProvider,
     increment_lang_counter,
     resolve_lang_id,
+    resolve_lang_name,
 )
 from ai_gateway.code_suggestions.processing import LanguageId, Prompt, TokenStrategyBase
 from ai_gateway.code_suggestions.processing.post.generations import (
@@ -133,13 +134,17 @@ class CodeGenerations:
                 prompt.prefix, stream=stream, **kwargs
             )
         elif isinstance(self.model, AmazonQModel):
-            res = await self.model.generate(
-                prefix,
-                suffix if suffix else "",
-                file_name,
-                lang_id.name.lower(),
-                **kwargs,
-            )
+            if lang := (editor_lang or resolve_lang_name(file_name)):
+                res = await self.model.generate(
+                    prefix,
+                    suffix if suffix else "",
+                    file_name,
+                    lang.lower(),
+                    **kwargs,
+                )
+            else:
+                res = None
         else:
             res = await self.model.generate(
                 prompt.prefix, "", stream=stream, **kwargs
...
@@ -102,6 +102,10 @@ _EDITOR_LANG_TO_LANG_ID = {
     name: language.lang_id for language in _ALL_LANGS for name in language.editor_names
 }

+_EXTENSION_TO_LANG_NAME = {
+    ext: language.grammar_name for language in _ALL_LANGS for ext in language.extensions
+}
+
 # A new line with a non-indented letter or comment (/*, #, //)
 _END_OF_CODE_BLOCK_REGEX = re.compile(r"\n([a-zA-Z]|(\/\*)|(#)|(\/\/))")
@@ -168,6 +172,11 @@ def lang_from_filename(file_name: Union[str, Path]) -> Optional[LanguageId]:
     return _EXTENSION_TO_LANG_ID.get(ext, None)


+def lang_name_from_filename(file_name: Union[str, Path]) -> Optional[str]:
+    ext = Path(file_name).suffix.replace(".", "")
+    return _EXTENSION_TO_LANG_NAME.get(ext, None)
+
+
 def lang_from_editor_lang(editor_lang: str) -> Optional[LanguageId]:
     return _EDITOR_LANG_TO_LANG_ID.get(editor_lang, None)
...
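`lang_name_from_filename` mirrors the existing `lang_from_filename` lookup but returns the grammar name string instead of a `LanguageId`. A toy version (the real `_EXTENSION_TO_LANG_NAME` is generated from `_ALL_LANGS`; the two entries below are illustrative):

from pathlib import Path
from typing import Optional, Union

_EXTENSION_TO_LANG_NAME = {"py": "Python", "rb": "Ruby"}

def lang_name_from_filename(file_name: Union[str, Path]) -> Optional[str]:
    ext = Path(file_name).suffix.replace(".", "")
    return _EXTENSION_TO_LANG_NAME.get(ext)

print(lang_name_from_filename("app/models/user.rb"))  # Ruby
print(lang_name_from_filename("Dockerfile"))          # None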
@@ -10,6 +10,8 @@ class FeatureFlag(StrEnum):
     # Definition: https://gitlab.com/gitlab-org/gitlab/-/blob/master/config/feature_flags/ops/expanded_ai_logging.yml
     EXPANDED_AI_LOGGING = "expanded_ai_logging"
     ENABLE_ANTHROPIC_PROMPT_CACHING = "enable_anthropic_prompt_caching"
+    DISABLE_CODE_GECKO_DEFAULT = "disable_code_gecko_default"
+    DUO_CHAT_REACT_AGENT_CLAUDE_3_7 = "duo_chat_react_agent_claude_3_7"


 def is_feature_enabled(feature_name: FeatureFlag | str) -> bool:
...