工作记忆
工作记忆是 Agent 推理过程中的"草稿纸",存储当前任务执行的中间状态和临时计算结果。它在单次任务执行期间存在,支持复杂的推理和规划流程。
一、什么是工作记忆?
1.1 定义
工作记忆(Working Memory) 是 Agent 在执行单次任务时维护的临时状态存储,用于保存推理过程中的中间结果、任务状态和临时数据。生命周期仅限于单次任务执行期间。
1.2 核心特征
┌─────────────────────────────────────────────────────────────┐
│ 工作记忆核心特征 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 存储位置:内存变量、Scratchpad │ │
│ │ 生命周期:单次任务执行期间 │ │
│ │ 容量限制:无固定限制,取决于系统内存 │ │
│ │ 访问速度:最快,直接内存访问 │ │
│ │ 持久性:任务结束后消失 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 典型内容: │
│ • 当前推理步骤的中间结果 │
│ • 任务执行状态和进度 │
│ • 工具调用的输入输出 │
│ • 临时计算和决策记录 │
│ • 执行轨迹(Thought/Observation/Action) │
│ │
└─────────────────────────────────────────────────────────────┘1.3 与短期/长期记忆的区别
| 维度 | 工作记忆 | 短期记忆 | 长期记忆 |
|---|---|---|---|
| 存储位置 | 内存变量 | LLM Context Window | 外部存储 |
| 生命周期 | 单次任务执行 | 会话期间 | 永久 |
| 访问速度 | 最快 | 快 | 较慢 |
| 用途 | 推理过程暂存 | 对话上下文 | 知识积累 |
| 示例 | ReAct 的 Thought/Obs | 对话历史 | 用户偏好 |
| 是否参与推理 | 直接参与 | 直接参与 | 需检索后参与 |
二、工作记忆的应用场景
2.1 ReAct 框架中的工作记忆
┌─────────────────────────────────────────────────────────────┐
│ ReAct 框架中的工作记忆 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 用户问题:北京今天天气如何?我应该穿什么衣服? │
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 工作记忆内容: │ │
│ │ │ │
│ │ Step 1: │ │
│ │ Thought: 需要先查询北京今天的天气 │ │
│ │ Action: search("北京天气 今天") │ │
│ │ Observation: 北京今天晴,气温 15-25°C │ │
│ │ ✓ 存入工作记忆:天气信息 │ │
│ │ │ │
│ │ Step 2: │ │
│ │ Thought: 根据天气信息,给出穿衣建议 │ │
│ │ Action: 从工作记忆读取天气信息 │ │
│ │ Final Answer: 北京今天晴天,气温15-25°C... │ │
│ │ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 工作记忆作用: │
│ • 存储 Step 1 的天气查询结果 │
│ • 供 Step 2 使用,避免重复查询 │
│ • 任务完成后自动清除 │
│ │
└─────────────────────────────────────────────────────────────┘2.2 多步骤任务中的工作记忆
┌─────────────────────────────────────────────────────────────┐
│ 多步骤任务中的工作记忆示例 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 任务:分析销售数据并生成报告 │
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 工作记忆状态: │ │
│ │ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ 任务进度 │ │ │
│ │ │ current_step: 3 │ │ │
│ │ │ total_steps: 5 │ │ │
│ │ │ status: "in_progress" │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ 中间结果 │ │ │
│ │ │ step_1_result: 数据读取成功 │ │ │
│ │ │ - 文件: sales_2024.csv │ │ │
│ │ │ - 行数: 10000 │ │ │
│ │ │ step_2_result: 数据清洗完成 │ │ │
│ │ │ - 有效数据: 9800 │ │ │
│ │ │ - 异常数据: 200 │ │ │
│ │ │ step_3_result: (进行中) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ 临时变量 │ │ │
│ │ │ analysis_context: {...} │ │ │
│ │ │ error_log: [] │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘2.3 任务中断与恢复
┌─────────────────────────────────────────────────────────────┐
│ 任务中断与恢复机制 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 场景:长时间任务被中断后需要恢复 │
│ │
│ 正常执行: │
│ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │ Step 1 │ → │ Step 2 │ → │ Step 3 │ → │ Step 4 │ │
│ └────────┘ └────────┘ └────────┘ └────────┘ │
│ ↓ ↓ ↓ ↓ │
│ [工作记忆] [工作记忆] [中断!] [恢复点] │
│ │
│ 恢复策略: │
│ 1. 定期将工作记忆快照保存到长期记忆 │
│ 2. 中断后从最近的快照恢复 │
│ 3. 继续执行未完成的步骤 │
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 快照保存到长期记忆: │ │
│ │ { │ │
│ │ "task_id": "task_123", │ │
│ │ "current_step": 2, │ │
│ │ "completed_steps": [1, 2], │ │
│ │ "intermediate_results": {...}, │ │
│ │ "checkpoint_time": "2024-01-15T10:30:00" │ │
│ │ } │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘三、工作记忆的实现
3.1 基础实现
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
import json
@dataclass
class WorkingMemory:
"""工作记忆基础实现"""
# 任务信息
task_id: str = ""
task_description: str = ""
# 执行状态
current_step: int = 0
total_steps: int = 0
status: str = "idle" # idle, running, completed, failed
# 推理轨迹
thoughts: List[str] = field(default_factory=list)
actions: List[str] = field(default_factory=list)
observations: List[str] = field(default_factory=list)
# 中间结果
intermediate_results: Dict[int, Any] = field(default_factory=dict)
# 临时变量
variables: Dict[str, Any] = field(default_factory=dict)
# 错误记录
errors: List[str] = field(default_factory=list)
def start_task(self, task_id: str, description: str, total_steps: int = 0):
"""开始任务"""
self.task_id = task_id
self.task_description = description
self.total_steps = total_steps
self.current_step = 0
self.status = "running"
self._clear()
def _clear(self):
"""清空工作记忆"""
self.thoughts.clear()
self.actions.clear()
self.observations.clear()
self.intermediate_results.clear()
self.variables.clear()
self.errors.clear()
def record_thought(self, thought: str):
"""记录推理思考"""
self.thoughts.append(thought)
def record_action(self, action: str):
"""记录执行动作"""
self.actions.append(action)
def record_observation(self, observation: str):
"""记录观察结果"""
self.observations.append(observation)
def save_result(self, step: int, result: Any):
"""保存步骤结果"""
self.intermediate_results[step] = {
"result": result,
"timestamp": datetime.now().isoformat()
}
def get_result(self, step: int) -> Optional[Any]:
"""获取步骤结果"""
if step in self.intermediate_results:
return self.intermediate_results[step]["result"]
return None
def set_variable(self, key: str, value: Any):
"""设置临时变量"""
self.variables[key] = value
def get_variable(self, key: str, default: Any = None) -> Any:
"""获取临时变量"""
return self.variables.get(key, default)
def record_error(self, error: str):
"""记录错误"""
self.errors.append(error)
def advance_step(self):
"""推进步骤"""
self.current_step += 1
def complete_task(self):
"""完成任务"""
self.status = "completed"
def fail_task(self, reason: str = ""):
"""任务失败"""
self.status = "failed"
if reason:
self.errors.append(reason)
def get_context(self) -> str:
"""获取当前上下文摘要"""
context = f"任务: {self.task_description}\n"
context += f"进度: {self.current_step}/{self.total_steps}\n"
context += f"状态: {self.status}\n\n"
if self.thoughts:
context += "推理轨迹:\n"
for i, thought in enumerate(self.thoughts[-5:], 1): # 最近5条
context += f" Thought {i}: {thought}\n"
if self.variables:
context += f"\n临时变量: {json.dumps(self.variables, ensure_ascii=False)}\n"
return context
def to_dict(self) -> Dict:
"""导出为字典(用于持久化)"""
return {
"task_id": self.task_id,
"task_description": self.task_description,
"current_step": self.current_step,
"total_steps": self.total_steps,
"status": self.status,
"thoughts": self.thoughts,
"actions": self.actions,
"observations": self.observations,
"intermediate_results": self.intermediate_results,
"variables": self.variables,
"errors": self.errors
}
@classmethod
def from_dict(cls, data: Dict) -> 'WorkingMemory':
"""从字典恢复"""
memory = cls()
for key, value in data.items():
setattr(memory, key, value)
return memory3.2 LangChain Agent 中的工作记忆
from langchain.agents import AgentExecutor, create_react_agent
from langchain_openai import OpenAI
from langchain.tools import Tool
from langchain import hub
class AgentWithWorkingMemory:
"""带工作记忆的 Agent"""
def __init__(self, llm, tools):
self.llm = llm
self.tools = tools
self.working_memory = WorkingMemory()
# 创建 ReAct Agent
prompt = hub.pull("hwchase17/react")
agent = create_react_agent(llm, tools, prompt)
self.agent_executor = AgentExecutor(
agent=agent,
tools=tools,
verbose=True,
handle_parsing_errors=True
)
def run(self, task: str) -> str:
"""执行任务"""
# 初始化工作记忆
self.working_memory.start_task(
task_id=f"task_{hash(task)}",
description=task
)
try:
# 执行 Agent
result = self.agent_executor.invoke({"input": task})
# 保存结果
self.working_memory.save_result(
step=1,
result=result["output"]
)
self.working_memory.complete_task()
return result["output"]
except Exception as e:
self.working_memory.fail_task(str(e))
raise
def get_execution_trace(self) -> List[Dict]:
"""获取执行轨迹"""
return list(zip(
self.working_memory.thoughts,
self.working_memory.actions,
self.working_memory.observations
))
# 使用示例
def search_tool(query: str) -> str:
return f"搜索结果: {query} 的相关信息..."
def calculator_tool(expression: str) -> str:
try:
return str(eval(expression))
except:
return "计算错误"
tools = [
Tool(name="Search", func=search_tool, description="搜索工具"),
Tool(name="Calculator", func=calculator_tool, description="计算器")
]
llm = OpenAI(temperature=0)
agent = AgentWithWorkingMemory(llm, tools)3.3 高级实现:支持中断恢复
import json
from typing import Optional, Callable
from datetime import datetime
class PersistentWorkingMemory(WorkingMemory):
"""支持持久化的工作记忆"""
def __init__(self, persistence_callback: Optional[Callable] = None):
super().__init__()
self.persistence_callback = persistence_callback
self.checkpoint_interval = 5 # 每5步检查点
def advance_step(self):
"""推进步骤,自动创建检查点"""
super().advance_step()
if self.current_step % self.checkpoint_interval == 0:
self.create_checkpoint()
def create_checkpoint(self):
"""创建检查点"""
checkpoint = {
"task_id": self.task_id,
"timestamp": datetime.now().isoformat(),
"memory_state": self.to_dict()
}
if self.persistence_callback:
self.persistence_callback(checkpoint)
return checkpoint
@classmethod
def restore_from_checkpoint(cls, checkpoint: Dict, persistence_callback: Callable = None) -> 'PersistentWorkingMemory':
"""从检查点恢复"""
memory = cls(persistence_callback)
memory_data = checkpoint["memory_state"]
for key, value in memory_data.items():
setattr(memory, key, value)
return memory
class ResumableAgent:
"""支持中断恢复的 Agent"""
def __init__(self, llm, tools, storage_client):
self.llm = llm
self.tools = tools
self.storage = storage_client # 长期记忆存储
def execute_with_resume(self, task_id: str, task_fn: Callable):
"""执行任务,支持中断恢复"""
# 尝试从检查点恢复
checkpoint = self.storage.get_checkpoint(task_id)
if checkpoint:
print(f"从检查点恢复: {checkpoint['timestamp']}")
working_memory = PersistentWorkingMemory.restore_from_checkpoint(
checkpoint,
persistence_callback=lambda cp: self.storage.save_checkpoint(task_id, cp)
)
else:
working_memory = PersistentWorkingMemory(
persistence_callback=lambda cp: self.storage.save_checkpoint(task_id, cp)
)
# 执行任务
try:
result = task_fn(working_memory)
working_memory.complete_task()
# 清理检查点
self.storage.delete_checkpoint(task_id)
return result
except Exception as e:
working_memory.record_error(str(e))
working_memory.fail_task(str(e))
raise四、工作记忆与推理框架
4.1 ReAct 框架
┌─────────────────────────────────────────────────────────────┐
│ ReAct 框架中的工作记忆流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 用户问题:Python 如何读取 CSV 文件? │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Iteration 1 │ │
│ │ ┌────────────────────────────────────────────────┐ │ │
│ │ │ Thought: 用户想知道 Python 读取 CSV 的方法 │ │ │
│ │ │ → 存入工作记忆 │ │ │
│ │ └────────────────────────────────────────────────┘ │ │
│ │ ┌────────────────────────────────────────────────┐ │ │
│ │ │ Action: search("Python read CSV") │ │ │
│ │ │ → 存入工作记忆 │ │ │
│ │ └────────────────────────────────────────────────┘ │ │
│ │ ┌────────────────────────────────────────────────┐ │ │
│ │ │ Observation: pandas.read_csv() 是最常用方法... │ │ │
│ │ │ → 存入工作记忆 │ │ │
│ │ └────────────────────────────────────────────────┘ │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Iteration 2 │ │
│ │ ┌────────────────────────────────────────────────┐ │ │
│ │ │ Thought: 从工作记忆读取上次搜索结果 │ │ │
│ │ │ 整理代码示例 │ │ │
│ │ └────────────────────────────────────────────────┘ │ │
│ │ ┌────────────────────────────────────────────────┐ │ │
│ │ │ Action: Finish │ │ │
│ │ └────────────────────────────────────────────────┘ │ │
│ │ ┌────────────────────────────────────────────────┐ │ │
│ │ │ Final Answer: Python 读取 CSV 最常用的方法是 │ │ │
│ │ │ 使用 pandas.read_csv(),示例代码如下: │ │ │
│ │ │ import pandas as pd │ │ │
│ │ │ df = pd.read_csv('file.csv') │ │ │
│ │ └────────────────────────────────────────────────┘ │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
│ 工作记忆状态(迭代过程中): │
│ thoughts: ["用户想知道...", "整理代码示例"] │
│ actions: ["search(...)", "Finish"] │
│ observations: ["pandas.read_csv() 是最常用方法..."] │
│ │
└─────────────────────────────────────────────────────────────┘4.2 Plan-and-Execute 框架
from typing import List, Dict, Any
from dataclasses import dataclass, field
@dataclass
class PlanStep:
"""计划步骤"""
id: int
description: str
status: str = "pending" # pending, running, completed, failed
result: Any = None
dependencies: List[int] = field(default_factory=list)
class PlanExecuteWorkingMemory(WorkingMemory):
"""Plan-Execute 框架的工作记忆"""
def __init__(self):
super().__init__()
self.plan: List[PlanStep] = []
self.execution_log: List[Dict] = []
def set_plan(self, steps: List[str]):
"""设置执行计划"""
self.plan = [
PlanStep(id=i+1, description=step)
for i, step in enumerate(steps)
]
self.total_steps = len(steps)
def get_next_step(self) -> Optional[PlanStep]:
"""获取下一个可执行的步骤"""
for step in self.plan:
if step.status == "pending":
# 检查依赖是否完成
deps_completed = all(
self.get_step_by_id(dep_id).status == "completed"
for dep_id in step.dependencies
)
if deps_completed:
return step
return None
def get_step_by_id(self, step_id: int) -> Optional[PlanStep]:
"""根据 ID 获取步骤"""
for step in self.plan:
if step.id == step_id:
return step
return None
def start_step(self, step_id: int):
"""开始执行步骤"""
step = self.get_step_by_id(step_id)
if step:
step.status = "running"
self.current_step = step_id
self.execution_log.append({
"step_id": step_id,
"action": "start",
"timestamp": datetime.now().isoformat()
})
def complete_step(self, step_id: int, result: Any):
"""完成步骤"""
step = self.get_step_by_id(step_id)
if step:
step.status = "completed"
step.result = result
self.save_result(step_id, result)
self.execution_log.append({
"step_id": step_id,
"action": "complete",
"result": result,
"timestamp": datetime.now().isoformat()
})
def fail_step(self, step_id: int, error: str):
"""步骤失败"""
step = self.get_step_by_id(step_id)
if step:
step.status = "failed"
self.record_error(f"Step {step_id} failed: {error}")
self.execution_log.append({
"step_id": step_id,
"action": "fail",
"error": error,
"timestamp": datetime.now().isoformat()
})
def get_plan_status(self) -> str:
"""获取计划状态摘要"""
completed = sum(1 for s in self.plan if s.status == "completed")
running = sum(1 for s in self.plan if s.status == "running")
pending = sum(1 for s in self.plan if s.status == "pending")
failed = sum(1 for s in self.plan if s.status == "failed")
return f"计划进度: {completed}完成 / {running}执行中 / {pending}待执行 / {failed}失败"五、工作记忆的优化技巧
5.1 减少内存占用
class EfficientWorkingMemory(WorkingMemory):
"""高效工作记忆"""
def __init__(self, max_thoughts: int = 20, max_observations: int = 10):
super().__init__()
self.max_thoughts = max_thoughts
self.max_observations = max_observations
def record_thought(self, thought: str):
"""记录思考,超出限制时压缩"""
self.thoughts.append(thought)
if len(self.thoughts) > self.max_thoughts:
# 压缩旧思考为摘要
old_thoughts = self.thoughts[:-self.max_thoughts//2]
summary = self._summarize_thoughts(old_thoughts)
self.thoughts = [f"[历史摘要] {summary}"] + self.thoughts[-self.max_thoughts//2:]
def record_observation(self, observation: str):
"""记录观察,超出限制时丢弃旧的"""
self.observations.append(observation)
if len(self.observations) > self.max_observations:
self.observations = self.observations[-self.max_observations:]
def _summarize_thoughts(self, thoughts: List[str]) -> str:
"""压缩思考为摘要"""
# 简化实现,实际可用 LLM 生成摘要
return "; ".join(thoughts[:3]) + "..." if len(thoughts) > 3 else "; ".join(thoughts)5.2 执行轨迹可视化
class TraceableWorkingMemory(WorkingMemory):
"""支持执行轨迹追踪的工作记忆"""
def __init__(self):
super().__init__()
self.trace_events: List[Dict] = []
def _record_event(self, event_type: str, data: Dict):
"""记录事件"""
self.trace_events.append({
"type": event_type,
"data": data,
"timestamp": datetime.now().isoformat(),
"step": self.current_step
})
def record_thought(self, thought: str):
super().record_thought(thought)
self._record_event("thought", {"content": thought})
def record_action(self, action: str):
super().record_action(action)
self._record_event("action", {"content": action})
def record_observation(self, observation: str):
super().record_observation(observation)
self._record_event("observation", {"content": observation})
def export_trace(self) -> List[Dict]:
"""导出执行轨迹"""
return self.trace_events.copy()
def visualize_trace(self) -> str:
"""可视化执行轨迹"""
output = "执行轨迹:\n" + "=" * 50 + "\n"
for event in self.trace_events:
event_type = event["type"]
content = event["data"]["content"]
step = event["step"]
icon = {
"thought": "💭",
"action": "🎬",
"observation": "👁️"
}.get(event_type, "•")
output += f"{icon} Step {step} [{event_type}]: {content[:100]}...\n"
return output六、面试高频问题
Q1: 工作记忆和短期记忆有什么区别?
答案要点:
| 维度 | 工作记忆 | 短期记忆 |
|---|---|---|
| 定义 | 推理过程的临时存储 | 当前对话的上下文 |
| 存储位置 | 内存变量 | LLM Context Window |
| 生命周期 | 单次任务执行 | 整个会话 |
| 内容 | Thought/Action/Observation | 对话历史 |
| 用途 | 支持推理和规划 | 保持对话连贯 |
关系:工作记忆的内容最终会变成短期记忆的一部分(如 ReAct 轨迹)。
Q2: 为什么 ReAct 框架需要工作记忆?
答案要点:
- 推理连续性:下一步推理需要上一步的结果
- 避免重复:存储已执行的动作和结果
- 错误恢复:记录错误便于调试和重试
- 执行追踪:完整的执行轨迹便于分析
Q3: 如何实现任务的中断恢复?
答案要点:
实现步骤:
1. 检查点机制
- 定期将工作记忆状态保存
- 保存到长期记忆存储
2. 恢复机制
- 从最近的检查点恢复
- 继续执行未完成的步骤
3. 状态管理
- 记录每个步骤的状态(pending/running/completed/failed)
- 记录依赖关系
4. 幂等性设计
- 每个步骤可重复执行
- 结果一致性Q4: 工作记忆的内容会占用 Context Window 吗?
答案要点:
会的,但有两种处理方式:
-
作为 Prompt 一部分:
- ReAct 轨迹会加入 Prompt
- 占用 Context Window
- LangChain 默认做法
-
仅内存存储:
- 只在内存中保存
- 不加入 Prompt
- 用于后续分析和持久化
最佳实践:关键轨迹加入 Prompt,次要信息仅内存保存。
Q5: 如何设计一个高效的 Agent 工作记忆系统?
答案要点:
设计要点:
1. 分层存储
- 热数据:内存(当前步骤)
- 温数据:内存(最近N步)
- 冷数据:压缩摘要
2. 容量管理
- 限制轨迹长度
- 自动压缩旧记录
- 丢弃无关信息
3. 结构化存储
- 使用 dataclass 定义结构
- 便于序列化和恢复
4. 可观测性
- 支持导出执行轨迹
- 便于调试和分析
5. 持久化支持
- 检查点机制
- 支持中断恢复七、总结
核心概念回顾
| 概念 | 定义 | 关键要点 |
|---|---|---|
| 工作记忆 | 推理过程的临时存储 | 单次任务生命周期 |
| 推理轨迹 | Thought/Action/Observation | ReAct 框架核心 |
| 检查点 | 任务状态的快照 | 支持中断恢复 |
| 执行状态 | 当前步骤和进度 | 任务跟踪 |
一句话总结
工作记忆是 Agent 推理过程的"草稿纸",存储执行轨迹和中间结果,支持 ReAct 等推理框架的实现,并可配合检查点机制实现任务中断恢复。
设计口诀
工作记忆设计口诀:
推理过程暂存处,生命周期单任务
轨迹记录 TAO,中间结果步步留
检查点定期存,中断恢复不用愁
容量管理要做好,压缩丢弃要讲究最后更新:2026年3月18日