feat: add XaiSession backend, improve handler switch & tg resilience

This commit is contained in:
Liang Jiaqing
2026-02-17 13:18:09 +08:00
parent 67c7b3fa71
commit f8e501a27a
5 changed files with 81 additions and 25 deletions

View File

@@ -5,7 +5,7 @@ if sys.stderr is None: sys.stderr = open(os.devnull, "w")
elif hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(errors='replace')
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from sidercall import SiderLLMSession, LLMSession, ToolClient, ClaudeSession
from sidercall import SiderLLMSession, LLMSession, ToolClient, ClaudeSession, XaiSession
from agent_loop import agent_runner_loop, StepOutcome, BaseHandler
from ga import GenericAgentHandler, smart_format, get_global_memory, format_error
@@ -28,14 +28,15 @@ def get_system_prompt():
class GeneraticAgent:
def __init__(self):
if not os.path.exists('temp'): os.makedirs('temp')
from sidercall import sider_cookie, oai_configs, claude_configs
from sidercall import sider_cookie, oai_configs, claude_configs, xai_api_key, proxy
llm_sessions = []
for cfg in claude_configs.values():
llm_sessions += [ClaudeSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
if sider_cookie: llm_sessions += [SiderLLMSession(default_model=x) for x in \
["gemini-3.0-flash", "claude-haiku-4.5", "kimi-k2"]]
if xai_api_key: llm_sessions += [XaiSession(xai_api_key, proxy)]
for cfg in oai_configs.values():
llm_sessions += [LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
llm_sessions += [LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'], proxy=cfg.get('proxy'))]
if len(llm_sessions) > 0: self.llmclient = ToolClient(llm_sessions, auto_save_tokens=True)
else: self.llmclient = None
self.lock = threading.Lock()
@@ -51,7 +52,9 @@ class GeneraticAgent:
self.llm_no = ((self.llm_no + 1) if n < 0 else n) % len(self.llmclient.backends)
self.llmclient.last_tools = ''
def list_llms(self): return [(i, b.default_model, i == self.llm_no) for i, b in enumerate(self.llmclient.backends)]
def get_llm_name(self): return self.llmclient.backends[self.llm_no].default_model
def get_llm_name(self):
b = self.llmclient.backends[self.llm_no]
return f"{type(b).__name__}/{b.default_model}"
def abort(self):
print('Abort current task...')
@@ -75,6 +78,9 @@ class GeneraticAgent:
sys_prompt = get_system_prompt()
handler = GenericAgentHandler(None, self.history, './temp')
if self.handler and self.handler.key_info:
handler.key_info = self.handler.key_info
handler.key_info += '\n如你确信任务已经改变,请先更新工作记忆去除无用部分\n'
self.handler = handler
self.llmclient.backend = self.llmclient.backends[self.llm_no]
gen = agent_runner_loop(self.llmclient, sys_prompt, raw_query,

View File

@@ -30,3 +30,13 @@
## 导航避坑
- `web_scan` 仅读当前页,不会导航。
- 切换网站用 `web_execute_js` + `location.href = 'url'`
## Google图片搜索操作
- **class名不可靠**Google的class均为混淆名(如F0uyec),随版本变化,禁止硬编码
- 点击图片结果:找搜索结果区内 `[role=button]` 的div而非外层容器或内部a/img
- `web_scan` 会过滤边栏内容边栏弹出后用JS提取
- 文本:`document.body.innerText`
- 大图遍历所有img`naturalWidth` 最大的那个取src通常>600px
- "访问"链接:遍历所有`a``textContent.includes('访问')`的href
- 缩略图base64结果中`img[src^="data:image"]`可直接提取保存
- 下载大图时注意JS返回的src可能被截断`return img.src`获取完整URL

View File

@@ -9,6 +9,7 @@ sider_cookie = mykeys.get("sider_cookie")
oai_configs = {k: v for k, v in vars(mykey).items() if k.startswith("oai_config") and v}
claude_configs = {k: v for k, v in vars(mykey).items() if k.startswith("claude_config") and v}
google_api_key = mykeys.get("google_api_key")
xai_api_key = mykeys.get("xai_api_key")
proxy = mykeys.get("proxy", 'http://127.0.0.1:2082')
proxies = {"http": proxy, "https": proxy} if proxy else None
@@ -90,13 +91,10 @@ class ClaudeSession:
return _ask_gen() if stream else ''.join(list(_ask_gen()))
class LLMSession:
def __init__(self, api_key, api_base, model, context_win=16000):
self.api_key = api_key
self.api_base = api_base
self.raw_msgs = []
self.messages = []
self.context_win = context_win
self.default_model = model
def __init__(self, api_key, api_base, model, context_win=12000, proxy=None):
self.api_key = api_key; self.api_base = api_base; self.default_model = model
self.context_win = context_win; self.raw_msgs = []; self.messages = []
self.proxies = {"http": proxy, "https": proxy} if proxy else None
self.lock = threading.Lock()
def raw_ask(self, messages, model=None, temperature=0.5):
@@ -104,8 +102,8 @@ class LLMSession:
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", "Accept": "text/event-stream"}
payload = {"model": model, "messages": messages, "temperature": temperature, "stream": True}
try:
with requests.post(f"{self.api_base}/chat/completions",
headers=headers, json=payload, stream=True, timeout=(5, 60)) as r:
with requests.post(f"{self.api_base}/chat/completions", headers=headers,
json=payload, stream=True, timeout=(5, 60), proxies=self.proxies) as r:
r.raise_for_status()
buffer = ''
for line in r.iter_lines():
@@ -209,6 +207,44 @@ class GeminiSession:
return f"[GeminiError] invalid response format: {e}"
return iter([full_text]) if stream else full_text
class XaiSession:
def __init__(self, api_key, proxy="http://127.0.0.1:2082", default_model="grok-4-1-fast-non-reasoning"):
import xai_sdk
from xai_sdk.chat import user, system
self._user, self._system = user, system
self.default_model = default_model
self._last_response_id = None # 多轮对话链
os.environ["XAI_API_KEY"] = api_key
if not proxy.startswith("http"): proxy = f"http://{proxy}"
os.environ.setdefault("grpc_proxy", proxy)
self._client = xai_sdk.Client()
def ask(self, prompt, model=None, system_prompt=None, stream=False):
"""发送消息自动串联多轮对话stream=True返回生成器"""
mdl = model or self.default_model
try:
kw = dict(model=mdl, store_messages=True)
if self._last_response_id: kw["previous_response_id"] = self._last_response_id
chat = self._client.chat.create(**kw)
if system_prompt: chat.append(self._system(system_prompt))
chat.append(self._user(prompt))
if stream: return self._stream(chat)
resp = chat.sample()
self._last_response_id = resp.id
return resp.content
except Exception as e:
err = f"[XaiError] {e}"
return iter([err]) if stream else err
def _stream(self, chat):
try:
last_resp = None
for resp, chunk in chat.stream():
last_resp = resp
if chunk and chunk.content: yield chunk.content
if last_resp and hasattr(last_resp, 'id'): self._last_response_id = last_resp.id
except Exception as e:
yield f"[XaiError] {e}"
def reset(self): self._last_response_id = None
class MockFunction:
def __init__(self, name, arguments): self.name, self.arguments = name, arguments

View File

@@ -42,6 +42,7 @@ def render_sidebar():
if st.button("强行停止任务"):
agent.abort()
st.toast("已发送停止信号")
st.rerun()
if st.button("重新注入System Prompt"):
agent.llmclient.last_tools = ''
st.toast("下次将重新注入System Prompt")

View File

@@ -97,23 +97,25 @@ if __name__ == '__main__':
threading.Thread(target=agent.run, daemon=True).start()
proxy = vars(mykey).get('proxy', 'http://127.0.0.1:2082')
print('proxy:', proxy)
request = HTTPXRequest(proxy=proxy, read_timeout=30, write_timeout=30, connect_timeout=30, pool_timeout=30)
app = (ApplicationBuilder()
.token(mykey.tg_bot_token)
.request(request)
.get_updates_request(request)
.build())
app.add_handler(CommandHandler("stop", cmd_abort))
app.add_handler(CommandHandler("llm", cmd_llm))
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_msg))
async def _error_handler(update, context: ContextTypes.DEFAULT_TYPE):
print(f"[{time.strftime('%m-%d %H:%M')}] TG error: {context.error}", flush=True)
app.add_error_handler(_error_handler)
print(f"TG bot starting... {time.strftime('%m-%d %H:%M')}")
while True:
try:
print(f"TG bot starting... {time.strftime('%m-%d %H:%M')}")
# Recreate request and app objects on each restart to avoid stale connections
request = HTTPXRequest(proxy=proxy, read_timeout=30, write_timeout=30, connect_timeout=30, pool_timeout=30)
app = (ApplicationBuilder()
.token(mykey.tg_bot_token)
.request(request)
.get_updates_request(request)
.build())
app.add_handler(CommandHandler("stop", cmd_abort))
app.add_handler(CommandHandler("llm", cmd_llm))
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_msg))
app.add_error_handler(_error_handler)
app.run_polling(
drop_pending_updates=True,
poll_interval=1.0,
@@ -122,3 +124,4 @@ if __name__ == '__main__':
except Exception as e:
print(f"[{time.strftime('%m-%d %H:%M')}] polling crashed: {e}", flush=True)
time.sleep(10)
asyncio.set_event_loop(asyncio.new_event_loop())