feat: add XaiSession backend, improve handler switch & tg resilience
This commit is contained in:
14
agentmain.py
14
agentmain.py
@@ -5,7 +5,7 @@ if sys.stderr is None: sys.stderr = open(os.devnull, "w")
|
|||||||
elif hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(errors='replace')
|
elif hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(errors='replace')
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
from sidercall import SiderLLMSession, LLMSession, ToolClient, ClaudeSession
|
from sidercall import SiderLLMSession, LLMSession, ToolClient, ClaudeSession, XaiSession
|
||||||
from agent_loop import agent_runner_loop, StepOutcome, BaseHandler
|
from agent_loop import agent_runner_loop, StepOutcome, BaseHandler
|
||||||
from ga import GenericAgentHandler, smart_format, get_global_memory, format_error
|
from ga import GenericAgentHandler, smart_format, get_global_memory, format_error
|
||||||
|
|
||||||
@@ -28,14 +28,15 @@ def get_system_prompt():
|
|||||||
class GeneraticAgent:
|
class GeneraticAgent:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not os.path.exists('temp'): os.makedirs('temp')
|
if not os.path.exists('temp'): os.makedirs('temp')
|
||||||
from sidercall import sider_cookie, oai_configs, claude_configs
|
from sidercall import sider_cookie, oai_configs, claude_configs, xai_api_key, proxy
|
||||||
llm_sessions = []
|
llm_sessions = []
|
||||||
for cfg in claude_configs.values():
|
for cfg in claude_configs.values():
|
||||||
llm_sessions += [ClaudeSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
|
llm_sessions += [ClaudeSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
|
||||||
if sider_cookie: llm_sessions += [SiderLLMSession(default_model=x) for x in \
|
if sider_cookie: llm_sessions += [SiderLLMSession(default_model=x) for x in \
|
||||||
["gemini-3.0-flash", "claude-haiku-4.5", "kimi-k2"]]
|
["gemini-3.0-flash", "claude-haiku-4.5", "kimi-k2"]]
|
||||||
|
if xai_api_key: llm_sessions += [XaiSession(xai_api_key, proxy)]
|
||||||
for cfg in oai_configs.values():
|
for cfg in oai_configs.values():
|
||||||
llm_sessions += [LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
|
llm_sessions += [LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'], proxy=cfg.get('proxy'))]
|
||||||
if len(llm_sessions) > 0: self.llmclient = ToolClient(llm_sessions, auto_save_tokens=True)
|
if len(llm_sessions) > 0: self.llmclient = ToolClient(llm_sessions, auto_save_tokens=True)
|
||||||
else: self.llmclient = None
|
else: self.llmclient = None
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
@@ -51,7 +52,9 @@ class GeneraticAgent:
|
|||||||
self.llm_no = ((self.llm_no + 1) if n < 0 else n) % len(self.llmclient.backends)
|
self.llm_no = ((self.llm_no + 1) if n < 0 else n) % len(self.llmclient.backends)
|
||||||
self.llmclient.last_tools = ''
|
self.llmclient.last_tools = ''
|
||||||
def list_llms(self): return [(i, b.default_model, i == self.llm_no) for i, b in enumerate(self.llmclient.backends)]
|
def list_llms(self): return [(i, b.default_model, i == self.llm_no) for i, b in enumerate(self.llmclient.backends)]
|
||||||
def get_llm_name(self): return self.llmclient.backends[self.llm_no].default_model
|
def get_llm_name(self):
|
||||||
|
b = self.llmclient.backends[self.llm_no]
|
||||||
|
return f"{type(b).__name__}/{b.default_model}"
|
||||||
|
|
||||||
def abort(self):
|
def abort(self):
|
||||||
print('Abort current task...')
|
print('Abort current task...')
|
||||||
@@ -75,6 +78,9 @@ class GeneraticAgent:
|
|||||||
|
|
||||||
sys_prompt = get_system_prompt()
|
sys_prompt = get_system_prompt()
|
||||||
handler = GenericAgentHandler(None, self.history, './temp')
|
handler = GenericAgentHandler(None, self.history, './temp')
|
||||||
|
if self.handler and self.handler.key_info:
|
||||||
|
handler.key_info = self.handler.key_info
|
||||||
|
handler.key_info += '\n如你确信任务已经改变,请先更新工作记忆去除无用部分\n'
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
self.llmclient.backend = self.llmclient.backends[self.llm_no]
|
self.llmclient.backend = self.llmclient.backends[self.llm_no]
|
||||||
gen = agent_runner_loop(self.llmclient, sys_prompt, raw_query,
|
gen = agent_runner_loop(self.llmclient, sys_prompt, raw_query,
|
||||||
|
|||||||
@@ -30,3 +30,13 @@
|
|||||||
## 导航避坑
|
## 导航避坑
|
||||||
- `web_scan` 仅读当前页,不会导航。
|
- `web_scan` 仅读当前页,不会导航。
|
||||||
- 切换网站用 `web_execute_js` + `location.href = 'url'`。
|
- 切换网站用 `web_execute_js` + `location.href = 'url'`。
|
||||||
|
|
||||||
|
## Google图片搜索操作
|
||||||
|
- **class名不可靠**:Google的class均为混淆名(如F0uyec),随版本变化,禁止硬编码
|
||||||
|
- 点击图片结果:找搜索结果区内 `[role=button]` 的div,而非外层容器或内部a/img
|
||||||
|
- `web_scan` 会过滤边栏内容,边栏弹出后用JS提取:
|
||||||
|
- 文本:`document.body.innerText`
|
||||||
|
- 大图:遍历所有img,按 `naturalWidth` 最大的那个取src(通常>600px)
|
||||||
|
- "访问"链接:遍历所有`a`找`textContent.includes('访问')`的href
|
||||||
|
- 缩略图base64:结果中`img[src^="data:image"]`可直接提取保存
|
||||||
|
- 下载大图时注意JS返回的src可能被截断,用`return img.src`获取完整URL
|
||||||
|
|||||||
54
sidercall.py
54
sidercall.py
@@ -9,6 +9,7 @@ sider_cookie = mykeys.get("sider_cookie")
|
|||||||
oai_configs = {k: v for k, v in vars(mykey).items() if k.startswith("oai_config") and v}
|
oai_configs = {k: v for k, v in vars(mykey).items() if k.startswith("oai_config") and v}
|
||||||
claude_configs = {k: v for k, v in vars(mykey).items() if k.startswith("claude_config") and v}
|
claude_configs = {k: v for k, v in vars(mykey).items() if k.startswith("claude_config") and v}
|
||||||
google_api_key = mykeys.get("google_api_key")
|
google_api_key = mykeys.get("google_api_key")
|
||||||
|
xai_api_key = mykeys.get("xai_api_key")
|
||||||
|
|
||||||
proxy = mykeys.get("proxy", 'http://127.0.0.1:2082')
|
proxy = mykeys.get("proxy", 'http://127.0.0.1:2082')
|
||||||
proxies = {"http": proxy, "https": proxy} if proxy else None
|
proxies = {"http": proxy, "https": proxy} if proxy else None
|
||||||
@@ -90,13 +91,10 @@ class ClaudeSession:
|
|||||||
return _ask_gen() if stream else ''.join(list(_ask_gen()))
|
return _ask_gen() if stream else ''.join(list(_ask_gen()))
|
||||||
|
|
||||||
class LLMSession:
|
class LLMSession:
|
||||||
def __init__(self, api_key, api_base, model, context_win=16000):
|
def __init__(self, api_key, api_base, model, context_win=12000, proxy=None):
|
||||||
self.api_key = api_key
|
self.api_key = api_key; self.api_base = api_base; self.default_model = model
|
||||||
self.api_base = api_base
|
self.context_win = context_win; self.raw_msgs = []; self.messages = []
|
||||||
self.raw_msgs = []
|
self.proxies = {"http": proxy, "https": proxy} if proxy else None
|
||||||
self.messages = []
|
|
||||||
self.context_win = context_win
|
|
||||||
self.default_model = model
|
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
def raw_ask(self, messages, model=None, temperature=0.5):
|
def raw_ask(self, messages, model=None, temperature=0.5):
|
||||||
@@ -104,8 +102,8 @@ class LLMSession:
|
|||||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", "Accept": "text/event-stream"}
|
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", "Accept": "text/event-stream"}
|
||||||
payload = {"model": model, "messages": messages, "temperature": temperature, "stream": True}
|
payload = {"model": model, "messages": messages, "temperature": temperature, "stream": True}
|
||||||
try:
|
try:
|
||||||
with requests.post(f"{self.api_base}/chat/completions",
|
with requests.post(f"{self.api_base}/chat/completions", headers=headers,
|
||||||
headers=headers, json=payload, stream=True, timeout=(5, 60)) as r:
|
json=payload, stream=True, timeout=(5, 60), proxies=self.proxies) as r:
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
buffer = ''
|
buffer = ''
|
||||||
for line in r.iter_lines():
|
for line in r.iter_lines():
|
||||||
@@ -209,6 +207,44 @@ class GeminiSession:
|
|||||||
return f"[GeminiError] invalid response format: {e}"
|
return f"[GeminiError] invalid response format: {e}"
|
||||||
return iter([full_text]) if stream else full_text
|
return iter([full_text]) if stream else full_text
|
||||||
|
|
||||||
|
class XaiSession:
|
||||||
|
def __init__(self, api_key, proxy="http://127.0.0.1:2082", default_model="grok-4-1-fast-non-reasoning"):
|
||||||
|
import xai_sdk
|
||||||
|
from xai_sdk.chat import user, system
|
||||||
|
self._user, self._system = user, system
|
||||||
|
self.default_model = default_model
|
||||||
|
self._last_response_id = None # 多轮对话链
|
||||||
|
os.environ["XAI_API_KEY"] = api_key
|
||||||
|
if not proxy.startswith("http"): proxy = f"http://{proxy}"
|
||||||
|
os.environ.setdefault("grpc_proxy", proxy)
|
||||||
|
self._client = xai_sdk.Client()
|
||||||
|
def ask(self, prompt, model=None, system_prompt=None, stream=False):
|
||||||
|
"""发送消息,自动串联多轮对话;stream=True返回生成器"""
|
||||||
|
mdl = model or self.default_model
|
||||||
|
try:
|
||||||
|
kw = dict(model=mdl, store_messages=True)
|
||||||
|
if self._last_response_id: kw["previous_response_id"] = self._last_response_id
|
||||||
|
chat = self._client.chat.create(**kw)
|
||||||
|
if system_prompt: chat.append(self._system(system_prompt))
|
||||||
|
chat.append(self._user(prompt))
|
||||||
|
if stream: return self._stream(chat)
|
||||||
|
resp = chat.sample()
|
||||||
|
self._last_response_id = resp.id
|
||||||
|
return resp.content
|
||||||
|
except Exception as e:
|
||||||
|
err = f"[XaiError] {e}"
|
||||||
|
return iter([err]) if stream else err
|
||||||
|
def _stream(self, chat):
|
||||||
|
try:
|
||||||
|
last_resp = None
|
||||||
|
for resp, chunk in chat.stream():
|
||||||
|
last_resp = resp
|
||||||
|
if chunk and chunk.content: yield chunk.content
|
||||||
|
if last_resp and hasattr(last_resp, 'id'): self._last_response_id = last_resp.id
|
||||||
|
except Exception as e:
|
||||||
|
yield f"[XaiError] {e}"
|
||||||
|
def reset(self): self._last_response_id = None
|
||||||
|
|
||||||
class MockFunction:
|
class MockFunction:
|
||||||
def __init__(self, name, arguments): self.name, self.arguments = name, arguments
|
def __init__(self, name, arguments): self.name, self.arguments = name, arguments
|
||||||
|
|
||||||
|
|||||||
1
stapp.py
1
stapp.py
@@ -42,6 +42,7 @@ def render_sidebar():
|
|||||||
if st.button("强行停止任务"):
|
if st.button("强行停止任务"):
|
||||||
agent.abort()
|
agent.abort()
|
||||||
st.toast("已发送停止信号")
|
st.toast("已发送停止信号")
|
||||||
|
st.rerun()
|
||||||
if st.button("重新注入System Prompt"):
|
if st.button("重新注入System Prompt"):
|
||||||
agent.llmclient.last_tools = ''
|
agent.llmclient.last_tools = ''
|
||||||
st.toast("下次将重新注入System Prompt")
|
st.toast("下次将重新注入System Prompt")
|
||||||
|
|||||||
15
tgapp.py
15
tgapp.py
@@ -97,6 +97,14 @@ if __name__ == '__main__':
|
|||||||
threading.Thread(target=agent.run, daemon=True).start()
|
threading.Thread(target=agent.run, daemon=True).start()
|
||||||
proxy = vars(mykey).get('proxy', 'http://127.0.0.1:2082')
|
proxy = vars(mykey).get('proxy', 'http://127.0.0.1:2082')
|
||||||
print('proxy:', proxy)
|
print('proxy:', proxy)
|
||||||
|
|
||||||
|
async def _error_handler(update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
|
print(f"[{time.strftime('%m-%d %H:%M')}] TG error: {context.error}", flush=True)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
print(f"TG bot starting... {time.strftime('%m-%d %H:%M')}")
|
||||||
|
# Recreate request and app objects on each restart to avoid stale connections
|
||||||
request = HTTPXRequest(proxy=proxy, read_timeout=30, write_timeout=30, connect_timeout=30, pool_timeout=30)
|
request = HTTPXRequest(proxy=proxy, read_timeout=30, write_timeout=30, connect_timeout=30, pool_timeout=30)
|
||||||
app = (ApplicationBuilder()
|
app = (ApplicationBuilder()
|
||||||
.token(mykey.tg_bot_token)
|
.token(mykey.tg_bot_token)
|
||||||
@@ -106,14 +114,8 @@ if __name__ == '__main__':
|
|||||||
app.add_handler(CommandHandler("stop", cmd_abort))
|
app.add_handler(CommandHandler("stop", cmd_abort))
|
||||||
app.add_handler(CommandHandler("llm", cmd_llm))
|
app.add_handler(CommandHandler("llm", cmd_llm))
|
||||||
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_msg))
|
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_msg))
|
||||||
|
|
||||||
async def _error_handler(update, context: ContextTypes.DEFAULT_TYPE):
|
|
||||||
print(f"[{time.strftime('%m-%d %H:%M')}] TG error: {context.error}", flush=True)
|
|
||||||
app.add_error_handler(_error_handler)
|
app.add_error_handler(_error_handler)
|
||||||
|
|
||||||
print(f"TG bot starting... {time.strftime('%m-%d %H:%M')}")
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
app.run_polling(
|
app.run_polling(
|
||||||
drop_pending_updates=True,
|
drop_pending_updates=True,
|
||||||
poll_interval=1.0,
|
poll_interval=1.0,
|
||||||
@@ -122,3 +124,4 @@ if __name__ == '__main__':
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{time.strftime('%m-%d %H:%M')}] polling crashed: {e}", flush=True)
|
print(f"[{time.strftime('%m-%d %H:%M')}] polling crashed: {e}", flush=True)
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||||
|
|||||||
Reference in New Issue
Block a user