跳到内容

构建和评估用于获取金属价格的ReAct Agent

AI Agent在金融、电子商务和客户支持等领域变得越来越有价值。这些Agent可以自主地与API交互,检索实时数据,并执行与用户目标一致的任务。评估这些Agent对于确保它们有效、准确并能响应不同的输入至关重要。

在本教程中,我们将

  1. 构建一个ReAct Agent来获取金属价格。
  2. 设置一个评估管道来跟踪关键性能指标。
  3. 运行并评估Agent在不同查询下的有效性。

点击此链接在Google Colab中打开notebook。

先决条件

  • Python 3.8+
  • 对LangGraph、LangChain和LLM有基本了解

安装Ragas及其他依赖项

使用pip安装Ragas和Langgraph

%pip install langgraph==0.2.44
%pip install ragas
%pip install nltk

构建ReAct Agent

初始化外部组件

首先,您有两种设置外部组件的选项

  1. 使用实时API密钥

    • metals.dev上注册账户以获取API密钥。
  2. 模拟API响应

    • 或者,您可以使用预定义的JSON对象来模拟API响应。这使您无需实时API密钥即可更快地开始。

选择最适合您需求的方法进行设置。

预定义JSON对象模拟API响应

如果您想快速开始而无需创建账户,可以跳过设置过程,使用下面提供的预定义JSON对象来模拟API响应。

metal_price = {
    "gold": 88.1553,
    "silver": 1.0523,
    "platinum": 32.169,
    "palladium": 35.8252,
    "lbma_gold_am": 88.3294,
    "lbma_gold_pm": 88.2313,
    "lbma_silver": 1.0545,
    "lbma_platinum_am": 31.99,
    "lbma_platinum_pm": 32.2793,
    "lbma_palladium_am": 36.0088,
    "lbma_palladium_pm": 36.2017,
    "mcx_gold": 93.2689,
    "mcx_gold_am": 94.281,
    "mcx_gold_pm": 94.1764,
    "mcx_silver": 1.125,
    "mcx_silver_am": 1.1501,
    "mcx_silver_pm": 1.1483,
    "ibja_gold": 93.2713,
    "copper": 0.0098,
    "aluminum": 0.0026,
    "lead": 0.0021,
    "nickel": 0.0159,
    "zinc": 0.0031,
    "lme_copper": 0.0096,
    "lme_aluminum": 0.0026,
    "lme_lead": 0.002,
    "lme_nickel": 0.0158,
    "lme_zinc": 0.0031,
}

定义get_metal_price工具

get_metal_price工具将被Agent用于获取指定金属的价格。我们将使用LangChain的@tool装饰器创建此工具。

如果您想使用来自metals.dev API的实时数据,可以修改函数以向API发起实时请求。

from langchain_core.tools import tool


# Define the tools for the agent to use
@tool
def get_metal_price(metal_name: str) -> float:
    """Fetches the current per gram price of the specified metal.

    Args:
        metal_name : The name of the metal (e.g., 'gold', 'silver', 'platinum').

    Returns:
        float: The current price of the metal in dollars per gram.

    Raises:
        KeyError: If the specified metal is not found in the data source.
    """
    try:
        metal_name = metal_name.lower().strip()
        if metal_name not in metal_price:
            raise KeyError(
                f"Metal '{metal_name}' not found. Available metals: {', '.join(metal_price['metals'].keys())}"
            )
        return metal_price[metal_name]
    except Exception as e:
        raise Exception(f"Error fetching metal price: {str(e)}")

将工具绑定到LLM

定义get_metal_price工具后,下一步是将其绑定到ChatOpenAI模型。这使得Agent能够在执行过程中根据用户的请求调用该工具,从而使其能够与外部数据交互并执行超出其原生能力的操作。

from langchain_openai import ChatOpenAI

tools = [get_metal_price]
llm = ChatOpenAI(model="gpt-4o-mini")
llm_with_tools = llm.bind_tools(tools)

在LangGraph中,状态(state)在跟踪和更新图执行过程中的信息方面起着至关重要的作用。随着图的不同部分运行,状态会随之演变,反映出变化,并包含在节点之间传递的信息。

例如,在这样一个对话系统中,状态用于跟踪交换的消息。每次生成新消息时,都会将其添加到状态中,更新后的状态通过节点传递,确保对话逻辑地进行。

定义状态

为了在LangGraph中实现这一点,我们定义了一个状态类,该类维护一个消息列表。每当生成新消息时,都会将其追加到此列表中,确保对话历史不断更新。

from langgraph.graph import END
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
from typing import Annotated
from typing_extensions import TypedDict


class GraphState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

定义should_continue函数

should_continue函数确定对话是否应继续进行进一步的工具交互或结束。具体来说,它检查最后一条消息是否包含任何工具调用(例如,获取金属价格的请求)。

  • 如果最后一条消息包含工具调用,表明Agent已调用外部工具,则对话继续并进入“tools”节点。
  • 如果没有工具调用,则对话结束,由END状态表示。
# Define the function that determines whether to continue or not
def should_continue(state: GraphState):
    messages = state["messages"]
    last_message = messages[-1]
    if last_message.tool_calls:
        return "tools"
    return END

调用模型

call_model函数与大语言模型(LLM)交互,根据对话的当前状态生成响应。它将更新后的状态作为输入,进行处理并返回模型生成的响应。

# Define the function that calls the model
def call_model(state: GraphState):
    messages = state["messages"]
    response = llm_with_tools.invoke(messages)
    return {"messages": [response]}

创建助手节点

assistant节点是一个关键组件,负责处理对话的当前状态,并使用大语言模型(LLM)生成相关响应。它评估状态,确定适当的操作流程,并调用LLM生成与当前对话一致的响应。

# Node
def assistant(state: GraphState):
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}

创建工具节点

tool_node负责管理与外部工具的交互,例如获取金属价格或执行超出LLM原生能力的其他操作。工具本身在代码前面已定义,tool_node根据当前状态和对话需求调用这些工具。

from langgraph.prebuilt import ToolNode

# Node
tools = [get_metal_price]
tool_node = ToolNode(tools)

构建图

图结构是Agent工作流的骨干,由相互连接的节点和边组成。为了构建这个图,我们使用StateGraph构建器,它允许我们定义和连接各种节点。每个节点代表过程中的一个步骤(例如,助手节点、工具节点),边则规定了这些步骤之间的执行流程。

from langgraph.graph import START, StateGraph
from IPython.display import Image, display

# Define a new graph for the agent
builder = StateGraph(GraphState)

# Define the two nodes we will cycle between
builder.add_node("assistant", assistant)
builder.add_node("tools", tool_node)

# Set the entrypoint as `agent`
builder.add_edge(START, "assistant")

# Making a conditional edge
# should_continue will determine which node is called next.
builder.add_conditional_edges("assistant", should_continue, ["tools", END])

# Making a normal edge from `tools` to `agent`.
# The `agent` node will be called after the `tool`.
builder.add_edge("tools", "assistant")

# Compile and display the graph for a visual overview
react_graph = builder.compile()
display(Image(react_graph.get_graph(xray=True).draw_mermaid_png()))

jpeg

为了测试我们的设置,我们将使用一个查询运行Agent。Agent将使用metals.dev API获取铜的价格。

from langchain_core.messages import HumanMessage

messages = [HumanMessage(content="What is the price of copper?")]
result = react_graph.invoke({"messages": messages})
result["messages"]
[HumanMessage(content='What is the price of copper?', id='4122f5d4-e298-49e8-a0e0-c98adda78c6c'),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_DkVQBK4UMgiXrpguUS2qC4mA', 'function': {'arguments': '{"metal_name":"copper"}', 'name': 'get_metal_price'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 116, 'total_tokens': 134, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-0f77b156-e43e-4c1e-bd3a-307333eefb68-0', tool_calls=[{'name': 'get_metal_price', 'args': {'metal_name': 'copper'}, 'id': 'call_DkVQBK4UMgiXrpguUS2qC4mA', 'type': 'tool_call'}], usage_metadata={'input_tokens': 116, 'output_tokens': 18, 'total_tokens': 134}),
 ToolMessage(content='0.0098', name='get_metal_price', id='422c089a-6b76-4e48-952f-8925c3700ae3', tool_call_id='call_DkVQBK4UMgiXrpguUS2qC4mA'),
 AIMessage(content='The price of copper is $0.0098 per gram.', response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 148, 'total_tokens': 162, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'stop', 'logprobs': None}, id='run-67cbf98b-4fa6-431e-9ce4-58697a76c36e-0', usage_metadata={'input_tokens': 148, 'output_tokens': 14, 'total_tokens': 162})]

将消息转换为Ragas评估格式

在当前实现中,GraphState将人类用户、AI(LLM的响应)以及任何外部工具(AI使用的API或服务)之间交换的消息存储在一个列表中。每条消息都是LangChain格式的对象

# Implementation of Graph State
class GraphState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

Agent执行期间每次交换消息时,都会将其添加到GraphState的消息列表中。然而,Ragas需要特定的消息格式来评估交互。

Ragas使用自己的格式来评估Agent交互。因此,如果您正在使用LangGraph,您需要将LangChain消息对象转换为Ragas消息对象。这使您可以使用Ragas内置的评估工具来评估您的AI Agent。

目标:将LangChain消息列表(例如,HumanMessage、AIMessage和ToolMessage)转换为Ragas期望的格式,以便评估框架能够正确理解和处理它们。

为了将LangChain消息列表转换为适合Ragas评估的格式,Ragas提供了函数convert_to_ragas_messages,该函数可用于将LangChain消息转换为Ragas期望的格式。

以下是如何使用该函数

from ragas.integrations.langgraph import convert_to_ragas_messages

# Assuming 'result["messages"]' contains the list of LangChain messages
ragas_trace = convert_to_ragas_messages(result["messages"])
ragas_trace  # List of Ragas messages
[HumanMessage(content='What is the price of copper?', metadata=None, type='human'),
 AIMessage(content='', metadata=None, type='ai', tool_calls=[ToolCall(name='get_metal_price', args={'metal_name': 'copper'})]),
 ToolMessage(content='0.0098', metadata=None, type='tool'),
 AIMessage(content='The price of copper is $0.0098 per gram.', metadata=None, type='ai', tool_calls=None)]

评估Agent的性能

在本教程中,我们将使用以下指标评估Agent

  • 工具调用准确性:ToolCallAccuracy是一个指标,可用于评估LLM识别和调用所需工具以完成给定任务的性能。

  • Agent目标准确性:Agent goal accuracy是一个指标,可用于评估LLM识别和实现用户目标的性能。这是一个二元指标,1表示AI已实现目标,0表示AI未实现目标。

首先,让我们用几个查询实际运行我们的Agent,并确保我们有这些查询的真实标签。

工具调用准确性

from ragas.metrics import ToolCallAccuracy
from ragas.dataset_schema import MultiTurnSample
from ragas.integrations.langgraph import convert_to_ragas_messages
import ragas.messages as r


ragas_trace = convert_to_ragas_messages(
    messages=result["messages"]
)  # List of Ragas messages converted using the Ragas function

sample = MultiTurnSample(
    user_input=ragas_trace,
    reference_tool_calls=[
        r.ToolCall(name="get_metal_price", args={"metal_name": "copper"})
    ],
)

tool_accuracy_scorer = ToolCallAccuracy()
await tool_accuracy_scorer.multi_turn_ascore(sample)
1.0

工具调用准确性:1,因为LLM正确识别并使用了必要的工具(get_metal_price)以及正确的参数(即金属名称为“copper”)。

Agent 目标准确性

messages = [HumanMessage(content="What is the price of 10 grams of silver?")]

result = react_graph.invoke({"messages": messages})
result["messages"]  # List of Langchain messages
[HumanMessage(content='What is the price of 10 grams of silver?', id='51a469de-5b7c-4d01-ab71-f8db64c8da49'),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_rdplOo95CRwo3mZcPu4dmNxG', 'function': {'arguments': '{"metal_name":"silver"}', 'name': 'get_metal_price'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 120, 'total_tokens': 137, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-3bb60e27-1275-41f1-a46e-03f77984c9d8-0', tool_calls=[{'name': 'get_metal_price', 'args': {'metal_name': 'silver'}, 'id': 'call_rdplOo95CRwo3mZcPu4dmNxG', 'type': 'tool_call'}], usage_metadata={'input_tokens': 120, 'output_tokens': 17, 'total_tokens': 137}),
 ToolMessage(content='1.0523', name='get_metal_price', id='0b5f9260-df26-4164-b042-6df2e869adfb', tool_call_id='call_rdplOo95CRwo3mZcPu4dmNxG'),
 AIMessage(content='The current price of silver is approximately $1.0523 per gram. Therefore, the price of 10 grams of silver would be about $10.52.', response_metadata={'token_usage': {'completion_tokens': 34, 'prompt_tokens': 151, 'total_tokens': 185, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'stop', 'logprobs': None}, id='run-93e38f71-cc9d-41d6-812a-bfad9f9231b2-0', usage_metadata={'input_tokens': 151, 'output_tokens': 34, 'total_tokens': 185})]
from ragas.integrations.langgraph import convert_to_ragas_messages

ragas_trace = convert_to_ragas_messages(
    result["messages"]
)  # List of Ragas messages converted using the Ragas function
ragas_trace
[HumanMessage(content='What is the price of 10 grams of silver?', metadata=None, type='human'),
 AIMessage(content='', metadata=None, type='ai', tool_calls=[ToolCall(name='get_metal_price', args={'metal_name': 'silver'})]),
 ToolMessage(content='1.0523', metadata=None, type='tool'),
 AIMessage(content='The current price of silver is approximately $1.0523 per gram. Therefore, the price of 10 grams of silver would be about $10.52.', metadata=None, type='ai', tool_calls=None)]
from ragas.dataset_schema import MultiTurnSample
from ragas.metrics import AgentGoalAccuracyWithReference
from ragas.llms import LangchainLLMWrapper


sample = MultiTurnSample(
    user_input=ragas_trace,
    reference="Price of 10 grams of silver",
)

scorer = AgentGoalAccuracyWithReference()

evaluator_llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini"))
scorer.llm = evaluator_llm
await scorer.multi_turn_ascore(sample)
1.0

Agent目标准确性:1,因为LLM正确地实现了用户获取10克白银价格的目标。

下一步是什么

🎉 恭喜!我们已经学会了如何使用Ragas评估框架来评估Agent。