Skip to content
Snippets Groups Projects
Commit 31701e4e authored by Eva Kadlecová's avatar Eva Kadlecová :two: Committed by Tetiana Chupryna
Browse files

fix: support agent_scratchpad attribute in Chat messages

parent 87710e86
No related branches found
No related tags found
1 merge request!50Q integration bt
......@@ -171,7 +171,15 @@ class ReActPromptTemplate(Runnable[ReActAgentInputs, PromptValue]):
)
)
elif m.role is Role.ASSISTANT:
messages.append(AIMessage(m.content))
messages.append(
AIMessage(
jinja2_formatter(
self.prompt_template["assistant"],
agent_scratchpad=m.agent_scratchpad,
final_answer=m.content,
)
)
)
else:
raise ValueError("Unsupported message")
......
import json
from typing import Literal, Optional, TypeVar
from typing import Literal, Optional, Self, TypeVar
from pydantic import BaseModel
import fastapi
from pydantic import BaseModel, model_validator
from ai_gateway.chat.context.current_page import CurrentPageContext
from ai_gateway.models.base_chat import Role
......@@ -59,7 +60,7 @@ TypeAgentInputs = TypeVar("TypeAgentInputs")
class AgentStep(BaseModel):
action: AgentToolAction
action: Optional[AgentToolAction] = None
observation: str
......@@ -85,4 +86,13 @@ class Message(BaseModel):
context: Optional[CurrentPageContext] = None
current_file: Optional[CurrentFile] = None
additional_context: Optional[list[AdditionalContext]] = None
resource_content: Optional[str] = None
agent_scratchpad: Optional[list[AgentStep]] = None
@model_validator(mode="after")
def validate_agent_scratchpad_role(self) -> Self:
if self.agent_scratchpad is not None and self.role != Role.ASSISTANT:
raise fastapi.HTTPException(
status_code=400,
detail="agent_scratchpad can only be present when role is ASSISTANT",
)
return self
{% include 'chat/react/partials/agent_scratchpad.jinja' %}
Thought:
\ No newline at end of file
{%- if final_answer -%}
{% include 'chat/react/partials/assistant_with_final_answer.jinja' %}
{%- else -%}
{% include 'chat/react/partials/assistant.jinja' %}
{%- endif -%}
\ No newline at end of file
{% include 'chat/react/partials/agent_scratchpad.jinja' %}
Thought:
{# Stuffing the assistant message as a chat history with the raw context retrieved at that time #}
{%- if agent_scratchpad %}
{%- for pad in agent_scratchpad -%}
Observation: {{ pad.observation }}
{% endfor -%}
Final Answer: {{ final_answer }}
{%- else -%}
{{ final_answer }}
{%- endif %}
\ No newline at end of file
......@@ -127,13 +127,11 @@ class TestReActAgentStream:
Message(
role=Role.USER,
content="chat history",
resource_content="Please use this information about identified issue",
),
Message(role=Role.ASSISTANT, content="chat history"),
Message(
role=Role.USER,
content="What's the title of this issue?",
resource_content="Please use this information about identified issue",
context=Context(type="issue", content="issue content"),
current_file=CurrentFile(
file_path="main.py",
......@@ -230,13 +228,11 @@ class TestReActAgentStream:
Message(
role=Role.USER,
content="chat history",
resource_content="Please use this information about identified issue",
),
Message(role=Role.ASSISTANT, content="chat history"),
Message(
role=Role.USER,
content="What's the title of this issue?",
resource_content="Please use this information about identified issue",
context=Context(type="issue", content="issue content"),
current_file=CurrentFile(
file_path="main.py",
......
import fastapi
import pytest
from langchain_core.messages import SystemMessage
from pydantic import AnyUrl
......@@ -123,7 +124,6 @@ class TestReActAgent:
Message(
role=Role.USER,
content="What's the title of this issue?",
resource_content="Please use this information about identified issue",
),
],
agent_scratchpad=[],
......@@ -144,7 +144,6 @@ class TestReActAgent:
Message(
role=Role.USER,
content="Summarize this Merge request",
resource_content="Please use this information about identified issue",
),
],
agent_scratchpad=[],
......@@ -163,7 +162,20 @@ class TestReActAgent:
ReActAgentInputs(
messages=[
Message(role=Role.USER, content="How can I log output?"),
Message(role=Role.ASSISTANT, content="Use print function"),
Message(
role=Role.ASSISTANT,
content="Use print function",
agent_scratchpad=[
AgentStep(
action=AgentToolAction(
thought="thought",
tool="tool",
tool_input="tool_input",
),
observation="observation",
)
],
),
Message(
role=Role.USER,
content="Can you explain the print function?",
......@@ -182,9 +194,21 @@ class TestReActAgent:
Message(
role=Role.USER,
content="what's the description of this issue",
resource_content="Please use this information about identified issue",
),
Message(role=Role.ASSISTANT, content="PoC ReAct"),
Message(
role=Role.ASSISTANT,
content="PoC ReAct",
agent_scratchpad=[
AgentStep(
action=AgentToolAction(
thought="thought",
tool="tool",
tool_input="tool_input",
),
observation="observation",
)
],
),
Message(role=Role.USER, content="What's your name?"),
],
agent_scratchpad=[
......@@ -213,7 +237,6 @@ class TestReActAgent:
Message(
role=Role.USER,
content="Explain this issue",
resource_content="Please use this information about identified issue",
context=Context(
type="issue", content="this issue is about Duo Chat"
),
......@@ -233,7 +256,6 @@ class TestReActAgent:
Message(
role=Role.USER,
content="Explain this issue",
resource_content="Please use this information about identified issue",
context=IssueContext(type="issue", title="Duo Chat issue"),
),
],
......@@ -345,7 +367,6 @@ class TestReActAgent:
Message(
role=Role.USER,
content="What's the title of this issue?",
resource_content="Please use this information about identified issue",
),
],
agent_scratchpad=[],
......@@ -359,7 +380,6 @@ class TestReActAgent:
Message(
role=Role.USER,
content="What's the title of this issue?",
resource_content="Please use this information about identified issue",
),
],
agent_scratchpad=[],
......@@ -373,7 +393,6 @@ class TestReActAgent:
Message(
role=Role.USER,
content="What's the title of this issue?",
resource_content="Please use this information about identified issue",
),
],
agent_scratchpad=[],
......@@ -394,7 +413,6 @@ class TestReActAgent:
Message(
role=Role.USER,
content="What's the title of this issue?",
resource_content="Please use this information about identified issue",
),
],
agent_scratchpad=[],
......@@ -415,7 +433,6 @@ class TestReActAgent:
Message(
role=Role.USER,
content="What's the title of this issue?",
resource_content="Please use this information about identified issue",
),
],
agent_scratchpad=[],
......@@ -506,3 +523,18 @@ class TestReActAgent:
assert actual_events == expected_events
assert str(exc_info.value) == error_message
@pytest.mark.asyncio
async def test_message_agent_scratchpad_validation(self):
with pytest.raises(fastapi.HTTPException) as exc_info:
Message(
role=Role.USER,
content="test",
agent_scratchpad=[AgentStep(action=None, observation="test")],
)
assert exc_info.value.status_code == 400
assert (
exc_info.value.detail
== "agent_scratchpad can only be present when role is ASSISTANT"
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment