From 01555c4fee7e8712fb07675019bde7c88daa0e0f Mon Sep 17 00:00:00 2001
From: Terri Chu <tchu@gitlab.com>
Date: Thu, 13 Apr 2023 13:11:51 -0400
Subject: [PATCH 1/3] Add Tanuki Bot backend service and API

Changelog: added
EE: true
---
 config/routes.rb                              |   1 +
 .../controllers/llm/tanuki_bot_controller.rb  |  21 ++
 .../development/tanuki_bot_parallel.yml       |   8 +
 ee/config/routes/llm.rb                       |   5 +
 ee/lib/gitlab/llm/tanuki_bot.rb               | 179 +++++++++++++++++
 .../llm/tanuki_bot_controller_spec.rb         |  53 +++++
 .../{tanuki_bot.rb => tanuki_bots.rb}         |   6 +-
 ee/spec/lib/gitlab/llm/tanuki_bot_spec.rb     | 184 ++++++++++++++++++
 ee/spec/models/ee/user_spec.rb                |  17 ++
 locale/gitlab.pot                             |   3 +
 10 files changed, 474 insertions(+), 3 deletions(-)
 create mode 100644 ee/app/controllers/llm/tanuki_bot_controller.rb
 create mode 100644 ee/config/feature_flags/development/tanuki_bot_parallel.yml
 create mode 100644 ee/config/routes/llm.rb
 create mode 100644 ee/lib/gitlab/llm/tanuki_bot.rb
 create mode 100644 ee/spec/controllers/llm/tanuki_bot_controller_spec.rb
 rename ee/spec/factories/embedding/{tanuki_bot.rb => tanuki_bots.rb} (72%)
 create mode 100644 ee/spec/lib/gitlab/llm/tanuki_bot_spec.rb

diff --git a/config/routes.rb b/config/routes.rb
index ebb0984a008ed933..dfca81c00cbcdeef 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -171,6 +171,7 @@
     draw :country
     draw :country_state
     draw :subscription
+    draw :llm
 
     scope '/push_from_secondary/:geo_node_id' do
       draw :git_http
diff --git a/ee/app/controllers/llm/tanuki_bot_controller.rb b/ee/app/controllers/llm/tanuki_bot_controller.rb
new file mode 100644
index 0000000000000000..0a9d0d50ecc02c6f
--- /dev/null
+++ b/ee/app/controllers/llm/tanuki_bot_controller.rb
@@ -0,0 +1,21 @@
+# frozen_string_literal: true
+
+module Llm
+  class TanukiBotController < ApplicationController
+    wrap_parameters format: []
+    feature_category :global_search
+
+    def ask
+      respond_to do |format|
+        format.json { render json: generate_response, status: :ok }
+        format.any { head :bad_request }
+      end
+    end
+
+    private
+
+    def generate_response
+      ::Gitlab::Llm::TanukiBot.execute(current_user: current_user, question: params.require(:q))
+    end
+  end
+end
diff --git a/ee/config/feature_flags/development/tanuki_bot_parallel.yml b/ee/config/feature_flags/development/tanuki_bot_parallel.yml
new file mode 100644
index 0000000000000000..78c5915f0ccfee77
--- /dev/null
+++ b/ee/config/feature_flags/development/tanuki_bot_parallel.yml
@@ -0,0 +1,8 @@
+---
+name: tanuki_bot_parallel
+introduced_by_url: https://gitlab.com/gitlab-org/gitlab/-/merge_requests/117695
+rollout_issue_url: https://gitlab.com/gitlab-org/gitlab/-/issues/407555
+milestone: '16.0'
+type: development
+group: group::global search
+default_enabled: true
diff --git a/ee/config/routes/llm.rb b/ee/config/routes/llm.rb
new file mode 100644
index 0000000000000000..8627b1838e3cc6ef
--- /dev/null
+++ b/ee/config/routes/llm.rb
@@ -0,0 +1,5 @@
+# frozen_string_literal: true
+
+namespace :llm do
+  post 'tanuki_bot/ask' => 'tanuki_bot#ask', as: :tanuki_bot_ask, constraints: { format: :json }
+end
diff --git a/ee/lib/gitlab/llm/tanuki_bot.rb b/ee/lib/gitlab/llm/tanuki_bot.rb
new file mode 100644
index 0000000000000000..532dce64b8271312
--- /dev/null
+++ b/ee/lib/gitlab/llm/tanuki_bot.rb
@@ -0,0 +1,179 @@
+# frozen_string_literal: true
+
+module Gitlab
+  module Llm
+    class TanukiBot
+      include ::Gitlab::Loggable
+
+      DEFAULT_OPTIONS = {
+        max_tokens: 256,
+        top_p: 1,
+        n: 1,
+        best_of: 1
+      }.freeze
+      CONTENT_ID_FIELD = 'ATTRS'
+      CONTENT_ID_REGEX = /CNT-IDX-(?<id>\d+)/
+      RECORD_LIMIT = Rails.env.production? ? 7 : 2
+
+      def self.execute(current_user:, question:, logger: nil)
+        new(current_user: current_user, question: question, logger: logger).execute
+      end
+
+      def initialize(current_user:, question:, logger: nil)
+        @current_user = current_user
+        @question = question
+        @logger = logger
+      end
+
+      def execute
+        return {} unless question.present?
+        return {} unless ::License.feature_available?(:ai_tanuki_bot)
+        return {} if ::Gitlab.com? && !current_user&.has_paid_namespace?(plans: [::Plan::ULTIMATE])
+        return {} unless Feature.enabled?(:openai_experimentation) && Feature.enabled?(:tanuki_bot, current_user)
+
+        search_documents = query_search_documents
+        return empty_response if search_documents.empty?
+
+        result = get_completions(search_documents)
+
+        build_response(result, search_documents)
+      end
+
+      private
+
+      attr_reader :current_user, :question, :logger
+
+      def client
+        @client ||= ::Gitlab::Llm::OpenAi::Client.new(current_user)
+      end
+
+      def build_initial_prompts(search_documents)
+        search_documents.to_h do |doc|
+          prompt = <<~PROMPT.strip
+            Use the following portion of a long document to see if any of the text is relevant to answer the question.
+            Return any relevant text verbatim.
+            #{doc[:content]}
+            Question: #{question}
+            Relevant text, if any:
+          PROMPT
+
+          [doc, prompt]
+        end
+      end
+
+      def send_initial_prompt(doc:, prompt:)
+        result = client.completions(prompt: prompt, **DEFAULT_OPTIONS)
+
+        debug(
+          document: doc[:id],
+          prompt: prompt,
+          result: result.parsed_response,
+          message: 'Initial prompt request'
+        )
+
+        raise result.dig('error', 'message') if result['error']
+
+        doc.merge(extracted_text: result['choices'].first['text'])
+      end
+
+      def sequential_completion(search_documents)
+        prompts = build_initial_prompts(search_documents)
+
+        prompts.map do |doc, prompt|
+          send_initial_prompt(doc: doc, prompt: prompt)
+        end
+      end
+
+      def parallel_completion(search_documents)
+        prompts = build_initial_prompts(search_documents)
+
+        threads = prompts.map do |doc, prompt|
+          Thread.new do
+            send_initial_prompt(doc: doc, prompt: prompt)
+          end
+        end
+
+        threads.map(&:value)
+      end
+
+      def get_completions(search_documents)
+        documents = if Feature.enabled?(:tanuki_bot_parallel)
+                      parallel_completion(search_documents)
+                    else
+                      sequential_completion(search_documents)
+                    end
+
+        content = build_content(documents)
+        final_prompt = <<~PROMPT.strip
+          Given the following extracted parts of a long document and a question,
+          create a final answer with references "#{CONTENT_ID_FIELD}".
+          If you don't know the answer, just say that you don't know. Don't try to make up an answer.
+          At the end of your answer ALWAYS return a "#{CONTENT_ID_FIELD}" part and
+          ALWAYS name it #{CONTENT_ID_FIELD}.
+
+          QUESTION: #{question}
+          =========
+          #{content}
+          =========
+          FINAL ANSWER:
+        PROMPT
+
+        final_prompt_result = client.completions(prompt: final_prompt, **DEFAULT_OPTIONS)
+        debug(prompt: final_prompt, result: final_prompt_result.parsed_response, message: 'Final prompt request')
+
+        final_prompt_result
+      end
+
+      def query_search_documents
+        embeddings_result = client.embeddings(input: question)
+        question_embedding = embeddings_result['data'].first['embedding']
+
+        nearest_neighbors = ::Embedding::TanukiBot.neighbor_for(question_embedding).limit(RECORD_LIMIT)
+        nearest_neighbors.map do |item|
+          item.metadata['source_url'] = item.url
+
+          {
+            id: item.id,
+            content: item.content,
+            metadata: item.metadata
+          }
+        end
+      end
+
+      def build_content(search_documents)
+        search_documents.map do |document|
+          <<~PROMPT.strip
+            CONTENT: #{document[:extracted_text]}
+            #{CONTENT_ID_FIELD}: CNT-IDX-#{document[:id]}
+          PROMPT
+        end.join("\n\n")
+      end
+
+      def build_response(result, search_documents)
+        output = result['choices'][0]['text'].split("#{CONTENT_ID_FIELD}:")
+        msg = output[0].strip
+        content_idx = output[1].scan(CONTENT_ID_REGEX).flatten.map(&:to_i)
+        documents = search_documents.filter { |doc| content_idx.include?(doc[:id]) }
+        sources = documents.pluck(:metadata).uniq # rubocop:disable CodeReuse/ActiveRecord
+
+        {
+          msg: msg,
+          sources: sources
+        }
+      end
+
+      def empty_response
+        {
+          msg: _("I do not know."), # TODO namespace this?
+          sources: []
+        }
+      end
+
+      def debug(payload)
+        return unless logger
+
+        logger.debug(build_structured_payload(**payload))
+      end
+    end
+  end
+end
diff --git a/ee/spec/controllers/llm/tanuki_bot_controller_spec.rb b/ee/spec/controllers/llm/tanuki_bot_controller_spec.rb
new file mode 100644
index 0000000000000000..81ef6d0912f68c2e
--- /dev/null
+++ b/ee/spec/controllers/llm/tanuki_bot_controller_spec.rb
@@ -0,0 +1,53 @@
+# frozen_string_literal: true
+
+require 'spec_helper'
+
+RSpec.describe Llm::TanukiBotController, feature_category: :global_search do
+  describe 'POST #ask' do
+    let(:question) { 'Some question' }
+
+    subject { post :ask, params: { q: question }, format: :json }
+
+    before do
+      allow(Gitlab::Llm::TanukiBot).to receive_message_chain(:new, :execute).and_return({})
+    end
+
+    it 'responds with a 401' do
+      subject
+
+      expect(response).to have_gitlab_http_status(:unauthorized)
+    end
+
+    context 'when the user is signed in' do
+      before do
+        sign_in(create(:user))
+      end
+
+      it 'responds with :bad_request if the request is not json' do
+        post :ask, params: { q: question }
+
+        expect(response).to have_gitlab_http_status(:bad_request)
+      end
+
+      it 'responds with :ok' do
+        subject
+
+        expect(response).to have_gitlab_http_status(:ok)
+      end
+
+      it 'calls TanukiBot service' do
+        expect(Gitlab::Llm::TanukiBot).to receive_message_chain(:new, :execute)
+
+        subject
+      end
+
+      context 'when question is not provided' do
+        let(:question) { nil }
+
+        it 'raises an error' do
+          expect { subject }.to raise_error(ActionController::ParameterMissing)
+        end
+      end
+    end
+  end
+end
diff --git a/ee/spec/factories/embedding/tanuki_bot.rb b/ee/spec/factories/embedding/tanuki_bots.rb
similarity index 72%
rename from ee/spec/factories/embedding/tanuki_bot.rb
rename to ee/spec/factories/embedding/tanuki_bots.rb
index 0f3fe6a3d83e2061..26e81d42974c7005 100644
--- a/ee/spec/factories/embedding/tanuki_bot.rb
+++ b/ee/spec/factories/embedding/tanuki_bots.rb
@@ -4,10 +4,10 @@
   factory :tanuki_bot_mvc, class: 'Embedding::TanukiBotMvc' do
     url { 'http://example.com/path/to/a/doc' }
 
-    metadata do
+    sequence(:metadata) do |n|
       {
-        info: 'A description',
-        source: 'path/to/a/doc.md',
+        info: "Description for #{n}",
+        source: "path/to/a/doc_#{n}.md",
         source_type: 'doc'
       }
     end
diff --git a/ee/spec/lib/gitlab/llm/tanuki_bot_spec.rb b/ee/spec/lib/gitlab/llm/tanuki_bot_spec.rb
new file mode 100644
index 0000000000000000..f665a54b23a80772
--- /dev/null
+++ b/ee/spec/lib/gitlab/llm/tanuki_bot_spec.rb
@@ -0,0 +1,184 @@
+# frozen_string_literal: true
+
+require 'spec_helper'
+
+RSpec.describe Gitlab::Llm::TanukiBot, feature_category: :global_search do
+  describe '#execute' do
+    let_it_be(:user) { create(:user) }
+    let_it_be(:embeddings) { create_list(:tanuki_bot, 2) }
+
+    let(:question) { 'A question' }
+    let(:answer) { 'The answer.' }
+    let(:logger) { instance_double('Logger') }
+    let(:instance) { described_class.new(current_user: user, question: question, logger: logger) }
+    let(:openai_client) { ::Gitlab::Llm::OpenAi::Client.new(user) }
+    let(:embedding_response) { { "data" => [{ "embedding" => Array.new(1536, 0.5) }] } }
+    let(:attrs) { embeddings.map(&:id).map { |x| "CNT-IDX-#{x}" }.join(", ") }
+    let(:completion_response) do
+      { "choices" => [{ "text" => "#{answer} ATTRS: #{attrs}" }] }
+    end
+
+    subject(:execute) { instance.execute }
+
+    before do
+      allow(License).to receive(:feature_available?).and_return(true)
+      allow(logger).to receive(:debug)
+    end
+
+    context 'with the ai_tanuki_bot license not available' do
+      before do
+        allow(License).to receive(:feature_available?).with(:ai_tanuki_bot).and_return(false)
+      end
+
+      it 'returns an empty hash' do
+        expect(execute).to eq({})
+      end
+    end
+
+    context 'with the tanuki_bot license available' do
+      context 'when on Gitlab.com' do
+        before do
+          allow(::Gitlab).to receive(:com?).and_return(true)
+        end
+
+        context 'when no user is provided' do
+          let(:user) { nil }
+
+          it 'returns an empty hash' do
+            expect(execute).to eq({})
+          end
+        end
+
+        context 'when the user does not have a paid namespace' do
+          before do
+            allow(user).to receive(:has_paid_namespace?).and_return(false)
+          end
+
+          it 'returns an empty hash' do
+            expect(execute).to eq({})
+          end
+        end
+
+        context 'when the user has a paid namespace' do
+          before do
+            allow(::Gitlab::Llm::OpenAi::Client).to receive(:new).and_return(openai_client)
+            allow(user).to receive(:has_paid_namespace?).and_return(true)
+          end
+
+          it 'executes calls through to OpenAI' do
+            create_list(:tanuki_bot, 2)
+
+            expect(openai_client).to receive(:completions).exactly(3).times.and_return(completion_response)
+            expect(openai_client).to receive(:embeddings).and_return(embedding_response)
+            allow(completion_response).to receive(:parsed_response).and_return(completion_response)
+
+            execute
+          end
+        end
+      end
+
+      context 'when the feature flags are disabled' do
+        using RSpec::Parameterized::TableSyntax
+
+        where(:openai_experimentation, :tanuki_bot) do
+          true | false
+          false | true
+          false | false
+        end
+
+        with_them do
+          before do
+            stub_feature_flags(openai_experimentation: openai_experimentation)
+            stub_feature_flags(tanuki_bot: tanuki_bot)
+          end
+
+          it 'returns an empty hash' do
+            expect(execute).to eq({})
+          end
+        end
+      end
+
+      context 'when the feature flags are enabled' do
+        before do
+          allow(completion_response).to receive(:parsed_response).and_return(completion_response)
+          allow(::Gitlab::Llm::OpenAi::Client).to receive(:new).and_return(openai_client)
+        end
+
+        context 'when the question is not provided' do
+          let(:question) { nil }
+
+          it 'returns an empty hash' do
+            expect(execute).to eq({})
+          end
+        end
+
+        context
 'when no neighbors are found' do
+          before do
+            allow(Embedding::TanukiBot).to receive(:neighbor_for).and_return(Embedding::TanukiBot.none)
+            allow(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response)
+          end
+
+          it 'returns an "I do not know" response' do
+            expect(execute).to eq({
+              msg: 'I do not know.',
+              sources: []
+            })
+          end
+        end
+
+        [true, false].each do |parallel_bot|
+          context "with tanuki_bot_parallel set to #{parallel_bot}" do
+            before do
+              stub_feature_flags(tanuki_bot_parallel: parallel_bot)
+            end
+
+            describe 'getting matching documents' do
+              before do
+                allow(openai_client).to receive(:completions).and_return(completion_response)
+              end
+
+              it 'creates an embedding for the question' do
+                expect(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response)
+
+                execute
+              end
+
+              it 'queries the embedding database for nearest neighbors' do
+                allow(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response)
+
+                result = execute
+
+                expect(result[:msg]).to eq(answer)
+                expect(result[:sources].count).to eq(2)
+
+                expected_sources = ::Embedding::TanukiBot.pluck(:metadata).pluck('source')
+                expected_source_urls = ::Embedding::TanukiBot.pluck(:url)
+
+                expect(result[:sources].pluck('source')).to match_array(expected_sources)
+                expect(result[:sources].pluck('source_url')).to match_array(expected_source_urls)
+              end
+            end
+
+            describe 'checking documents for relevance and summarizing' do
+              before do
+                allow(openai_client).to receive(:embeddings).and_return(embedding_response)
+              end
+
+              it 'calls the completions API once for each document and once for summarizing' do
+                expect(openai_client).to receive(:completions)
+                  .with(hash_including(prompt: /see if any of the text is relevant to answer the question/))
+                  .and_return(completion_response).twice
+
+                expect(openai_client).to receive(:completions)
+                  .with(hash_including(prompt: /create a final answer with references/))
+                  .and_return(completion_response).once
+
+                execute
+              end
+            end
+          end
+        end
+      end
+    end
+  end
+end
diff --git a/ee/spec/models/ee/user_spec.rb b/ee/spec/models/ee/user_spec.rb
index ec1480176c4d06ef..32040529dc266775 100644
--- a/ee/spec/models/ee/user_spec.rb
+++ b/ee/spec/models/ee/user_spec.rb
@@ -1480,6 +1480,23 @@
         expect(user.has_paid_namespace?).to eq(false)
       end
     end
+
+    context 'when passed a subset of plans' do
+      it 'returns true', :aggregate_failures do
+        bronze_group.add_reporter(user)
+
+        expect(user.has_paid_namespace?(plans: [::Plan::BRONZE])).to eq(true)
+        expect(user.has_paid_namespace?(plans: [::Plan::ULTIMATE])).to eq(false)
+      end
+    end
+
+    context 'when passed a non-paid plan' do
+      it 'returns false' do
+        free_group.add_owner(user)
+
+        expect(user.has_paid_namespace?(plans: [::Plan::ULTIMATE, ::Plan::FREE])).to eq(false)
+      end
+    end
   end
 
   context 'when passed a plan' do
diff --git a/locale/gitlab.pot b/locale/gitlab.pot
index 62e14eebc9a02a69..a0361e04a1185b22 100644
--- a/locale/gitlab.pot
+++ b/locale/gitlab.pot
@@ -21737,6 +21737,9 @@ msgstr ""
 msgid "I accept the %{terms_link}"
 msgstr ""
 
+msgid "I do not know."
+msgstr ""
+
 msgid "I forgot my password"
 msgstr ""
 
-- 
GitLab

From 39c5377f50bfea2d044c92d3df9d288959e4d905 Mon Sep 17 00:00:00 2001
From: Dmitry Gruzd <dgruzd@gitlab.com>
Date: Tue, 25 Apr 2023 20:48:22 +0200
Subject: [PATCH 2/3] Rebase changes and address reviewer's feedback

---
 .../controllers/llm/tanuki_bot_controller.rb  |  8 +++
 ee/lib/gitlab/llm/open_ai/client.rb           |  7 +--
 ee/lib/gitlab/llm/tanuki_bot.rb               | 51 +++++++++++++------
 .../llm/tanuki_bot_controller_spec.rb         | 13 +++++
 ee/spec/lib/gitlab/llm/tanuki_bot_spec.rb     | 20 ++++----
 5 files changed, 71 insertions(+), 28 deletions(-)

diff --git a/ee/app/controllers/llm/tanuki_bot_controller.rb b/ee/app/controllers/llm/tanuki_bot_controller.rb
index 0a9d0d50ecc02c6f..ac5db552c78ee594 100644
--- a/ee/app/controllers/llm/tanuki_bot_controller.rb
+++ b/ee/app/controllers/llm/tanuki_bot_controller.rb
@@ -5,6 +5,8 @@ class TanukiBotController < ApplicationController
     wrap_parameters format: []
     feature_category :global_search
 
+    before_action :verify_tanuki_bot_enabled
+
     def ask
       respond_to do |format|
         format.json { render json: generate_response, status: :ok }
@@ -14,6 +16,12 @@ def ask
 
     private
 
+    def verify_tanuki_bot_enabled
+      return if ::Gitlab::Llm::TanukiBot.enabled_for?(user: current_user)
+
+      head :unauthorized
+    end
+
     def generate_response
       ::Gitlab::Llm::TanukiBot.execute(current_user: current_user, question: params.require(:q))
     end
diff --git a/ee/lib/gitlab/llm/open_ai/client.rb b/ee/lib/gitlab/llm/open_ai/client.rb
index 207ec3918b01b1c8..c926b9c29bab017b 100644
--- a/ee/lib/gitlab/llm/open_ai/client.rb
+++ b/ee/lib/gitlab/llm/open_ai/client.rb
@@ -18,8 +18,9 @@ class Client
 
         include ExponentialBackoff
 
-        def initialize(user)
+        def initialize(user, request_timeout: nil)
           @user = user
+          @request_timeout = request_timeout
         end
 
         def chat(content:, **options)
@@ -73,10 +74,10 @@ def embeddings(input:, **options)
 
         retry_methods_with_exponential_backoff :chat, :completions, :edits, :embeddings
 
-        attr_reader :user
+        attr_reader :user, :request_timeout
 
        def client
-          @client ||= OpenAI::Client.new(access_token: access_token)
+          @client ||= OpenAI::Client.new(access_token: access_token, request_timeout: request_timeout)
        end
 
        def enabled?
diff --git a/ee/lib/gitlab/llm/tanuki_bot.rb b/ee/lib/gitlab/llm/tanuki_bot.rb
index 532dce64b8271312..6e562e1634d69ee5 100644
--- a/ee/lib/gitlab/llm/tanuki_bot.rb
+++ b/ee/lib/gitlab/llm/tanuki_bot.rb
@@ -11,25 +11,33 @@ class TanukiBot
         n: 1,
         best_of: 1
       }.freeze
+      REQUEST_TIMEOUT = 30
       CONTENT_ID_FIELD = 'ATTRS'
       CONTENT_ID_REGEX = /CNT-IDX-(?<id>\d+)/
-      RECORD_LIMIT = Rails.env.production? ? 7 : 2
+      RECORD_LIMIT = 7
 
       def self.execute(current_user:, question:, logger: nil)
         new(current_user: current_user, question: question, logger: logger).execute
       end
 
+      def self.enabled_for?(user:)
+        return false unless user
+        return false unless ::License.feature_available?(:ai_tanuki_bot)
+        return false unless Feature.enabled?(:openai_experimentation) && Feature.enabled?(:tanuki_bot, user)
+        return false if ::Gitlab.com? && !user&.has_paid_namespace?(plans: [::Plan::ULTIMATE])
+
+        true
+      end
+
       def initialize(current_user:, question:, logger: nil)
         @current_user = current_user
         @question = question
-        @logger = logger
+        @logger = logger || Gitlab::AppJsonLogger.build
       end
 
       def execute
         return {} unless question.present?
-        return {} unless ::License.feature_available?(:ai_tanuki_bot)
-        return {} if ::Gitlab.com? && !current_user&.has_paid_namespace?(plans: [::Plan::ULTIMATE])
-        return {} unless Feature.enabled?(:openai_experimentation) && Feature.enabled?(:tanuki_bot, current_user)
+        return {} unless self.class.enabled_for?(user: current_user)
 
         search_documents = query_search_documents
         return empty_response if search_documents.empty?
@@ -44,7 +52,7 @@ def execute
       attr_reader :current_user, :question, :logger
 
       def client
-        @client ||= ::Gitlab::Llm::OpenAi::Client.new(current_user)
+        @client ||= ::Gitlab::Llm::OpenAi::Client.new(current_user, request_timeout: REQUEST_TIMEOUT)
       end
 
       def build_initial_prompts(search_documents)
@@ -64,14 +72,15 @@ def build_initial_prompts(search_documents)
       def send_initial_prompt(doc:, prompt:)
         result = client.completions(prompt: prompt, **DEFAULT_OPTIONS)
 
-        debug(
+        info(
           document: doc[:id],
-          prompt: prompt,
+          openai_completions_response: prompt,
+          status_code: result.code,
           result: result.parsed_response,
           message: 'Initial prompt request'
         )
 
-        raise result.dig('error', 'message') if result['error']
+        raise result.dig('error', 'message') || "Initial prompt request failed with '#{result}'" unless result.success?
 
         doc.merge(extracted_text: result['choices'].first['text'])
       end
@@ -118,17 +127,24 @@ def get_completions(search_documents)
           FINAL ANSWER:
         PROMPT
 
-        final_prompt_result = client.completions(prompt: final_prompt, **DEFAULT_OPTIONS)
-        debug(prompt: final_prompt, result: final_prompt_result.parsed_response, message: 'Final prompt request')
+        result = client.completions(prompt: final_prompt, **DEFAULT_OPTIONS)
+        info(
+          openai_completions_response: result,
+          status_code: result.code,
+          result: result.parsed_response,
+          message: 'Final prompt request'
+        )
+
+        raise result.dig('error', 'message') || "Final prompt request failed with '#{result}'" unless result.success?
 
-        final_prompt_result
+        result
       end
 
       def query_search_documents
         embeddings_result = client.embeddings(input: question)
         question_embedding = embeddings_result['data'].first['embedding']
 
-        nearest_neighbors = ::Embedding::TanukiBot.neighbor_for(question_embedding).limit(RECORD_LIMIT)
+        nearest_neighbors = ::Embedding::TanukiBotMvc.neighbor_for(question_embedding).limit(RECORD_LIMIT)
         nearest_neighbors.map do |item|
           item.metadata['source_url'] = item.url
 
           {
@@ -151,6 +167,9 @@ def build_content(search_documents)
 
       def build_response(result, search_documents)
         output = result['choices'][0]['text'].split("#{CONTENT_ID_FIELD}:")
+
+        raise 'Failed to parse the response' if output.length != 2
+
         msg = output[0].strip
         content_idx = output[1].scan(CONTENT_ID_REGEX).flatten.map(&:to_i)
         documents = search_documents.filter { |doc| content_idx.include?(doc[:id]) }
@@ -164,15 +183,15 @@ def build_response(result, search_documents)
 
       def empty_response
         {
-          msg: _("I do not know."), # TODO namespace this?
+ msg: _("I do not know."), sources: [] } end - def debug(payload) + def info(payload) return unless logger - logger.debug(build_structured_payload(**payload)) + logger.info(build_structured_payload(**payload)) end end end diff --git a/ee/spec/controllers/llm/tanuki_bot_controller_spec.rb b/ee/spec/controllers/llm/tanuki_bot_controller_spec.rb index 81ef6d0912f68c2e..37ab8734d7bacf25 100644 --- a/ee/spec/controllers/llm/tanuki_bot_controller_spec.rb +++ b/ee/spec/controllers/llm/tanuki_bot_controller_spec.rb @@ -10,6 +10,7 @@ before do allow(Gitlab::Llm::TanukiBot).to receive_message_chain(:new, :execute).and_return({}) + allow(Gitlab::Llm::TanukiBot).to receive(:enabled_for?).and_return(true) end it 'responds with a 401' do @@ -23,6 +24,18 @@ sign_in(create(:user)) end + context 'when user does not have access to the feature' do + before do + allow(Gitlab::Llm::TanukiBot).to receive(:enabled_for?).and_return(false) + end + + it 'responds with a 401' do + subject + + expect(response).to have_gitlab_http_status(:unauthorized) + end + end + it 'responds with :bad_request if the request is not json' do post :ask, params: { q: question } diff --git a/ee/spec/lib/gitlab/llm/tanuki_bot_spec.rb b/ee/spec/lib/gitlab/llm/tanuki_bot_spec.rb index f665a54b23a80772..87ade6fc1b865c02 100644 --- a/ee/spec/lib/gitlab/llm/tanuki_bot_spec.rb +++ b/ee/spec/lib/gitlab/llm/tanuki_bot_spec.rb @@ -5,7 +5,7 @@ RSpec.describe Gitlab::Llm::TanukiBot, feature_category: :global_search do describe '#execute' do let_it_be(:user) { create(:user) } - let_it_be(:embeddings) { create_list(:tanuki_bot, 2) } + let_it_be(:embeddings) { create_list(:tanuki_bot_mvc, 2) } let(:question) { 'A question' } let(:answer) { 'The answer.' } @@ -14,15 +14,17 @@ let(:openai_client) { ::Gitlab::Llm::OpenAi::Client.new(user) } let(:embedding_response) { { "data" => [{ "embedding" => Array.new(1536, 0.5) }] } } let(:attrs) { embeddings.map(&:id).map { |x| "CNT-IDX-#{x}" }.join(", ") } - let(:completion_response) do - { "choices" => [{ "text" => "#{answer} ATTRS: #{attrs}" }] } - end + let(:completion_response) { { "choices" => [{ "text" => "#{answer} ATTRS: #{attrs}" }] } } + let(:status_code) { 200 } + let(:success) { true } subject(:execute) { instance.execute } before do allow(License).to receive(:feature_available?).and_return(true) - allow(logger).to receive(:debug) + allow(logger).to receive(:info) + allow(completion_response).to receive(:code).and_return(status_code) + allow(completion_response).to receive(:success?).and_return(success) end context 'with the ai_tanuki_bot license not available' do @@ -66,7 +68,7 @@ end it 'executes calls through to open ai' do - create_list(:tanuki_bot, 2) + embeddings expect(openai_client).to receive(:completions).exactly(3).times.and_return(completion_response) expect(openai_client).to receive(:embeddings).and_return(embedding_response) @@ -114,7 +116,7 @@ context 'when no neighbors are found' do before do - allow(Embedding::TanukiBot).to receive(:neighbor_for).and_return(Embedding::TanukiBot.none) + allow(Embedding::TanukiBotMvc).to receive(:neighbor_for).and_return(Embedding::TanukiBotMvc.none) allow(openai_client).to receive(:embeddings).with(input: question).and_return(embedding_response) end @@ -151,8 +153,8 @@ expect(result[:msg]).to eq(answer) expect(result[:sources].count).to eq(2) - expected_sources = ::Embedding::TanukiBot.pluck(:metadata).pluck('source') - expected_source_urls = ::Embedding::TanukiBot.pluck(:url) + expected_sources = 
+                expected_sources = ::Embedding::TanukiBotMvc.pluck(:metadata).pluck('source')
+                expected_source_urls = ::Embedding::TanukiBotMvc.pluck(:url)
 
                expect(result[:sources].pluck('source')).to match_array(expected_sources)
                expect(result[:sources].pluck('source_url')).to match_array(expected_source_urls)
-- 
GitLab

From 3bf3b0bdfd243b800d84c814e6c180debade793e Mon Sep 17 00:00:00 2001
From: Dmitry Gruzd <dgruzd@gitlab.com>
Date: Tue, 25 Apr 2023 21:45:00 +0200
Subject: [PATCH 3/3] Rename document to document_id

---
 ee/lib/gitlab/llm/tanuki_bot.rb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ee/lib/gitlab/llm/tanuki_bot.rb b/ee/lib/gitlab/llm/tanuki_bot.rb
index 6e562e1634d69ee5..1073d743ab094871 100644
--- a/ee/lib/gitlab/llm/tanuki_bot.rb
+++ b/ee/lib/gitlab/llm/tanuki_bot.rb
@@ -73,7 +73,7 @@ def send_initial_prompt(doc:, prompt:)
         result = client.completions(prompt: prompt, **DEFAULT_OPTIONS)
 
         info(
-          document: doc[:id],
+          document_id: doc[:id],
           openai_completions_response: prompt,
           status_code: result.code,
           result: result.parsed_response,
-- 
GitLab
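
A quick way to exercise the series end to end after applying it locally is to drive the service object from a Rails console instead of going through the controller. The snippet below is a minimal sketch, not part of the patches: the username is hypothetical, and it assumes the `openai_experimentation` and `tanuki_bot` feature flags, an `ai_tanuki_bot`-licensed instance, and (on GitLab.com) a user with a paid Ultimate namespace — the same guards `enabled_for?` checks above.

    # Rails console smoke test (sketch; prerequisites as noted above)
    user = User.find_by!(username: 'some-ultimate-user') # hypothetical user
    Feature.enable(:openai_experimentation)
    Feature.enable(:tanuki_bot, user)

    # Returns { msg: ..., sources: [...] } on success, {} when any guard in
    # enabled_for? fails, and { msg: "I do not know.", sources: [] } when no
    # embeddings match the question.
    Gitlab::Llm::TanukiBot.execute(current_user: user, question: 'How do I fork a project?')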