Skip to content
Snippets Groups Projects
Commit 629087ff authored by Dmitry Gruzd's avatar Dmitry Gruzd :two:
Browse files

Rebase changes and address reviewer's feedback

parent 1a40621c
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !117695. Comments created here will be created in the context of that merge request.
......@@ -5,6 +5,8 @@ class TanukiBotController < ApplicationController
wrap_parameters format: []
feature_category :global_search
before_action :verify_tanuki_bot_enabled
def ask
respond_to do |format|
format.json { render json: generate_response, status: :ok }
......@@ -17,5 +19,11 @@ def ask
# Delegates the user's question to the TanukiBot LLM service and returns
# its response hash. The `q` parameter is mandatory; `params.require`
# raises ActionController::ParameterMissing when it is absent.
def generate_response
  question = params.require(:q)

  ::Gitlab::Llm::TanukiBot.execute(current_user: current_user, question: question)
end
# Before-action filter: halts the request with 401 Unauthorized unless
# the TanukiBot feature is enabled for the current user.
def verify_tanuki_bot_enabled
  head :unauthorized unless ::Gitlab::Llm::TanukiBot.enabled_for?(user: current_user)
end
end
end
......@@ -18,8 +18,9 @@ class Client
include ExponentialBackoff
def initialize(user)
def initialize(user, request_timeout: nil)
@user = user
@request_timeout = request_timeout
end
def chat(content:, **options)
......@@ -73,10 +74,10 @@ def embeddings(input:, **options)
retry_methods_with_exponential_backoff :chat, :completions, :edits, :embeddings
attr_reader :user
attr_reader :user, :request_timeout
def client
@client ||= OpenAI::Client.new(access_token: access_token)
@client ||= OpenAI::Client.new(access_token: access_token, request_timeout: request_timeout)
end
def enabled?
......
......@@ -11,6 +11,7 @@ class TanukiBot
n: 1,
best_of: 1
}.freeze
REQUEST_TIMEOUT = 30
CONTENT_ID_FIELD = 'ATTRS'
CONTENT_ID_REGEX = /CNT-IDX-(?<id>\d+)/
RECORD_LIMIT = Rails.env.production? ? 7 : 2
......@@ -19,17 +20,24 @@ def self.execute(current_user:, question:, logger: nil)
new(current_user: current_user, question: question, logger: logger).execute
end
# Whether the TanukiBot feature is available to +user+.
#
# All of the following must hold:
# - a user is present,
# - the :ai_tanuki_bot license feature is available,
# - both the :openai_experimentation and :tanuki_bot feature flags are on,
# - on GitLab.com, the user belongs to a paid Ultimate namespace.
#
# @param user [User, nil]
# @return [Boolean]
def self.enabled_for?(user:)
  return false unless user
  return false unless ::License.feature_available?(:ai_tanuki_bot)
  return false unless Feature.enabled?(:openai_experimentation) && Feature.enabled?(:tanuki_bot, user)
  # `user` is guaranteed non-nil by the first guard, so safe navigation
  # (`user&.`) is redundant here.
  return false if ::Gitlab.com? && !user.has_paid_namespace?(plans: [::Plan::ULTIMATE])

  true
end
def initialize(current_user:, question:, logger: nil)
@current_user = current_user
@question = question
@logger = logger
@logger = logger || Gitlab::AppJsonLogger.build
end
def execute
return {} unless question.present?
return {} unless ::License.feature_available?(:ai_tanuki_bot)
return {} if ::Gitlab.com? && !current_user&.has_paid_namespace?(plans: [::Plan::ULTIMATE])
return {} unless Feature.enabled?(:openai_experimentation) && Feature.enabled?(:tanuki_bot, current_user)
return {} unless self.class.enabled_for?(user: current_user)
search_documents = query_search_documents
return empty_response if search_documents.empty?
......@@ -44,7 +52,7 @@ def execute
attr_reader :current_user, :question, :logger
def client
@client ||= ::Gitlab::Llm::OpenAi::Client.new(current_user)
@client ||= ::Gitlab::Llm::OpenAi::Client.new(current_user, request_timeout: REQUEST_TIMEOUT)
end
def build_initial_prompts(search_documents)
......@@ -64,14 +72,15 @@ def build_initial_prompts(search_documents)
def send_initial_prompt(doc:, prompt:)
result = client.completions(prompt: prompt, **DEFAULT_OPTIONS)
debug(
info(
document: doc[:id],
prompt: prompt,
code: result.code,
result: result.parsed_response,
message: 'Initial prompt request'
)
raise result.dig('error', 'message') if result['error']
raise result.dig('error', 'message') || "Initial prompt request failed with '#{result}'" unless result.success?
doc.merge(extracted_text: result['choices'].first['text'])
end
......@@ -119,7 +128,16 @@ def get_completions(search_documents)
PROMPT
final_prompt_result = client.completions(prompt: final_prompt, **DEFAULT_OPTIONS)
debug(prompt: final_prompt, result: final_prompt_result.parsed_response, message: 'Final prompt request')
info(
prompt: final_prompt,
code: final_prompt_result.code,
result: final_prompt_result.parsed_response,
message: 'Final prompt request'
)
unless final_prompt_result.success?
raise final_prompt_result.dig('error', 'message') || "Final prompt request failed with '#{result}'"
end
final_prompt_result
end
......@@ -128,7 +146,7 @@ def query_search_documents
embeddings_result = client.embeddings(input: question)
question_embedding = embeddings_result['data'].first['embedding']
nearest_neighbors = ::Embedding::TanukiBot.neighbor_for(question_embedding).limit(RECORD_LIMIT)
nearest_neighbors = ::Embedding::TanukiBotMvc.neighbor_for(question_embedding).limit(RECORD_LIMIT)
nearest_neighbors.map do |item|
item.metadata['source_url'] = item.url
......@@ -169,10 +187,10 @@ def empty_response
}
end
def debug(payload)
def info(payload)
return unless logger
logger.debug(build_structured_payload(**payload))
logger.info(build_structured_payload(**payload))
end
end
end
......
......@@ -10,6 +10,7 @@
before do
allow(Gitlab::Llm::TanukiBot).to receive_message_chain(:new, :execute).and_return({})
allow(Gitlab::Llm::TanukiBot).to receive(:enabled_for?).and_return(true)
end
it 'responds with a 401' do
......@@ -23,6 +24,18 @@
sign_in(create(:user))
end
context 'when user does not have access to the feature' do
before do
allow(Gitlab::Llm::TanukiBot).to receive(:enabled_for?).and_return(false)
end
it 'responds with a 401' do
subject
expect(response).to have_gitlab_http_status(:unauthorized)
end
end
it 'responds with :bad_request if the request is not json' do
post :ask, params: { q: question }
......
......@@ -5,7 +5,7 @@
RSpec.describe Gitlab::Llm::TanukiBot, feature_category: :global_search do
describe '#execute' do
let_it_be(:user) { create(:user) }
let_it_be(:embeddings) { create_list(:tanuki_bot, 2) }
let_it_be(:embeddings) { create_list(:tanuki_bot_mvc, 2) }
let(:question) { 'A question' }
let(:answer) { 'The answer.' }
......@@ -14,15 +14,17 @@
let(:openai_client) { ::Gitlab::Llm::OpenAi::Client.new(user) }
let(:embedding_response) { { "data" => [{ "embedding" => Array.new(1536, 0.5) }] } }
let(:attrs) { embeddings.map(&:id).map { |x| "CNT-IDX-#{x}" }.join(", ") }
let(:completion_response) do
{ "choices" => [{ "text" => "#{answer} ATTRS: #{attrs}" }] }
end
let(:completion_response) { { "choices" => [{ "text" => "#{answer} ATTRS: #{attrs}" }] } }
let(:status_code) { 200 }
let(:success) { true }
subject(:execute) { instance.execute }
before do
allow(License).to receive(:feature_available?).and_return(true)
allow(logger).to receive(:debug)
allow(logger).to receive(:info)
allow(completion_response).to receive(:code).and_return(status_code)
allow(completion_response).to receive(:success?).and_return(success)
end
context 'with the ai_tanuki_bot license not available' do
......@@ -66,7 +68,7 @@
end
it 'executes calls through to open ai' do
create_list(:tanuki_bot, 2)
create_list(:tanuki_bot_mvc, 2)
expect(openai_client).to receive(:completions).exactly(3).times.and_return(completion_response)
expect(openai_client).to receive(:embeddings).and_return(embedding_response)
......@@ -114,7 +116,7 @@
context 'when no neighbors are found' do
before do
allow(Embedding::TanukiBot).to receive(:neighbor_for).and_return(Embedding::TanukiBot.none)
allow(Embedding::TanukiBotMvc).to receive(:neighbor_for).and_return(Embedding::TanukiBotMvc.none)
allow(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response)
end
......@@ -151,8 +153,8 @@
expect(result[:msg]).to eq(answer)
expect(result[:sources].count).to eq(2)
expected_sources = ::Embedding::TanukiBot.pluck(:metadata).pluck('source')
expected_source_urls = ::Embedding::TanukiBot.pluck(:url)
expected_sources = ::Embedding::TanukiBotMvc.pluck(:metadata).pluck('source')
expected_source_urls = ::Embedding::TanukiBotMvc.pluck(:url)
expect(result[:sources].pluck('source')).to match_array(expected_sources)
expect(result[:sources].pluck('source_url')).to match_array(expected_source_urls)
......
0% — Loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment