Skip to content
Snippets Groups Projects
Commit 58f4792e authored by Terri Chu
Browse files

Merge branch 'tchu-bot-create-new-api' into 'master'

Draft: Create initial Tanuki bot api endpoint

See merge request gitlab-org/gitlab!117695



Merged-by: Terri Chu <tchu@gitlab.com>
Co-authored-by: Dmitry Gruzd <dgruzd@gitlab.com>
Co-authored-by: Madelein van Niekerk <mvanniekerk@gitlab.com>
parents 143b6f53 6effe0ec
No related branches found
No related tags found
No related merge requests found
---
# Feature flag definition for the Tanuki bot API endpoint.
name: tanuki_bot
# NOTE(review): this commit merges gitlab!117695; confirm 117585 is the MR
# that actually introduced the flag.
introduced_by_url: https://gitlab.com/gitlab-org/gitlab/-/merge_requests/117585
rollout_issue_url: https://gitlab.com/gitlab-org/gitlab/-/issues/407337
milestone: '15.11'
type: development
group: group::global search
# Disabled by default; must be enabled explicitly.
default_enabled: false
......@@ -171,6 +171,7 @@
draw :country
draw :country_state
draw :subscription
draw :llm
scope '/push_from_secondary/:geo_node_id' do
draw :git_http
......
# frozen_string_literal: true
module Llm
# Controller backing the Tanuki bot question endpoint
# (POST /llm/tanuki_bot/ask, JSON only — see the llm routes file).
class TanukiBotController < ApplicationController
# NOTE(review): CSRF protection is skipped on this POST endpoint —
# presumably because it is consumed as a JSON API; confirm an
# authenticity token really cannot be supplied by the caller.
skip_before_action :verify_authenticity_token
# Do not wrap JSON params in a root key, so `q` stays top-level.
wrap_parameters format: []
feature_category :global_search
# Answers the question in params[:q] via the TanukiBot service.
# Responds 200 with the service result for JSON requests, 400 otherwise.
# Raises ActionController::ParameterMissing when `q` is absent.
def ask
respond_to do |format|
format.json { render json: service.execute, status: :ok }
format.any { head :bad_request }
end
end
private
# Builds the service object; `q` is required and raises if missing.
def service
::Gitlab::Llm::TanukiBot.new(current_user: current_user, question: params.require(:q))
end
end
end
......@@ -312,11 +312,12 @@ def owns_group_without_trial?
# Returns true if the user is a Reporter or higher on any namespace
# currently on a paid plan
def has_paid_namespace?
def has_paid_namespace?(plans: ::Plan::PAID_HOSTED_PLANS)
paid_hosted_plans = ::Plan::PAID_HOSTED_PLANS & plans
::Namespace
.from("(#{namespace_union_for_reporter_developer_maintainer_owned}) #{::Namespace.table_name}")
.include_gitlab_subscription
.where(gitlab_subscriptions: { hosted_plan: ::Plan.where(name: ::Plan::PAID_HOSTED_PLANS) })
.where(gitlab_subscriptions: { hosted_plan: ::Plan.where(name: paid_hosted_plans) })
.any?
end
......
......@@ -247,6 +247,7 @@ class Features
unique_project_download_limit
vulnerability_auto_fix
vulnerability_finding_signatures
tanuki_bot
].freeze
STARTER_FEATURES_WITH_USAGE_PING = %i[
......
# frozen_string_literal: true
# Routes for LLM-backed features. The Tanuki bot ask endpoint is JSON-only;
# non-JSON requests are rejected by the format constraint.
namespace :llm do
post 'tanuki_bot/ask' => 'tanuki_bot#ask', as: :tanuki_bot_ask, constraints: { format: :json }
end
# frozen_string_literal: true

module Gitlab
  module Llm
    # Answers a user's question against GitLab documentation: the question is
    # embedded via OpenAI, the nearest documents are fetched from a pgvector
    # table, each document is checked for relevant text with one completions
    # call, and a final completion produces an answer with CNT-IDX source
    # references.
    class TanukiBot
      include ::Gitlab::Loggable

      # Options sent with every OpenAI completions request.
      DEFAULT_OPTIONS = {
        max_tokens: 256,
        top_p: 1,
        n: 1,
        best_of: 1
      }.freeze
      # Label the model is instructed to emit before its source references.
      CONTENT_ID_FIELD = 'ATTRS'
      CONTENT_ID_REGEX = /CNT-IDX-(?<id>\d+)/
      # Number of nearest-neighbour documents considered per question.
      EMBEDDINGS_LIMIT = Rails.env.production? ? 7 : 2

      # Example usage (dev)
      # execute(current_user: User.first, question: 'What is advanced search?', logger: Logger.new($stdout))
      def self.execute(current_user:, question:, logger: nil)
        new(current_user: current_user, question: question, logger: logger).execute
      end

      def initialize(current_user:, question:, logger: nil)
        @current_user = current_user
        @question = question
        @logger = logger
      end

      # Returns a { msg:, sources: } hash, or {} when the question is blank,
      # the license/feature flags are off, or (on .com) the user has no
      # Ultimate namespace.
      def execute
        return {} unless question.present?
        return {} unless ::License.feature_available?(:tanuki_bot)
        return {} if ::Gitlab.com? && !current_user&.has_paid_namespace?(plans: [::Plan::ULTIMATE])
        return {} unless Feature.enabled?(:openai_experimentation) && Feature.enabled?(:tanuki_bot, current_user)

        search_documents = query_search_documents
        result = get_completions(search_documents)

        build_response(result, search_documents)
      end

      private

      attr_reader :current_user, :question, :logger

      def client
        @client ||= ::Gitlab::Llm::OpenAi::Client.new(current_user)
      end

      # Issues one completions request per document to extract relevant text,
      # then a final request summarising all extracts into an answer.
      # Raises with the API error message if any per-document call fails.
      def get_completions(search_documents)
        documents = search_documents.map do |doc|
          initial_prompt = <<~PROMPT.strip
            Use the following portion of a long document to see if any of the text is relevant to answer the question.
            Return any relevant text verbatim.
            #{doc[:content]}
            Question: #{question}
            Relevant text, if any:
          PROMPT

          result = client.completions(prompt: initial_prompt, **DEFAULT_OPTIONS)

          debug(
            document: doc[:id],
            prompt: initial_prompt,
            result: result.parsed_response,
            message: 'Initial prompt request'
          )

          raise result.dig('error', 'message') if result['error']

          doc.merge(extracted_text: result.dig('choices', 0, 'text'))
        end

        content = build_content(documents)
        final_prompt = <<~PROMPT.strip
          Given the following extracted parts of a long document and a question,
          create a final answer with references "#{CONTENT_ID_FIELD}".
          If you don't know the answer, just say that you don't know. Don't try to make up an answer.
          At the end of your answer ALWAYS return a "#{CONTENT_ID_FIELD}" part and
          ALWAYS name it #{CONTENT_ID_FIELD}.
          QUESTION: #{question}
          =========
          #{content}
          =========
          FINAL ANSWER:
        PROMPT

        final_prompt_result = client.completions(prompt: final_prompt, **DEFAULT_OPTIONS)
        debug(prompt: final_prompt, result: final_prompt_result.parsed_response, message: 'Final prompt request')

        final_prompt_result
      end

      # TODO - this method will be re-written when the database is in place to use ActiveRecord
      def query_search_documents
        embeddings_result = client.embeddings(input: question)
        # Coerce every component with Float() so the value interpolated into
        # the SQL below can never carry anything but a numeric vector, even if
        # the API response is malformed (Float raises on non-numeric input).
        embedding = embeddings_result['data'].first['embedding']
        question_embeddings = "[#{embedding.map { |component| Float(component) }.join(',')}]"

        # the database must be setup manually to test locally
        # the SQL below can be substituted in place of the embeddings
        # SELECT id, content, metadata FROM tanuki_bot WHERE ID IN (25776, 24193, 8748, 19869, 25783, 4218, 35058);
        # NOTE(review): hard-coded local connection details; must move to
        # Rails configuration before this leaves draft.
        conn = PG.connect(dbname: 'tanuki_bot', host: 'localhost', port: 5433, user: 'postgres', password: 'password')
        sql_input = <<~SQL
          SELECT id, content, metadata
          FROM tanuki_bot
          ORDER BY embedding <-> '#{question_embeddings}'
          LIMIT #{EMBEDDINGS_LIMIT};
        SQL
        db_results_tmp = conn.exec(sql_input).to_a

        db_results_tmp.map do |row|
          metadata = ::Gitlab::Json.parse(row['metadata'])

          # Extension regexes anchor the dot (the previous /.(md|erb)/ stripped
          # ANY character followed by "md"/"erb", mangling e.g. paths
          # containing "amd"), and source prefixes are removed only from the
          # start of the path rather than everywhere via gsub.
          metadata['source_url'] = case metadata['source_type']
                                   when 'doc'
                                     page = metadata['source'].delete_prefix('doc/').sub(/\.md\z/, '')
                                     ::Gitlab::Routing.url_helpers.help_page_url(page)
                                   when 'handbook'
                                     page = metadata['source'].delete_prefix('handbook/').gsub(/\.(md|erb)/, '')
                                     "https://about.gitlab.com/#{page}"
                                   else # blog
                                     "https://about.gitlab.com/#{metadata['source'].gsub(/\.(md|erb)/, '')}"
                                   end

          {
            id: row['id'].to_i,
            content: row['content'],
            metadata: metadata
          }
        end
      end

      # Joins the per-document extracts into the CONTENT/ATTRS blocks the
      # final prompt references.
      def build_content(search_documents)
        search_documents.map do |document|
          <<~PROMPT.strip
            CONTENT: #{document[:extracted_text]}
            #{CONTENT_ID_FIELD}: CNT-IDX-#{document[:id]}
          PROMPT
        end.join("\n\n")
      end

      # Splits the model output into the answer text and the trailing ATTRS
      # reference list, then resolves the CNT-IDX ids back to the metadata of
      # the matching search documents.
      def build_response(result, search_documents)
        output = result.dig('choices', 0, 'text').to_s.split("#{CONTENT_ID_FIELD}:")
        msg = output[0].to_s.strip
        # output[1] is nil when the model ignored the ATTRS instruction;
        # treat that as "no references" instead of raising NoMethodError.
        content_idx = output[1].to_s.scan(CONTENT_ID_REGEX).flatten.map(&:to_i)
        documents = search_documents.filter { |doc| content_idx.include?(doc[:id]) }
        sources = documents.pluck(:metadata) # rubocop:disable CodeReuse/ActiveRecord

        {
          msg: msg,
          sources: sources
        }
      end

      # No-op unless a logger was supplied (dev/debug only).
      def debug(payload)
        return unless logger

        logger.debug(build_structured_payload(**payload))
      end
    end
  end
end
# frozen_string_literal: true

require 'spec_helper'

RSpec.describe Llm::TanukiBotController, feature_category: :global_search do
  # The route is declared as a POST and the example issues `post :ask`, so the
  # previous "GET #ask" description was misleading.
  describe 'POST #ask' do
    let(:question) { 'Some question' }

    subject { post :ask, params: { q: question }, format: :json }

    before do
      # The service layer is stubbed out; these examples only cover the
      # controller's auth, format handling, and parameter validation.
      allow(Gitlab::Llm::TanukiBot).to receive_message_chain(:new, :execute).and_return({})
    end

    it 'responds with a 401' do
      subject

      expect(response).to have_gitlab_http_status(:unauthorized)
    end

    context 'when the user is signed in' do
      before do
        sign_in(create(:user))
      end

      it 'responds with :bad_request if the request is not json' do
        post :ask, params: { q: question }

        expect(response).to have_gitlab_http_status(:bad_request)
      end

      it 'responds with :ok' do
        subject

        expect(response).to have_gitlab_http_status(:ok)
      end

      it 'calls TanukiBot service' do
        expect(Gitlab::Llm::TanukiBot).to receive_message_chain(:new, :execute)

        subject
      end

      context 'when question is not provided' do
        let(:question) { nil }

        it 'raises an error' do
          expect { subject }.to raise_error(ActionController::ParameterMissing)
        end
      end
    end
  end
end
# frozen_string_literal: true
require 'spec_helper'
# Specs for the Tanuki bot service: license/plan/flag gating plus the
# embeddings -> nearest-neighbour -> completions -> response pipeline, with
# OpenAI and PG both stubbed.
RSpec.describe Gitlab::Llm::TanukiBot, feature_category: :global_search do
describe '#execute' do
let_it_be(:user) { create(:user) }
let(:question) { 'A question' }
let(:answer) { 'The answer.' }
let(:instance) { described_class.new(current_user: user, question: question, logger: nil) }
let(:openai_client) { ::Gitlab::Llm::OpenAi::Client.new(user) }
# 1536 is the embedding dimensionality the queries below are matched against.
let(:embedding_response) { { "data" => [{ "embedding" => Array.new(1536, 0.5) }] } }
# Completion text carries the ATTRS reference list the service parses.
let(:completion_response) do
{ "choices" => [{ "text" => "#{answer} ATTRS: CNT-IDX-111, CNT-IDX-222, CNT-IDX-333" }] }
end
let(:connection) { instance_double(PG::Connection) }
# One row per source_type branch (doc / blog / handbook).
let(:database_response) do
[
{
"id" => "111",
"content" => "Documentation content",
"metadata" => "{\"source\": \"doc/folder/index.md\", \"source_type\": \"doc\"}"
},
{
"id" => "222",
"content" => "Blog content",
"metadata" => "{\"source\": \"blog/folder/index.html.md.erb\", \"source_type\": \"blog\"}"
},
{
"id" => "333",
"content" => "Handbook content",
"metadata" => "{\"source\": \"handbook/folder/index.html.md\", \"source_type\": \"handbook\"}"
}
]
end
subject(:execute) { instance.execute }
before do
# License check passes by default; individual contexts override it.
allow(License).to receive(:feature_available?).and_return(true)
end
context 'with the tanuki_bot license not available' do
before do
allow(License).to receive(:feature_available?).with(:tanuki_bot).and_return(false)
end
it 'returns an empty hash' do
expect(execute).to eq({})
end
end
context 'with the tanuki_bot license available' do
context 'when on Gitlab.com' do
before do
allow(::Gitlab).to receive(:com?).and_return(true)
end
context 'when no user is provided' do
let(:user) { nil }
it 'returns an empty hash' do
expect(execute).to eq({})
end
end
context 'when the user does not have a paid namespace' do
before do
allow(user).to receive(:has_paid_namespace?).and_return(false)
end
it 'returns an empty hash' do
expect(execute).to eq({})
end
end
context 'when the user has a paid namespace' do
before do
allow(::Gitlab::Llm::OpenAi::Client).to receive(:new).and_return(openai_client)
allow(user).to receive(:has_paid_namespace?).and_return(true)
end
it 'executes calls through to open ai' do
allow(PG).to receive(:connect).and_return(connection)
allow(connection).to receive(:exec).and_return(database_response)
# 4 completions = 3 per-document extractions + 1 final summary.
expect(openai_client).to receive(:completions).exactly(4).times.and_return(completion_response)
expect(openai_client).to receive(:embeddings).and_return(embedding_response)
allow(completion_response).to receive(:parsed_response).and_return(completion_response)
execute
end
end
end
context 'when the feature flags are disabled' do
using RSpec::Parameterized::TableSyntax
# Both flags must be enabled; any other combination short-circuits.
where(:openai_experimentation, :tanuki_bot) do
true | false
false | true
false | false
end
with_them do
before do
stub_feature_flags(openai_experimentation: openai_experimentation)
stub_feature_flags(tanuki_bot: tanuki_bot)
end
it 'returns an empty hash' do
expect(execute).to eq({})
end
end
end
context 'when the feature flags are enabled' do
before do
# parsed_response is only used for debug logging; echo the hash back.
allow(completion_response).to receive(:parsed_response).and_return(completion_response)
allow(::Gitlab::Llm::OpenAi::Client).to receive(:new).and_return(openai_client)
end
context 'when the question is not provided' do
let(:question) { nil }
it 'returns an empty hash' do
expect(execute).to eq({})
end
end
describe 'getting matching documents' do
before do
allow(openai_client).to receive(:completions).and_return(completion_response)
end
it 'creates an embedding for the question' do
expect(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response)
allow(PG).to receive(:connect).and_return(connection)
allow(connection).to receive(:exec).and_return(database_response)
execute
end
it 'queries the database for nearest neighbor embeddings' do
expect(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response)
expect(PG).to receive(:connect).and_return(connection)
# The SQL must interpolate the stubbed 1536-component vector verbatim.
expect(connection).to receive(:exec)
.with(/embedding <-> '\[#{Array.new(1536, 0.5).join(",")}\]'/)
.and_return(database_response)
execute
end
end
describe 'checking documents for relevance and summarizing' do
before do
allow(openai_client).to receive(:embeddings).and_return(embedding_response)
allow(PG).to receive(:connect).and_return(connection)
allow(connection).to receive(:exec).and_return(database_response)
end
# NOTE: my local database matches on 2 documents and will have to be set up when moving to ActiveRecord
it 'calls the completions API once for each document and once for summarizing' do
expect(openai_client).to receive(:completions)
.with(hash_including(prompt: /see if any of the text is relevant to answer the question/))
.and_return(completion_response).exactly(3).times
expect(openai_client).to receive(:completions)
.with(hash_including(prompt: /create a final answer with references/))
.and_return(completion_response).once
execute
end
end
describe 'building the response' do
before do
allow(openai_client).to receive(:embeddings).and_return(embedding_response).once
allow(openai_client).to receive(:completions).and_return(completion_response).exactly(4).times
end
# Covers source_url generation for all three source_type branches.
it 'returns the message and sources' do
expect(PG).to receive(:connect).and_return(connection)
expect(connection).to receive(:exec)
.with(/embedding <-> '\[#{Array.new(1536, 0.5).join(",")}\]'/)
.and_return(database_response)
expected_response = {
msg: answer,
sources: [
{
"source" => "doc/folder/index.md",
"source_type" => "doc",
"source_url" => "http://localhost/help/folder/index"
},
{
"source" => "blog/folder/index.html.md.erb",
"source_type" => "blog",
"source_url" => "https://about.gitlab.com/blog/folder/index.html"
},
{
"source" => "handbook/folder/index.html.md",
"source_type" => "handbook",
"source_url" => "https://about.gitlab.com/folder/index.html"
}
]
}
expect(execute).to eq(expected_response)
end
end
end
end
end
end
......@@ -1480,6 +1480,23 @@
expect(user.has_paid_namespace?).to eq(false)
end
end
context 'when passed a subset of plans' do
it 'returns true', :aggregate_failures do
bronze_group.add_reporter(user)
expect(user.has_paid_namespace?(plans: [::Plan::BRONZE])).to eq(true)
expect(user.has_paid_namespace?(plans: [::Plan::ULTIMATE])).to eq(false)
end
end
context 'when passed a non-paid plan' do
it 'returns false' do
free_group.add_owner(user)
expect(user.has_paid_namespace?(plans: [::Plan::ULTIMATE, ::Plan::FREE])).to eq(false)
end
end
end
describe '#owns_paid_namespace?', :saas do
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment