import json, re, os from dataclasses import dataclass from typing import Any, Optional @dataclass class StepOutcome: data: Any next_prompt: Optional[str] = None should_exit: bool = False def try_call_generator(func, *args, **kwargs): ret = func(*args, **kwargs) if hasattr(ret, '__iter__') and not isinstance(ret, (str, bytes, dict, list)): ret = yield from ret return ret class BaseHandler: def tool_before_callback(self, tool_name, args, response): pass def tool_after_callback(self, tool_name, args, response, ret): pass def next_prompt_patcher(self, next_prompt, outcome, turn): return next_prompt def dispatch(self, tool_name, args, response, index=0): method_name = f"do_{tool_name}" if hasattr(self, method_name): args['_index'] = index prer = yield from try_call_generator(self.tool_before_callback, tool_name, args, response) ret = yield from try_call_generator(getattr(self, method_name), args, response) _ = yield from try_call_generator(self.tool_after_callback, tool_name, args, response, ret) return ret elif tool_name == 'bad_json': return StepOutcome(None, next_prompt=args.get('msg', 'bad_json'), should_exit=False) else: yield f"未知工具: {tool_name}\n" return StepOutcome(None, next_prompt=f"未知工具 {tool_name}", should_exit=False) def json_default(o): if isinstance(o, set): return list(o) return str(o) def exhaust(g): try: while True: next(g) except StopIteration as e: return e.value def get_pretty_json(data): if isinstance(data, dict) and "script" in data: data = data.copy(); data["script"] = data["script"].replace("; ", ";\n ") return json.dumps(data, indent=2, ensure_ascii=False).replace('\\n', '\n') def agent_runner_loop(client, system_prompt, user_input, handler, tools_schema, max_turns=40, verbose=True, initial_user_content=None): messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": initial_user_content if initial_user_content is not None else user_input} ] turn = 0; handler._done_hooks = []; handler.max_turns = max_turns while turn < handler.max_turns: turn += 1; md = '**' if verbose else '' yield f"{md}LLM Running (Turn {turn}) ...{md}\n\n" if turn%10 == 0: client.last_tools = '' # 每10轮重置一次工具描述,避免上下文过大导致的模型性能下降 response_gen = client.chat(messages=messages, tools=tools_schema) if verbose: response = yield from response_gen yield '\n\n' else: response = exhaust(response_gen) cleaned = _clean_content(response.content) if cleaned: yield cleaned + '\n' if not response.tool_calls: tool_calls = [{'tool_name': 'no_tool', 'args': {}}] else: tool_calls = [{'tool_name': tc.function.name, 'args': json.loads(tc.function.arguments), 'id': tc.id} for tc in response.tool_calls] tool_results = []; next_prompts = set(); should_exit = None for ii, tc in enumerate(tool_calls): tool_name, args, tid = tc['tool_name'], tc['args'], tc.get('id', '') if tool_name == 'no_tool': pass else: if verbose: yield f"🛠️ 正在调用工具: `{tool_name}` 📥参数:\n````text\n{get_pretty_json(args)}\n````\n" else: yield f"🛠️ {tool_name}({_compact_tool_args(tool_name, args)})\n\n\n" handler.current_turn = turn gen = handler.dispatch(tool_name, args, response, index=ii) try: v = next(gen) def proxy(): yield v; return (yield from gen) if verbose: yield '`````\n' outcome = (yield from proxy()) if verbose else exhaust(proxy()) if verbose: yield '`````\n' except StopIteration as e: outcome = e.value if outcome.should_exit: return {'result': 'EXITED', 'data': outcome.data} # should_exit is only used for immediate exit if not outcome.next_prompt: should_exit = {'result': 'CURRENT_TASK_DONE', 'data': outcome.data}; break if outcome.next_prompt.startswith('未知工具'): client.last_tools = '' if outcome.data is not None: datastr = json.dumps(outcome.data, ensure_ascii=False, default=json_default) if type(outcome.data) in [dict, list] else str(outcome.data) tool_results.append({'tool_use_id': tid, 'content': datastr}) next_prompts.add(outcome.next_prompt) if len(next_prompts) == 0: if len(handler._done_hooks) == 0: return should_exit next_prompts.add(handler._done_hooks.pop(0)) next_prompt = handler.next_prompt_patcher("\n".join(next_prompts), None, turn) messages = [{"role": "user", "content": next_prompt, "tool_results": tool_results}] # just new message, history is kept in *Session return {'result': 'MAX_TURNS_EXCEEDED'} def _clean_content(text): if not text: return '' def _shrink_code(m): lines = m.group(0).split('\n') lang = lines[0].replace('```','').strip() body = [l for l in lines[1:-1] if l.strip()] # 去掉```行和空行 if len(body) <= 6: return m.group(0) # 短代码保留 preview = '\n'.join(body[:5]) return f'```{lang}\n{preview}\n ... ({len(body)} lines)\n```' text = re.sub(r'```[\s\S]*?```', _shrink_code, text) for p in [r'[\s\S]*?', r'[\s\S]*?', r'(\r?\n){3,}']: text = re.sub(p, '\n\n' if '\\n' in p else '', text) return text.strip() def _compact_tool_args(name, args): a = {k: v for k, v in args.items() if k != '_index'} for k in ('path',): # 只缩短路径 if k in a: a[k] = os.path.basename(a[k]) if name == 'update_working_checkpoint': s = a.get('key_info', ''); return (s[:60]+'...') if len(s)>60 else s s = json.dumps(a, ensure_ascii=False); return (s[:120]+'...') if len(s)>120 else s