feat: add XaiSession backend, improve handler switch & tg resilience
This commit is contained in:
14
agentmain.py
14
agentmain.py
@@ -5,7 +5,7 @@ if sys.stderr is None: sys.stderr = open(os.devnull, "w")
|
||||
elif hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(errors='replace')
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from sidercall import SiderLLMSession, LLMSession, ToolClient, ClaudeSession
|
||||
from sidercall import SiderLLMSession, LLMSession, ToolClient, ClaudeSession, XaiSession
|
||||
from agent_loop import agent_runner_loop, StepOutcome, BaseHandler
|
||||
from ga import GenericAgentHandler, smart_format, get_global_memory, format_error
|
||||
|
||||
@@ -28,14 +28,15 @@ def get_system_prompt():
|
||||
class GeneraticAgent:
|
||||
def __init__(self):
|
||||
if not os.path.exists('temp'): os.makedirs('temp')
|
||||
from sidercall import sider_cookie, oai_configs, claude_configs
|
||||
from sidercall import sider_cookie, oai_configs, claude_configs, xai_api_key, proxy
|
||||
llm_sessions = []
|
||||
for cfg in claude_configs.values():
|
||||
llm_sessions += [ClaudeSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
|
||||
if sider_cookie: llm_sessions += [SiderLLMSession(default_model=x) for x in \
|
||||
["gemini-3.0-flash", "claude-haiku-4.5", "kimi-k2"]]
|
||||
if xai_api_key: llm_sessions += [XaiSession(xai_api_key, proxy)]
|
||||
for cfg in oai_configs.values():
|
||||
llm_sessions += [LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
|
||||
llm_sessions += [LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'], proxy=cfg.get('proxy'))]
|
||||
if len(llm_sessions) > 0: self.llmclient = ToolClient(llm_sessions, auto_save_tokens=True)
|
||||
else: self.llmclient = None
|
||||
self.lock = threading.Lock()
|
||||
@@ -51,7 +52,9 @@ class GeneraticAgent:
|
||||
self.llm_no = ((self.llm_no + 1) if n < 0 else n) % len(self.llmclient.backends)
|
||||
self.llmclient.last_tools = ''
|
||||
def list_llms(self): return [(i, b.default_model, i == self.llm_no) for i, b in enumerate(self.llmclient.backends)]
|
||||
def get_llm_name(self): return self.llmclient.backends[self.llm_no].default_model
|
||||
def get_llm_name(self):
|
||||
b = self.llmclient.backends[self.llm_no]
|
||||
return f"{type(b).__name__}/{b.default_model}"
|
||||
|
||||
def abort(self):
|
||||
print('Abort current task...')
|
||||
@@ -75,6 +78,9 @@ class GeneraticAgent:
|
||||
|
||||
sys_prompt = get_system_prompt()
|
||||
handler = GenericAgentHandler(None, self.history, './temp')
|
||||
if self.handler and self.handler.key_info:
|
||||
handler.key_info = self.handler.key_info
|
||||
handler.key_info += '\n如你确信任务已经改变,请先更新工作记忆去除无用部分\n'
|
||||
self.handler = handler
|
||||
self.llmclient.backend = self.llmclient.backends[self.llm_no]
|
||||
gen = agent_runner_loop(self.llmclient, sys_prompt, raw_query,
|
||||
|
||||
@@ -30,3 +30,13 @@
|
||||
## 导航避坑
|
||||
- `web_scan` 仅读当前页,不会导航。
|
||||
- 切换网站用 `web_execute_js` + `location.href = 'url'`。
|
||||
|
||||
## Google图片搜索操作
|
||||
- **class名不可靠**:Google的class均为混淆名(如F0uyec),随版本变化,禁止硬编码
|
||||
- 点击图片结果:找搜索结果区内 `[role=button]` 的div,而非外层容器或内部a/img
|
||||
- `web_scan` 会过滤边栏内容,边栏弹出后用JS提取:
|
||||
- 文本:`document.body.innerText`
|
||||
- 大图:遍历所有img,按 `naturalWidth` 最大的那个取src(通常>600px)
|
||||
- "访问"链接:遍历所有`a`找`textContent.includes('访问')`的href
|
||||
- 缩略图base64:结果中`img[src^="data:image"]`可直接提取保存
|
||||
- 下载大图时注意JS返回的src可能被截断,用`return img.src`获取完整URL
|
||||
|
||||
54
sidercall.py
54
sidercall.py
@@ -9,6 +9,7 @@ sider_cookie = mykeys.get("sider_cookie")
|
||||
oai_configs = {k: v for k, v in vars(mykey).items() if k.startswith("oai_config") and v}
|
||||
claude_configs = {k: v for k, v in vars(mykey).items() if k.startswith("claude_config") and v}
|
||||
google_api_key = mykeys.get("google_api_key")
|
||||
xai_api_key = mykeys.get("xai_api_key")
|
||||
|
||||
proxy = mykeys.get("proxy", 'http://127.0.0.1:2082')
|
||||
proxies = {"http": proxy, "https": proxy} if proxy else None
|
||||
@@ -90,13 +91,10 @@ class ClaudeSession:
|
||||
return _ask_gen() if stream else ''.join(list(_ask_gen()))
|
||||
|
||||
class LLMSession:
|
||||
def __init__(self, api_key, api_base, model, context_win=16000):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.raw_msgs = []
|
||||
self.messages = []
|
||||
self.context_win = context_win
|
||||
self.default_model = model
|
||||
def __init__(self, api_key, api_base, model, context_win=12000, proxy=None):
|
||||
self.api_key = api_key; self.api_base = api_base; self.default_model = model
|
||||
self.context_win = context_win; self.raw_msgs = []; self.messages = []
|
||||
self.proxies = {"http": proxy, "https": proxy} if proxy else None
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def raw_ask(self, messages, model=None, temperature=0.5):
|
||||
@@ -104,8 +102,8 @@ class LLMSession:
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", "Accept": "text/event-stream"}
|
||||
payload = {"model": model, "messages": messages, "temperature": temperature, "stream": True}
|
||||
try:
|
||||
with requests.post(f"{self.api_base}/chat/completions",
|
||||
headers=headers, json=payload, stream=True, timeout=(5, 60)) as r:
|
||||
with requests.post(f"{self.api_base}/chat/completions", headers=headers,
|
||||
json=payload, stream=True, timeout=(5, 60), proxies=self.proxies) as r:
|
||||
r.raise_for_status()
|
||||
buffer = ''
|
||||
for line in r.iter_lines():
|
||||
@@ -209,6 +207,44 @@ class GeminiSession:
|
||||
return f"[GeminiError] invalid response format: {e}"
|
||||
return iter([full_text]) if stream else full_text
|
||||
|
||||
class XaiSession:
|
||||
def __init__(self, api_key, proxy="http://127.0.0.1:2082", default_model="grok-4-1-fast-non-reasoning"):
|
||||
import xai_sdk
|
||||
from xai_sdk.chat import user, system
|
||||
self._user, self._system = user, system
|
||||
self.default_model = default_model
|
||||
self._last_response_id = None # 多轮对话链
|
||||
os.environ["XAI_API_KEY"] = api_key
|
||||
if not proxy.startswith("http"): proxy = f"http://{proxy}"
|
||||
os.environ.setdefault("grpc_proxy", proxy)
|
||||
self._client = xai_sdk.Client()
|
||||
def ask(self, prompt, model=None, system_prompt=None, stream=False):
|
||||
"""发送消息,自动串联多轮对话;stream=True返回生成器"""
|
||||
mdl = model or self.default_model
|
||||
try:
|
||||
kw = dict(model=mdl, store_messages=True)
|
||||
if self._last_response_id: kw["previous_response_id"] = self._last_response_id
|
||||
chat = self._client.chat.create(**kw)
|
||||
if system_prompt: chat.append(self._system(system_prompt))
|
||||
chat.append(self._user(prompt))
|
||||
if stream: return self._stream(chat)
|
||||
resp = chat.sample()
|
||||
self._last_response_id = resp.id
|
||||
return resp.content
|
||||
except Exception as e:
|
||||
err = f"[XaiError] {e}"
|
||||
return iter([err]) if stream else err
|
||||
def _stream(self, chat):
|
||||
try:
|
||||
last_resp = None
|
||||
for resp, chunk in chat.stream():
|
||||
last_resp = resp
|
||||
if chunk and chunk.content: yield chunk.content
|
||||
if last_resp and hasattr(last_resp, 'id'): self._last_response_id = last_resp.id
|
||||
except Exception as e:
|
||||
yield f"[XaiError] {e}"
|
||||
def reset(self): self._last_response_id = None
|
||||
|
||||
class MockFunction:
|
||||
def __init__(self, name, arguments): self.name, self.arguments = name, arguments
|
||||
|
||||
|
||||
1
stapp.py
1
stapp.py
@@ -42,6 +42,7 @@ def render_sidebar():
|
||||
if st.button("强行停止任务"):
|
||||
agent.abort()
|
||||
st.toast("已发送停止信号")
|
||||
st.rerun()
|
||||
if st.button("重新注入System Prompt"):
|
||||
agent.llmclient.last_tools = ''
|
||||
st.toast("下次将重新注入System Prompt")
|
||||
|
||||
25
tgapp.py
25
tgapp.py
@@ -97,23 +97,25 @@ if __name__ == '__main__':
|
||||
threading.Thread(target=agent.run, daemon=True).start()
|
||||
proxy = vars(mykey).get('proxy', 'http://127.0.0.1:2082')
|
||||
print('proxy:', proxy)
|
||||
request = HTTPXRequest(proxy=proxy, read_timeout=30, write_timeout=30, connect_timeout=30, pool_timeout=30)
|
||||
app = (ApplicationBuilder()
|
||||
.token(mykey.tg_bot_token)
|
||||
.request(request)
|
||||
.get_updates_request(request)
|
||||
.build())
|
||||
app.add_handler(CommandHandler("stop", cmd_abort))
|
||||
app.add_handler(CommandHandler("llm", cmd_llm))
|
||||
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_msg))
|
||||
|
||||
async def _error_handler(update, context: ContextTypes.DEFAULT_TYPE):
|
||||
print(f"[{time.strftime('%m-%d %H:%M')}] TG error: {context.error}", flush=True)
|
||||
app.add_error_handler(_error_handler)
|
||||
|
||||
print(f"TG bot starting... {time.strftime('%m-%d %H:%M')}")
|
||||
while True:
|
||||
try:
|
||||
print(f"TG bot starting... {time.strftime('%m-%d %H:%M')}")
|
||||
# Recreate request and app objects on each restart to avoid stale connections
|
||||
request = HTTPXRequest(proxy=proxy, read_timeout=30, write_timeout=30, connect_timeout=30, pool_timeout=30)
|
||||
app = (ApplicationBuilder()
|
||||
.token(mykey.tg_bot_token)
|
||||
.request(request)
|
||||
.get_updates_request(request)
|
||||
.build())
|
||||
app.add_handler(CommandHandler("stop", cmd_abort))
|
||||
app.add_handler(CommandHandler("llm", cmd_llm))
|
||||
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_msg))
|
||||
app.add_error_handler(_error_handler)
|
||||
|
||||
app.run_polling(
|
||||
drop_pending_updates=True,
|
||||
poll_interval=1.0,
|
||||
@@ -122,3 +124,4 @@ if __name__ == '__main__':
|
||||
except Exception as e:
|
||||
print(f"[{time.strftime('%m-%d %H:%M')}] polling crashed: {e}", flush=True)
|
||||
time.sleep(10)
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
|
||||
Reference in New Issue
Block a user