refactor: unify LLM session constructors to accept cfg dict

This commit is contained in:
Liang Jiaqing
2026-03-17 16:52:12 +08:00
parent d896355bad
commit 00038cd1ae
2 changed files with 39 additions and 47 deletions

View File

@@ -44,16 +44,10 @@ class GeneraticAgent:
for k, cfg in mykeys.items(): for k, cfg in mykeys.items():
if not any(x in k for x in ['api', 'config', 'cookie']): continue if not any(x in k for x in ['api', 'config', 'cookie']): continue
try: try:
if 'claude' in k: llm_sessions += [ClaudeSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])] if 'claude' in k: llm_sessions += [ClaudeSession(cfg=cfg)]
if 'oai' in k: llm_sessions += [LLMSession( if 'oai' in k: llm_sessions += [LLMSession(cfg=cfg)]
api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'], proxy=cfg.get('proxy'), if 'xai' in k: llm_sessions += [XaiSession(cfg=cfg)]
api_mode=cfg.get('api_mode', 'chat_completions'), if 'sider' in k: llm_sessions += [SiderLLMSession(cfg={'apikey': cfg, 'model': x}) for x in \
max_retries=cfg.get('max_retries', 2),
connect_timeout=cfg.get('connect_timeout', 10),
read_timeout=cfg.get('read_timeout', 120),
)]
if 'xai' in k: llm_sessions += [XaiSession(cfg, mykeys.get('proxy', ''))]
if 'sider' in k: llm_sessions += [SiderLLMSession(cfg, default_model=x) for x in \
["gemini-3.0-flash", "gpt-5.4"]] ["gemini-3.0-flash", "gpt-5.4"]]
except: pass except: pass
if len(llm_sessions) > 0: self.llmclient = ToolClient(llm_sessions, auto_save_tokens=True) if len(llm_sessions) > 0: self.llmclient = ToolClient(llm_sessions, auto_save_tokens=True)
@@ -62,10 +56,8 @@ class GeneraticAgent:
self.history = [] self.history = []
self.task_queue = queue.Queue() self.task_queue = queue.Queue()
self.is_running, self.stop_sig = False, False self.is_running, self.stop_sig = False, False
self.llm_no = 0 self.llm_no = 0; self.inc_out = False
self.inc_out = False self.handler = None; self.verbose = True
self.handler = None
self.verbose = True
def next_llm(self, n=-1): def next_llm(self, n=-1):
self.llm_no = ((self.llm_no + 1) if n < 0 else n) % len(self.llmclient.backends) self.llm_no = ((self.llm_no + 1) if n < 0 else n) % len(self.llmclient.backends)
@@ -79,8 +71,7 @@ class GeneraticAgent:
print('Abort current task...') print('Abort current task...')
if not self.is_running: return if not self.is_running: return
self.stop_sig = True self.stop_sig = True
if self.handler is not None: if self.handler is not None: self.handler.code_stop_signal.append(1)
self.handler.code_stop_signal.append(1)
def put_task(self, query, source="user", images=None): def put_task(self, query, source="user", images=None):
display_queue = queue.Queue() display_queue = queue.Queue()

View File

