Refactor LLM backend structure and optimize config/launch
This commit is contained in:
12
agentmain.py
12
agentmain.py
@@ -34,7 +34,7 @@ class GeneraticAgent:
|
|||||||
["gemini-3.0-flash", "claude-haiku-4.5", "kimi-k2"]]
|
["gemini-3.0-flash", "claude-haiku-4.5", "kimi-k2"]]
|
||||||
if oai_apikey: llm_sessions += [LLMSession(api_key=oai_apikey, api_base=oai_apibase)]
|
if oai_apikey: llm_sessions += [LLMSession(api_key=oai_apikey, api_base=oai_apibase)]
|
||||||
if len(llm_sessions) > 0:
|
if len(llm_sessions) > 0:
|
||||||
llmclient = ToolClient([x.ask for x in llm_sessions], auto_save_tokens=True)
|
llmclient = ToolClient(llm_sessions, auto_save_tokens=True)
|
||||||
self.llmclient = llmclient
|
self.llmclient = llmclient
|
||||||
else:
|
else:
|
||||||
self.llmclient = None
|
self.llmclient = None
|
||||||
@@ -50,18 +50,20 @@ class GeneraticAgent:
|
|||||||
self.handler = None
|
self.handler = None
|
||||||
|
|
||||||
def next_llm(self):
|
def next_llm(self):
|
||||||
self.llm_no = (self.llm_no + 1) % len(self.llmclient.raw_apis)
|
self.llm_no = (self.llm_no + 1) % len(self.llmclient.backends)
|
||||||
self.llmclient.last_tools = ''
|
self.llmclient.last_tools = ''
|
||||||
|
|
||||||
def abort(self):
|
def abort(self):
|
||||||
print('About to abort current task...')
|
print('Abort current task...')
|
||||||
if not self.is_running: return
|
if not self.is_running: return
|
||||||
self.stop_sig = True
|
self.stop_sig = True
|
||||||
if self.handler is not None:
|
if self.handler is not None:
|
||||||
self.handler.code_stop_signal.append(1)
|
self.handler.code_stop_signal.append(1)
|
||||||
|
|
||||||
def put_task(self, query, source="user"):
|
def put_task(self, query, source="user"):
|
||||||
self.display_queue.queue.clear()
|
while self.display_queue.qsize() > 0:
|
||||||
|
try: self.display_queue.get_nowait()
|
||||||
|
except queue.Empty: break
|
||||||
self.task_queue.put({"query": query, "source": source})
|
self.task_queue.put({"query": query, "source": source})
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
@@ -78,7 +80,7 @@ class GeneraticAgent:
|
|||||||
sys_prompt = get_system_prompt()
|
sys_prompt = get_system_prompt()
|
||||||
handler = GenericAgentHandler(None, self.history, './temp')
|
handler = GenericAgentHandler(None, self.history, './temp')
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
self.llmclient.raw_api = self.llmclient.raw_apis[self.llm_no]
|
self.llmclient.backend = self.llmclient.backends[self.llm_no]
|
||||||
gen = agent_runner_loop(self.llmclient, sys_prompt,
|
gen = agent_runner_loop(self.llmclient, sys_prompt,
|
||||||
raw_query, handler, TOOLS_SCHEMA, max_turns=25)
|
raw_query, handler, TOOLS_SCHEMA, max_turns=25)
|
||||||
|
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ def get_screen_width():
|
|||||||
# 如果不是 Windows 或者出错了,返回一个兜底值 (比如 1920)
|
# 如果不是 Windows 或者出错了,返回一个兜底值 (比如 1920)
|
||||||
return 1920
|
return 1920
|
||||||
|
|
||||||
def start_streamlit():
|
def start_streamlit(port):
|
||||||
global proc
|
global proc
|
||||||
cmd = [
|
cmd = [
|
||||||
sys.executable, "-m", "streamlit", "run", "stapp.py",
|
sys.executable, "-m", "streamlit", "run", "stapp.py",
|
||||||
"--server.port", "8501",
|
"--server.port", str(port),
|
||||||
"--server.headless", "true",
|
"--server.headless", "true",
|
||||||
"--theme.base", "dark" #以此默认开启暗黑模式,更有极客感
|
"--theme.base", "dark" #以此默认开启暗黑模式,更有极客感
|
||||||
]
|
]
|
||||||
@@ -31,7 +31,8 @@ def start_streamlit():
|
|||||||
atexit.register(proc.kill)
|
atexit.register(proc.kill)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
t = threading.Thread(target=start_streamlit, daemon=True)
|
port = sys.argv[1] if len(sys.argv) > 1 else "8501"
|
||||||
|
t = threading.Thread(target=start_streamlit, args=(port,), daemon=True)
|
||||||
t.start()
|
t.start()
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
screen_width = get_screen_width()
|
screen_width = get_screen_width()
|
||||||
@@ -40,7 +41,7 @@ if __name__ == '__main__':
|
|||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
webview.create_window(
|
webview.create_window(
|
||||||
title='GenericAgent',
|
title='GenericAgent',
|
||||||
url='http://localhost:8501',
|
url=f'http://localhost:{port}',
|
||||||
width=WINDOW_WIDTH,
|
width=WINDOW_WIDTH,
|
||||||
height=WINDOW_HEIGHT,
|
height=WINDOW_HEIGHT,
|
||||||
x=x_pos, y=TOP_PADDING,
|
x=x_pos, y=TOP_PADDING,
|
||||||
|
|||||||
89
sidercall.py
89
sidercall.py
@@ -1,14 +1,22 @@
|
|||||||
import os, json, re, time, requests, sys, threading
|
import os, json, re, time, requests, sys, threading
|
||||||
|
|
||||||
try: from mykey import sider_cookie
|
try: import mykey
|
||||||
except ImportError: sider_cookie = ""
|
except: raise Exception('[ERROR] mykey.py not found, please copy mykey_template.py to mykey.py and fill your LLM backend.')
|
||||||
try: from mykey import oai_apikey, oai_apibase, oai_model
|
|
||||||
except ImportError: oai_apikey = oai_apibase = oai_model = ""
|
def get_config(name, default=""): return getattr(mykey, name, default)
|
||||||
|
|
||||||
|
sider_cookie = get_config("sider_cookie")
|
||||||
|
oai_apikey = get_config("oai_apikey")
|
||||||
|
oai_apibase = get_config("oai_apibase")
|
||||||
|
oai_model = get_config("oai_model")
|
||||||
|
google_api_key = get_config("google_api_key")
|
||||||
|
proxy = get_config("proxy", 'http://127.0.0.1:2082')
|
||||||
|
proxies = {"http": proxy, "https": proxy} if proxy else None
|
||||||
|
|
||||||
class SiderLLMSession:
|
class SiderLLMSession:
|
||||||
def __init__(self, default_model="gemini-3.0-flash"):
|
def __init__(self, default_model="gemini-3.0-flash"):
|
||||||
from sider_ai_api import Session
|
from sider_ai_api import Session
|
||||||
self._core = Session(cookie=sider_cookie, proxies={'https':'127.0.0.1:2082'})
|
self._core = Session(cookie=sider_cookie, proxies=proxies)
|
||||||
self.default_model = default_model
|
self.default_model = default_model
|
||||||
def ask(self, prompt, model=None, stream=False):
|
def ask(self, prompt, model=None, stream=False):
|
||||||
if model is None: model = self.default_model
|
if model is None: model = self.default_model
|
||||||
@@ -19,6 +27,34 @@ class SiderLLMSession:
|
|||||||
if stream: return iter([full_text]) # gen有奇怪的空回复或死循环行为,sider足够快
|
if stream: return iter([full_text]) # gen有奇怪的空回复或死循环行为,sider足够快
|
||||||
return full_text
|
return full_text
|
||||||
|
|
||||||
|
class GeminiSession:
|
||||||
|
def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy):
|
||||||
|
self.api_key = api_key or google_api_key
|
||||||
|
if not self.api_key: raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
|
||||||
|
self.default_model = default_model
|
||||||
|
self.proxies = {"http":proxy, "https":proxy} if proxy else None
|
||||||
|
def ask(self, prompt, model=None, stream=False):
|
||||||
|
if model is None: model = self.default_model
|
||||||
|
url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent?key={self.api_key}"
|
||||||
|
headers = {"Content-Type":"application/json"}
|
||||||
|
data = {"contents":[{"role":"user","parts":[{"text":prompt}]}]}
|
||||||
|
try:
|
||||||
|
kw = {"headers":headers, "json":data, "timeout":60, 'proxies': self.proxies}
|
||||||
|
r = requests.post(url, **kw)
|
||||||
|
except Exception as e:
|
||||||
|
return f"[GeminiError] request failed: {e}"
|
||||||
|
if r.status_code != 200:
|
||||||
|
body = r.text[:500].replace("\n"," ")
|
||||||
|
return f"[GeminiError] HTTP {r.status_code}: {body}"
|
||||||
|
try:
|
||||||
|
obj = r.json(); cands = obj.get("candidates") or []
|
||||||
|
if not cands: return "[GeminiError] empty candidates"
|
||||||
|
parts = (cands[0].get("content") or {}).get("parts") or []
|
||||||
|
full_text = "".join(p.get("text","") for p in parts)
|
||||||
|
except Exception as e:
|
||||||
|
return f"[GeminiError] invalid response format: {e}"
|
||||||
|
return iter([full_text]) if stream else full_text
|
||||||
|
|
||||||
class LLMSession:
|
class LLMSession:
|
||||||
def __init__(self, api_key=oai_apikey, api_base=oai_apibase, model=oai_model, context_win=12000):
|
def __init__(self, api_key=oai_apikey, api_base=oai_apibase, model=oai_model, context_win=12000):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -26,11 +62,11 @@ class LLMSession:
|
|||||||
self.raw_msgs = []
|
self.raw_msgs = []
|
||||||
self.messages = []
|
self.messages = []
|
||||||
self.context_win = context_win
|
self.context_win = context_win
|
||||||
self.model = model
|
self.default_model = model
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
def raw_ask(self, messages, model=None, temperature=0.5):
|
def raw_ask(self, messages, model=None, temperature=0.5):
|
||||||
if model is None: model = self.model
|
if model is None: model = self.default_model
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", "Accept": "text/event-stream"}
|
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", "Accept": "text/event-stream"}
|
||||||
payload = {"model": model, "messages": messages, "temperature": temperature, "stream": True}
|
payload = {"model": model, "messages": messages, "temperature": temperature, "stream": True}
|
||||||
try:
|
try:
|
||||||
@@ -69,7 +105,7 @@ class LLMSession:
|
|||||||
return messages
|
return messages
|
||||||
|
|
||||||
def summary_history(self, model=None):
|
def summary_history(self, model=None):
|
||||||
if model is None: model = self.model
|
if model is None: model = self.default_model
|
||||||
with self.lock:
|
with self.lock:
|
||||||
keep = 0; tok = 0
|
keep = 0; tok = 0
|
||||||
for m in reversed(self.raw_msgs):
|
for m in reversed(self.raw_msgs):
|
||||||
@@ -90,7 +126,7 @@ class LLMSession:
|
|||||||
else: self.raw_msgs = old + self.raw_msgs # 不做了,下次再做
|
else: self.raw_msgs = old + self.raw_msgs # 不做了,下次再做
|
||||||
|
|
||||||
def ask(self, prompt, model=None, image_base64=None, stream=False):
|
def ask(self, prompt, model=None, image_base64=None, stream=False):
|
||||||
if model is None: model = self.model
|
if model is None: model = self.default_model
|
||||||
def _ask_gen():
|
def _ask_gen():
|
||||||
content = ''
|
content = ''
|
||||||
with self.lock:
|
with self.lock:
|
||||||
@@ -132,10 +168,10 @@ class MockResponse:
|
|||||||
return f"<MockResponse thinking={bool(self.thinking)}, content='{self.content}', tools={bool(self.tool_calls)}>"
|
return f"<MockResponse thinking={bool(self.thinking)}, content='{self.content}', tools={bool(self.tool_calls)}>"
|
||||||
|
|
||||||
class ToolClient:
|
class ToolClient:
|
||||||
def __init__(self, raw_api_func, auto_save_tokens=False):
|
def __init__(self, backends, auto_save_tokens=False):
|
||||||
if isinstance(raw_api_func, list): self.raw_apis = raw_api_func
|
if isinstance(backends, list): self.backends = backends
|
||||||
else: self.raw_apis = [raw_api_func]
|
else: self.backends = [backends]
|
||||||
self.raw_api = self.raw_apis[0]
|
self.backend = self.backends[0]
|
||||||
self.auto_save_tokens = auto_save_tokens
|
self.auto_save_tokens = auto_save_tokens
|
||||||
self.last_tools = ''
|
self.last_tools = ''
|
||||||
self.total_cd_tokens = 0
|
self.total_cd_tokens = 0
|
||||||
@@ -145,7 +181,7 @@ class ToolClient:
|
|||||||
print("Full prompt length:", len(full_prompt), 'chars')
|
print("Full prompt length:", len(full_prompt), 'chars')
|
||||||
with open('model_responses.txt', 'a', encoding='utf-8', errors="replace") as f:
|
with open('model_responses.txt', 'a', encoding='utf-8', errors="replace") as f:
|
||||||
f.write(f"=== Prompt ===\n{full_prompt}\n")
|
f.write(f"=== Prompt ===\n{full_prompt}\n")
|
||||||
gen = self.raw_api(full_prompt, stream=True)
|
gen = self.backend.ask(full_prompt, stream=True)
|
||||||
raw_text = ''; summarytag = '[NextWillSummary]'
|
raw_text = ''; summarytag = '[NextWillSummary]'
|
||||||
for chunk in gen:
|
for chunk in gen:
|
||||||
raw_text += chunk;
|
raw_text += chunk;
|
||||||
@@ -243,19 +279,26 @@ def tryparse(json_str):
|
|||||||
return json.loads(json_str[:-1])
|
return json.loads(json_str[:-1])
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
import sys, os
|
||||||
try: from mykey import sider_cookie
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
except ImportError: sider_cookie = ""
|
try:
|
||||||
try: from mykey import oai_apikey, oai_apibase, oai_model
|
import mykey
|
||||||
except ImportError: oai_apikey = oai_apibase = oai_model = ""
|
except ImportError:
|
||||||
|
class MockMyKey: pass
|
||||||
|
mykey = MockMyKey()
|
||||||
|
|
||||||
|
sider_cookie = get_config("sider_cookie")
|
||||||
|
oai_apikey = get_config("oai_apikey")
|
||||||
|
oai_apibase = get_config("oai_apibase")
|
||||||
|
oai_model = get_config("oai_model")
|
||||||
|
google_api_key = get_config("google_api_key")
|
||||||
|
|
||||||
llmclient = ToolClient(LLMSession(api_key=oai_apikey, api_base=oai_apibase, model=oai_model).ask)
|
llmclient = ToolClient(GeminiSession(api_key=google_api_key, proxy='127.0.0.1:2082').ask)
|
||||||
print(llmclient.raw_api("Hello, world!", stream=False))
|
#llmclient = ToolClient(LLMSession(api_key=oai_apikey, api_base=oai_apibase, model=oai_model).ask)
|
||||||
#llmclient = ToolClient(SiderLLMSession().ask)
|
#llmclient = ToolClient(SiderLLMSession().ask)
|
||||||
def get_final(gen):
|
def get_final(gen):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True: print('mid:', next(gen))
|
||||||
print('mid:', next(gen))
|
|
||||||
except StopIteration as e:
|
except StopIteration as e:
|
||||||
return e.value
|
return e.value
|
||||||
|
|
||||||
|
|||||||
5
stapp.py
5
stapp.py
@@ -27,13 +27,12 @@ if "idle_buf" not in st.session_state: st.session_state.idle_buf = ""
|
|||||||
if "messages" not in st.session_state: st.session_state.messages = []
|
if "messages" not in st.session_state: st.session_state.messages = []
|
||||||
|
|
||||||
for msg in st.session_state.messages:
|
for msg in st.session_state.messages:
|
||||||
with st.chat_message(msg["role"]):
|
with st.chat_message(msg["role"]): st.markdown(msg["content"])
|
||||||
st.markdown(msg["content"])
|
|
||||||
|
|
||||||
@st.fragment
|
@st.fragment
|
||||||
def render_llm_switcher():
|
def render_llm_switcher():
|
||||||
current_idx = agent.llm_no
|
current_idx = agent.llm_no
|
||||||
st.caption(f"LLM Core: {current_idx}")
|
st.caption(f"LLM Core: {current_idx}: {agent.llmclient.backends[current_idx].default_model}", help="点击切换备用链路")
|
||||||
if st.button("切换备用链路"):
|
if st.button("切换备用链路"):
|
||||||
agent.next_llm()
|
agent.next_llm()
|
||||||
st.rerun(scope="fragment")
|
st.rerun(scope="fragment")
|
||||||
|
|||||||
Reference in New Issue
Block a user