Skip to content
Snippets Groups Projects
Commit 94bb5ce5 authored by Terri Chu
Browse files

Merge branch 'tchu-bot-create-new-api' into 'add-tanuki-bot-model'

Draft: Create initial Tanuki bot api endpoint

See merge request gitlab-org/gitlab!117695



Merged-by: Terri Chu <tchu@gitlab.com>
parents 1d98f59d 4ebace43
No related branches found
No related tags found
No related merge requests found
......@@ -171,6 +171,7 @@
draw :country
draw :country_state
draw :subscription
draw :llm
scope '/push_from_secondary/:geo_node_id' do
draw :git_http
......
# frozen_string_literal: true

module Llm
  # JSON endpoint that forwards documentation questions to the Tanuki bot
  # service and returns its answer. Non-JSON requests are rejected.
  class TanukiBotController < ApplicationController
    wrap_parameters format: []

    feature_category :global_search

    # POST /llm/tanuki_bot/ask
    # Responds 200 with the service result for JSON, 400 for any other format.
    def ask
      respond_to do |format|
        format.json do
          render json: tanuki_bot_answer, status: :ok
        end

        format.any { head :bad_request }
      end
    end

    private

    # Delegates to the TanukiBot service; raises ParameterMissing when the
    # required +q+ parameter is absent.
    def tanuki_bot_answer
      ::Gitlab::Llm::TanukiBot.execute(current_user: current_user, question: params.require(:q))
    end
  end
end
---
# Feature flag: when enabled, the Tanuki bot sends its per-document
# relevance prompts to OpenAI in parallel threads instead of sequentially.
name: tanuki_bot_parallel
introduced_by_url: https://gitlab.com/gitlab-org/gitlab/-/merge_requests/117695
rollout_issue_url: https://gitlab.com/gitlab-org/gitlab/-/issues/407555
milestone: '16.0'
type: development
group: group::global search
default_enabled: true
# frozen_string_literal: true

# Routes for LLM-backed features.
namespace :llm do
  # JSON-only endpoint for asking the Tanuki bot a documentation question.
  post 'tanuki_bot/ask' => 'tanuki_bot#ask', as: :tanuki_bot_ask, constraints: { format: :json }
end
# frozen_string_literal: true

module Gitlab
  module Llm
    # Answers a user's question about GitLab documentation.
    #
    # Flow:
    #   1. Embed the question via the OpenAI embeddings API.
    #   2. Fetch the nearest-neighbor documentation chunks from
    #      Embedding::TanukiBot.
    #   3. Ask the completions API to extract the relevant text from each
    #      chunk (in parallel threads behind the :tanuki_bot_parallel flag).
    #   4. Ask the completions API for a final answer with source references.
    class TanukiBot
      include ::Gitlab::Loggable

      # Options sent with every completions request.
      DEFAULT_OPTIONS = {
        max_tokens: 256,
        top_p: 1,
        n: 1,
        best_of: 1
      }.freeze

      # Label the model is instructed to emit before its source references.
      CONTENT_ID_FIELD = 'ATTRS'
      # Parses document ids out of "CNT-IDX-<id>" references in the answer.
      CONTENT_ID_REGEX = /CNT-IDX-(?<id>\d+)/
      # How many nearest-neighbor documents are considered per question.
      RECORD_LIMIT = Rails.env.production? ? 7 : 2

      # Convenience wrapper around +new(...).execute+.
      def self.execute(current_user:, question:, logger: nil)
        new(current_user: current_user, question: question, logger: logger).execute
      end

      # @param current_user [User, nil] user asking the question
      # @param question [String, nil] free-text question
      # @param logger [#debug, nil] optional logger for prompt debugging
      def initialize(current_user:, question:, logger: nil)
        @current_user = current_user
        @question = question
        @logger = logger
      end

      # Returns a Hash with :msg and :sources, or {} when the question is
      # blank or the feature is unavailable to this user/instance.
      def execute
        return {} unless question.present?
        return {} unless ::License.feature_available?(:ai_tanuki_bot)
        return {} if ::Gitlab.com? && !current_user&.has_paid_namespace?(plans: [::Plan::ULTIMATE])
        return {} unless Feature.enabled?(:openai_experimentation) && Feature.enabled?(:tanuki_bot, current_user)

        search_documents = query_search_documents
        return empty_response if search_documents.empty?

        result = get_completions(search_documents)
        build_response(result, search_documents)
      end

      private

      attr_reader :current_user, :question, :logger

      def client
        @client ||= ::Gitlab::Llm::OpenAi::Client.new(current_user)
      end

      # Builds one relevance-extraction prompt per document.
      # Returns a Hash of document hash => prompt string.
      def build_initial_prompts(search_documents)
        search_documents.to_h do |doc|
          prompt = <<~PROMPT.strip
            Use the following portion of a long document to see if any of the text is relevant to answer the question.
            Return any relevant text verbatim.
            #{doc[:content]}
            Question: #{question}
            Relevant text, if any:
          PROMPT

          [doc, prompt]
        end
      end

      # Sends a single relevance-extraction prompt and merges the extracted
      # text into the document hash. Raises the OpenAI error message when the
      # API reports an error.
      def send_initial_prompt(doc:, prompt:)
        result = client.completions(prompt: prompt, **DEFAULT_OPTIONS)

        debug(
          document: doc[:id],
          prompt: prompt,
          result: result.parsed_response,
          message: 'Initial prompt request'
        )

        raise result.dig('error', 'message') if result['error']

        doc.merge(extracted_text: result['choices'].first['text'])
      end

      # One completions request at a time. NOTE: fixes the original method
      # name typo ("sequential_competion"); the method is private and only
      # called from get_completions below.
      def sequential_completion(search_documents)
        build_initial_prompts(search_documents).map do |doc, prompt|
          send_initial_prompt(doc: doc, prompt: prompt)
        end
      end

      # All completions requests concurrently; the calls are I/O bound so
      # threads help despite the GVL. Thread#value joins and re-raises any
      # exception raised inside the thread.
      def parallel_completion(search_documents)
        threads = build_initial_prompts(search_documents).map do |doc, prompt|
          Thread.new { send_initial_prompt(doc: doc, prompt: prompt) }
        end

        threads.map(&:value)
      end

      # Extracts relevant text from each document, then asks for a final
      # answer referencing the documents by CNT-IDX id.
      def get_completions(search_documents)
        documents =
          if Feature.enabled?(:tanuki_bot_parallel)
            parallel_completion(search_documents)
          else
            sequential_completion(search_documents)
          end

        content = build_content(documents)
        final_prompt = <<~PROMPT.strip
          Given the following extracted parts of a long document and a question,
          create a final answer with references "#{CONTENT_ID_FIELD}".
          If you don't know the answer, just say that you don't know. Don't try to make up an answer.
          At the end of your answer ALWAYS return a "#{CONTENT_ID_FIELD}" part and
          ALWAYS name it #{CONTENT_ID_FIELD}.
          QUESTION: #{question}
          =========
          #{content}
          =========
          FINAL ANSWER:
        PROMPT

        final_prompt_result = client.completions(prompt: final_prompt, **DEFAULT_OPTIONS)
        debug(prompt: final_prompt, result: final_prompt_result.parsed_response, message: 'Final prompt request')

        final_prompt_result
      end

      # Embeds the question and returns the nearest documentation chunks as
      # hashes with :id, :content and :metadata (source_url injected).
      def query_search_documents
        embeddings_result = client.embeddings(input: question)
        question_embedding = embeddings_result['data'].first['embedding']

        nearest_neighbors = ::Embedding::TanukiBot.neighbor_for(question_embedding).limit(RECORD_LIMIT)
        nearest_neighbors.map do |item|
          item.metadata['source_url'] = item.url

          {
            id: item.id,
            content: item.content,
            metadata: item.metadata
          }
        end
      end

      # Concatenates the extracted text of all documents, tagging each with
      # its CNT-IDX id so the final answer can reference it.
      def build_content(search_documents)
        search_documents.map do |document|
          <<~PROMPT.strip
            CONTENT: #{document[:extracted_text]}
            #{CONTENT_ID_FIELD}: CNT-IDX-#{document[:id]}
          PROMPT
        end.join("\n\n")
      end

      # Splits the model output into the answer text and the referenced
      # document ids, then maps those ids back to source metadata.
      def build_response(result, search_documents)
        output = result['choices'][0]['text'].split("#{CONTENT_ID_FIELD}:")
        msg = output[0].to_s.strip
        # `to_s` guards against a reply without an ATTRS section (output[1]
        # would be nil and previously raised NoMethodError).
        content_idx = output[1].to_s.scan(CONTENT_ID_REGEX).flatten.map(&:to_i)
        documents = search_documents.filter { |doc| content_idx.include?(doc[:id]) }
        sources = documents.pluck(:metadata).uniq # rubocop:disable CodeReuse/ActiveRecord

        {
          msg: msg,
          sources: sources
        }
      end

      # Fallback when no matching documentation was found.
      def empty_response
        {
          msg: _("I do not know."), # TODO namespace this?
          sources: []
        }
      end

      # Logs a structured debug payload when a logger was supplied.
      def debug(payload)
        return unless logger

        logger.debug(build_structured_payload(**payload))
      end
    end
  end
end
# frozen_string_literal: true

require 'spec_helper'

RSpec.describe Llm::TanukiBotController, feature_category: :global_search do
  # The action is routed as a POST (post 'tanuki_bot/ask'); the original
  # description said "GET #ask" which mislabeled the request method.
  describe 'POST #ask' do
    let(:question) { 'Some question' }

    subject { post :ask, params: { q: question }, format: :json }

    before do
      allow(Gitlab::Llm::TanukiBot).to receive_message_chain(:new, :execute).and_return({})
    end

    # Anonymous requests are rejected by authentication.
    it 'responds with a 401' do
      subject

      expect(response).to have_gitlab_http_status(:unauthorized)
    end

    context 'when the user is signed in' do
      before do
        sign_in(create(:user))
      end

      # The route constrains the format to JSON; anything else is a 400.
      it 'responds with :bad_request if the request is not json' do
        post :ask, params: { q: question }

        expect(response).to have_gitlab_http_status(:bad_request)
      end

      it 'responds with :ok' do
        subject

        expect(response).to have_gitlab_http_status(:ok)
      end

      it 'calls TanukiBot service' do
        expect(Gitlab::Llm::TanukiBot).to receive_message_chain(:new, :execute)

        subject
      end

      context 'when question is not provided' do
        let(:question) { nil }

        # params.require(:q) in the controller raises when :q is missing.
        it 'raises an error' do
          expect { subject }.to raise_error(ActionController::ParameterMissing)
        end
      end
    end
  end
end
......@@ -2,12 +2,12 @@
FactoryBot.define do
factory :tanuki_bot, class: 'Embedding::TanukiBot' do
url { 'http://example.com/path/to/a/doc' }
sequence(:url) { |n| "https://example.com/path/to/a/doc_#{n}" }
metadata do
sequence(:metadata) do |n|
{
info: 'A description',
source: 'path/to/a/doc.md',
info: "Description for #{n}",
source: "path/to/a/doc_#{n}.md",
source_type: 'doc'
}
end
......
# frozen_string_literal: true

require 'spec_helper'

RSpec.describe Gitlab::Llm::TanukiBot, feature_category: :global_search do
  describe '#execute' do
    let_it_be(:user) { create(:user) }
    let_it_be(:embeddings) { create_list(:tanuki_bot, 2) }

    let(:question) { 'A question' }
    let(:answer) { 'The answer.' }
    let(:logger) { instance_double('Logger') }
    let(:instance) { described_class.new(current_user: user, question: question, logger: logger) }
    let(:openai_client) { ::Gitlab::Llm::OpenAi::Client.new(user) }
    # 1536 is the embedding dimension the service's neighbor lookup expects.
    let(:embedding_response) { { "data" => [{ "embedding" => Array.new(1536, 0.5) }] } }
    # References in the "CNT-IDX-<id>" format the service parses out of answers.
    let(:attrs) { embeddings.map(&:id).map { |x| "CNT-IDX-#{x}" }.join(", ") }
    let(:completion_response) do
      { "choices" => [{ "text" => "#{answer} ATTRS: #{attrs}" }] }
    end

    subject(:execute) { instance.execute }

    before do
      allow(License).to receive(:feature_available?).and_return(true)
      allow(logger).to receive(:debug)
    end

    context 'with the ai_tanuki_bot license not available' do
      before do
        allow(License).to receive(:feature_available?).with(:ai_tanuki_bot).and_return(false)
      end

      it 'returns an empty hash' do
        expect(execute).to eq({})
      end
    end

    context 'with the tanuki_bot license available' do
      context 'when on Gitlab.com' do
        before do
          allow(::Gitlab).to receive(:com?).and_return(true)
        end

        context 'when no user is provided' do
          let(:user) { nil }

          it 'returns an empty hash' do
            expect(execute).to eq({})
          end
        end

        context 'when the user does not have a paid namespace' do
          before do
            allow(user).to receive(:has_paid_namespace?).and_return(false)
          end

          it 'returns an empty hash' do
            expect(execute).to eq({})
          end
        end

        context 'when the user has a paid namespace' do
          before do
            allow(::Gitlab::Llm::OpenAi::Client).to receive(:new).and_return(openai_client)
            allow(user).to receive(:has_paid_namespace?).and_return(true)
          end

          it 'executes calls through to open ai' do
            create_list(:tanuki_bot, 2)

            # 3 completions: one per document plus one final-answer request.
            expect(openai_client).to receive(:completions).exactly(3).times.and_return(completion_response)
            expect(openai_client).to receive(:embeddings).and_return(embedding_response)
            allow(completion_response).to receive(:parsed_response).and_return(completion_response)

            execute
          end
        end
      end

      context 'when the feature flags are disabled' do
        using RSpec::Parameterized::TableSyntax

        # Both flags must be enabled; any disabled combination short-circuits.
        where(:openai_experimentation, :tanuki_bot) do
          true  | false
          false | true
          false | false
        end

        with_them do
          before do
            stub_feature_flags(openai_experimentation: openai_experimentation)
            stub_feature_flags(tanuki_bot: tanuki_bot)
          end

          it 'returns an empty hash' do
            expect(execute).to eq({})
          end
        end
      end

      context 'when the feature flags are enabled' do
        before do
          allow(completion_response).to receive(:parsed_response).and_return(completion_response)
          allow(::Gitlab::Llm::OpenAi::Client).to receive(:new).and_return(openai_client)
        end

        context 'when the question is not provided' do
          let(:question) { nil }

          it 'returns an empty hash' do
            expect(execute).to eq({})
          end
        end

        context 'when no neighbors are found' do
          before do
            allow(Embedding::TanukiBot).to receive(:neighbor_for).and_return(Embedding::TanukiBot.none)
            allow(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response)
          end

          it 'returns an i do not know' do
            expect(execute).to eq({
              msg: 'I do not know.',
              sources: []
            })
          end
        end

        # Behavior must be identical with parallel and sequential completion.
        [true, false].each do |parallel_bot|
          context "with tanuki_bot_parallel set to #{parallel_bot}" do
            before do
              stub_feature_flags(tanuki_bot_parallel: parallel_bot)
            end

            describe 'getting matching documents' do
              before do
                allow(openai_client).to receive(:completions).and_return(completion_response)
              end

              it 'creates an embedding for the question' do
                expect(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response)

                execute
              end

              it 'queries the embedding database for nearest neighbors' do
                allow(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response)

                result = execute

                expect(result[:msg]).to eq(answer)
                expect(result[:sources].count).to eq(2)
                expected_sources = ::Embedding::TanukiBot.pluck(:metadata).pluck('source')
                expected_source_urls = ::Embedding::TanukiBot.pluck(:url)
                expect(result[:sources].pluck('source')).to match_array(expected_sources)
                expect(result[:sources].pluck('source_url')).to match_array(expected_source_urls)
              end
            end

            describe 'checking documents for relevance and summarizing' do
              before do
                allow(openai_client).to receive(:embeddings).and_return(embedding_response)
              end

              it 'calls the completions API once for each document and once for summarizing' do
                expect(openai_client).to receive(:completions)
                  .with(hash_including(prompt: /see if any of the text is relevant to answer the question/))
                  .and_return(completion_response).twice
                expect(openai_client).to receive(:completions)
                  .with(hash_including(prompt: /create a final answer with references/))
                  .and_return(completion_response).once

                execute
              end
            end
          end
        end
      end
    end
  end
end
......@@ -1480,6 +1480,23 @@
expect(user.has_paid_namespace?).to eq(false)
end
end
context 'when passed a subset of plans' do
it 'returns true', :aggregate_failures do
bronze_group.add_reporter(user)
expect(user.has_paid_namespace?(plans: [::Plan::BRONZE])).to eq(true)
expect(user.has_paid_namespace?(plans: [::Plan::ULTIMATE])).to eq(false)
end
end
context 'when passed a non-paid plan' do
it 'returns false' do
free_group.add_owner(user)
expect(user.has_paid_namespace?(plans: [::Plan::ULTIMATE, ::Plan::FREE])).to eq(false)
end
end
end
context 'when passed a plan' do
......
......@@ -21725,6 +21725,9 @@ msgstr ""
msgid "I accept the %{terms_link}"
msgstr ""
 
msgid "I do not know."
msgstr ""
msgid "I forgot my password"
msgstr ""
 
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment