Skip to content
Snippets Groups Projects
Commit 629087ff authored by Dmitry Gruzd's avatar Dmitry Gruzd :two:
Browse files

Rebase changes and address reviewer's feedback

parent 1a40621c
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !117695. Comments created here will be created in the context of that merge request.
...@@ -5,6 +5,8 @@ class TanukiBotController < ApplicationController ...@@ -5,6 +5,8 @@ class TanukiBotController < ApplicationController
wrap_parameters format: [] wrap_parameters format: []
feature_category :global_search feature_category :global_search
before_action :verify_tanuki_bot_enabled
def ask def ask
respond_to do |format| respond_to do |format|
format.json { render json: generate_response, status: :ok } format.json { render json: generate_response, status: :ok }
...@@ -17,5 +19,11 @@ def ask ...@@ -17,5 +19,11 @@ def ask
# Delegates the user's question to the TanukiBot LLM service and returns
# its response hash.
#
# Requires the :q request parameter; a missing parameter raises
# ActionController::ParameterMissing (rendered as 400 by Rails).
def generate_response
  question = params.require(:q)
  ::Gitlab::Llm::TanukiBot.execute(current_user: current_user, question: question)
end
# before_action guard: allows the request through when TanukiBot is
# enabled for the current user, otherwise responds 401 Unauthorized
# and halts the action.
def verify_tanuki_bot_enabled
  enabled = ::Gitlab::Llm::TanukiBot.enabled_for?(user: current_user)

  head :unauthorized unless enabled
end
end end
end end
...@@ -18,8 +18,9 @@ class Client ...@@ -18,8 +18,9 @@ class Client
include ExponentialBackoff include ExponentialBackoff
def initialize(user) def initialize(user, request_timeout: nil)
@user = user @user = user
@request_timeout = request_timeout
end end
def chat(content:, **options) def chat(content:, **options)
...@@ -73,10 +74,10 @@ def embeddings(input:, **options) ...@@ -73,10 +74,10 @@ def embeddings(input:, **options)
retry_methods_with_exponential_backoff :chat, :completions, :edits, :embeddings retry_methods_with_exponential_backoff :chat, :completions, :edits, :embeddings
attr_reader :user attr_reader :user, :request_timeout
# Memoized OpenAI API client, configured with the resolved access token
# and this instance's request timeout (nil means library default).
def client
  @client ||= OpenAI::Client.new(
    access_token: access_token,
    request_timeout: request_timeout
  )
end
def enabled? def enabled?
......
...@@ -11,6 +11,7 @@ class TanukiBot ...@@ -11,6 +11,7 @@ class TanukiBot
n: 1, n: 1,
best_of: 1 best_of: 1
}.freeze }.freeze
REQUEST_TIMEOUT = 30
CONTENT_ID_FIELD = 'ATTRS' CONTENT_ID_FIELD = 'ATTRS'
CONTENT_ID_REGEX = /CNT-IDX-(?<id>\d+)/ CONTENT_ID_REGEX = /CNT-IDX-(?<id>\d+)/
RECORD_LIMIT = Rails.env.production? ? 7 : 2 RECORD_LIMIT = Rails.env.production? ? 7 : 2
...@@ -19,17 +20,24 @@ def self.execute(current_user:, question:, logger: nil) ...@@ -19,17 +20,24 @@ def self.execute(current_user:, question:, logger: nil)
new(current_user: current_user, question: question, logger: logger).execute new(current_user: current_user, question: question, logger: logger).execute
end end
# Whether the TanukiBot chat feature is available to +user+.
#
# All of the following must hold, checked in order:
#   * a user is present (anonymous requests are rejected),
#   * the :ai_tanuki_bot licensed feature is available,
#   * both the :openai_experimentation and :tanuki_bot feature flags are on,
#   * on GitLab.com only, the user belongs to a paid Ultimate namespace.
#
# @param user [User, nil]
# @return [Boolean]
def self.enabled_for?(user:)
  return false unless user
  return false unless ::License.feature_available?(:ai_tanuki_bot)
  return false unless Feature.enabled?(:openai_experimentation) && Feature.enabled?(:tanuki_bot, user)
  # `user` is guaranteed non-nil by the first guard, so the original's
  # safe navigation (`user&.`) was redundant and is dropped here.
  return false if ::Gitlab.com? && !user.has_paid_namespace?(plans: [::Plan::ULTIMATE])

  true
end
def initialize(current_user:, question:, logger: nil) def initialize(current_user:, question:, logger: nil)
@current_user = current_user @current_user = current_user
@question = question @question = question
@logger = logger @logger = logger || Gitlab::AppJsonLogger.build
end end
def execute def execute
return {} unless question.present? return {} unless question.present?
return {} unless ::License.feature_available?(:ai_tanuki_bot) return {} unless self.class.enabled_for?(user: current_user)
return {} if ::Gitlab.com? && !current_user&.has_paid_namespace?(plans: [::Plan::ULTIMATE])
return {} unless Feature.enabled?(:openai_experimentation) && Feature.enabled?(:tanuki_bot, current_user)
search_documents = query_search_documents search_documents = query_search_documents
return empty_response if search_documents.empty? return empty_response if search_documents.empty?
...@@ -44,7 +52,7 @@ def execute ...@@ -44,7 +52,7 @@ def execute
attr_reader :current_user, :question, :logger attr_reader :current_user, :question, :logger
# Memoized OpenAI client for the current user, with the TanukiBot
# REQUEST_TIMEOUT (seconds) applied to every request.
def client
  @client ||= ::Gitlab::Llm::OpenAi::Client.new(
    current_user,
    request_timeout: REQUEST_TIMEOUT
  )
end
def build_initial_prompts(search_documents) def build_initial_prompts(search_documents)
...@@ -64,14 +72,15 @@ def build_initial_prompts(search_documents) ...@@ -64,14 +72,15 @@ def build_initial_prompts(search_documents)
# Sends one per-document completion request, logs the outcome, and
# attaches the model's extracted text to the document hash.
#
# @param doc [Hash] a search document (must contain :id)
# @param prompt [String] the prompt built for this document
# @return [Hash] +doc+ merged with :extracted_text
# @raise [RuntimeError] when the HTTP response is not successful; uses
#   the API's error message when present, otherwise a generic message
def send_initial_prompt(doc:, prompt:)
  result = client.completions(prompt: prompt, **DEFAULT_OPTIONS)

  info(
    document: doc[:id],
    prompt: prompt,
    code: result.code,
    result: result.parsed_response,
    message: 'Initial prompt request'
  )

  unless result.success?
    raise result.dig('error', 'message') || "Initial prompt request failed with '#{result}'"
  end

  doc.merge(extracted_text: result['choices'].first['text'])
end
...@@ -119,7 +128,16 @@ def get_completions(search_documents) ...@@ -119,7 +128,16 @@ def get_completions(search_documents)
PROMPT PROMPT
final_prompt_result = client.completions(prompt: final_prompt, **DEFAULT_OPTIONS) final_prompt_result = client.completions(prompt: final_prompt, **DEFAULT_OPTIONS)
debug(prompt: final_prompt, result: final_prompt_result.parsed_response, message: 'Final prompt request') info(
prompt: final_prompt,
code: final_prompt_result.code,
result: final_prompt_result.parsed_response,
message: 'Final prompt request'
)
unless final_prompt_result.success?
raise final_prompt_result.dig('error', 'message') || "Final prompt request failed with '#{result}'"
end
final_prompt_result final_prompt_result
end end
...@@ -128,7 +146,7 @@ def query_search_documents ...@@ -128,7 +146,7 @@ def query_search_documents
embeddings_result = client.embeddings(input: question) embeddings_result = client.embeddings(input: question)
question_embedding = embeddings_result['data'].first['embedding'] question_embedding = embeddings_result['data'].first['embedding']
nearest_neighbors = ::Embedding::TanukiBot.neighbor_for(question_embedding).limit(RECORD_LIMIT) nearest_neighbors = ::Embedding::TanukiBotMvc.neighbor_for(question_embedding).limit(RECORD_LIMIT)
nearest_neighbors.map do |item| nearest_neighbors.map do |item|
item.metadata['source_url'] = item.url item.metadata['source_url'] = item.url
...@@ -169,10 +187,10 @@ def empty_response ...@@ -169,10 +187,10 @@ def empty_response
} }
end end
# Emits a structured info-level log entry; no-op when no logger is set.
# Safe navigation short-circuits, so the payload is only built when a
# logger is present — same behavior as an explicit early return.
def info(payload)
  logger&.info(build_structured_payload(**payload))
end
end end
end end
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
before do before do
allow(Gitlab::Llm::TanukiBot).to receive_message_chain(:new, :execute).and_return({}) allow(Gitlab::Llm::TanukiBot).to receive_message_chain(:new, :execute).and_return({})
allow(Gitlab::Llm::TanukiBot).to receive(:enabled_for?).and_return(true)
end end
it 'responds with a 401' do it 'responds with a 401' do
...@@ -23,6 +24,18 @@ ...@@ -23,6 +24,18 @@
sign_in(create(:user)) sign_in(create(:user))
end end
context 'when user does not have access to the feature' do
  # Force the controller's before_action guard to reject the request.
  before do
    allow(Gitlab::Llm::TanukiBot).to receive(:enabled_for?).and_return(false)
  end

  it 'responds with a 401' do
    subject

    expect(response).to have_gitlab_http_status(:unauthorized)
  end
end
it 'responds with :bad_request if the request is not json' do it 'responds with :bad_request if the request is not json' do
post :ask, params: { q: question } post :ask, params: { q: question }
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
RSpec.describe Gitlab::Llm::TanukiBot, feature_category: :global_search do RSpec.describe Gitlab::Llm::TanukiBot, feature_category: :global_search do
describe '#execute' do describe '#execute' do
let_it_be(:user) { create(:user) } let_it_be(:user) { create(:user) }
let_it_be(:embeddings) { create_list(:tanuki_bot, 2) } let_it_be(:embeddings) { create_list(:tanuki_bot_mvc, 2) }
let(:question) { 'A question' } let(:question) { 'A question' }
let(:answer) { 'The answer.' } let(:answer) { 'The answer.' }
...@@ -14,15 +14,17 @@ ...@@ -14,15 +14,17 @@
let(:openai_client) { ::Gitlab::Llm::OpenAi::Client.new(user) } let(:openai_client) { ::Gitlab::Llm::OpenAi::Client.new(user) }
let(:embedding_response) { { "data" => [{ "embedding" => Array.new(1536, 0.5) }] } } let(:embedding_response) { { "data" => [{ "embedding" => Array.new(1536, 0.5) }] } }
let(:attrs) { embeddings.map(&:id).map { |x| "CNT-IDX-#{x}" }.join(", ") } let(:attrs) { embeddings.map(&:id).map { |x| "CNT-IDX-#{x}" }.join(", ") }
let(:completion_response) do let(:completion_response) { { "choices" => [{ "text" => "#{answer} ATTRS: #{attrs}" }] } }
{ "choices" => [{ "text" => "#{answer} ATTRS: #{attrs}" }] } let(:status_code) { 200 }
end let(:success) { true }
subject(:execute) { instance.execute } subject(:execute) { instance.execute }
before do before do
allow(License).to receive(:feature_available?).and_return(true) allow(License).to receive(:feature_available?).and_return(true)
allow(logger).to receive(:debug) allow(logger).to receive(:info)
allow(completion_response).to receive(:code).and_return(status_code)
allow(completion_response).to receive(:success?).and_return(success)
end end
context 'with the ai_tanuki_bot license not available' do context 'with the ai_tanuki_bot license not available' do
...@@ -66,7 +68,7 @@ ...@@ -66,7 +68,7 @@
end end
it 'executes calls through to open ai' do it 'executes calls through to open ai' do
create_list(:tanuki_bot, 2) create_list(:tanuki_bot_mvc, 2)
expect(openai_client).to receive(:completions).exactly(3).times.and_return(completion_response) expect(openai_client).to receive(:completions).exactly(3).times.and_return(completion_response)
expect(openai_client).to receive(:embeddings).and_return(embedding_response) expect(openai_client).to receive(:embeddings).and_return(embedding_response)
...@@ -114,7 +116,7 @@ ...@@ -114,7 +116,7 @@
context 'when no neighbors are found' do context 'when no neighbors are found' do
before do before do
allow(Embedding::TanukiBot).to receive(:neighbor_for).and_return(Embedding::TanukiBot.none) allow(Embedding::TanukiBotMvc).to receive(:neighbor_for).and_return(Embedding::TanukiBotMvc.none)
allow(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response) allow(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response)
end end
...@@ -151,8 +153,8 @@ ...@@ -151,8 +153,8 @@
expect(result[:msg]).to eq(answer) expect(result[:msg]).to eq(answer)
expect(result[:sources].count).to eq(2) expect(result[:sources].count).to eq(2)
expected_sources = ::Embedding::TanukiBot.pluck(:metadata).pluck('source') expected_sources = ::Embedding::TanukiBotMvc.pluck(:metadata).pluck('source')
expected_source_urls = ::Embedding::TanukiBot.pluck(:url) expected_source_urls = ::Embedding::TanukiBotMvc.pluck(:url)
expect(result[:sources].pluck('source')).to match_array(expected_sources) expect(result[:sources].pluck('source')).to match_array(expected_sources)
expect(result[:sources].pluck('source_url')).to match_array(expected_source_urls) expect(result[:sources].pluck('source_url')).to match_array(expected_source_urls)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment