refactor: move tools ownership from NativeToolClient to Session layer

- tools state now held by Session (NativeClaudeSession.tools)
- MixinSession.__setattr__ broadcasts tools/system to all sub-sessions
- NativeToolClient no longer duplicates tools storage
- fix: use type(s) is instead of isinstance to avoid catching NativeOAISession subclass
This commit is contained in:
Liang Jiaqing
2026-04-03 23:05:28 +08:00
parent 555eeabf56
commit 14125ed57c

View File

@@ -485,7 +485,7 @@ class NativeClaudeSession(BaseSession):
self._session_id = str(uuid.uuid4()) self._session_id = str(uuid.uuid4())
self._account_uuid = str(uuid.uuid4()) self._account_uuid = str(uuid.uuid4())
self._device_id = uuid.uuid4().hex + uuid.uuid4().hex[:32] self._device_id = uuid.uuid4().hex + uuid.uuid4().hex[:32]
self.tools = None
def raw_ask(self, messages, tools=None, system=None, model=None, temperature=0.5, max_tokens=6144): def raw_ask(self, messages, tools=None, system=None, model=None, temperature=0.5, max_tokens=6144):
model = model or self.default_model model = model or self.default_model
headers = {"Content-Type": "application/json", "anthropic-version": "2023-06-01", headers = {"Content-Type": "application/json", "anthropic-version": "2023-06-01",
@@ -516,7 +516,7 @@ class NativeClaudeSession(BaseSession):
content_blocks = yield from _parse_claude_sse(resp.iter_lines()) content_blocks = yield from _parse_claude_sse(resp.iter_lines())
return content_blocks or [] return content_blocks or []
def ask(self, msg, tools=None, model=None): def ask(self, msg, model=None):
assert type(msg) is dict assert type(msg) is dict
with self.lock: with self.lock:
self.history.append(msg) self.history.append(msg)
@@ -524,7 +524,7 @@ class NativeClaudeSession(BaseSession):
messages = [{"role": m["role"], "content": list(m["content"])} for m in self.history] messages = [{"role": m["role"], "content": list(m["content"])} for m in self.history]
content_blocks = None content_blocks = None
gen = self.raw_ask(messages, tools, self.system, model) gen = self.raw_ask(messages, self.tools, self.system, model)
try: try:
while True: yield next(gen) while True: yield next(gen)
except StopIteration as e: content_blocks = e.value or [] except StopIteration as e: content_blocks = e.value or []
@@ -750,6 +750,12 @@ class MixinSession:
self.default_model = getattr(self._sessions[0], 'default_model', None) self.default_model = getattr(self._sessions[0], 'default_model', None)
self._cur_idx, self._switched_at = 0, 0.0 self._cur_idx, self._switched_at = 0, 0.0
def __getattr__(self, name): return getattr(self._sessions[0], name) def __getattr__(self, name): return getattr(self._sessions[0], name)
def __setattr__(self, name, value):
if name in ('system', 'tools'):
for s in self._sessions:
v = openai_tools_to_claude(value) if name == 'tools' and type(s) is NativeClaudeSession else value
setattr(s, name, v)
else: object.__setattr__(self, name, value)
@property @property
def primary(self): return self._sessions[0] def primary(self): return self._sessions[0]
def _pick(self): def _pick(self):
@@ -793,7 +799,6 @@ class NativeToolClient:
def __init__(self, backend): def __init__(self, backend):
self.backend = backend self.backend = backend
self.backend.system = self.THINKING_PROMPT self.backend.system = self.THINKING_PROMPT
self.tools = {}
self.name = self.backend.name self.name = self.backend.name
self._pending_tool_ids = [] self._pending_tool_ids = []
def set_system(self, extra_system): def set_system(self, extra_system):
@@ -801,7 +806,7 @@ class NativeToolClient:
if combined != self.backend.system: print(f"[Debug] Updated system prompt, length {len(combined)} chars.") if combined != self.backend.system: print(f"[Debug] Updated system prompt, length {len(combined)} chars.")
self.backend.system = combined self.backend.system = combined
def chat(self, messages, tools=None): def chat(self, messages, tools=None):
if tools: self.tools = openai_tools_to_claude(tools) if type(self.backend) is NativeClaudeSession else tools if tools: self.backend.tools = tools
combined_content = []; resp = None; tool_results = [] combined_content = []; resp = None; tool_results = []
for msg in messages: for msg in messages:
c = msg.get('content', '') c = msg.get('content', '')
@@ -821,7 +826,7 @@ class NativeToolClient:
self._pending_tool_ids = [] self._pending_tool_ids = []
merged = {"role": "user", "content": tool_result_blocks + combined_content} merged = {"role": "user", "content": tool_result_blocks + combined_content}
_write_llm_log('Prompt', json.dumps(merged, ensure_ascii=False, indent=2)) _write_llm_log('Prompt', json.dumps(merged, ensure_ascii=False, indent=2))
gen = self.backend.ask(merged, self.tools); gen = self.backend.ask(merged)
try: try:
while True: while True:
chunk = next(gen); yield chunk chunk = next(gen); yield chunk