Commit 1e2065f5 authored by Britney Tong

Merge remote-tracking branch 'public/id-merge-upstream-to-q-integration' into q-integration-bt

parents d1b02061 aa2cfebf
Showing 170 additions and 43 deletions
@@ -61,13 +61,13 @@ include:
     inputs:
       runway_service_id: ai-gateway
       image: "$CI_REGISTRY_IMAGE/model-gateway:${CI_COMMIT_SHORT_SHA}"
-      runway_version: v3.55.3
+      runway_version: v3.58.0
   - project: "gitlab-com/gl-infra/platform/runway/runwayctl"
     file: "ci-tasks/service-project/runway.yml"
     inputs:
       runway_service_id: ai-gateway-custom
       image: "$SELF_HOSTED_TARGET_IMAGE"
-      runway_version: v3.55.3
+      runway_version: v3.58.0
   - component: ${CI_SERVER_FQDN}/gitlab-org/components/danger-review/danger-review@2.0.0
     rules:
       - if: $CI_SERVER_HOST == "gitlab.com"
......
 include:
   # see https://gitlab.com/gitlab-com/gl-infra/common-ci-tasks/-/blob/main/oidc.md
   - project: 'gitlab-com/gl-infra/common-ci-tasks'
-    ref: v2.57.3 # renovate:managed
+    ref: v2.61 # renovate:managed
     file: 'oidc.yml'

 .ingest-base:
......
@@ -16,6 +16,8 @@ spec:
     - europe-west3
     - asia-northeast1
     - europe-west9
+  deployment:
+    strategy: "expedited"
   request_timeout: 60
   observability:
     scrape_targets:
......
-lefthook 1.10.10 # datasource=github-releases depName=evilmartians/lefthook
+lefthook 1.11.0 # datasource=github-releases depName=evilmartians/lefthook
 python 3.11.11 # datasource=github-tags depName=python/cpython
 gcloud 428.0.0 # datasource=github-tags depName=GoogleCloudPlatform/cloud-sdk-docker
 poetry 2.0.1 # datasource=pypi depName=poetry
......
# Guidelines for contributing to the project
WIP
@@ -33,6 +33,7 @@ from ai_gateway.async_dependency_resolver import (
     get_code_suggestions_completions_anthropic_provider,
     get_code_suggestions_completions_fireworks_qwen_factory_provider,
     get_code_suggestions_completions_litellm_factory_provider,
+    get_code_suggestions_completions_litellm_vertex_codestral_factory_provider,
     get_code_suggestions_completions_vertex_legacy_provider,
     get_code_suggestions_generations_agent_factory_provider,
     get_code_suggestions_generations_anthropic_chat_factory_provider,
@@ -54,11 +55,16 @@ from ai_gateway.code_suggestions.base import CodeSuggestionsOutput
 from ai_gateway.code_suggestions.processing.base import ModelEngineOutput
 from ai_gateway.code_suggestions.processing.ops import lang_from_filename
 from ai_gateway.config import Config
-from ai_gateway.feature_flags.context import current_feature_flag_context
+from ai_gateway.feature_flags.context import (
+    FeatureFlag,
+    current_feature_flag_context,
+    is_feature_enabled,
+)
 from ai_gateway.instrumentators.base import TelemetryInstrumentator
 from ai_gateway.internal_events import InternalEventsClient
 from ai_gateway.models import KindAnthropicModel, KindModelProvider
 from ai_gateway.models.base import TokensConsumptionMetadata
+from ai_gateway.models.vertex_text import KindVertexTextModel
 from ai_gateway.prompts import BasePromptRegistry
 from ai_gateway.prompts.typing import ModelMetadata
 from ai_gateway.structured_logging import get_request_logger
@@ -108,6 +114,9 @@ async def completions(
     completions_amazon_q_factory: Factory[CodeCompletions] = Depends(
         get_code_suggestions_completions_amazon_q_factory_provider
     ),
+    completions_litellm_vertex_codestral_factory: Factory[CodeCompletions] = Depends(
+        get_code_suggestions_completions_litellm_vertex_codestral_factory_provider
+    ),
     completions_agent_factory: Factory[CodeCompletions] = Depends(
         get_code_suggestions_completions_agent_factory_provider
     ),
@@ -127,6 +136,7 @@ async def completions(
         completions_fireworks_qwen_factory,
         completions_agent_factory,
         completions_amazon_q_factory,
+        completions_litellm_vertex_codestral_factory,
         internal_event_client,
     )
@@ -465,6 +475,7 @@ def _build_code_completions(
     completions_fireworks_qwen_factory: Factory[CodeCompletions],
     completions_agent_factory: Factory[CodeCompletions],
     completions_amazon_q_factory: Factory[CodeCompletions],
+    completions_litellm_vertex_codestral_factory: Factory[CodeCompletions],
     internal_event_client: InternalEventsClient,
 ) -> tuple[CodeCompletions | CodeCompletionsLegacy, dict]:
     kwargs = {}
@@ -491,7 +502,10 @@ def _build_code_completions(
         )
         return code_completions, kwargs
-    elif payload.model_provider == KindModelProvider.FIREWORKS:
+    elif payload.model_provider == KindModelProvider.FIREWORKS or (
+        not _allow_vertex_codestral()
+        and is_feature_enabled(FeatureFlag.DISABLE_CODE_GECKO_DEFAULT)
+    ):
         FireworksHandler(payload, request, kwargs).update_completion_params()
         code_completions = _resolve_code_completions_litellm(
             payload=payload,
@@ -509,6 +523,35 @@ def _build_code_completions(
             model__current_user=current_user,
             model__role_arn=payload.role_arn,
         )
+    elif (
+        (
+            (
+                payload.model_provider == KindModelProvider.VERTEX_AI
+                and payload.model_name == KindVertexTextModel.CODESTRAL_2501
+            )
+            or is_feature_enabled(FeatureFlag.DISABLE_CODE_GECKO_DEFAULT)
+        )
+        # Codestral is currently not supported in asia-* locations
+        and _allow_vertex_codestral()
+    ):
+        code_completions = _resolve_code_completions_vertex_codestral(
+            payload=payload,
+            completions_litellm_vertex_codestral_factory=completions_litellm_vertex_codestral_factory,
+        )
+        # We need to pass this here since litellm.LiteLlmTextGenModel sets the default
+        # temperature and max_output_tokens in the `generate` function signature.
+        # To override those values, the kwargs passed to `generate` are updated here.
+        # For further details, see:
+        # https://gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/-/merge_requests/1172#note_2060587592
+        #
+        # The temperature value is taken from Mistral's docs: https://docs.mistral.ai/api/#operation/createFIMCompletion
+        # context_max_percent is set to 0.3 to limit the amount of context for now, because latency increases with larger context.
+        kwargs.update(
+            {"temperature": 0.7, "max_output_tokens": 64, "context_max_percent": 0.3}
+        )
+        if payload.context:
+            kwargs.update({"code_context": [ctx.content for ctx in payload.context]})
     else:
         code_completions = completions_legacy_factory()
         LegacyHandler(payload, request, kwargs).update_completion_params()
@@ -529,6 +572,19 @@ def _build_code_completions(
     return code_completions, kwargs


+def _resolve_code_completions_vertex_codestral(
+    payload: SuggestionsRequest,
+    completions_litellm_vertex_codestral_factory: Factory[CodeCompletions],
+) -> CodeCompletions:
+    if payload.prompt_version == 2 and payload.prompt is not None:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="You cannot specify a prompt with the given provider and model combination",
+        )
+
+    return completions_litellm_vertex_codestral_factory()
+
+
 def _resolve_agent_code_completions(
     model_metadata: ModelMetadata,
     current_user: StarletteUser,
@@ -587,6 +643,14 @@ def _generation_suggestion_choices(text: str) -> list:
     return [SuggestionsResponse.Choice(text=text)] if text else []


+def _allow_vertex_codestral():
+    return not _get_gcp_location().startswith("asia-")
+
+
+def _get_gcp_location():
+    return Config().google_cloud_platform.location
+
+
 async def _handle_stream(
     response: AsyncIterator[CodeSuggestionsChunk],
 ) -> StreamSuggestionsResponse:
......
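Read together, the completions hunks above change how `_build_code_completions` picks a backend: Fireworks-hosted Qwen also becomes the fallback when Codestral cannot be served (asia-* locations) and `disable_code_gecko_default` is enabled, while Vertex-hosted Codestral 25.01 is used when it is explicitly requested or when the flag is on and the location allows it. A condensed sketch of that decision order follows; the helper is hypothetical, the real logic lives in `_build_code_completions`, and the Amazon Q, Anthropic, and custom-model branches are unchanged and omitted:

    def _pick_completions_backend(payload) -> str:
        # Hypothetical summary helper, for illustration only.
        gecko_disabled = is_feature_enabled(FeatureFlag.DISABLE_CODE_GECKO_DEFAULT)

        if payload.model_provider == KindModelProvider.FIREWORKS or (
            not _allow_vertex_codestral() and gecko_disabled
        ):
            return "fireworks_qwen"       # FIM completions via Fireworks
        if (
            (
                payload.model_provider == KindModelProvider.VERTEX_AI
                and payload.model_name == KindVertexTextModel.CODESTRAL_2501
            )
            or gecko_disabled
        ) and _allow_vertex_codestral():  # asia-* locations are excluded
            return "vertex_codestral"     # Codestral 25.01 on Vertex AI
        return "code_gecko_legacy"        # existing default path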
@@ -3,6 +3,7 @@ from fastapi import Request
 from ai_gateway.api.middleware import X_GITLAB_LANGUAGE_SERVER_VERSION
 from ai_gateway.api.v2.code.typing import CompletionsRequestWithVersion
 from ai_gateway.code_suggestions.language_server import LanguageServerVersion
+from ai_gateway.models.base import KindModelProvider


 class BaseModelProviderHandler:
@@ -47,6 +48,12 @@ class FireworksHandler(BaseModelProviderHandler):
         if self.payload.context:
             self._update_code_context()

+        if not self.payload.model_provider:
+            self.payload.model_provider = KindModelProvider.FIREWORKS
+
+        if not self.payload.model_name:
+            self.payload.model_name = "qwen2p5-coder-7b"
+

 class LegacyHandler(BaseModelProviderHandler):
     def update_completion_params(self):
......
 from time import time
-from typing import Annotated, AsyncIterator, Optional
+from typing import Annotated, AsyncIterator

 from dependency_injector.providers import Factory
 from dependency_injector.wiring import Provide, inject
@@ -61,7 +61,7 @@ async def get_prompt_registry():
 async def handle_stream(
     stream: AsyncIterator[CodeSuggestionsChunk],
-    engine: StreamModelEngine,
+    metadata: ResponseMetadataBase,
 ) -> StreamSuggestionsResponse:
     async def _stream_response_generator():
         async for chunk in stream:
@@ -154,6 +154,7 @@ async def code_completion(
     payload: EditorContentCompletionPayload,
     current_user: StarletteUser,
     stream_handler: StreamHandler,
+    snowplow_event_context: SnowplowEventContext,
     completions_legacy_factory: Factory[CodeCompletionsLegacy] = Provide[
         ContainerApplication.code_suggestions.completions.vertex_legacy.provider
     ],
@@ -164,7 +165,6 @@ async def code_completion(
         ContainerApplication.code_suggestions.completions.amazon_q_factory.provider
     ],
     code_context: list[CodeContextPayload] = None,
-    snowplow_event_context: Optional[SnowplowEventContext] = None,
 ):
     kwargs = {}
@@ -221,7 +221,8 @@ async def code_completion(
         suggestions = [suggestions]

     if isinstance(suggestions[0], AsyncIterator):
-        return await stream_handler(suggestions[0], engine)
+        stream_metadata = _get_stream_metadata(engine, snowplow_event_context)
+        return await stream_handler(suggestions[0], stream_metadata)

     return CompletionResponse(
         choices=_completion_suggestion_choices(suggestions),
@@ -264,6 +265,7 @@ async def code_generation(
     current_user: StarletteUser,
     prompt_registry: BasePromptRegistry,
     stream_handler: StreamHandler,
+    snowplow_event_context: SnowplowEventContext,
     generations_vertex_factory: Factory[CodeGenerations] = Provide[
         ContainerApplication.code_suggestions.generations.vertex.provider
     ],
@@ -277,7 +279,6 @@ async def code_generation(
         ContainerApplication.code_suggestions.generations.amazon_q_factory.provider
     ],
     code_context: list[CodeContextPayload] = None,
-    snowplow_event_context: Optional[SnowplowEventContext] = None,
 ):
     model_provider = payload.model_provider

     if model_provider == KindModelProvider.AMAZON_Q:
@@ -351,7 +352,8 @@ async def code_generation(
     # Handle streaming case
     if isinstance(suggestions[0], AsyncIterator):
-        return await stream_handler(suggestions[0], engine)
+        stream_metadata = _get_stream_metadata(engine, snowplow_event_context)
+        return await stream_handler(suggestions[0], stream_metadata)

     return CompletionResponse(
         choices=_completion_suggestion_choices(suggestions),
@@ -371,3 +373,18 @@ def _create_response_metadata(model, lang, timestamp):
         ),
         enabled_feature_flags=current_feature_flag_context.get(),
     )
+
+
+def _get_stream_metadata(
+    engine: StreamModelEngine,
+    snowplow_event_context: SnowplowEventContext,
+) -> ResponseMetadataBase:
+    return ResponseMetadataBase(
+        timestamp=int(time()),
+        model=ModelMetadata(
+            engine=engine.model.metadata.engine,
+            name=engine.model.metadata.name,
+        ),
+        enabled_feature_flags=current_feature_flag_context.get(),
+        region=snowplow_event_context.region,
+    )
@@ -133,6 +133,7 @@ class ResponseMetadataBase(BaseModel):
     model: Optional[ModelMetadata] = None
     timestamp: int
     enabled_feature_flags: Optional[list[str]] = None
+    region: Optional[str] = None


 class CompletionResponse(BaseModel):
@@ -157,6 +158,6 @@ class StreamHandler(Protocol):
     async def __call__(
         self,
         stream: AsyncIterator[CodeSuggestionsChunk],
-        engine: StreamModelEngine,
+        metadata: ResponseMetadataBase,
     ) -> Union[StreamSuggestionsResponse, EventSourceResponse]:
         pass
-from time import time
 from typing import Annotated, AsyncIterator

 from fastapi import APIRouter, Depends, Request
@@ -8,12 +7,7 @@ from sse_starlette.sse import EventSourceResponse
 from ai_gateway.api.auth_utils import StarletteUser, get_current_user
 from ai_gateway.api.feature_category import feature_category
 from ai_gateway.api.v3.code.completions import code_suggestions as v3_code_suggestions
-from ai_gateway.api.v3.code.typing import (
-    CompletionRequest,
-    ModelMetadata,
-    ResponseMetadataBase,
-    StreamModelEngine,
-)
+from ai_gateway.api.v3.code.typing import CompletionRequest, ResponseMetadataBase
 from ai_gateway.api.v4.code.typing import (
     StreamDelta,
     StreamEvent,
@@ -23,7 +17,6 @@ from ai_gateway.api.v4.code.typing import (
 from ai_gateway.async_dependency_resolver import get_config, get_container_application
 from ai_gateway.code_suggestions import CodeSuggestionsChunk
 from ai_gateway.config import Config
-from ai_gateway.feature_flags.context import current_feature_flag_context
 from ai_gateway.prompts import BasePromptRegistry

 __all__ = [
@@ -39,23 +32,14 @@ async def get_prompt_registry():
 async def handle_stream_sse(
     stream: AsyncIterator[CodeSuggestionsChunk],
-    engine: StreamModelEngine,
+    metadata: ResponseMetadataBase,
 ) -> EventSourceResponse:
     async def _stream_response_generator():
         def _start_message():
             # To minimize redundancy, we're only sending metadata in the first SSE message.
             return StreamSSEMessage(
                 event=StreamEvent.START,
-                data={
-                    "metadata": ResponseMetadataBase(
-                        timestamp=int(time()),
-                        model=ModelMetadata(
-                            engine=engine.model.metadata.engine,
-                            name=engine.model.metadata.name,
-                        ),
-                        enabled_feature_flags=current_feature_flag_context.get(),
-                    ).model_dump(exclude_none=True)
-                },
+                data={"metadata": metadata.model_dump(exclude_none=True)},
             ).dump_with_json_data()

         def _content_message(chunk):
......
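With the stream metadata now built once in v3's `_get_stream_metadata` and handed through the `StreamHandler` protocol, the first SSE message in v4 carries the new `region` field. A rough illustration of the START event payload; all field values below are invented:

    # Hypothetical values, for illustration only.
    metadata = ResponseMetadataBase(
        timestamp=1739462400,
        model=ModelMetadata(engine="vertex-ai", name="codestral-2501"),
        enabled_feature_flags=["disable_code_gecko_default"],
        region="us-central1",
    )
    # handle_stream_sse() serializes this into the first message, roughly:
    # event: start
    # data: {"metadata": {"model": {...}, "timestamp": 1739462400,
    #        "enabled_feature_flags": ["disable_code_gecko_default"],
    #        "region": "us-central1"}}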
@@ -64,6 +64,10 @@ async def get_code_suggestions_completions_amazon_q_factory_provider():
     yield get_container_application().code_suggestions.completions.amazon_q_factory


+async def get_code_suggestions_completions_litellm_vertex_codestral_factory_provider():
+    yield get_container_application().code_suggestions.completions.litellm_vertex_codestral_factory
+
+
 async def get_code_suggestions_completions_agent_factory_provider():
     yield get_container_application().code_suggestions.completions.agent_factory
......
@@ -55,6 +55,7 @@ USE_CASES_MODELS_MAP = {
         KindAnthropicModel.CLAUDE_3_5_SONNET,
         KindAnthropicModel.CLAUDE_2_1,
         KindVertexTextModel.CODE_GECKO_002,
+        KindVertexTextModel.CODESTRAL_2501,
         KindLiteLlmModel.CODEGEMMA,
         KindLiteLlmModel.CODELLAMA,
         KindLiteLlmModel.CODESTRAL,
......
@@ -178,6 +178,23 @@ class ContainerCodeCompletions(containers.DeclarativeContainer):
         ).provider,
     )

+    litellm_vertex_codestral_factory = providers.Factory(
+        CodeCompletions,
+        model=providers.Factory(
+            litellm,
+            name=KindVertexTextModel.CODESTRAL_2501,
+            provider=KindModelProvider.VERTEX_AI,
+        ),
+        tokenization_strategy=providers.Factory(
+            TokenizerTokenStrategy, tokenizer=tokenizer
+        ),
+        post_processor=providers.Factory(
+            PostProcessorCompletions,
+            extras=[PostProcessorOperation.STRIP_ASTERISKS],
+            exclude=config.excl_post_process,
+        ).provider,
+    )
+
     agent_factory = providers.Factory(
         CodeCompletions,
         model=providers.Factory(agent_model),
......
@@ -10,6 +10,7 @@ class FeatureFlag(StrEnum):
     # Definition: https://gitlab.com/gitlab-org/gitlab/-/blob/master/config/feature_flags/ops/expanded_ai_logging.yml
     EXPANDED_AI_LOGGING = "expanded_ai_logging"
     ENABLE_ANTHROPIC_PROMPT_CACHING = "enable_anthropic_prompt_caching"
+    DISABLE_CODE_GECKO_DEFAULT = "disable_code_gecko_default"


 def is_feature_enabled(feature_name: FeatureFlag | str) -> bool:
......
@@ -144,7 +144,7 @@ class ChatAmazonQ(BaseChatModel):
         messages: List[BaseMessage],
         user: StarletteUser,
         role_arn: str,
-        conversation_id: str,
+        conversation_id: Optional[str],
         **kwargs: Any,
     ):
         """
......
@@ -166,7 +166,7 @@ class AmazonQClient:
     def _send_message(self, payload):
         print("DEBUG [AmazonQClient]: _send_message payload", payload)
         return self.client.send_message(
-            message=payload["message"], conversationId=payload["conversation_id"]
+            message=payload["message"]
         )

     def _retry_send_event(self, error, code, payload):
......
@@ -121,7 +121,7 @@ class MessageProcessor:
         )
         if messages and system_message.content is not None:
             # Create new content by concatenating strings
-            new_content = f"{messages[0].content}"
+            new_content = f"{system_message.content}\n{messages[0].content}"
             messages[0].content = new_content

     def _extract_content(self, messages: List[BaseMessage]) -> str:
......
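The `MessageProcessor` fix above means the system prompt is no longer dropped when it is folded into the first message. A small before/after illustration; the message contents are invented:

    # system_message.content == "You are a code assistant."
    # messages[0].content    == "Write a binary search in Python."

    # Before the fix: new_content == "Write a binary search in Python."  (system prompt lost)
    # After the fix:  new_content == "You are a code assistant.\nWrite a binary search in Python."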
@@ -5,6 +5,7 @@ from litellm import CustomStreamWrapper, ModelResponse, acompletion
 from litellm.exceptions import APIConnectionError, InternalServerError
 from openai import AsyncOpenAI

+from ai_gateway.config import Config
 from ai_gateway.models.base import (
     KindModelProvider,
     ModelAPIError,
@@ -17,6 +18,7 @@ from ai_gateway.models.base_text import (
     TextGenModelChunk,
     TextGenModelOutput,
 )
+from ai_gateway.models.vertex_text import KindVertexTextModel
 from ai_gateway.safety_attributes import SafetyAttributes
 from ai_gateway.tracking import SnowplowEventContext
@@ -99,6 +101,9 @@ MODEL_STOP_TOKENS = {
         "<|fim_middle|>",
         "<|file_separator|>",
     ],
+    # Ref: https://docs.litellm.ai/docs/providers/vertex#mistral-api
+    # This model is served by Vertex AI but accessed through LiteLLM abstraction
+    KindVertexTextModel.CODESTRAL_2501: ["\n\n", "\n+++++"],
     KindLiteLlmModel.QWEN_2_5: [
         "<|fim_prefix|>",
         "<|fim_suffix|>",
@@ -113,6 +118,10 @@ MODEL_STOP_TOKENS = {
 }

 MODEL_SPECIFICATIONS = {
+    KindVertexTextModel.CODESTRAL_2501: {
+        "timeout": 60,
+        "completion_type": ModelCompletionType.TEXT,
+    },
     KindLiteLlmModel.QWEN_2_5: {
         "timeout": 60,
         "completion_type": ModelCompletionType.FIM,
@@ -245,7 +254,7 @@ class LiteLlmChatModel(ChatModelBase):
         provider_endpoints: Optional[dict] = None,
         async_fireworks_client: Optional[AsyncOpenAI] = None,
     ):
-        if not custom_models_enabled and provider == KindModelProvider.LITELLM:
+        if not custom_models_enabled:
             if endpoint is not None or api_key is not None:
                 raise ValueError("specifying custom models endpoint is disabled")
@@ -402,10 +411,14 @@ class LiteLlmTextGenModel(TextGenModelBase):
             "top_p": top_p,
             "stream": stream,
             "timeout": self.specifications.get("timeout", 30.0),
-            "stop": self._get_stop_tokens(),
+            "stop": self._get_stop_tokens(suffix),
         }

-        completion_args = completion_args | self.model_metadata_to_params()
+        if self._is_vertex():
+            completion_args["vertex_ai_location"] = self._get_vertex_model_location()
+            completion_args["model"] = self.metadata.name
+        else:
+            completion_args = completion_args | self.model_metadata_to_params()

         if self._completion_type() == ModelCompletionType.TEXT:
             completion_args["suffix"] = suffix
@@ -449,9 +462,18 @@ class LiteLlmTextGenModel(TextGenModelBase):
             ),
         )

-    def _get_stop_tokens(self):
+    def _get_stop_tokens(self, suffix):
         return self.stop_tokens

+    def _is_vertex(self):
+        return self.provider == KindModelProvider.VERTEX_AI
+
+    def _get_vertex_model_location(self):
+        if Config().vertex_text_model.location.startswith("europe-"):
+            return "europe-west4"
+        return "us-central1"
+
     @classmethod
     def from_model_name(
         cls,
@@ -466,8 +488,8 @@ class LiteLlmTextGenModel(TextGenModelBase):
         provider_endpoints: Optional[dict] = None,
         async_fireworks_client: Optional[AsyncOpenAI] = None,
     ):
-        if endpoint is not None or api_key is not None:
-            if not custom_models_enabled and provider == KindModelProvider.LITELLM:
+        if not custom_models_enabled:
+            if endpoint is not None or api_key is not None:
                 raise ValueError("specifying custom models endpoint is disabled")

         if provider == KindModelProvider.MISTRALAI:
@@ -483,7 +505,10 @@ class LiteLlmTextGenModel(TextGenModelBase):
             identifier = f"text-completion-openai/{identifier}"

         try:
-            kind_model = KindLiteLlmModel(name)
+            if provider == KindModelProvider.VERTEX_AI:
+                kind_model = KindVertexTextModel(name)
+            else:
+                kind_model = KindLiteLlmModel(name)
         except ValueError:
             raise ValueError(f"no model found by the name '{name}'")
......
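For the Vertex-hosted Codestral path, `_get_vertex_model_location` pins LiteLLM's `vertex_ai_location` to one of two regions. A sketch of the mapping, assuming `Config().vertex_text_model.location` holds the serving region; the example locations are illustrative:

    # "europe-west3"            -> requests pinned to "europe-west4"
    # "europe-west9"            -> requests pinned to "europe-west4"
    # anything else (e.g. "us-east4") -> requests pinned to "us-central1"
    # asia-* locations never reach this code: _allow_vertex_codestral() in the
    # v2 completions router already falls back to Fireworks for those.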
@@ -113,7 +113,7 @@ class KindVertexTextModel(StrEnum):
     TEXTEMBEDDING_GECKO_003 = "textembedding-gecko@003"

     # Mistral AI
-    CODESTRAL_2405 = "codestral@2405"
+    CODESTRAL_2501 = "codestral-2501"

     # This method handles the provider prefix transformation for
     # Vertex AI models
......
@@ -10,6 +10,7 @@ __all__ = [
     "ChatLiteLLMParams",
     "ChatAmazonQParams",
     "ChatAnthropicParams",
+    "ChatAmazonQParams",
 ]
......