Skip to content
Snippets Groups Projects
Commit 84e0a960 authored by Igor Drozdov's avatar Igor Drozdov :two:
Browse files

Merge branch 'id-refactor-auth-token' into 'main'

chore(q): take cloud connector token from current user

See merge request !1939
parents 3add107e a2860ba9
No related branches found
No related tags found
4 merge requests!24test: multi-choice for code assistance,!19Draft: Q integration,!14Vivekp feature main,!10Vivekp feature main
......@@ -9,8 +9,10 @@ class StarletteUser(BaseUser):
def __init__(
self,
cloud_connector_user: CloudConnectorUser,
cloud_connector_token: Optional[str] = None,
):
self.cloud_connector_user = cloud_connector_user
self.cloud_connector_token = cloud_connector_token
# overriding starlette BaseUser methods
@property
......
......@@ -14,6 +14,7 @@ from gitlab_cloud_connector import (
CloudConnectorUser,
)
from gitlab_cloud_connector import authenticate as cloud_connector_authenticate
from gitlab_cloud_connector.auth import AUTH_HEADER
from langsmith.run_helpers import tracing_context
from starlette.authentication import (
AuthCredentials,
......@@ -272,8 +273,10 @@ class MiddlewareAuthentication(Middleware):
if cloud_connector_error:
raise AuthenticationError(cloud_connector_error.error_message)
_, _, cloud_connector_token = conn.headers.get(AUTH_HEADER).partition(" ")
return AuthCredentials(cloud_connector_user.claims.scopes), StarletteUser(
cloud_connector_user
cloud_connector_user,
cloud_connector_token,
)
@timing("auth_duration_s")
......
......@@ -2,7 +2,6 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from gitlab_cloud_connector import GitLabFeatureCategory, GitLabUnitPrimitive
from gitlab_cloud_connector.auth import AUTH_HEADER
from ai_gateway.api.auth_utils import StarletteUser, get_current_user
from ai_gateway.api.feature_category import feature_category
......@@ -47,7 +46,6 @@ async def oauth_create_application(
try:
q_client = amazon_q_client_factory.get_client(
current_user=current_user,
auth_header=request.headers.get(AUTH_HEADER),
role_arn=application_request.role_arn,
)
......
......@@ -2,7 +2,6 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from gitlab_cloud_connector import GitLabFeatureCategory, GitLabUnitPrimitive
from gitlab_cloud_connector.auth import AUTH_HEADER
from ai_gateway.api.auth_utils import StarletteUser, get_current_user
from ai_gateway.api.feature_category import feature_category
......@@ -47,7 +46,6 @@ async def events(
try:
q_client = amazon_q_client_factory.get_client(
current_user=current_user,
auth_header=request.headers.get(AUTH_HEADER),
role_arn=event_request.role_arn,
)
......
......@@ -9,7 +9,6 @@ from gitlab_cloud_connector import (
GitLabFeatureCategory,
GitLabUnitPrimitive,
)
from gitlab_cloud_connector.auth import AUTH_HEADER
from ai_gateway.api.auth_utils import StarletteUser, get_current_user
from ai_gateway.api.error_utils import capture_validation_errors
......@@ -481,7 +480,6 @@ def _build_code_completions(
tracking_event = f"request_{unit_primitive}_complete_code"
code_completions = completions_amazon_q_factory(
model__current_user=current_user,
model__auth_header=request.headers.get(AUTH_HEADER),
model__role_arn=payload.role_arn,
)
else:
......
......@@ -3,6 +3,7 @@ import hashlib
import json
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional
import requests
from gitlab_cloud_connector import CompositeProvider
......@@ -24,7 +25,7 @@ class GlgoAuthority:
self.kid = self._build_kid(signing_key)
self.token_endpoint = f"{glgo_base_url}/cc/token"
def token(self, user_id: str, cloud_connector_token: str):
def token(self, user_id: str, cloud_connector_token: Optional[str]):
token = self._build_token(user_id, cloud_connector_token)
headers = {
......@@ -44,7 +45,7 @@ class GlgoAuthority:
return response.json().get("token")
def _build_token(self, user_id: str, cloud_connector_token: str):
def _build_token(self, user_id: str, cloud_connector_token: Optional[str]):
now = datetime.now(timezone.utc)
payload = {
"iss": self.ISSUER,
......
......@@ -33,8 +33,8 @@ class AmazonQClientFactory:
self.endpoint_url = endpoint_url
self.region = region
def get_client(self, current_user: StarletteUser, auth_header: str, role_arn: str):
token = self._get_glgo_token(current_user, auth_header)
def get_client(self, current_user: StarletteUser, role_arn: str):
token = self._get_glgo_token(current_user)
credentials = self._get_aws_credentials(current_user, token, role_arn)
return AmazonQClient(
......@@ -46,7 +46,6 @@ class AmazonQClientFactory:
def _get_glgo_token(
self,
current_user: StarletteUser,
auth_header: str,
):
user_id = current_user.global_user_id
if not user_id:
......@@ -55,10 +54,9 @@ class AmazonQClientFactory:
)
try:
_, _, cloud_connector_token = auth_header.partition(" ")
token = self.glgo_authority.token(
user_id=user_id,
cloud_connector_token=cloud_connector_token,
cloud_connector_token=current_user.cloud_connector_token,
)
request_log.info("Obtained Glgo token", source=__name__, user_id=user_id)
return token
......
......@@ -21,11 +21,9 @@ class AmazonQModel(TextGenModelBase):
self,
current_user: StarletteUser,
role_arn: str,
auth_header: str,
client_factory: AmazonQClientFactory,
):
self._current_user = current_user
self._auth_header = auth_header
self._role_arn = role_arn
self._client_factory = client_factory
self._metadata = ModelMetadata(
......@@ -47,7 +45,6 @@ class AmazonQModel(TextGenModelBase):
) -> TextGenModelOutput:
q_client = self._client_factory.get_client(
current_user=self._current_user,
auth_header=self._auth_header,
role_arn=self._role_arn,
)
......
......@@ -29,13 +29,18 @@ def homepage(
detail="Unauthorized to access homepage",
)
content = {
"authenticated": request.user.is_authenticated,
"is_debug": request.user.is_debug,
"scopes": request.auth.scopes,
}
if request.user.cloud_connector_token:
content["token"] = request.user.cloud_connector_token
return JSONResponse(
status_code=200,
content={
"authenticated": request.user.is_authenticated,
"is_debug": request.user.is_debug,
"scopes": request.auth.scopes,
},
content=content,
)
......@@ -149,6 +154,7 @@ invalid_authentication_token_type_error = {
"authenticated": True,
"is_debug": False,
"scopes": ["feature1", "feature3"],
"token": "12345",
},
["auth_duration_s", "token_issuer"],
),
......@@ -174,6 +180,7 @@ invalid_authentication_token_type_error = {
"authenticated": True,
"is_debug": False,
"scopes": ["feature2", "feature3"],
"token": "12345",
},
["auth_duration_s", "token_issuer"],
),
......@@ -199,6 +206,7 @@ invalid_authentication_token_type_error = {
"authenticated": True,
"is_debug": False,
"scopes": ["feature1", "feature2", "feature3"],
"token": "12345",
},
["auth_duration_s", "token_issuer"],
),
......@@ -386,6 +394,7 @@ invalid_authentication_token_type_error = {
"authenticated": True,
"is_debug": False,
"scopes": ["feature1", "feature3"],
"token": "12345",
},
["auth_duration_s", "token_issuer"],
),
......@@ -440,6 +449,7 @@ invalid_authentication_token_type_error = {
"authenticated": True,
"is_debug": False,
"scopes": ["feature1", "feature3"],
"token": "12345",
},
["auth_duration_s", "token_issuer"],
),
......@@ -470,6 +480,7 @@ invalid_authentication_token_type_error = {
"authenticated": True,
"is_debug": False,
"scopes": ["feature1", "feature3"],
"token": "12345",
},
["auth_duration_s", "token_issuer"],
),
......@@ -561,6 +572,7 @@ invalid_authentication_token_type_error = {
"authenticated": True,
"is_debug": False,
"scopes": ["feature1", "feature3"],
"token": "12345",
},
[
"auth_duration_s",
......@@ -621,6 +633,7 @@ def test_failed_authorization_logging(
"authenticated": True,
"is_debug": False,
"scopes": ["feature1", "feature3"],
"token": "12345",
},
"Header is missing: 'X-Gitlab-Duo-Seat-Count'",
[
......@@ -655,6 +668,7 @@ def test_failed_authorization_logging(
"authenticated": True,
"is_debug": False,
"scopes": ["feature1", "feature3"],
"token": "12345",
},
"Header mismatch 'X-Gitlab-Duo-Seat-Count'",
[
......
......@@ -40,6 +40,7 @@ class TestAmazonQClientFactory:
def mock_user(self):
user = MagicMock(spec=StarletteUser)
user.global_user_id = "test-user-id"
user.cloud_connector_token = "mock-cloud-connector-token"
user.claims = MagicMock(subject="test-session")
return user
......@@ -47,9 +48,7 @@ class TestAmazonQClientFactory:
self, amazon_q_client_factory, mock_user, mock_glgo_authority
):
mock_glgo_authority.token.return_value = "mock-token"
token = amazon_q_client_factory._get_glgo_token(
mock_user, "Bearer mock-cloud-connector-token"
)
token = amazon_q_client_factory._get_glgo_token(mock_user)
mock_glgo_authority.token.assert_called_once_with(
user_id="test-user-id", cloud_connector_token="mock-cloud-connector-token"
......@@ -62,9 +61,7 @@ class TestAmazonQClientFactory:
mock_user.global_user_id = None
with pytest.raises(HTTPException) as exc:
amazon_q_client_factory._get_glgo_token(
mock_user, "Bearer mock-cloud-connector-token"
)
amazon_q_client_factory._get_glgo_token(mock_user)
assert exc.value.status_code == 400
assert exc.value.detail == "User Id is missing"
......@@ -74,9 +71,7 @@ class TestAmazonQClientFactory:
mock_glgo_authority.token.side_effect = KeyError()
with pytest.raises(HTTPException) as exc:
amazon_q_client_factory._get_glgo_token(
mock_user, "Bearer mock-cloud-connector-token"
)
amazon_q_client_factory._get_glgo_token(mock_user)
assert exc.value.status_code == 500
assert exc.value.detail == "Cannot obtain OIDC token"
......@@ -158,7 +153,6 @@ class TestAmazonQClientFactory:
client = amazon_q_client_factory.get_client(
current_user=mock_user,
auth_header="Bearer mock-cloud-connector-token",
role_arn="mock-role-arn",
)
......
......@@ -13,11 +13,10 @@ def test_amazon_q_model_init():
mock_user = MagicMock(spec=StarletteUser)
mock_factory = MagicMock(spec=AmazonQClientFactory)
model = AmazonQModel(mock_user, "test-role", "test-header", mock_factory)
model = AmazonQModel(mock_user, "test-role", mock_factory)
assert model._current_user == mock_user
assert model._role_arn == "test-role"
assert model._auth_header == "test-header"
assert model.metadata.name == KindAmazonQModel.AMAZON_Q
assert model.metadata.engine == KindAmazonQModel.AMAZON_Q
......@@ -29,7 +28,7 @@ async def test_amazon_q_model_generate():
mock_client = MagicMock()
mock_factory.get_client.return_value = mock_client
model = AmazonQModel(mock_user, "test-role", "test-header", mock_factory)
model = AmazonQModel(mock_user, "test-role", mock_factory)
mock_client.generate_code_recommendations.return_value = {
"CodeRecommendations": [{"content": "Generated Code"}]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment