refactor: optimize ClaudeSession context trimming with tag compression, fix tool_use parsing to use last match
This commit is contained in:
87
sidercall.py
87
sidercall.py
@@ -26,47 +26,30 @@ class SiderLLMSession:
|
||||
full_text = self._core.chat(prompt, model, stream=False)
|
||||
if stream: return iter([full_text]) # gen有奇怪的空回复或死循环行为,sider足够快
|
||||
return full_text
|
||||
|
||||
class GeminiSession:
    """Single-turn client for the Gemini ``generateContent`` REST endpoint.

    Failures are reported in-band: ``ask`` does not raise for transport,
    HTTP-status or response-format problems; it returns a string that starts
    with ``[GeminiError]`` instead.
    """

    def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy):
        """Store credentials, the default model and the optional proxy.

        Args:
            api_key: Key to use; falls back to the module-level
                ``google_api_key`` when falsy.
            default_model: Model name used when ``ask`` gets ``model=None``.
            proxy: Proxy URL applied to both http and https; falsy disables it.

        Raises:
            ValueError: If no API key is available.
        """
        self.api_key = api_key or google_api_key
        if not self.api_key:
            raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
        self.default_model = default_model
        self.proxies = {"http": proxy, "https": proxy} if proxy else None

    def ask(self, prompt, model=None, stream=False):
        """Send ``prompt`` as a single user turn and return the reply text.

        Args:
            prompt: Text placed in the only ``parts`` entry of the request.
            model: Model name; ``None`` selects ``self.default_model``.
            stream: When true, a successful reply is wrapped in a one-item
                iterator. Error strings are returned as plain ``str`` in
                either mode.

        Returns:
            The concatenated ``text`` of the first candidate's parts (or a
            one-item iterator over it), or a ``[GeminiError] ...`` string.
        """
        if model is None:
            model = self.default_model
        # The key travels in a header rather than the query string: requests'
        # exception messages embed the full URL, so a key in the URL would be
        # echoed back through the "request failed" error string below.
        url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent"
        headers = {"Content-Type": "application/json", "x-goog-api-key": self.api_key}
        payload = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]}
        try:
            r = requests.post(url, headers=headers, json=payload, timeout=60, proxies=self.proxies)
        except Exception as e:
            # Deliberate best-effort: callers get an error string, not a raise.
            return f"[GeminiError] request failed: {e}"
        if r.status_code != 200:
            # Keep the error short and on one line.
            body = r.text[:500].replace("\n", " ")
            return f"[GeminiError] HTTP {r.status_code}: {body}"
        try:
            cands = r.json().get("candidates") or []
            if not cands:
                return "[GeminiError] empty candidates"
            parts = (cands[0].get("content") or {}).get("parts") or []
            full_text = "".join(p.get("text", "") for p in parts)
        except Exception as e:
            return f"[GeminiError] invalid response format: {e}"
        return iter([full_text]) if stream else full_text
|
||||
|
||||
|
||||
class ClaudeSession:
|
||||
def __init__(self, api_key, api_base, model="claude-opus", context_win=12000):
    """Store connection settings and create the empty per-session state.

    Args:
        api_key: Credential stored as ``self.api_key``.
        api_base: Base URL; trailing slashes are stripped before storing.
        model: Stored as ``self.default_model``.
        context_win: Context budget read by ``_trim_messages``
            (presumably in tokens at ~4 chars each -- confirm).
    """
    self.api_key = api_key
    self.api_base = api_base.rstrip('/')
    self.default_model = model
    self.context_win = context_win
    self.raw_msgs = []
    self.lock = threading.Lock()
|
||||
def _trim_messages(self, messages):
    """Shrink ``messages`` so their combined prompt text fits the budget.

    Stage 1 (always runs): every message older than the newest four has the
    body of each ``<thinking>``, ``<tool_use>`` and ``<tool_result>`` block
    cut to 200 characters plus ``'...'``. This mutates the message dicts in
    place; the untouched text is saved under ``'orig'``, whose presence also
    prevents the message from being compressed again on later calls.

    Stage 2: if the total prompt length still exceeds ``context_win * 4``
    characters, only the newest messages whose cumulative length fits within
    90% of that budget are kept.

    Args:
        messages: List of dicts, each with a string under ``'prompt'``.

    Returns:
        ``messages`` itself when it fits the budget; otherwise a new list of
        the newest messages in their original order, or the last two
        messages when not even the newest one fits.
    """
    keep_recent = 4   # newest messages that are never compressed
    max_body = 200    # characters kept inside each compressed tag block

    # Compile once per call instead of once per message and tag.
    tag_patterns = [
        re.compile(rf'(<{tag}>)([\s\S]*?)(</{tag}>)')
        for tag in ('thinking', 'tool_use', 'tool_result')
    ]

    def _shorten(match):
        body = match.group(2)
        if len(body) <= max_body:
            return match.group(0)
        return match.group(1) + body[:max_body] + '...' + match.group(3)

    for msg in messages[:-keep_recent]:
        if 'orig' in msg:
            continue  # already compressed by an earlier call
        msg['orig'] = msg['prompt']
        for pattern in tag_patterns:
            msg['prompt'] = pattern.sub(_shorten, msg['prompt'])

    # All sizes below are in characters.
    budget = self.context_win * 4
    if sum(len(m['prompt']) for m in messages) <= budget:
        return messages

    target, used, kept = budget * 0.9, 0, []
    for msg in reversed(messages):
        size = len(msg['prompt'])
        if used + size > target:
            break
        kept.append(msg)
        used += size
    if used > 10000 * 4:
        print(f'[DEBUG] Whole context length {used//4}.')
    return kept[::-1] or messages[-2:]
|
||||
def raw_ask(self, messages, model=None, temperature=0.5, max_tokens=4096):
|
||||
model = model or self.default_model
|
||||
@@ -195,6 +178,34 @@ class LLMSession:
|
||||
if stream: return _ask_gen()
|
||||
return ''.join(list(_ask_gen()))
|
||||
|
||||
|
||||
class GeminiSession:
    """Single-turn client for the Gemini ``generateContent`` REST endpoint.

    Failures are reported in-band: ``ask`` does not raise for transport,
    HTTP-status or response-format problems; it returns a string that starts
    with ``[GeminiError]`` instead.
    """

    def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy):
        """Store credentials, the default model and the optional proxy.

        Args:
            api_key: Key to use; falls back to the module-level
                ``google_api_key`` when falsy.
            default_model: Model name used when ``ask`` gets ``model=None``.
            proxy: Proxy URL applied to both http and https; falsy disables it.

        Raises:
            ValueError: If no API key is available.
        """
        self.api_key = api_key or google_api_key
        if not self.api_key:
            raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
        self.default_model = default_model
        self.proxies = {"http": proxy, "https": proxy} if proxy else None

    def ask(self, prompt, model=None, stream=False):
        """Send ``prompt`` as a single user turn and return the reply text.

        Args:
            prompt: Text placed in the only ``parts`` entry of the request.
            model: Model name; ``None`` selects ``self.default_model``.
            stream: When true, a successful reply is wrapped in a one-item
                iterator. Error strings are returned as plain ``str`` in
                either mode.

        Returns:
            The concatenated ``text`` of the first candidate's parts (or a
            one-item iterator over it), or a ``[GeminiError] ...`` string.
        """
        if model is None:
            model = self.default_model
        # The key travels in a header rather than the query string: requests'
        # exception messages embed the full URL, so a key in the URL would be
        # echoed back through the "request failed" error string below.
        url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent"
        headers = {"Content-Type": "application/json", "x-goog-api-key": self.api_key}
        payload = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]}
        try:
            r = requests.post(url, headers=headers, json=payload, timeout=60, proxies=self.proxies)
        except Exception as e:
            # Deliberate best-effort: callers get an error string, not a raise.
            return f"[GeminiError] request failed: {e}"
        if r.status_code != 200:
            # Keep the error short and on one line.
            body = r.text[:500].replace("\n", " ")
            return f"[GeminiError] HTTP {r.status_code}: {body}"
        try:
            cands = r.json().get("candidates") or []
            if not cands:
                return "[GeminiError] empty candidates"
            parts = (cands[0].get("content") or {}).get("parts") or []
            full_text = "".join(p.get("text", "") for p in parts)
        except Exception as e:
            return f"[GeminiError] invalid response format: {e}"
        return iter([full_text]) if stream else full_text
|
||||
|
||||
class MockFunction:
|
||||
def __init__(self, name, arguments):
|
||||
@@ -291,11 +302,11 @@ class ToolClient:
|
||||
|
||||
tool_calls = None
|
||||
tool_pattern = r"<tool_use>(.*?)</tool_use>"
|
||||
tool_match = re.search(tool_pattern, remaining_text, re.DOTALL)
|
||||
tool_all = re.findall(tool_pattern, remaining_text, re.DOTALL)
|
||||
|
||||
json_str = ""
|
||||
if tool_match:
|
||||
json_str = tool_match.group(1).strip()
|
||||
if tool_all:
|
||||
json_str = tool_all[-1].strip()
|
||||
remaining_text = re.sub(tool_pattern, "", remaining_text, flags=re.DOTALL)
|
||||
elif '<tool_use>' in remaining_text:
|
||||
weaktoolstr = remaining_text.split('<tool_use>')[-1].strip()
|
||||
@@ -317,7 +328,7 @@ class ToolClient:
|
||||
if func_name: tool_calls = [MockToolCall(func_name, args)]
|
||||
except json.JSONDecodeError as e:
|
||||
print("[Warn] Failed to parse tool_use JSON:", json_str)
|
||||
tool_calls = [MockToolCall('bad_json', {'msg': f'Failed to parse tool_use JSON: {str(e)}'})]
|
||||
tool_calls = [MockToolCall('bad_json', {'msg': f'Failed to parse tool_use JSON: {json_str[:200]}'})]
|
||||
except Exception as e:
|
||||
print("[Error] Exception during tool_use parsing:", str(e), data)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user