feat: stream LLM responses and improve agent UI
This commit is contained in:
99
sidercall.py
99
sidercall.py
@@ -1,4 +1,4 @@
|
||||
import os, json, re, time, requests
|
||||
import os, json, re, time, requests, sys
|
||||
|
||||
try: from mykey import sider_cookie
|
||||
except ImportError: sider_cookie = ""
|
||||
@@ -10,12 +10,14 @@ class SiderLLMSession:
|
||||
from sider_ai_api import Session
|
||||
self._core = Session(cookie=sider_cookie, proxies={'https':'127.0.0.1:2082'})
|
||||
self.default_model = default_model
|
||||
def ask(self, prompt, model=None):
|
||||
def ask(self, prompt, model=None, stream=False):
|
||||
if model is None: model = self.default_model
|
||||
if len(prompt) > 29000:
|
||||
print(f"[Warn] Prompt too long ({len(prompt)} chars), truncating.")
|
||||
prompt = prompt[-29000:]
|
||||
return ''.join(self._core.chat(prompt, model))
|
||||
gen = self._core.chat(prompt, model)
|
||||
if stream: return gen
|
||||
return ''.join(list(gen))
|
||||
|
||||
class LLMSession:
|
||||
def __init__(self, api_key=oai_apikey, api_base=oai_apibase, model=oai_model, context_win=16000):
|
||||
@@ -28,17 +30,29 @@ class LLMSession:
|
||||
|
||||
def raw_ask(self, messages, model=None, temperature=0.5):
|
||||
if model is None: model = self.model
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", "Accept": "text/event-stream"}
|
||||
payload = {"model": model, "messages": messages, "temperature": temperature, "stream": True}
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.api_base}/chat/completions", headers=headers, timeout=60,
|
||||
json={"model": model, "messages": messages, "temperature": temperature} )
|
||||
res_json = response.json()
|
||||
content = res_json["choices"][0]["message"]["content"]
|
||||
return content
|
||||
with requests.post(f"{self.api_base}/chat/completions",
|
||||
headers=headers, json=payload, stream=True, timeout=(5, 60)) as r:
|
||||
r.raise_for_status()
|
||||
buffer = ''
|
||||
for line in r.iter_lines():
|
||||
line = line.decode("utf-8")
|
||||
if not line or not line.startswith("data:"): continue
|
||||
data = line[5:].lstrip()
|
||||
if data == "[DONE]": break
|
||||
obj = json.loads(data)
|
||||
ch = (obj.get("choices") or [{}])[0]
|
||||
if ch.get("finish_reason") is not None: break
|
||||
delta = (ch.get("delta") or {}).get("content")
|
||||
if not delta: continue
|
||||
yield delta
|
||||
buffer += delta
|
||||
if '</tool_use>' in buffer[-30:]: break
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
yield f"Error: {str(e)}"
|
||||
|
||||
def make_messages(self, raw_list, omit_images=True):
|
||||
messages = []
|
||||
for msg in raw_list:
|
||||
@@ -60,22 +74,28 @@ class LLMSession:
|
||||
p = "Summarize prev summary and prev conversations into compact memory (facts/decisions/constraints/open questions). Do NOT restate long schemas. The new summary should less than 1000 tokens.\n"
|
||||
messages = self.make_messages(old, omit_images=True)
|
||||
messages += [{"role":"user", "content":p}]
|
||||
summary = self.raw_ask(messages, model, temperature=0.1)
|
||||
summary = ''.join(list(self.raw_ask(messages, model, temperature=0.1)))
|
||||
if not summary.startswith("Error:"):
|
||||
self.raw_msgs.insert(0, {"role":"system", "prompt":"Prev summary:\n"+summary, "image":None})
|
||||
else: self.raw_msgs = old + self.raw_msgs # 不做了,下次再做
|
||||
|
||||
def ask(self, prompt, model=None, image_base64=None):
|
||||
def ask(self, prompt, model=None, image_base64=None, stream=False):
|
||||
if model is None: model = self.model
|
||||
self.raw_msgs.append({"role": "user", "prompt": prompt, "image": image_base64})
|
||||
messages = self.make_messages(self.raw_msgs[:-1], omit_images=True)
|
||||
messages += self.make_messages([self.raw_msgs[-1]], omit_images=False)
|
||||
total_len = sum(2000 if isinstance(m["content"], list) else len(str(m["content"]))//4 for m in messages) # estimate token count
|
||||
content = self.raw_ask(messages, model)
|
||||
if not content.startswith("Error:"):
|
||||
self.raw_msgs.append({"role": "assistant", "prompt": content, "image": None})
|
||||
if total_len > self.context_win: self.summary_history()
|
||||
return content
|
||||
gen = self.raw_ask(messages, model)
|
||||
def _ask_gen():
|
||||
content = ''
|
||||
for chunk in gen:
|
||||
content += chunk; yield chunk
|
||||
if not content.startswith("Error:"):
|
||||
self.raw_msgs.append({"role": "assistant", "prompt": content, "image": None})
|
||||
if total_len > 5000: print(f"[Debug] Whole context length {total_len}.")
|
||||
if total_len > self.context_win: self.summary_history()
|
||||
if stream: return _ask_gen()
|
||||
return ''.join(list(_ask_gen()))
|
||||
|
||||
|
||||
class MockFunction:
|
||||
@@ -109,7 +129,10 @@ class ToolClient:
|
||||
def chat(self, messages, tools=None):
|
||||
full_prompt = self._build_protocol_prompt(messages, tools)
|
||||
print("Full prompt length:", len(full_prompt))
|
||||
raw_text = self.raw_api(full_prompt)
|
||||
gen = self.raw_api(full_prompt, stream=True)
|
||||
raw_text = ''
|
||||
for chunk in gen:
|
||||
raw_text += chunk; yield chunk
|
||||
with open('model_responses.txt', 'a', encoding='utf-8', errors="replace") as f:
|
||||
f.write(f"=== Prompt ===\n{full_prompt}\n=== Response ===\n{raw_text}\n\n")
|
||||
return self._parse_mixed_response(raw_text)
|
||||
@@ -127,7 +150,7 @@ class ToolClient:
|
||||
请按照以下步骤思考并行动:
|
||||
1. **思考**: 在 `<thinking>` 标签中先进行思考,分析现状和策略。
|
||||
2. **总结**: 在 `<summary>` 中输出*极为简短*的高度概括的单行(<30字)物理快照,包括上次工具调用结果获取的新信息+本次工具调用意图和预期。此内容将进入长期工作记忆,记录关键信息,严禁输出无实际信息增量的描述。
|
||||
3. **行动**: 如果需要调用工具,请紧接着输出一个 **<tool_use>块**,然后结束,我会稍后给你返回<tool_result>块。
|
||||
3. **行动**: 如果需要调用工具,请在回复正文之后输出一个 **<tool_use>块**,然后结束,我会稍后给你返回<tool_result>块。
|
||||
格式: ```<tool_use>\n{{"function": "工具名", "arguments": {{参数}}}}\n</tool_use>\n```
|
||||
|
||||
### 可用工具库
|
||||
@@ -164,7 +187,7 @@ class ToolClient:
|
||||
|
||||
tool_calls = None
|
||||
tool_pattern = r"<tool_use>(.*?)</tool_use>"
|
||||
tool_match = re.search(tool_pattern, text, re.DOTALL)
|
||||
tool_match = re.search(tool_pattern, remaining_text, re.DOTALL)
|
||||
|
||||
json_str = ""
|
||||
if tool_match:
|
||||
@@ -173,6 +196,8 @@ class ToolClient:
|
||||
elif '<tool_use>' in remaining_text:
|
||||
weaktoolstr = remaining_text.split('<tool_use>')[-1].strip()
|
||||
json_str = weaktoolstr if weaktoolstr.endswith('}') else ''
|
||||
if json_str == '' and '```' in weaktoolstr and weaktoolstr.split('```')[0].strip().endswith('}'):
|
||||
json_str = weaktoolstr.split('```')[0].strip()
|
||||
remaining_text = remaining_text.replace('<tool_use>'+weaktoolstr, "")
|
||||
|
||||
if json_str:
|
||||
@@ -184,7 +209,7 @@ class ToolClient:
|
||||
if func_name: tool_calls = [MockToolCall(func_name, args)]
|
||||
except json.JSONDecodeError:
|
||||
print("[Warn] Failed to parse tool_use JSON:", json_str)
|
||||
thinking += f"[Warn] JSON 解析失败,模型输出了无效的 JSON."
|
||||
remaining_text += f"[Warning] JSON 解析失败,模型输出了无效的 JSON."
|
||||
except Exception as e:
|
||||
print("[Error] Exception during tool_use parsing:", str(e), data)
|
||||
|
||||
@@ -198,20 +223,32 @@ def tryparse(json_str):
|
||||
return json.loads(json_str[:-1])
|
||||
|
||||
if __name__ == "__main__":
|
||||
llmclient = ToolClient(LLMSession().ask)
|
||||
response = llmclient.chat(
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
try: from mykey import sider_cookie
|
||||
except ImportError: sider_cookie = ""
|
||||
try: from mykey import oai_apikey, oai_apibase, oai_model
|
||||
except ImportError: oai_apikey = oai_apibase = oai_model = ""
|
||||
|
||||
llmclient = ToolClient(LLMSession(api_key=oai_apikey, api_base=oai_apibase, model=oai_model).ask)
|
||||
print(llmclient.raw_api("Hello, world!", stream=False))
|
||||
#llmclient = ToolClient(SiderLLMSession().ask)
|
||||
def get_final(gen):
|
||||
try:
|
||||
while True:
|
||||
print('mid:', next(gen))
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
response = get_final(llmclient.chat(
|
||||
messages=[{"role": "user", "content": "我的IP是多少"}],
|
||||
tools=[{"name": "get_ip", "parameters": {}}]
|
||||
)
|
||||
# 4. 获取结果
|
||||
))
|
||||
print(f"思考: {response.thinking}")
|
||||
# -> 我需要查一下 IP。
|
||||
|
||||
if response.tool_calls:
|
||||
cmd = response.tool_calls[0]
|
||||
print(f"调用: {cmd.function.name} 参数: {cmd.function.arguments}")
|
||||
|
||||
response = llmclient.chat(
|
||||
response = get_final(llmclient.chat(
|
||||
messages=[{"role": "user", "content": "<tool_result>10.176.45.12</tool_result>"}]
|
||||
)
|
||||
))
|
||||
print(response.content)
|
||||
Reference in New Issue
Block a user