Skip to content
Snippets Groups Projects
Verified Commit c11ec96c authored by Gosia Ksionek's avatar Gosia Ksionek :palm_tree: Committed by GitLab
Browse files

Handle errors in AI gateway client

parent b5994f2c
No related branches found
No related tags found
1 merge request!148088Handle errors in AI gateway client
......@@ -22,6 +22,8 @@ class Client
ALLOWED_PAYLOAD_PARAM_KEYS = %i[temperature max_tokens_to_sample stop_sequences].freeze
ConnectionError = Class.new(StandardError)
def self.access_token(scopes:)
::CloudConnector::AccessService.new.access_token(audience: JWT_AUDIENCE, scopes: scopes)
end
......@@ -41,6 +43,8 @@ def complete(prompt:, **options)
perform_completion_request(prompt: prompt, options: options.except(:stream))
end
logger.info_or_debug(user, message: "Received response from AI Gateway", response: response["response"])
track_prompt_size(token_size(prompt))
track_response_size(token_size(response["response"]))
......@@ -52,16 +56,24 @@ def stream(prompt:, **options)
response_body = ""
perform_completion_request(prompt: prompt, options: options.merge(stream: true)) do |chunk|
response = perform_completion_request(prompt: prompt, options: options.merge(stream: true)) do |chunk|
response_body += chunk
yield chunk if block_given?
end
track_prompt_size(token_size(prompt))
track_response_size(token_size(response_body))
if response.success?
logger.info_or_debug(user, message: "Received response from AI Gateway", response: response_body)
track_prompt_size(token_size(prompt))
track_response_size(token_size(response_body))
response_body
else
logger.error(message: "Received error from AI gateway", response: response_body)
response_body
raise ConnectionError, 'AI gateway not reachable'
end
end
traceable :stream, name: 'Request to AI Gateway', run_type: 'llm'
......@@ -73,7 +85,7 @@ def perform_completion_request(prompt:, options:)
logger.info(message: "Performing request to AI Gateway", options: options)
timeout = options.delete(:timeout) || DEFAULT_TIMEOUT
response = Gitlab::HTTP.post(
Gitlab::HTTP.post(
"#{Gitlab::AiGateway.url}#{endpoint_url(options)}",
headers: request_headers,
body: request_body(prompt: prompt, options: options).to_json,
......@@ -83,10 +95,6 @@ def perform_completion_request(prompt:, options:)
) do |fragment|
yield fragment if block_given?
end
logger.info_or_debug(user, message: "Received response from AI Gateway", response: response)
response
end
def enabled?
......
......@@ -75,6 +75,12 @@ def execute
context: context,
content: _("GitLab Duo didn't respond. Try again? If it fails again, your request might be too large.")
)
rescue Gitlab::Llm::AiGateway::Client::ConnectionError => error
Gitlab::ErrorTracking.track_exception(error)
Answer.error_answer(
context: context,
content: _("GitLab Duo could not connect to the AI provider.")
)
end
traceable :execute, name: 'Run ReAct'
......
......@@ -227,9 +227,19 @@
context 'when response contains multiple events' do
let(:expected_response) { "Hello World" }
let(:success) do
instance_double(HTTParty::Response,
code: 200,
success?: true,
parsed_response: response_body,
headers: response_headers,
body: response_body
)
end
before do
allow(Gitlab::HTTP).to receive(:post).and_yield("Hello").and_yield(" ").and_yield("World")
allow(Gitlab::HTTP).to receive(:post).and_return(success)
.and_yield("Hello").and_yield(" ").and_yield("World")
end
it 'provides parsed streamed response' do
......@@ -240,44 +250,66 @@
it 'returns response' do
expect(described_class.new(user).stream(prompt: 'anything', **options)).to eq(expected_response)
end
end
context 'when additional params are passed in as options' do
let(:options) do
{ temperature: 1, stop_sequences: %W[\n\nHuman Observation:], max_tokens_to_sample: 1024,
disallowed_param: 1 }
end
context 'when additional params are passed in as options' do
let(:options) do
{ temperature: 1, stop_sequences: %W[\n\nHuman Observation:], max_tokens_to_sample: 1024,
disallowed_param: 1 }
end
let(:expected_response) { "Hello World" }
let(:expected_response) { "Hello World" }
before do
allow(Gitlab::HTTP).to receive(:post).and_yield("Hello").and_yield(" ").and_yield("World")
end
before do
allow(Gitlab::HTTP).to receive(:post).and_return(success)
.and_yield("Hello").and_yield(" ").and_yield("World")
end
it 'passes the allowed options as params' do
expect(described_class.new(user).stream(prompt: 'anything', **options)).to eq(expected_response)
it 'passes the allowed options as params' do
expect(described_class.new(user).stream(prompt: 'anything', **options)).to eq(expected_response)
expect(Gitlab::HTTP).to have_received(:post).with(
anything,
hash_including(
body: including(
'"temperature":1',
'"stop_sequences":["\n\nHuman","Observation:"]',
'"max_tokens_to_sample":1024'
)
)
)
end
it 'does not pass the disallowed options as params' do
expect(described_class.new(user).stream(prompt: 'anything', **options)).to eq(expected_response)
expect(Gitlab::HTTP).to have_received(:post).with(
anything,
hash_including(
body: including(
'"temperature":1',
'"stop_sequences":["\n\nHuman","Observation:"]',
'"max_tokens_to_sample":1024'
expect(Gitlab::HTTP).to have_received(:post).with(
anything,
hash_excluding(
body: include('disallowed_param')
)
)
end
end
end
context 'when response is not successful' do
let(:response_body) { expected_response.to_json }
let(:failure) do
instance_double(HTTParty::Response,
code: 400,
success?: false,
parsed_response: response_body,
headers: response_headers
)
end
it 'does not pass the disallowed options as params' do
expect(described_class.new(user).stream(prompt: 'anything', **options)).to eq(expected_response)
before do
allow(Gitlab::HTTP).to receive(:post).and_return(failure)
end
expect(Gitlab::HTTP).to have_received(:post).with(
anything,
hash_excluding(
body: include('disallowed_param')
)
)
it 'raises error' do
expect { described_class.new(user).stream(prompt: 'anything', **options) }
.to raise_error(Gitlab::Llm::AiGateway::Client::ConnectionError)
end
end
end
......
......@@ -500,6 +500,26 @@
end
end
end
context 'when connection error is raised' do
let(:error) { ::Gitlab::Llm::AiGateway::Client::ConnectionError.new }
before do
allow(Gitlab::ErrorTracking).to receive(:track_exception)
end
context 'when streamed request times out' do
it 'returns an error' do
allow(ai_request_double).to receive(:request).and_raise(error)
answer = agent.execute
expect(answer.is_final).to eq(true)
expect(answer.content).to include("GitLab Duo could not connect to the AI provider")
expect(Gitlab::ErrorTracking).to have_received(:track_exception).with(error)
end
end
end
end
def claude_3_system_prompt(agent)
......
......@@ -23046,6 +23046,9 @@ msgstr ""
msgid "GitLab Community Edition"
msgstr ""
 
msgid "GitLab Duo could not connect to the AI provider."
msgstr ""
msgid "GitLab Duo didn't respond. Try again? If it fails again, your request might be too large."
msgstr ""
 
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment