refactor: optimize ClaudeSession context trimming with tag compression, fix tool_use parsing to use last match
This commit is contained in:
85
sidercall.py
85
sidercall.py
@@ -27,46 +27,29 @@ class SiderLLMSession:
|
|||||||
if stream: return iter([full_text]) # gen有奇怪的空回复或死循环行为,sider足够快
|
if stream: return iter([full_text]) # gen有奇怪的空回复或死循环行为,sider足够快
|
||||||
return full_text
|
return full_text
|
||||||
|
|
||||||
class GeminiSession:
    """Minimal stateless client for the Gemini ``generateContent`` REST endpoint.

    Each ``ask`` call sends only the given prompt as a single user turn; no
    conversation history is stored on the instance.
    """

    def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy):
        """Store credentials and connection settings.

        Args:
            api_key: Gemini API key; falls back to the module-level
                ``google_api_key`` when falsy.
            default_model: Model name used when ``ask`` is called without one.
            proxy: Proxy URL applied to both http and https; a falsy value
                disables proxying.

        Raises:
            ValueError: If neither ``api_key`` nor ``google_api_key`` is set.
        """
        self.api_key = api_key or google_api_key
        if not self.api_key:
            raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
        self.default_model = default_model
        self.proxies = {"http": proxy, "https": proxy} if proxy else None

    def ask(self, prompt, model=None, stream=False):
        """Send ``prompt`` to the model and return the generated text.

        Failures are never raised: they are reported as a string starting
        with ``[GeminiError]``.

        Args:
            prompt: Text sent as the single user message.
            model: Model name; defaults to ``self.default_model``.
            stream: When true, the full text is wrapped in a one-item
                iterator (no real streaming happens). Error strings are
                returned as plain ``str`` regardless of this flag.

        Returns:
            The concatenated text parts of the first candidate, a one-item
            iterator over it when ``stream`` is true, or an error string.
        """
        if model is None:
            model = self.default_model
        url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent"
        # The key travels in a header instead of a ``?key=`` query parameter:
        # exception messages from ``requests`` usually contain the request URL,
        # so a key in the URL would leak into the error string returned below.
        headers = {"Content-Type": "application/json", "x-goog-api-key": self.api_key}
        data = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]}
        try:
            r = requests.post(url, headers=headers, json=data, timeout=60, proxies=self.proxies)
        except Exception as e:
            # Best-effort contract: report the failure as text instead of raising.
            return f"[GeminiError] request failed: {e}"
        if r.status_code != 200:
            # Keep the error on one line and bounded in size.
            body = r.text[:500].replace("\n", " ")
            return f"[GeminiError] HTTP {r.status_code}: {body}"
        try:
            cands = r.json().get("candidates") or []
            if not cands:
                return "[GeminiError] empty candidates"
            parts = (cands[0].get("content") or {}).get("parts") or []
            full_text = "".join(p.get("text", "") for p in parts)
        except Exception as e:
            return f"[GeminiError] invalid response format: {e}"
        return iter([full_text]) if stream else full_text
|
|
||||||
|
|
||||||
class ClaudeSession:
|
class ClaudeSession:
|
||||||
def __init__(self, api_key, api_base, model="claude-opus", context_win=24000):
|
def __init__(self, api_key, api_base, model="claude-opus", context_win=12000):
|
||||||
self.api_key, self.api_base, self.default_model, self.context_win = api_key, api_base.rstrip('/'), model, context_win
|
self.api_key, self.api_base, self.default_model, self.context_win = api_key, api_base.rstrip('/'), model, context_win
|
||||||
self.raw_msgs, self.lock = [], threading.Lock()
|
self.raw_msgs, self.lock = [], threading.Lock()
|
||||||
def _trim_messages(self, messages):
|
def _trim_messages(self, messages):
|
||||||
total = sum(len(m['prompt'])//4 for m in messages)
|
# 压缩4轮前的assistant消息:truncate <thinking>/<tool_use> 块
|
||||||
if total <= self.context_win: return messages
|
for i, msg in enumerate(messages):
|
||||||
target, current, result = self.context_win * 0.9, 0, []
|
if i < len(messages) - 4 and 'orig' not in msg:
|
||||||
|
msg['orig'] = msg['prompt']
|
||||||
|
for tag in ('thinking', 'tool_use', 'tool_result'):
|
||||||
|
msg['prompt'] = re.sub(
|
||||||
|
rf'(<{tag}>)([\s\S]*?)(</{tag}>)',
|
||||||
|
lambda m: m.group(1) + (m.group(2)[:200] + '...') + m.group(3) if len(m.group(2)) > 200 else m.group(0),
|
||||||
|
msg['prompt']
|
||||||
|
)
|
||||||
|
total = sum(len(m['prompt']) for m in messages)
|
||||||
|
if total <= self.context_win * 4: return messages
|
||||||
|
target, current, result = self.context_win * 4 * 0.9, 0, []
|
||||||
for msg in reversed(messages):
|
for msg in reversed(messages):
|
||||||
if (msg_len := len(msg['prompt'])//4) + current <= target:
|
if (msg_len := len(msg['prompt'])) + current <= target:
|
||||||
result.append(msg); current += msg_len
|
result.append(msg); current += msg_len
|
||||||
else: break
|
else: break
|
||||||
|
if current > 10000 * 4: print(f'[DEBUG] Whole context length {current//4}.')
|
||||||
return result[::-1] or messages[-2:]
|
return result[::-1] or messages[-2:]
|
||||||
def raw_ask(self, messages, model=None, temperature=0.5, max_tokens=4096):
|
def raw_ask(self, messages, model=None, temperature=0.5, max_tokens=4096):
|
||||||
model = model or self.default_model
|
model = model or self.default_model
|
||||||
@@ -196,6 +179,34 @@ class LLMSession:
|
|||||||
return ''.join(list(_ask_gen()))
|
return ''.join(list(_ask_gen()))
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiSession:
    """Minimal stateless client for the Gemini ``generateContent`` REST endpoint.

    Each ``ask`` call sends only the given prompt as a single user turn; no
    conversation history is stored on the instance.
    """

    def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy):
        """Store credentials and connection settings.

        Args:
            api_key: Gemini API key; falls back to the module-level
                ``google_api_key`` when falsy.
            default_model: Model name used when ``ask`` is called without one.
            proxy: Proxy URL applied to both http and https; a falsy value
                disables proxying.

        Raises:
            ValueError: If neither ``api_key`` nor ``google_api_key`` is set.
        """
        self.api_key = api_key or google_api_key
        if not self.api_key:
            raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
        self.default_model = default_model
        self.proxies = {"http": proxy, "https": proxy} if proxy else None

    def ask(self, prompt, model=None, stream=False):
        """Send ``prompt`` to the model and return the generated text.

        Failures are never raised: they are reported as a string starting
        with ``[GeminiError]``.

        Args:
            prompt: Text sent as the single user message.
            model: Model name; defaults to ``self.default_model``.
            stream: When true, the full text is wrapped in a one-item
                iterator (no real streaming happens). Error strings are
                returned as plain ``str`` regardless of this flag.

        Returns:
            The concatenated text parts of the first candidate, a one-item
            iterator over it when ``stream`` is true, or an error string.
        """
        if model is None:
            model = self.default_model
        url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent"
        # The key travels in a header instead of a ``?key=`` query parameter:
        # exception messages from ``requests`` usually contain the request URL,
        # so a key in the URL would leak into the error string returned below.
        headers = {"Content-Type": "application/json", "x-goog-api-key": self.api_key}
        data = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]}
        try:
            r = requests.post(url, headers=headers, json=data, timeout=60, proxies=self.proxies)
        except Exception as e:
            # Best-effort contract: report the failure as text instead of raising.
            return f"[GeminiError] request failed: {e}"
        if r.status_code != 200:
            # Keep the error on one line and bounded in size.
            body = r.text[:500].replace("\n", " ")
            return f"[GeminiError] HTTP {r.status_code}: {body}"
        try:
            cands = r.json().get("candidates") or []
            if not cands:
                return "[GeminiError] empty candidates"
            parts = (cands[0].get("content") or {}).get("parts") or []
            full_text = "".join(p.get("text", "") for p in parts)
        except Exception as e:
            return f"[GeminiError] invalid response format: {e}"
        return iter([full_text]) if stream else full_text
|
||||||
|
|
||||||
class MockFunction:
|
class MockFunction:
|
||||||
def __init__(self, name, arguments):
|
def __init__(self, name, arguments):
|
||||||
self.name = name
|
self.name = name
|
||||||
@@ -291,11 +302,11 @@ class ToolClient:
|
|||||||
|
|
||||||
tool_calls = None
|
tool_calls = None
|
||||||
tool_pattern = r"<tool_use>(.*?)</tool_use>"
|
tool_pattern = r"<tool_use>(.*?)</tool_use>"
|
||||||
tool_match = re.search(tool_pattern, remaining_text, re.DOTALL)
|
tool_all = re.findall(tool_pattern, remaining_text, re.DOTALL)
|
||||||
|
|
||||||
json_str = ""
|
json_str = ""
|
||||||
if tool_match:
|
if tool_all:
|
||||||
json_str = tool_match.group(1).strip()
|
json_str = tool_all[-1].strip()
|
||||||
remaining_text = re.sub(tool_pattern, "", remaining_text, flags=re.DOTALL)
|
remaining_text = re.sub(tool_pattern, "", remaining_text, flags=re.DOTALL)
|
||||||
elif '<tool_use>' in remaining_text:
|
elif '<tool_use>' in remaining_text:
|
||||||
weaktoolstr = remaining_text.split('<tool_use>')[-1].strip()
|
weaktoolstr = remaining_text.split('<tool_use>')[-1].strip()
|
||||||
@@ -317,7 +328,7 @@ class ToolClient:
|
|||||||
if func_name: tool_calls = [MockToolCall(func_name, args)]
|
if func_name: tool_calls = [MockToolCall(func_name, args)]
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print("[Warn] Failed to parse tool_use JSON:", json_str)
|
print("[Warn] Failed to parse tool_use JSON:", json_str)
|
||||||
tool_calls = [MockToolCall('bad_json', {'msg': f'Failed to parse tool_use JSON: {str(e)}'})]
|
tool_calls = [MockToolCall('bad_json', {'msg': f'Failed to parse tool_use JSON: {json_str[:200]}'})]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("[Error] Exception during tool_use parsing:", str(e), data)
|
print("[Error] Exception during tool_use parsing:", str(e), data)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user