93 lines
4.2 KiB
Python
93 lines
4.2 KiB
Python
import json, re
|
||
from dataclasses import dataclass
|
||
from typing import Any, Optional
|
||
@dataclass
|
||
class StepOutcome:
|
||
data: Any
|
||
next_prompt: Optional[str] = None
|
||
should_exit: bool = False
|
||
|
||
def try_call_generator(func, *args, **kwargs):
|
||
ret = func(*args, **kwargs)
|
||
if hasattr(ret, '__iter__') and not isinstance(ret, (str, bytes, dict, list)):
|
||
ret = yield from ret
|
||
return ret
|
||
|
||
class BaseHandler:
|
||
def tool_before_callback(self, tool_name, args, response): pass
|
||
def tool_after_callback(self, tool_name, args, response, ret): pass
|
||
def dispatch(self, tool_name, args, response):
|
||
method_name = f"do_{tool_name}"
|
||
if hasattr(self, method_name):
|
||
_ = yield from try_call_generator(self.tool_before_callback, tool_name, args, response)
|
||
ret = yield from try_call_generator(getattr(self, method_name), args, response)
|
||
_ = yield from try_call_generator(self.tool_after_callback, tool_name, args, response, ret)
|
||
return ret
|
||
elif tool_name == 'bad_json':
|
||
return StepOutcome(None, next_prompt=args.get('msg', 'bad_json'), should_exit=False)
|
||
else:
|
||
yield f"未知工具: {tool_name}\n"
|
||
return StepOutcome(None, next_prompt=f"未知工具 {tool_name}", should_exit=False)
|
||
|
||
def json_default(o):
|
||
if isinstance(o, set): return list(o)
|
||
return str(o)
|
||
|
||
def exhaust(g):
|
||
try:
|
||
while True: next(g)
|
||
except StopIteration as e: return e.value
|
||
|
||
def get_pretty_json(data):
|
||
if isinstance(data, dict) and "script" in data:
|
||
data = data.copy()
|
||
data["script"] = data["script"].replace("; ", ";\n ")
|
||
return json.dumps(data, indent=2, ensure_ascii=False).replace('\\n', '\n')
|
||
|
||
def agent_runner_loop(client, system_prompt, user_input, handler, tools_schema, max_turns=15, verbose=True):
|
||
messages = [
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_input}
|
||
]
|
||
for turn in range(max_turns):
|
||
yield f"**LLM Running (Turn {turn+1}) ...**\n\n"
|
||
if (turn+1) % 10 == 0: client.last_tools = '' # 每10轮重置一次工具描述,避免上下文过大导致的模型性能下降
|
||
response_gen = client.chat(messages=messages, tools=tools_schema)
|
||
response = yield from response_gen
|
||
if verbose: yield '\n\n'
|
||
|
||
if not response.tool_calls:
|
||
tool_name, args = 'no_tool', {}
|
||
else:
|
||
tool_call = response.tool_calls[0]
|
||
tool_name = tool_call.function.name
|
||
args = json.loads(tool_call.function.arguments)
|
||
|
||
if tool_name == 'no_tool': pass
|
||
else:
|
||
showarg = get_pretty_json(args)
|
||
if not verbose and len(showarg) > 200: showarg = showarg[:200] + ' ...'
|
||
yield f"🛠️ **正在调用工具:** `{tool_name}` 📥**参数:**\n````text\n{showarg}\n````\n"
|
||
gen = handler.dispatch(tool_name, args, response)
|
||
if verbose:
|
||
yield '`````\n'
|
||
outcome = yield from gen
|
||
yield '`````\n'
|
||
else:
|
||
outcome = exhaust(gen)
|
||
|
||
if outcome.next_prompt is None: return {'result': 'CURRENT_TASK_DONE', 'data': outcome.data}
|
||
if outcome.should_exit: return {'result': 'EXITED', 'data': outcome.data}
|
||
if outcome.next_prompt.startswith('未知工具'): client.last_tools = ''
|
||
|
||
next_prompt = ""
|
||
if outcome.data is not None:
|
||
datastr = json.dumps(outcome.data, ensure_ascii=False, default=json_default) if type(outcome.data) in [dict, list] else str(outcome.data)
|
||
next_prompt += f"<tool_result>\n{datastr}\n</tool_result>\n\n"
|
||
next_prompt += outcome.next_prompt
|
||
if (turn+1) % 7 == 0:
|
||
next_prompt += f"\n\n[DANGER] 已连续执行第 {turn+1} 轮。禁止无效重试。若无有效进展,必须切换策略:1. 探测物理边界 2. 请求用户协助。如有需要,可调用 update_working_checkpoint 保存关键上下文。"
|
||
if (turn+1) % 30 == 0:
|
||
next_prompt += f"\n\n### [DANGER] 已连续执行第 {turn+1} 轮。你必须总结情况进行ask_user,不允许继续重试。"
|
||
messages = [{"role": "user", "content": next_prompt}]
|
||
return {'result': 'MAX_TURNS_EXCEEDED'} |