refactor: optimize ClaudeSession context trimming with tag compression, fix tool_use parsing to use last match

This commit is contained in:
Liang Jiaqing
2026-02-12 22:52:47 +08:00
parent 4a2807f200
commit 39f6a851dd

View File

@@ -27,46 +27,29 @@ class SiderLLMSession:
if stream: return iter([full_text]) # gen有奇怪的空回复或死循环行为sider足够快
return full_text
class GeminiSession:
def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy):
self.api_key = api_key or google_api_key
if not self.api_key: raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
self.default_model = default_model
self.proxies = {"http":proxy, "https":proxy} if proxy else None
def ask(self, prompt, model=None, stream=False):
if model is None: model = self.default_model
url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent?key={self.api_key}"
headers = {"Content-Type":"application/json"}
data = {"contents":[{"role":"user","parts":[{"text":prompt}]}]}
try:
kw = {"headers":headers, "json":data, "timeout":60, 'proxies': self.proxies}
r = requests.post(url, **kw)
except Exception as e:
return f"[GeminiError] request failed: {e}"
if r.status_code != 200:
body = r.text[:500].replace("\n"," ")
return f"[GeminiError] HTTP {r.status_code}: {body}"
try:
obj = r.json(); cands = obj.get("candidates") or []
if not cands: return "[GeminiError] empty candidates"
parts = (cands[0].get("content") or {}).get("parts") or []
full_text = "".join(p.get("text","") for p in parts)
except Exception as e:
return f"[GeminiError] invalid response format: {e}"
return iter([full_text]) if stream else full_text
class ClaudeSession:
def __init__(self, api_key, api_base, model="claude-opus", context_win=24000):
def __init__(self, api_key, api_base, model="claude-opus", context_win=12000):
self.api_key, self.api_base, self.default_model, self.context_win = api_key, api_base.rstrip('/'), model, context_win
self.raw_msgs, self.lock = [], threading.Lock()
def _trim_messages(self, messages):
total = sum(len(m['prompt'])//4 for m in messages)
if total <= self.context_win: return messages
target, current, result = self.context_win * 0.9, 0, []
# 压缩4轮前的assistant消息truncate <thinking>/<tool_use> 块
for i, msg in enumerate(messages):
if i < len(messages) - 4 and 'orig' not in msg:
msg['orig'] = msg['prompt']
for tag in ('thinking', 'tool_use', 'tool_result'):
msg['prompt'] = re.sub(
rf'(<{tag}>)([\s\S]*?)(</{tag}>)',
lambda m: m.group(1) + (m.group(2)[:200] + '...') + m.group(3) if len(m.group(2)) > 200 else m.group(0),
msg['prompt']
)
total = sum(len(m['prompt']) for m in messages)
if total <= self.context_win * 4: return messages
target, current, result = self.context_win * 4 * 0.9, 0, []
for msg in reversed(messages):
if (msg_len := len(msg['prompt'])//4) + current <= target:
if (msg_len := len(msg['prompt'])) + current <= target:
result.append(msg); current += msg_len
else: break
if current > 10000 * 4: print(f'[DEBUG] Whole context length {current//4}.')
return result[::-1] or messages[-2:]
def raw_ask(self, messages, model=None, temperature=0.5, max_tokens=4096):
model = model or self.default_model
@@ -196,6 +179,34 @@ class LLMSession:
return ''.join(list(_ask_gen()))
class GeminiSession:
def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy):
self.api_key = api_key or google_api_key
if not self.api_key: raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
self.default_model = default_model
self.proxies = {"http":proxy, "https":proxy} if proxy else None
def ask(self, prompt, model=None, stream=False):
if model is None: model = self.default_model
url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent?key={self.api_key}"
headers = {"Content-Type":"application/json"}
data = {"contents":[{"role":"user","parts":[{"text":prompt}]}]}
try:
kw = {"headers":headers, "json":data, "timeout":60, 'proxies': self.proxies}
r = requests.post(url, **kw)
except Exception as e:
return f"[GeminiError] request failed: {e}"
if r.status_code != 200:
body = r.text[:500].replace("\n"," ")
return f"[GeminiError] HTTP {r.status_code}: {body}"
try:
obj = r.json(); cands = obj.get("candidates") or []
if not cands: return "[GeminiError] empty candidates"
parts = (cands[0].get("content") or {}).get("parts") or []
full_text = "".join(p.get("text","") for p in parts)
except Exception as e:
return f"[GeminiError] invalid response format: {e}"
return iter([full_text]) if stream else full_text
class MockFunction:
def __init__(self, name, arguments):
self.name = name
@@ -291,11 +302,11 @@ class ToolClient:
tool_calls = None
tool_pattern = r"<tool_use>(.*?)</tool_use>"
tool_match = re.search(tool_pattern, remaining_text, re.DOTALL)
tool_all = re.findall(tool_pattern, remaining_text, re.DOTALL)
json_str = ""
if tool_match:
json_str = tool_match.group(1).strip()
if tool_all:
json_str = tool_all[-1].strip()
remaining_text = re.sub(tool_pattern, "", remaining_text, flags=re.DOTALL)
elif '<tool_use>' in remaining_text:
weaktoolstr = remaining_text.split('<tool_use>')[-1].strip()
@@ -317,7 +328,7 @@ class ToolClient:
if func_name: tool_calls = [MockToolCall(func_name, args)]
except json.JSONDecodeError as e:
print("[Warn] Failed to parse tool_use JSON:", json_str)
tool_calls = [MockToolCall('bad_json', {'msg': f'Failed to parse tool_use JSON: {str(e)}'})]
tool_calls = [MockToolCall('bad_json', {'msg': f'Failed to parse tool_use JSON: {json_str[:200]}'})]
except Exception as e:
print("[Error] Exception during tool_use parsing:", str(e), data)