@@ -54,10 +54,10 @@ def build_multimodal_content(prompt_text, image_paths):
return parts return parts
class SiderLLMSession: class SiderLLMSession:
def __init__(self, sider_cookie, default_model="gemini-3.0-flash"): def __init__(self, cfg):
from sider_ai_api import Session # 不使用sider的话没必要安装这个包 from sider_ai_api import Session # 不使用sider的话没必要安装这个包
self._core = Session(cookie=sider_cookie, proxies=proxies) self._core = Session(cookie=cfg['apikey'], proxies=proxies)
self.default_model = default_model self.default_model = cfg.get('model', 'gemini-3.0-flash')
def ask(self, prompt, model=None, stream=False): def ask(self, prompt, model=None, stream=False):
if model is None: model = self.default_model if model is None: model = self.default_model
if len(prompt) > 28000: if len(prompt) > 28000:
@@ -68,8 +68,10 @@ class SiderLLMSession:
return full_text return full_text
class ClaudeSession: class ClaudeSession:
def __init__(self, api_key, api_base, model="claude-opus", context_win=12000): def __init__(self, cfg):
self.api_key, self.api_base, self.default_model, self.context_win = api_key, api_base.rstrip('/'), model, context_win self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
self.default_model = cfg.get('model', 'claude-opus')
self.context_win = cfg.get('context_win', 12000)
self.raw_msgs, self.lock = [], threading.Lock() self.raw_msgs, self.lock = [], threading.Lock()
def _trim_messages(self, messages): def _trim_messages(self, messages):
compress_history_tags(messages) compress_history_tags(messages)
@@ -118,28 +120,28 @@ class ClaudeSession:
return _ask_gen() if stream else ''.join(list(_ask_gen())) return _ask_gen() if stream else ''.join(list(_ask_gen()))
class LLMSession: class LLMSession:
def __init__(self, api_key, api_base, model, context_win=16000, proxy=None, api_mode="chat_completions", def __init__(self, cfg):
max_retries=2, connect_timeout=10, read_timeout=120): self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
self.api_key = api_key; self.api_base = api_base.rstrip('/'); self.default_model = model self.default_model = cfg['model']
self.context_win = context_win; self.raw_msgs = []; self.messages = [] self.context_win = cfg.get('context_win', 16000)
self.raw_msgs, self.messages = [], []
proxy = cfg.get('proxy')
self.proxies = {"http": proxy, "https": proxy} if proxy else None self.proxies = {"http": proxy, "https": proxy} if proxy else None
self.prompt_cache = cfg.get('prompt_cache', False)
self.lock = threading.Lock() self.lock = threading.Lock()
self.max_retries = max(0, int(max_retries)) self.max_retries = max(0, int(cfg.get('max_retries', 2)))
self.connect_timeout = max(1, int(connect_timeout)) self.connect_timeout = max(1, int(cfg.get('connect_timeout', 10)))
self.read_timeout = max(5, int(read_timeout)) self.read_timeout = max(5, int(cfg.get('read_timeout', 120)))
mode = str(api_mode or "chat_completions").strip().lower().replace('-', '_') mode = str(cfg.get('api_mode', 'chat_completions')).strip().lower().replace('-', '_')
if mode in ["responses", "response"]: self.api_mode = "responses" if mode in ["responses", "response"]: self.api_mode = "responses"
else: self.api_mode = "chat_completions" else: self.api_mode = "chat_completions"
def _retry_delay(self, resp, attempt): def _retry_delay(self, resp, attempt):
retry_after = None retry_after = None
try: try:
if resp is not None: if resp is not None: retry_after = (resp.headers or {}).get("retry-after")
retry_after = (resp.headers or {}).get("retry-after") if retry_after is not None: retry_after = float(retry_after)
if retry_after is not None: except: retry_after = None
retry_after = float(retry_after)
except:
retry_after = None
if retry_after is None: retry_after = min(30.0, 1.5 * (2 ** attempt)) if retry_after is None: retry_after = min(30.0, 1.5 * (2 ** attempt))
return max(0.5, float(retry_after)) return max(0.5, float(retry_after))
@@ -312,7 +314,7 @@ class LLMSession:
content += chunk; yield chunk content += chunk; yield chunk
if not content.startswith("Error:"): if not content.startswith("Error:"):
self.raw_msgs.append({"role": "assistant", "prompt": content, "image": None}) self.raw_msgs.append({"role": "assistant", "prompt": content, "image": None})
if total_len > 5000: print(f"[Debug] Whole context length {total_len} {str(msg_lens)}.") if total_len > self.context_win // 2: print(f"[Debug] Whole context length {total_len} {str(msg_lens)}.")
if total_len > self.context_win: if total_len > self.context_win:
yield '[NextWillSummary]' yield '[NextWillSummary]'
threading.Thread(target=self.summary_history, daemon=True).start() threading.Thread(target=self.summary_history, daemon=True).start()
@@ -321,11 +323,12 @@ class LLMSession:
class GeminiSession: class GeminiSession:
def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy): def __init__(self, cfg):
self.api_key = api_key or google_api_key self.api_key = cfg.get('apikey') or google_api_key
if not self.api_key: raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置") if not self.api_key: raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
self.default_model = default_model self.default_model = cfg.get('model', 'gemini-2.0-flash-001')
self.proxies = {"http":proxy, "https":proxy} if proxy else None p = cfg.get('proxy', proxy)
self.proxies = {"http":p, "https":p} if p else None
def ask(self, prompt, model=None, stream=False): def ask(self, prompt, model=None, stream=False):
if model is None: model = self.default_model if model is None: model = self.default_model
url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent?key={self.api_key}" url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent?key={self.api_key}"
@@ -349,13 +352,14 @@ class GeminiSession:
return iter([full_text]) if stream else full_text return iter([full_text]) if stream else full_text
class XaiSession: class XaiSession:
def __init__(self, api_key, proxy="http://127.0.0.1:2082", default_model="grok-4-1-fast-non-reasoning"): def __init__(self, cfg):
import xai_sdk import xai_sdk
from xai_sdk.chat import user, system from xai_sdk.chat import user, system
self._user, self._system = user, system self._user, self._system = user, system
self.default_model = default_model self.default_model = cfg.get('model', 'grok-4-1-fast-non-reasoning')
self._last_response_id = None # 多轮对话链 self._last_response_id = None # 多轮对话链
os.environ["XAI_API_KEY"] = api_key os.environ["XAI_API_KEY"] = cfg['apikey']
proxy = cfg.get('proxy', 'http://127.0.0.1:2082')
if not proxy.startswith("http"): proxy = f"http://{proxy}" if not proxy.startswith("http"): proxy = f"http://{proxy}"
os.environ.setdefault("grpc_proxy", proxy) os.environ.setdefault("grpc_proxy", proxy)
self._client = xai_sdk.Client() self._client = xai_sdk.Client()
@@ -598,10 +602,7 @@ if __name__ == "__main__":
} }
google_api_key = mykeys.get("google_api_key") google_api_key = mykeys.get("google_api_key")
cfg = oai_configs.get("oai_config") cfg = oai_configs.get("oai_config")
llmclient = ToolClient(LLMSession(cfg))
llmclient = ToolClient(GeminiSession(api_key=google_api_key, proxy='127.0.0.1:2082').ask)
#llmclient = ToolClient(LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model']).ask)
#llmclient = ToolClient(SiderLLMSession().ask)
def get_final(gen): def get_final(gen):
try: try:
while True: print('mid:', next(gen)) while True: print('mid:', next(gen))