refactor llmcore: extract _parse_claude_sse, simplify ToolClient, enhance logging and context trimming
This commit is contained in:
348
llmcore.py
348
llmcore.py
@@ -14,20 +14,19 @@ mykeys = _load_mykeys()
|
|||||||
proxy = mykeys.get("proxy", 'http://127.0.0.1:2082')
|
proxy = mykeys.get("proxy", 'http://127.0.0.1:2082')
|
||||||
proxies = {"http": proxy, "https": proxy} if proxy else None
|
proxies = {"http": proxy, "https": proxy} if proxy else None
|
||||||
|
|
||||||
def compress_history_tags(messages, keep_recent=10, max_len=1000):
|
def compress_history_tags(messages, keep_recent=10, max_len=800):
|
||||||
"""Compress <thinking>/<tool_use>/<tool_result> tags in older messages to save tokens.
|
"""Compress <thinking>/<tool_use>/<tool_result> tags in older messages to save tokens.
|
||||||
Supports both prompt-style (ClaudeSession/LLMSession) and content-style (NativeClaudeSession) messages."""
|
Supports both prompt-style (ClaudeSession/LLMSession) and content-style (NativeClaudeSession) messages."""
|
||||||
compress_history_tags._cd = getattr(compress_history_tags, '_cd', 0) + 1
|
compress_history_tags._cd = getattr(compress_history_tags, '_cd', 0) + 1
|
||||||
if compress_history_tags._cd % 5 != 0: return messages
|
if compress_history_tags._cd % 5 != 0: return messages
|
||||||
|
_before = sum(len(json.dumps(m)) for m in messages)
|
||||||
_pats = {tag: re.compile(rf'(<{tag}>)([\s\S]*?)(</{tag}>)') for tag in ('thinking', 'tool_use', 'tool_result')}
|
_pats = {tag: re.compile(rf'(<{tag}>)([\s\S]*?)(</{tag}>)') for tag in ('thinking', 'tool_use', 'tool_result')}
|
||||||
def _trunc(text):
|
def _trunc(text):
|
||||||
for pat in _pats.values(): text = pat.sub(lambda m: m.group(1) + m.group(2)[:max_len] + '...' + m.group(3) if len(m.group(2)) > max_len else m.group(0), text)
|
for pat in _pats.values(): text = pat.sub(lambda m: m.group(1) + m.group(2)[:max_len] + '...' + m.group(3) if len(m.group(2)) > max_len else m.group(0), text)
|
||||||
return text
|
return text
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
if i >= len(messages) - keep_recent: break
|
if i >= len(messages) - keep_recent: break
|
||||||
if 'prompt' in msg and 'orig' not in msg:
|
if 'prompt' in msg: msg['prompt'] = _trunc(msg['prompt'])
|
||||||
msg['orig'] = msg['prompt']
|
|
||||||
msg['prompt'] = _trunc(msg['prompt'])
|
|
||||||
elif 'content' in msg and 'prompt' not in msg:
|
elif 'content' in msg and 'prompt' not in msg:
|
||||||
c = msg['content']
|
c = msg['content']
|
||||||
if isinstance(c, str): msg['content'] = _trunc(c)
|
if isinstance(c, str): msg['content'] = _trunc(c)
|
||||||
@@ -35,6 +34,7 @@ def compress_history_tags(messages, keep_recent=10, max_len=1000):
|
|||||||
for block in c:
|
for block in c:
|
||||||
if isinstance(block, dict) and block.get('type') == 'text' and isinstance(block.get('text'), str):
|
if isinstance(block, dict) and block.get('type') == 'text' and isinstance(block.get('text'), str):
|
||||||
block['text'] = _trunc(block['text'])
|
block['text'] = _trunc(block['text'])
|
||||||
|
print(f"[Cut] {_before} -> {sum(len(json.dumps(m)) for m in messages)}")
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def auto_make_url(base, path):
|
def auto_make_url(base, path):
|
||||||
@@ -75,62 +75,171 @@ class SiderLLMSession:
|
|||||||
if stream: return iter([full_text]) # gen有奇怪的空回复或死循环行为,sider足够快
|
if stream: return iter([full_text]) # gen有奇怪的空回复或死循环行为,sider足够快
|
||||||
return full_text
|
return full_text
|
||||||
|
|
||||||
|
def _parse_claude_sse(resp_lines):
|
||||||
|
"""Parse Anthropic SSE stream. Yields text chunks, returns list[content_block]."""
|
||||||
|
content_blocks = []; current_block = None; tool_json_buf = ""
|
||||||
|
for line in resp_lines:
|
||||||
|
if not line: continue
|
||||||
|
line = line.decode('utf-8') if isinstance(line, bytes) else line
|
||||||
|
if not line.startswith("data:"): continue
|
||||||
|
data_str = line[5:].lstrip()
|
||||||
|
if data_str == "[DONE]": break
|
||||||
|
try: evt = json.loads(data_str)
|
||||||
|
except: continue
|
||||||
|
evt_type = evt.get("type", "")
|
||||||
|
if evt_type == "message_start":
|
||||||
|
usage = evt.get("message", {}).get("usage", {})
|
||||||
|
ci, cr, inp = usage.get("cache_creation_input_tokens", 0), usage.get("cache_read_input_tokens", 0), usage.get("input_tokens", 0)
|
||||||
|
print(f"[Cache] input={inp} creation={ci} read={cr}")
|
||||||
|
elif evt_type == "content_block_start":
|
||||||
|
block = evt.get("content_block", {})
|
||||||
|
if block.get("type") == "text": current_block = {"type": "text", "text": ""}
|
||||||
|
elif block.get("type") == "tool_use":
|
||||||
|
current_block = {"type": "tool_use", "id": block.get("id", ""), "name": block.get("name", ""), "input": {}}
|
||||||
|
tool_json_buf = ""
|
||||||
|
elif evt_type == "content_block_delta":
|
||||||
|
delta = evt.get("delta", {})
|
||||||
|
if delta.get("type") == "text_delta":
|
||||||
|
text = delta.get("text", "")
|
||||||
|
if current_block and current_block.get("type") == "text": current_block["text"] += text
|
||||||
|
if text: yield text
|
||||||
|
elif delta.get("type") == "input_json_delta": tool_json_buf += delta.get("partial_json", "")
|
||||||
|
elif evt_type == "content_block_stop":
|
||||||
|
if current_block:
|
||||||
|
if current_block["type"] == "tool_use":
|
||||||
|
try: current_block["input"] = json.loads(tool_json_buf) if tool_json_buf else {}
|
||||||
|
except: current_block["input"] = {"_raw": tool_json_buf}
|
||||||
|
content_blocks.append(current_block)
|
||||||
|
current_block = None
|
||||||
|
return content_blocks
|
||||||
|
|
||||||
|
def _parse_openai_sse(resp_lines, api_mode="chat_completions"):
|
||||||
|
"""Parse OpenAI SSE stream (chat_completions or responses API).
|
||||||
|
Yields text chunks, returns list[content_block].
|
||||||
|
content_block: {type:'text', text:str} | {type:'tool_use', id:str, name:str, input:dict}
|
||||||
|
"""
|
||||||
|
content_text = ""
|
||||||
|
if api_mode == "responses":
|
||||||
|
seen_delta = False; fc_buf = {}; current_fc_idx = None
|
||||||
|
for line in resp_lines:
|
||||||
|
if not line: continue
|
||||||
|
line = line.decode('utf-8', errors='replace') if isinstance(line, bytes) else line
|
||||||
|
if not line.startswith("data:"): continue
|
||||||
|
data_str = line[5:].lstrip()
|
||||||
|
if data_str == "[DONE]": break
|
||||||
|
try: evt = json.loads(data_str)
|
||||||
|
except: continue
|
||||||
|
etype = evt.get("type", "")
|
||||||
|
if etype == "response.output_text.delta":
|
||||||
|
delta = evt.get("delta", "")
|
||||||
|
if delta: seen_delta = True; content_text += delta; yield delta
|
||||||
|
elif etype == "response.output_text.done" and not seen_delta:
|
||||||
|
text = evt.get("text", "")
|
||||||
|
if text: content_text += text; yield text
|
||||||
|
elif etype == "response.output_item.added":
|
||||||
|
item = evt.get("item", {})
|
||||||
|
if item.get("type") == "function_call":
|
||||||
|
idx = evt.get("output_index", 0)
|
||||||
|
fc_buf[idx] = {"id": item.get("call_id", item.get("id", "")), "name": item.get("name", ""), "args": ""}
|
||||||
|
current_fc_idx = idx
|
||||||
|
elif etype == "response.function_call_arguments.delta":
|
||||||
|
idx = evt.get("output_index", current_fc_idx or 0)
|
||||||
|
if idx in fc_buf: fc_buf[idx]["args"] += evt.get("delta", "")
|
||||||
|
elif etype == "response.function_call_arguments.done":
|
||||||
|
idx = evt.get("output_index", current_fc_idx or 0)
|
||||||
|
if idx in fc_buf: fc_buf[idx]["args"] = evt.get("arguments", fc_buf[idx]["args"])
|
||||||
|
elif etype == "error":
|
||||||
|
err = evt.get("error", {})
|
||||||
|
emsg = err.get("message", str(err)) if isinstance(err, dict) else str(err)
|
||||||
|
if emsg: content_text += f"Error: {emsg}"; yield f"Error: {emsg}"
|
||||||
|
break
|
||||||
|
elif etype == "response.completed":
|
||||||
|
usage = evt.get("response", {}).get("usage", {})
|
||||||
|
cached = (usage.get("input_tokens_details") or {}).get("cached_tokens", 0)
|
||||||
|
inp = usage.get("input_tokens", 0)
|
||||||
|
if inp: print(f"[Cache] input={inp} cached={cached}")
|
||||||
|
break
|
||||||
|
blocks = []
|
||||||
|
if content_text: blocks.append({"type": "text", "text": content_text})
|
||||||
|
for idx in sorted(fc_buf):
|
||||||
|
fc = fc_buf[idx]
|
||||||
|
try: inp = json.loads(fc["args"]) if fc["args"] else {}
|
||||||
|
except: inp = {"_raw": fc["args"]}
|
||||||
|
blocks.append({"type": "tool_use", "id": fc["id"], "name": fc["name"], "input": inp})
|
||||||
|
return blocks
|
||||||
|
else:
|
||||||
|
tc_buf = {} # index -> {id, name, args}
|
||||||
|
for line in resp_lines:
|
||||||
|
if not line: continue
|
||||||
|
line = line.decode('utf-8', errors='replace') if isinstance(line, bytes) else line
|
||||||
|
if not line.startswith("data:"): continue
|
||||||
|
data_str = line[5:].lstrip()
|
||||||
|
if data_str == "[DONE]": break
|
||||||
|
try: evt = json.loads(data_str)
|
||||||
|
except: continue
|
||||||
|
ch = (evt.get("choices") or [{}])[0]
|
||||||
|
delta = ch.get("delta", {})
|
||||||
|
if delta.get("content"):
|
||||||
|
text = delta["content"]; content_text += text; yield text
|
||||||
|
for tc in delta.get("tool_calls", []):
|
||||||
|
idx = tc.get("index", 0)
|
||||||
|
if idx not in tc_buf: tc_buf[idx] = {"id": tc.get("id", ""), "name": "", "args": ""}
|
||||||
|
if tc.get("function", {}).get("name"): tc_buf[idx]["name"] = tc["function"]["name"]
|
||||||
|
if tc.get("function", {}).get("arguments"): tc_buf[idx]["args"] += tc["function"]["arguments"]
|
||||||
|
usage = evt.get("usage")
|
||||||
|
if usage:
|
||||||
|
cached = (usage.get("prompt_tokens_details") or {}).get("cached_tokens", 0)
|
||||||
|
print(f"[Cache] input={usage.get('prompt_tokens',0)} cached={cached}")
|
||||||
|
blocks = []
|
||||||
|
if content_text: blocks.append({"type": "text", "text": content_text})
|
||||||
|
for idx in sorted(tc_buf):
|
||||||
|
tc = tc_buf[idx]
|
||||||
|
try: inp = json.loads(tc["args"]) if tc["args"] else {}
|
||||||
|
except: inp = {"_raw": tc["args"]}
|
||||||
|
blocks.append({"type": "tool_use", "id": tc["id"], "name": tc["name"], "input": inp})
|
||||||
|
return blocks
|
||||||
|
|
||||||
class ClaudeSession:
|
class ClaudeSession:
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
|
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
|
||||||
self.default_model = cfg.get('model', 'claude-opus')
|
self.default_model = cfg.get('model', 'claude-opus')
|
||||||
self.context_win = cfg.get('context_win', 18000)
|
self.context_win = cfg.get('context_win', 18000)
|
||||||
self.raw_msgs, self.lock = [], threading.Lock()
|
self.raw_msgs, self.lock = [], threading.Lock()
|
||||||
self.prompt_cache = cfg.get('prompt_cache', False)
|
self.system = ""
|
||||||
def _trim_messages(self, messages):
|
def _trim_messages(self, raw_msgs):
|
||||||
compress_history_tags(messages)
|
compress_history_tags(raw_msgs)
|
||||||
total = sum(len(m['prompt']) for m in messages)
|
total = sum(len(m['prompt']) for m in raw_msgs)
|
||||||
if total <= self.context_win * 3: return messages
|
print(f'[Debug] Current context: {total} chars, {len(raw_msgs)} messages.')
|
||||||
|
if total <= self.context_win * 3: return raw_msgs
|
||||||
target, current, result = self.context_win * 3 * 0.6, 0, []
|
target, current, result = self.context_win * 3 * 0.6, 0, []
|
||||||
for msg in reversed(messages):
|
for msg in reversed(raw_msgs):
|
||||||
if (msg_len := len(msg['prompt'])) + current <= target:
|
if (msg_len := len(msg['prompt'])) + current <= target:
|
||||||
result.append(msg); current += msg_len
|
result.append(msg); current += msg_len
|
||||||
else: break
|
else: break
|
||||||
if current > self.context_win * 2.7: print(f'[DEBUG] {len(result)} contexts, whole length {current//3} tokens.')
|
print(f'[Debug] Trimmed context, current: {current} chars, {len(result)} messages.')
|
||||||
return result[::-1] or messages[-2:]
|
return result[::-1] or raw_msgs[-2:]
|
||||||
def raw_ask(self, messages, model=None, temperature=0.5, max_tokens=6144):
|
def raw_ask(self, messages, model=None, temperature=0.5, max_tokens=6144):
|
||||||
model = model or self.default_model
|
model = model or self.default_model
|
||||||
if 'kimi' in model.lower() or 'moonshot' in model.lower(): temperature = 1.0 # kimi/moonshot only accepts temp 1.0
|
if 'kimi' in model.lower() or 'moonshot' in model.lower(): temperature = 1.0 # kimi/moonshot only accepts temp 1.0
|
||||||
headers = {"x-api-key": self.api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01", "anthropic-beta": "prompt-caching-2024-07-31"}
|
headers = {"x-api-key": self.api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01", "anthropic-beta": "prompt-caching-2024-07-31"}
|
||||||
payload = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "stream": True}
|
payload = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "stream": True}
|
||||||
|
if self.system: payload["system"] = [{"type": "text", "text": self.system, "cache_control": {"type": "persistent"}}]
|
||||||
try:
|
try:
|
||||||
with requests.post(auto_make_url(self.api_base, "messages"), headers=headers, json=payload, stream=True, timeout=(5,30)) as r:
|
with requests.post(auto_make_url(self.api_base, "messages"), headers=headers, json=payload, stream=True, timeout=(5,30)) as r:
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
for line in r.iter_lines():
|
yield from _parse_claude_sse(r.iter_lines())
|
||||||
if not line: continue
|
|
||||||
line = line.decode("utf-8") if isinstance(line, bytes) else line
|
|
||||||
if not line.startswith("data:"): continue
|
|
||||||
data = line[5:].lstrip()
|
|
||||||
if data == "[DONE]": break
|
|
||||||
try:
|
|
||||||
obj = json.loads(data)
|
|
||||||
if obj.get("type") == "message_start":
|
|
||||||
usage = obj.get("message", {}).get("usage", {})
|
|
||||||
ci, cr, inp = usage.get("cache_creation_input_tokens", 0), usage.get("cache_read_input_tokens", 0), usage.get("input_tokens", 0)
|
|
||||||
print(f"[Cache] input={inp} creation={ci} read={cr}")
|
|
||||||
elif obj.get("type") == "content_block_delta" and obj.get("delta", {}).get("type") == "text_delta":
|
|
||||||
text = obj["delta"].get("text", "")
|
|
||||||
if text: yield text
|
|
||||||
except: pass
|
|
||||||
except Exception as e: yield f"Error: {str(e)}"
|
except Exception as e: yield f"Error: {str(e)}"
|
||||||
def make_messages(self, raw_list):
|
def make_messages(self, raw_list):
|
||||||
trimmed = self._trim_messages(raw_list)
|
msgs = [{"role": m['role'], "content": [{"type": "text", "text": m['prompt']}]} for m in raw_list]
|
||||||
msgs = [{"role": m['role'], "content": [{"type": "text", "text": m['prompt']}] if m['role'] == "assistant" else m['prompt']} for m in trimmed]
|
c = msgs[-1]["content"]
|
||||||
for i in range(len(msgs)-1, -1, -1):
|
c[-1] = dict(c[-1], cache_control={"type": "ephemeral"})
|
||||||
if msgs[i]["role"] == "assistant":
|
|
||||||
msgs[i]["content"][-1]["cache_control"] = {"type": "ephemeral"}
|
|
||||||
break
|
|
||||||
return msgs
|
return msgs
|
||||||
def ask(self, prompt, model=None, stream=False):
|
def ask(self, prompt, model=None, stream=False):
|
||||||
def _ask_gen():
|
def _ask_gen():
|
||||||
content = ''
|
content = ''
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.raw_msgs.append({"role": "user", "prompt": prompt})
|
self.raw_msgs.append({"role": "user", "prompt": prompt})
|
||||||
|
self.raw_msgs = self._trim_messages(self.raw_msgs)
|
||||||
messages = self.make_messages(self.raw_msgs)
|
messages = self.make_messages(self.raw_msgs)
|
||||||
for chunk in self.raw_ask(messages, model):
|
for chunk in self.raw_ask(messages, model):
|
||||||
content += chunk; yield chunk
|
content += chunk; yield chunk
|
||||||
@@ -145,7 +254,6 @@ class LLMSession:
|
|||||||
self.raw_msgs, self.messages = [], []
|
self.raw_msgs, self.messages = [], []
|
||||||
proxy = cfg.get('proxy')
|
proxy = cfg.get('proxy')
|
||||||
self.proxies = {"http": proxy, "https": proxy} if proxy else None
|
self.proxies = {"http": proxy, "https": proxy} if proxy else None
|
||||||
self.prompt_cache = cfg.get('prompt_cache', False)
|
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
self.max_retries = max(0, int(cfg.get('max_retries', 2)))
|
self.max_retries = max(0, int(cfg.get('max_retries', 2)))
|
||||||
self.connect_timeout = max(1, int(cfg.get('connect_timeout', 10)))
|
self.connect_timeout = max(1, int(cfg.get('connect_timeout', 10)))
|
||||||
@@ -201,7 +309,7 @@ class LLMSession:
|
|||||||
if self.reasoning_effort: payload["reasoning"] = {"effort": self.reasoning_effort}
|
if self.reasoning_effort: payload["reasoning"] = {"effort": self.reasoning_effort}
|
||||||
else:
|
else:
|
||||||
url = auto_make_url(self.api_base, "chat/completions")
|
url = auto_make_url(self.api_base, "chat/completions")
|
||||||
payload = {"model": model, "messages": messages, "temperature": temperature, "stream": True}
|
payload = {"model": model, "messages": messages, "temperature": temperature, "stream": True, "stream_options": {"include_usage": True}}
|
||||||
if self.reasoning_effort: payload["reasoning_effort"] = self.reasoning_effort
|
if self.reasoning_effort: payload["reasoning_effort"] = self.reasoning_effort
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
streamed_any = False
|
streamed_any = False
|
||||||
@@ -216,42 +324,9 @@ class LLMSession:
|
|||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
continue
|
continue
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
buffer = ''; seen_delta = False
|
for chunk in _parse_openai_sse(r.iter_lines(), self.api_mode):
|
||||||
for line in r.iter_lines():
|
streamed_any = True
|
||||||
line = line.decode("utf-8") if isinstance(line, bytes) else line
|
yield chunk
|
||||||
if not line or not line.startswith("data:"): continue
|
|
||||||
data = line[5:].lstrip()
|
|
||||||
if data == "[DONE]": break
|
|
||||||
try: obj = json.loads(data)
|
|
||||||
except: continue
|
|
||||||
if self.api_mode == "responses":
|
|
||||||
etype = obj.get("type", "")
|
|
||||||
delta = obj.get("delta", "") if etype == "response.output_text.delta" else ""
|
|
||||||
if delta:
|
|
||||||
streamed_any = True; seen_delta = True
|
|
||||||
yield delta; buffer += delta
|
|
||||||
elif etype == "response.output_text.done" and not seen_delta:
|
|
||||||
text = obj.get("text", "")
|
|
||||||
if text:
|
|
||||||
streamed_any = True
|
|
||||||
yield text; buffer += text
|
|
||||||
elif etype == "error":
|
|
||||||
err = obj.get("error", {})
|
|
||||||
emsg = err.get("message", str(err)) if isinstance(err, dict) else str(err)
|
|
||||||
if emsg:
|
|
||||||
yield f"Error: {emsg}"
|
|
||||||
return
|
|
||||||
elif etype == "response.completed":
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
ch = (obj.get("choices") or [{}])[0]
|
|
||||||
finish_reason = ch.get("finish_reason")
|
|
||||||
delta = (ch.get("delta") or {}).get("content")
|
|
||||||
if delta:
|
|
||||||
streamed_any = True
|
|
||||||
yield delta; buffer += delta
|
|
||||||
if finish_reason: break
|
|
||||||
#if '</tool_use>' in buffer[-30:]: break
|
|
||||||
return
|
return
|
||||||
except requests.HTTPError as e:
|
except requests.HTTPError as e:
|
||||||
resp = getattr(e, "response", None)
|
resp = getattr(e, "response", None)
|
||||||
@@ -420,64 +495,46 @@ class NativeOAISession:
|
|||||||
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
|
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
|
||||||
self.default_model = cfg.get('model', 'gpt-4o')
|
self.default_model = cfg.get('model', 'gpt-4o')
|
||||||
self.context_win = cfg.get('context_win', 28000)
|
self.context_win = cfg.get('context_win', 28000)
|
||||||
|
self.reasoning_effort = cfg.get('reasoning_effort')
|
||||||
proxy = cfg.get('proxy')
|
proxy = cfg.get('proxy')
|
||||||
self.proxies = {"http": proxy, "https": proxy} if proxy else None
|
self.proxies = {"http": proxy, "https": proxy} if proxy else None
|
||||||
self.history = []; self.system = None; self.lock = threading.Lock()
|
self.history = []; self.system = ''; self.lock = threading.Lock()
|
||||||
def set_system(self, system_text): self.system = system_text
|
|
||||||
|
|
||||||
def raw_ask(self, messages, tools=None, system=None, model=None, temperature=0.5, max_tokens=6144, **kw):
|
def raw_ask(self, messages, tools=None, system=None, model=None, temperature=0.5, max_tokens=6144, **kw):
|
||||||
"""OpenAI streaming. yields text chunks, generator return = list[content_block]"""
|
"""OpenAI streaming. yields text chunks, generator return = list[content_block]"""
|
||||||
model = model or self.default_model
|
model = model or self.default_model
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||||
msgs = ([{"role": "system", "content": system}] if system else []) + messages
|
msgs = ([{"role": "system", "content": system}] if system else []) + messages
|
||||||
payload = {"model": model, "messages": msgs, "temperature": temperature, "max_tokens": max_tokens, "stream": True}
|
payload = {"model": model, "messages": msgs, "temperature": temperature, "max_tokens": max_tokens, "stream": True, "stream_options": {"include_usage": True}}
|
||||||
if tools: payload["tools"] = tools
|
if tools: payload["tools"] = tools
|
||||||
|
if self.reasoning_effort: payload["reasoning_effort"] = self.reasoning_effort
|
||||||
try:
|
try:
|
||||||
resp = requests.post(auto_make_url(self.api_base, "chat/completions"), headers=headers, json=payload, stream=True, timeout=120, proxies=self.proxies)
|
resp = requests.post(auto_make_url(self.api_base, "chat/completions"), headers=headers, json=payload, stream=True, timeout=120, proxies=self.proxies)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
err = f"Error: HTTP {resp.status_code} {resp.text[:500]}"; yield err; return [{"type": "text", "text": err}]
|
err = f"Error: HTTP {resp.status_code} {resp.text[:500]}"; yield err; return [{"type": "text", "text": err}]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
err = f"Error: {e}"; yield err; return [{"type": "text", "text": err}]
|
err = f"Error: {e}"; yield err; return [{"type": "text", "text": err}]
|
||||||
content_text = ""; tc_buf = {} # index -> {id, name, args_str}
|
gen = _parse_openai_sse(resp.iter_lines(), "chat_completions")
|
||||||
for line in resp.iter_lines():
|
try:
|
||||||
if not line: continue
|
while True: yield next(gen)
|
||||||
line = line.decode('utf-8', errors='replace') if isinstance(line, bytes) else line
|
except StopIteration as e:
|
||||||
if not line.startswith("data: "): continue
|
return e.value or []
|
||||||
data_str = line[6:]
|
|
||||||
if data_str.strip() == "[DONE]": break
|
|
||||||
try: evt = json.loads(data_str)
|
|
||||||
except: continue
|
|
||||||
delta = evt.get("choices", [{}])[0].get("delta", {})
|
|
||||||
if delta.get("content"):
|
|
||||||
text = delta["content"]; content_text += text; yield text
|
|
||||||
for tc in delta.get("tool_calls", []):
|
|
||||||
idx = tc.get("index", 0)
|
|
||||||
if idx not in tc_buf: tc_buf[idx] = {"id": tc.get("id", ""), "name": "", "args": ""}
|
|
||||||
if tc.get("function", {}).get("name"): tc_buf[idx]["name"] = tc["function"]["name"]
|
|
||||||
if tc.get("function", {}).get("arguments"): tc_buf[idx]["args"] += tc["function"]["arguments"]
|
|
||||||
blocks = []
|
|
||||||
if content_text: blocks.append({"type": "text", "text": content_text})
|
|
||||||
for idx in sorted(tc_buf):
|
|
||||||
tc = tc_buf[idx]
|
|
||||||
try: inp = json.loads(tc["args"]) if tc["args"] else {}
|
|
||||||
except: inp = {"_raw": tc["args"]}
|
|
||||||
blocks.append({"type": "tool_use", "id": tc["id"], "name": tc["name"], "input": inp})
|
|
||||||
return blocks
|
|
||||||
|
|
||||||
def ask(self, msg, tools=None, model=None, **kw):
|
def ask(self, msg, tools=None, model=None, **kw):
|
||||||
"""Managed ask with history. yields text chunks, return MockResponse"""
|
assert type(msg) is dict
|
||||||
if isinstance(msg, str): msg = {"role": "user", "content": msg}
|
|
||||||
elif isinstance(msg, list): msg = {"role": "user", "content": msg}
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.history.append(msg)
|
self.history.append(msg)
|
||||||
compress_history_tags(self.history)
|
compress_history_tags(self.history)
|
||||||
cost = sum(len(json.dumps(m, ensure_ascii=False)) for m in self.history)
|
cost = sum(len(json.dumps(m, ensure_ascii=False)) for m in self.history)
|
||||||
|
print(f'[Debug] Current context: {cost} chars, {len(self.history)} messages.')
|
||||||
if cost > self.context_win * 3:
|
if cost > self.context_win * 3:
|
||||||
target = self.context_win * 3 * 0.6
|
target = self.context_win * 3 * 0.6
|
||||||
while len(self.history) > 2 and cost > target:
|
while len(self.history) > 2 and cost > target:
|
||||||
self.history.pop(0); self.history.pop(0)
|
self.history.pop(0); self.history.pop(0)
|
||||||
cost = sum(len(json.dumps(m, ensure_ascii=False)) for m in self.history)
|
cost = sum(len(json.dumps(m, ensure_ascii=False)) for m in self.history)
|
||||||
|
print(f'[Debug] Trimmed context, current: {cost} chars, {len(self.history)} messages.')
|
||||||
messages = list(self.history)
|
messages = list(self.history)
|
||||||
|
|
||||||
content_blocks = None
|
content_blocks = None
|
||||||
gen = self.raw_ask(messages, tools, self.system, model)
|
gen = self.raw_ask(messages, tools, self.system, model)
|
||||||
try:
|
try:
|
||||||
@@ -495,14 +552,10 @@ class NativeClaudeSession:
|
|||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
|
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
|
||||||
self.default_model = cfg.get('model', 'claude-opus')
|
self.default_model = cfg.get('model', 'claude-opus')
|
||||||
self.context_win = cfg.get('context_win', 32000)
|
self.context_win = cfg.get('context_win', 30000)
|
||||||
self.history = []
|
self.history = []; self.system = ''; self.lock = threading.Lock()
|
||||||
self.system = None
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
def set_system(self, system_text): self.system = system_text
|
|
||||||
|
|
||||||
def raw_ask(self, messages, tools=None, system=None, model=None, temperature=0.5, max_tokens=6144):
|
def raw_ask(self, messages, tools=None, system=None, model=None, temperature=0.5, max_tokens=6144):
|
||||||
"""底层API调用。yields text chunks,generator return = list[content_block]"""
|
|
||||||
model = model or self.default_model
|
model = model or self.default_model
|
||||||
headers = {"x-api-key": self.api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01", "anthropic-beta": "prompt-caching-2024-07-31"}
|
headers = {"x-api-key": self.api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01", "anthropic-beta": "prompt-caching-2024-07-31"}
|
||||||
payload = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "stream": True}
|
payload = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "stream": True}
|
||||||
@@ -510,12 +563,8 @@ class NativeClaudeSession:
|
|||||||
tools = [dict(t) for t in tools]; tools[-1]["cache_control"] = {"type": "ephemeral"}
|
tools = [dict(t) for t in tools]; tools[-1]["cache_control"] = {"type": "ephemeral"}
|
||||||
payload["tools"] = tools
|
payload["tools"] = tools
|
||||||
if system: payload["system"] = [{"type": "text", "text": system, "cache_control": {"type": "ephemeral"}}]
|
if system: payload["system"] = [{"type": "text", "text": system, "cache_control": {"type": "ephemeral"}}]
|
||||||
# 历史消息缓存:最后一个assistant消息加cache_control
|
messages[-1] = {**messages[-1], "content": list(messages[-1]["content"])}
|
||||||
for i in range(len(messages) - 1, -1, -1):
|
messages[-1]["content"][-1] = dict(messages[-1]["content"][-1], cache_control={"type": "ephemeral"})
|
||||||
if messages[i]["role"] == "assistant":
|
|
||||||
c = messages[i].get("content", [])
|
|
||||||
if isinstance(c, list) and c: messages[i] = {**messages[i], "content": [*c[:-1], {**c[-1], "cache_control": {"type": "ephemeral"}}]}
|
|
||||||
break
|
|
||||||
try:
|
try:
|
||||||
resp = requests.post(auto_make_url(self.api_base, "messages"), headers=headers, json=payload, stream=True, timeout=120)
|
resp = requests.post(auto_make_url(self.api_base, "messages"), headers=headers, json=payload, stream=True, timeout=120)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
@@ -526,51 +575,22 @@ class NativeClaudeSession:
|
|||||||
error_msg = f"Error: {e}"
|
error_msg = f"Error: {e}"
|
||||||
yield error_msg
|
yield error_msg
|
||||||
return [{"type": "text", "text": error_msg}]
|
return [{"type": "text", "text": error_msg}]
|
||||||
|
content_blocks = yield from _parse_claude_sse(resp.iter_lines())
|
||||||
content_blocks = []; current_block = None; tool_json_buf = ""
|
return content_blocks or []
|
||||||
for line in resp.iter_lines():
|
|
||||||
if not line: continue
|
|
||||||
line = line.decode('utf-8') if isinstance(line, bytes) else line
|
|
||||||
data_str = line[6:]
|
|
||||||
if data_str.strip() == "[DONE]": break
|
|
||||||
try: evt = json.loads(data_str)
|
|
||||||
except: continue
|
|
||||||
evt_type = evt.get("type", "")
|
|
||||||
if evt_type == "content_block_start":
|
|
||||||
block = evt.get("content_block", {})
|
|
||||||
if block.get("type") == "text": current_block = {"type": "text", "text": ""}
|
|
||||||
elif block.get("type") == "tool_use":
|
|
||||||
current_block = {"type": "tool_use", "id": block.get("id", ""), "name": block.get("name", ""), "input": {}}
|
|
||||||
tool_json_buf = ""
|
|
||||||
elif evt_type == "content_block_delta":
|
|
||||||
delta = evt.get("delta", {})
|
|
||||||
if delta.get("type") == "text_delta":
|
|
||||||
text = delta.get("text", "")
|
|
||||||
if current_block: current_block["text"] += text
|
|
||||||
yield text
|
|
||||||
elif delta.get("type") == "input_json_delta": tool_json_buf += delta.get("partial_json", "")
|
|
||||||
elif evt_type == "content_block_stop":
|
|
||||||
if current_block:
|
|
||||||
if current_block["type"] == "tool_use":
|
|
||||||
try: current_block["input"] = json.loads(tool_json_buf) if tool_json_buf else {}
|
|
||||||
except: current_block["input"] = {"_raw": tool_json_buf}
|
|
||||||
content_blocks.append(current_block)
|
|
||||||
current_block = None
|
|
||||||
return content_blocks
|
|
||||||
|
|
||||||
def ask(self, msg, tools=None, model=None):
|
def ask(self, msg, tools=None, model=None):
|
||||||
"""增量ask。msg: str|list[content_block]|dict。yields text chunks, return MockResponse"""
|
assert type(msg) is dict
|
||||||
if isinstance(msg, str): msg = {"role": "user", "content": msg}
|
|
||||||
elif isinstance(msg, list): msg = {"role": "user", "content": msg}
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.history.append(msg)
|
self.history.append(msg)
|
||||||
compress_history_tags(self.history)
|
compress_history_tags(self.history)
|
||||||
cost = sum(len(json.dumps(m, ensure_ascii=False)) for m in self.history)
|
cost = sum(len(json.dumps(m, ensure_ascii=False)) for m in self.history)
|
||||||
|
print(f'[Debug] Current context: {cost} chars, {len(self.history)} messages.')
|
||||||
if cost > self.context_win * 3:
|
if cost > self.context_win * 3:
|
||||||
target = self.context_win * 3 * 0.6
|
target = self.context_win * 3 * 0.6
|
||||||
while len(self.history) > 2 and cost > target:
|
while len(self.history) > 2 and cost > target:
|
||||||
self.history.pop(0); self.history.pop(0)
|
self.history.pop(0); self.history.pop(0)
|
||||||
cost = sum(len(json.dumps(m, ensure_ascii=False)) for m in self.history)
|
cost = sum(len(json.dumps(m, ensure_ascii=False)) for m in self.history)
|
||||||
|
print(f'[Debug] Trimmed context, current: {cost} chars, {len(self.history)} messages.')
|
||||||
messages = list(self.history)
|
messages = list(self.history)
|
||||||
|
|
||||||
content_blocks = None
|
content_blocks = None
|
||||||
@@ -619,10 +639,8 @@ class MockResponse:
|
|||||||
return f"<MockResponse thinking={bool(self.thinking)}, content='{self.content}', tools={bool(self.tool_calls)}>"
|
return f"<MockResponse thinking={bool(self.thinking)}, content='{self.content}', tools={bool(self.tool_calls)}>"
|
||||||
|
|
||||||
class ToolClient:
|
class ToolClient:
|
||||||
def __init__(self, backends, auto_save_tokens=True):
|
def __init__(self, backend, auto_save_tokens=True):
|
||||||
if isinstance(backends, list): self.backends = backends
|
self.backend = backend
|
||||||
else: self.backends = [backends]
|
|
||||||
self.backend = self.backends[0]
|
|
||||||
self.auto_save_tokens = auto_save_tokens
|
self.auto_save_tokens = auto_save_tokens
|
||||||
self.last_tools = ''
|
self.last_tools = ''
|
||||||
self.total_cd_tokens = 0
|
self.total_cd_tokens = 0
|
||||||
@@ -726,19 +744,17 @@ class ToolClient:
|
|||||||
system_content = next((m['content'] for m in messages if m['role'].lower() == 'system'), "")
|
system_content = next((m['content'] for m in messages if m['role'].lower() == 'system'), "")
|
||||||
history_msgs = [m for m in messages if m['role'].lower() != 'system']
|
history_msgs = [m for m in messages if m['role'].lower() != 'system']
|
||||||
tool_instruction = self._prepare_tool_instruction(tools)
|
tool_instruction = self._prepare_tool_instruction(tools)
|
||||||
|
system = ""
|
||||||
prompt = ""
|
if system_content: system += f"{system_content}\n"
|
||||||
if system_content: prompt += f"=== SYSTEM ===\n{system_content}\n"
|
system += f"{tool_instruction}"
|
||||||
prompt += f"{tool_instruction}\n\n"
|
user = ""
|
||||||
for m in history_msgs:
|
for m in history_msgs:
|
||||||
role = "USER" if m['role'] == 'user' else "ASSISTANT"
|
role = "USER" if m['role'] == 'user' else "ASSISTANT"
|
||||||
prompt += f"=== {role} ===\n{m['content']}\n\n"
|
user += f"=== {role} ===\n{m['content']}\n\n"
|
||||||
self.total_cd_tokens += self._estimate_content_len(m['content'])
|
self.total_cd_tokens += self._estimate_content_len(m['content'])
|
||||||
|
|
||||||
if self.total_cd_tokens > 6000: self.last_tools = ''
|
if self.total_cd_tokens > 6000: self.last_tools = ''
|
||||||
|
user += "=== ASSISTANT ===\n"
|
||||||
prompt += "=== ASSISTANT ===\n"
|
return system + user
|
||||||
return prompt
|
|
||||||
|
|
||||||
def _parse_mixed_response(self, text):
|
def _parse_mixed_response(self, text):
|
||||||
remaining_text = text; thinking = ''
|
remaining_text = text; thinking = ''
|
||||||
@@ -830,8 +846,10 @@ class NativeToolClient:
|
|||||||
combined_content = []; resp = None
|
combined_content = []; resp = None
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
c = msg.get('content', '')
|
c = msg.get('content', '')
|
||||||
|
if msg['role'] == 'system':
|
||||||
|
self.set_system(c); continue
|
||||||
if isinstance(c, str): combined_content.append({"type": "text", "text": c})
|
if isinstance(c, str): combined_content.append({"type": "text", "text": c})
|
||||||
elif isinstance(c, list) or isinstance(c, dict): combined_content.extend(c)
|
elif isinstance(c, list): combined_content.extend(c)
|
||||||
if self._pending_tool_ids and isinstance(self.backend, NativeClaudeSession):
|
if self._pending_tool_ids and isinstance(self.backend, NativeClaudeSession):
|
||||||
tool_result_blocks = [{"type": "tool_result", "tool_use_id": tid, "content": ""} for tid in self._pending_tool_ids]
|
tool_result_blocks = [{"type": "tool_result", "tool_use_id": tid, "content": ""} for tid in self._pending_tool_ids]
|
||||||
combined_content = tool_result_blocks + combined_content
|
combined_content = tool_result_blocks + combined_content
|
||||||
|
|||||||
Reference in New Issue
Block a user