Commits on Source (38)
......@@ -62,7 +62,7 @@ include:
runway_service_id: ai-gateway
image: "$CI_REGISTRY_IMAGE/model-gateway:${CI_COMMIT_SHORT_SHA}"
runway_version: v2.43.0
- component: ${CI_SERVER_FQDN}/gitlab-org/components/danger-review/danger-review@1.2.0
- component: ${CI_SERVER_FQDN}/gitlab-org/components/danger-review/danger-review@1.4.1
rules:
- if: $CI_SERVER_HOST == "gitlab.com"
......
# flake8: noqa
from ai_gateway.agents import chat, container
from ai_gateway.agents import container
from ai_gateway.agents.base import *
from ai_gateway.agents.registry import *
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence, Tuple, TypeVar
from typing import Annotated, Any, Generic, Optional, Sequence, Tuple, TypeVar
from jinja2 import BaseLoader, Environment
from langchain_core.prompts.chat import MessageLikeRepresentation
from langchain_core.runnables import Runnable, RunnableBinding
from pydantic import BaseModel, Field
from ai_gateway.auth.user import GitLabUser
from ai_gateway.gitlab_features import GitLabUnitPrimitive, WrongUnitPrimitives
__all__ = ["Agent", "BaseAgentRegistry"]
__all__ = [
"Agent",
"BaseAgentRegistry",
"BaseAgentConfig",
"AgentConfig",
"Model",
]
Input = TypeVar("Input")
Output = TypeVar("Output")
# Agents may operate with unit primitives in various ways.
# Basic agents typically use plain strings as unit primitives.
# More sophisticated agents, like Duo Chat, assign unit primitives to specific tools.
# Making TypeUnitPrimitive generic enables storage of unit primitives in any desired format.
TypeUnitPrimitive = TypeVar("TypeUnitPrimitive")
jinja_env = Environment(loader=BaseLoader())
......@@ -20,6 +33,35 @@ def _format_str(content: str, options: dict[str, Any]) -> str:
return jinja_env.from_string(content).render(options)
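For reference, a minimal self-contained sketch of the Jinja rendering that _format_str performs (the template string and options below are illustrative, not taken from this MR):

from jinja2 import BaseLoader, Environment

jinja_env = Environment(loader=BaseLoader())

def _format_str(content: str, options: dict) -> str:
    # Render a Jinja template string with the given options.
    return jinja_env.from_string(content).render(options)

# "You are a {{role}} assistant" + {"role": "GitLab Duo"} -> "You are a GitLab Duo assistant"
print(_format_str("You are a {{role}} assistant", {"role": "GitLab Duo"}))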
class Model(BaseModel):
class Params(BaseModel):
temperature: float
timeout: Annotated[
float | tuple[float, float] | None,
Field(serialization_alias="request_timeout"),
] = None
top_p: float | None = None
top_k: int | None = None
max_tokens: Optional[int] = 2_048
max_retries: Optional[int] = 1
name: str
provider: str
params: Params | None = None
class BaseAgentConfig(BaseModel, Generic[TypeUnitPrimitive]):
name: str
model: Model
unit_primitives: list[TypeUnitPrimitive]
prompt_template: dict[str, str]
stop: list[str] | None = None
class AgentConfig(BaseAgentConfig):
unit_primitives: list[GitLabUnitPrimitive]
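Together with the TypeUnitPrimitive comment above, the generic base class lets other agent families pick their own unit-primitive representation. A hypothetical sketch (ChatToolPrimitive and ChatAgentConfig are illustrative names only, not part of this change):

# Hypothetical: a chat-style agent could bind each unit primitive to a tool,
# while basic agents keep the flat list[GitLabUnitPrimitive] used by AgentConfig.
class ChatToolPrimitive(BaseModel):
    tool: str
    unit_primitive: GitLabUnitPrimitive

class ChatAgentConfig(BaseAgentConfig[ChatToolPrimitive]):
    unit_primitives: list[ChatToolPrimitive]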
class Agent(RunnableBinding[Input, Output]):
name: str
unit_primitives: list[GitLabUnitPrimitive]
......@@ -61,7 +103,8 @@ class BaseAgentRegistry(ABC):
) -> Agent:
agent = self.get(agent_id, options)
if not set(agent.unit_primitives).issubset(user.unit_primitives):
raise WrongUnitPrimitives
for unit_primitive in agent.unit_primitives:
if not user.can(unit_primitive):
raise WrongUnitPrimitives
return agent
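A usage sketch of the tightened check, assuming the surrounding signature is roughly get_on_behalf(user, agent_id, options) as suggested by the hunk above:

# Sketch only; the real call site is the v2 chat endpoint.
try:
    agent = registry.get_on_behalf(user, "chat/react", None)
except WrongUnitPrimitives:
    # Raised as soon as any required unit primitive is missing from the user's claims.
    ...  # surface a 403 to the caller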
# flake8: noqa
from ai_gateway.agents.chat.react import *
from ai_gateway.agents.chat.typing import *
from dependency_injector import containers, providers
from ai_gateway.agents import chat
from ai_gateway.agents.registry import LocalAgentRegistry, ModelProvider
from ai_gateway.agents.registry import LocalAgentRegistry
from ai_gateway.chat import agents as chat
__all__ = [
"ContainerAgents",
......@@ -15,6 +15,5 @@ class ContainerAgents(containers.DeclarativeContainer):
agent_registry = providers.Singleton(
LocalAgentRegistry.from_local_yaml,
model_factories={ModelProvider.ANTHROPIC: _anthropic_claude_fn},
class_overrides={"chat/react": chat.ReActAgent},
)
---
name: Claude 3 ReAct Chat agent
provider: anthropic
model: claude-3-sonnet-20240229
model:
name: claude-3-sonnet-20240229
provider: anthropic
params:
temperature: 0.0
timeout: 0.2
max_tokens: 2_048
max_retries: 1
unit_primitives:
- duo_chat
prompt_template:
......
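As a sanity check of the new schema, a definition like the one above parses straight into AgentConfig, which is what from_local_yaml now does. A minimal sketch (the prompt_template content is a placeholder; the real template is elided in this diff):

import yaml

from ai_gateway.agents.base import AgentConfig

raw = yaml.safe_load(
    """
name: Claude 3 ReAct Chat agent
model:
  name: claude-3-sonnet-20240229
  provider: anthropic
  params:
    temperature: 0.0
    timeout: 0.2
    max_tokens: 2_048
    max_retries: 1
unit_primitives:
  - duo_chat
prompt_template:
  system: placeholder
"""
)

config = AgentConfig(**raw)
assert config.model.params.max_tokens == 2048  # PyYAML (YAML 1.1) parses 2_048 as 2048
assert config.model.provider == "anthropic"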
from enum import Enum
from pathlib import Path
from typing import Any, NamedTuple, Optional, Protocol, Type
from typing import Any, NamedTuple, Optional, Type
import yaml
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from ai_gateway.agents.base import Agent, BaseAgentRegistry
from ai_gateway.agents.base import Agent, AgentConfig, BaseAgentRegistry, Model
__all__ = ["LocalAgentRegistry", "ModelProvider"]
class ModelProvider(str, Enum):
ANTHROPIC = "anthropic"
class ModelFactoryType(Protocol):
def __call__(
self, *, model: str, **model_kwargs: Optional[Any]
) -> BaseChatModel: ...
__all__ = ["LocalAgentRegistry", "AgentRegistered"]
class AgentRegistered(NamedTuple):
klass: Type[Agent]
config: dict
config: AgentConfig
class LocalAgentRegistry(BaseAgentRegistry):
......@@ -33,10 +23,8 @@ class LocalAgentRegistry(BaseAgentRegistry):
def __init__(
self,
agents_registered: dict[str, AgentRegistered],
model_factories: dict[ModelProvider, ModelFactoryType],
):
self.agents_registered = agents_registered
self.model_factories = model_factories
def _resolve_id(self, agent_id: str) -> str:
_, _, agent_type = agent_id.partition("/")
......@@ -46,41 +34,27 @@ class LocalAgentRegistry(BaseAgentRegistry):
return f"{agent_id}/{self.key_agent_type_base}"
def _get_model(
self, provider: str, name: str, **kwargs: Optional[Any]
) -> BaseChatModel:
if model_factory := self.model_factories.get(ModelProvider(provider), None):
return model_factory(model=name, **kwargs)
raise ValueError(f"unknown provider `{provider}`.")
def get(self, agent_id: str, options: Optional[dict[str, Any]] = None) -> Any:
agent_id = self._resolve_id(agent_id)
klass, config = self.agents_registered[agent_id]
# TODO: read model parameters such as `temperature`, `top_k`
# and pass them to the model factory via **kwargs.
model: Runnable = self._get_model(
provider=config["provider"],
name=config["model"],
)
model: Runnable = _get_model(config.model)
if "stop" in config:
model = model.bind(stop=config["stop"])
if config.stop:
model = model.bind(stop=config.stop)
messages = klass.build_messages(config["prompt_template"], options or {})
messages = klass.build_messages(config.prompt_template, options or {})
prompt = ChatPromptTemplate.from_messages(messages)
return klass(
name=config["name"],
name=config.name,
chain=prompt | model,
unit_primitives=config["unit_primitives"],
unit_primitives=config.unit_primitives,
)
@classmethod
def from_local_yaml(
cls,
model_factories: dict[ModelProvider, ModelFactoryType],
class_overrides: dict[str, Type[Agent]],
) -> "LocalAgentRegistry":
"""Iterate over all agent definition files matching [usecase]/[type].yml,
......@@ -99,7 +73,21 @@ class LocalAgentRegistry(BaseAgentRegistry):
with open(path, "r") as fp:
klass = class_overrides.get(agent_id, Agent)
agents_registered[agent_id] = AgentRegistered(
klass=klass, config=yaml.safe_load(fp)
klass=klass, config=AgentConfig(**yaml.safe_load(fp))
)
return cls(agents_registered, model_factories)
return cls(agents_registered)
def _get_model(model: Model) -> BaseChatModel:
model_params = (
model.params.model_dump(exclude_none=True, by_alias=True)
if model.params
else {}
)
return ChatLiteLLM(
model=model.name,
custom_llm_provider=model.provider,
**model_params,
)
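Note the by_alias dump: Params.timeout carries serialization_alias="request_timeout", so the keyword arguments handed to ChatLiteLLM use its request_timeout name. A small sketch of the dump step in isolation (values mirror the YAML above):

params = Model.Params(temperature=0.0, timeout=0.2, max_tokens=2048, max_retries=1)
params.model_dump(exclude_none=True, by_alias=True)
# -> {"temperature": 0.0, "request_timeout": 0.2, "max_tokens": 2048, "max_retries": 1}
# top_p and top_k are dropped by exclude_none because they were never set.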
......@@ -26,9 +26,9 @@ from uvicorn.protocols.utils import get_path_with_query_string
from ai_gateway.api.timing import timing
from ai_gateway.auth import AuthProvider, UserClaims
from ai_gateway.auth.self_signed_jwt import SELF_SIGNED_TOKEN_ISSUER
from ai_gateway.auth.user import GitLabUser
from ai_gateway.instrumentators.base import Telemetry, TelemetryInstrumentator
from ai_gateway.self_signed_jwt.token_authority import SELF_SIGNED_TOKEN_ISSUER
from ai_gateway.tracking.errors import log_exception
__all__ = [
......
......@@ -6,6 +6,7 @@ from dependency_injector.providers import Factory, FactoryAggregate
from fastapi import APIRouter, Depends, HTTPException, Request, status
from ai_gateway.api.feature_category import feature_category
from ai_gateway.api.v1.chat.auth import ChatInvokable, authorize_with_unit_primitive
from ai_gateway.api.v1.chat.typing import (
ChatRequest,
ChatResponse,
......@@ -36,24 +37,35 @@ log = structlog.stdlib.get_logger("chat")
router = APIRouter()
CHAT_INVOKABLES = [
ChatInvokable(name="explain_code", unit_primitive=GitLabUnitPrimitive.DUO_CHAT),
ChatInvokable(name="write_tests", unit_primitive=GitLabUnitPrimitive.DUO_CHAT),
ChatInvokable(name="refactor_code", unit_primitive=GitLabUnitPrimitive.DUO_CHAT),
ChatInvokable(
name="explain_vulnerability",
unit_primitive=GitLabUnitPrimitive.EXPLAIN_VULNERABILITY,
),
# Deprecated. Added for backward compatibility.
# Please refer to `v2/chat/agent` for additional details.
ChatInvokable(name="agent", unit_primitive=GitLabUnitPrimitive.DUO_CHAT),
]
@router.post("/agent", response_model=ChatResponse, status_code=status.HTTP_200_OK)
@router.post(
"/{chat_invokable}", response_model=ChatResponse, status_code=status.HTTP_200_OK
)
@feature_category(GitLabFeatureCategory.DUO_CHAT)
@authorize_with_unit_primitive("chat_invokable", chat_invokables=CHAT_INVOKABLES)
async def chat(
request: Request,
chat_request: ChatRequest,
chat_invokable: str,
current_user: Annotated[GitLabUser, Depends(get_current_user)],
anthropic_claude_factory: FactoryAggregate = Depends(
get_chat_anthropic_claude_factory_provider
),
litellm_factory: Factory = Depends(get_chat_litellm_factory_provider),
):
if not current_user.can(GitLabUnitPrimitive.DUO_CHAT):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Unauthorized to access duo chat",
)
prompt_component = chat_request.prompt_components[0]
payload = prompt_component.payload
......
import functools
import typing
from fastapi import HTTPException, Request, status
from pydantic import BaseModel
from ai_gateway.gitlab_features import GitLabUnitPrimitive
__all__ = ["ChatInvokable", "authorize_with_unit_primitive"]
class ChatInvokable(BaseModel):
name: str
unit_primitive: GitLabUnitPrimitive
def authorize_with_unit_primitive(
request_param: str, *, chat_invokables: list[ChatInvokable]
):
def decorator(func: typing.Callable) -> typing.Callable:
chat_invokable_by_name = {ci.name: ci for ci in chat_invokables}
@functools.wraps(func)
async def wrapper(
request: Request, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
request_param_val = request.path_params[request_param]
chat_invokable = chat_invokable_by_name.get(request_param_val, None)
if not chat_invokable:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Not found"
)
current_user = request.user
unit_primitive = chat_invokable.unit_primitive
if not current_user.can(unit_primitive):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Unauthorized to access {unit_primitive}",
)
return await func(request, *args, **kwargs)
return wrapper
return decorator
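Usage mirrors the /{chat_invokable} route above: the decorated endpoint declares a path parameter whose name matches request_param, and the decorator resolves it against chat_invokables before the handler runs, returning 404 for unknown names and 403 for missing unit primitives. A stripped-down sketch (the handler body is illustrative only):

from fastapi import APIRouter, Request

router = APIRouter()

invokables = [
    ChatInvokable(name="explain_code", unit_primitive=GitLabUnitPrimitive.DUO_CHAT),
]

@router.post("/{chat_invokable}")
@authorize_with_unit_primitive("chat_invokable", chat_invokables=invokables)
async def chat(request: Request, chat_invokable: str):
    # Reached only for known, authorized invokables.
    return {"invoked": chat_invokable}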
from typing import Annotated
import structlog
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from ai_gateway.api.feature_category import feature_category
from ai_gateway.api.middleware import (
X_GITLAB_GLOBAL_USER_ID_HEADER,
X_GITLAB_REALM_HEADER,
)
from ai_gateway.api.v1.code.typing import Token
from ai_gateway.async_dependency_resolver import get_token_authority
from ai_gateway.auth.self_signed_jwt import SELF_SIGNED_TOKEN_ISSUER, TokenAuthority
from ai_gateway.auth.user import GitLabUser, get_current_user
from ai_gateway.gitlab_features import GitLabFeatureCategory, GitLabUnitPrimitive
from ai_gateway.self_signed_jwt import TokenAuthority
__all__ = [
"router",
]
from ai_gateway.self_signed_jwt.token_authority import SELF_SIGNED_TOKEN_ISSUER
log = structlog.stdlib.get_logger("user_access_token")
......@@ -31,6 +26,12 @@ async def user_access_token(
request: Request,
current_user: Annotated[GitLabUser, Depends(get_current_user)],
token_authority: TokenAuthority = Depends(get_token_authority),
x_gitlab_global_user_id: Annotated[
str, Header()
] = None, # This is the value of X_GITLAB_GLOBAL_USER_ID_HEADER
x_gitlab_realm: Annotated[
str, Header()
] = None, # This is the value of X_GITLAB_REALM_HEADER
):
if not current_user.can(
GitLabUnitPrimitive.CODE_SUGGESTIONS,
......@@ -41,22 +42,22 @@ async def user_access_token(
detail="Unauthorized to create user access token for code suggestions",
)
gitlab_user_id = request.headers.get(X_GITLAB_GLOBAL_USER_ID_HEADER)
if not gitlab_user_id:
if not x_gitlab_global_user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing X-Gitlab-Global-User-Id header",
)
gitlab_realm = request.headers.get(X_GITLAB_REALM_HEADER)
if not gitlab_realm:
if not x_gitlab_realm:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing X-Gitlab-Realm header",
)
try:
token, expires_at = token_authority.encode(gitlab_user_id, gitlab_realm)
token, expires_at = token_authority.encode(
x_gitlab_global_user_id, x_gitlab_realm
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to generate JWT")
......
......@@ -6,10 +6,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from ai_gateway.api.feature_category import feature_category
from ai_gateway.api.v1.x_ray.typing import XRayRequest, XRayResponse
from ai_gateway.async_dependency_resolver import get_x_ray_anthropic_claude
from ai_gateway.auth.self_signed_jwt import SELF_SIGNED_TOKEN_ISSUER
from ai_gateway.auth.user import GitLabUser, get_current_user
from ai_gateway.gitlab_features import GitLabFeatureCategory, GitLabUnitPrimitive
from ai_gateway.models import AnthropicModel
from ai_gateway.self_signed_jwt.token_authority import SELF_SIGNED_TOKEN_ISSUER
__all__ = [
"router",
......
......@@ -4,17 +4,17 @@ import structlog
from fastapi import APIRouter, Depends, HTTPException, Request, status
from starlette.responses import StreamingResponse
from ai_gateway.agents.chat import (
from ai_gateway.api.feature_category import feature_category
from ai_gateway.api.v2.chat.typing import AgentRequest, AgentStreamResponseEvent
from ai_gateway.async_dependency_resolver import get_container_application
from ai_gateway.auth.user import GitLabUser, get_current_user
from ai_gateway.chat.agents import (
AgentStep,
AgentToolAction,
ReActAgentInputs,
ReActAgentToolAction,
TypeReActAgentAction,
)
from ai_gateway.api.feature_category import feature_category
from ai_gateway.api.v2.chat.typing import AgentRequest, AgentStreamResponseEvent
from ai_gateway.async_dependency_resolver import get_container_application
from ai_gateway.auth.user import GitLabUser, get_current_user
from ai_gateway.chat.executor import GLAgentRemoteExecutor
from ai_gateway.gitlab_features import GitLabFeatureCategory, WrongUnitPrimitives
......
......@@ -2,7 +2,7 @@ from typing import Literal, Optional
from pydantic import BaseModel, Field
from ai_gateway.agents.chat import Context, TypeReActAgentAction
from ai_gateway.chat.agents import Context, TypeReActAgentAction
__all__ = [
"ReActAgentScratchpad",
......
......@@ -36,6 +36,7 @@ from ai_gateway.async_dependency_resolver import (
get_code_suggestions_generations_vertex_provider,
get_snowplow_instrumentator,
)
from ai_gateway.auth.self_signed_jwt import SELF_SIGNED_TOKEN_ISSUER
from ai_gateway.auth.user import GitLabUser, get_current_user
from ai_gateway.code_suggestions import (
CodeCompletions,
......@@ -47,7 +48,6 @@ from ai_gateway.code_suggestions.processing.ops import lang_from_filename
from ai_gateway.gitlab_features import GitLabFeatureCategory, GitLabUnitPrimitive
from ai_gateway.instrumentators.base import TelemetryInstrumentator
from ai_gateway.models import KindAnthropicModel, KindModelProvider
from ai_gateway.self_signed_jwt.token_authority import SELF_SIGNED_TOKEN_ISSUER
from ai_gateway.tracking import SnowplowEvent, SnowplowEventContext
from ai_gateway.tracking.errors import log_exception
from ai_gateway.tracking.instrumentator import SnowplowInstrumentator
......
......@@ -18,6 +18,7 @@ from ai_gateway.api.v3.code.typing import (
ResponseMetadataBase,
StreamSuggestionsResponse,
)
from ai_gateway.auth.self_signed_jwt import SELF_SIGNED_TOKEN_ISSUER
from ai_gateway.auth.user import GitLabUser, get_current_user
from ai_gateway.code_suggestions import (
CodeCompletions,
......@@ -29,7 +30,6 @@ from ai_gateway.code_suggestions import (
from ai_gateway.container import ContainerApplication
from ai_gateway.gitlab_features import GitLabFeatureCategory, GitLabUnitPrimitive
from ai_gateway.models import KindModelProvider
from ai_gateway.self_signed_jwt.token_authority import SELF_SIGNED_TOKEN_ISSUER
__all__ = [
"router",
......
# flake8: noqa
from ai_gateway.auth import cache
from ai_gateway.auth import cache, container
from ai_gateway.auth.providers import *
from ai_gateway.auth.self_signed_jwt import *
from ai_gateway.auth.user import *
from dependency_injector import containers, providers
from ai_gateway.self_signed_jwt.token_authority import TokenAuthority
from ai_gateway.auth.self_signed_jwt import TokenAuthority
__all__ = ["ContainerSelfSignedJwt"]
class ContainerSelfSignedJwt(containers.DeclarativeContainer):
......
# flake8: noqa
from ai_gateway.auth.self_signed_jwt.token_authority import *
......@@ -8,6 +8,7 @@ from ai_gateway.gitlab_features import GitLabUnitPrimitive
from ai_gateway.tracking.errors import log_exception
__all__ = [
"SELF_SIGNED_TOKEN_ISSUER",
"TokenAuthority",
]
......@@ -38,7 +39,7 @@ class TokenAuthority:
token = jwt.encode(claims, self.signing_key, algorithm=self.ALGORITHM)
return (token, int(expires_at.timestamp()))
return token, int(expires_at.timestamp())
except JWTError as err:
log_exception(err)
raise