refactor: unify LLM session constructors to accept cfg dict

This commit is contained in:
Liang Jiaqing
2026-03-17 16:52:12 +08:00
parent d896355bad
commit 00038cd1ae
2 changed files with 39 additions and 47 deletions

View File

@@ -54,10 +54,10 @@ def build_multimodal_content(prompt_text, image_paths):
return parts
class SiderLLMSession:
def __init__(self, sider_cookie, default_model="gemini-3.0-flash"):
def __init__(self, cfg):
from sider_ai_api import Session # 不使用sider的话没必要安装这个包
self._core = Session(cookie=sider_cookie, proxies=proxies)
self.default_model = default_model
self._core = Session(cookie=cfg['apikey'], proxies=proxies)
self.default_model = cfg.get('model', 'gemini-3.0-flash')
def ask(self, prompt, model=None, stream=False):
if model is None: model = self.default_model
if len(prompt) > 28000:
@@ -68,8 +68,10 @@ class SiderLLMSession:
return full_text
class ClaudeSession:
def __init__(self, api_key, api_base, model="claude-opus", context_win=12000):
self.api_key, self.api_base, self.default_model, self.context_win = api_key, api_base.rstrip('/'), model, context_win
def __init__(self, cfg):
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
self.default_model = cfg.get('model', 'claude-opus')
self.context_win = cfg.get('context_win', 12000)
self.raw_msgs, self.lock = [], threading.Lock()
def _trim_messages(self, messages):
compress_history_tags(messages)
@@ -118,28 +120,28 @@ class ClaudeSession:
return _ask_gen() if stream else ''.join(list(_ask_gen()))
class LLMSession:
def __init__(self, api_key, api_base, model, context_win=16000, proxy=None, api_mode="chat_completions",
max_retries=2, connect_timeout=10, read_timeout=120):
self.api_key = api_key; self.api_base = api_base.rstrip('/'); self.default_model = model
self.context_win = context_win; self.raw_msgs = []; self.messages = []
def __init__(self, cfg):
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
self.default_model = cfg['model']
self.context_win = cfg.get('context_win', 16000)
self.raw_msgs, self.messages = [], []
proxy = cfg.get('proxy')
self.proxies = {"http": proxy, "https": proxy} if proxy else None
self.prompt_cache = cfg.get('prompt_cache', False)
self.lock = threading.Lock()
self.max_retries = max(0, int(max_retries))
self.connect_timeout = max(1, int(connect_timeout))
self.read_timeout = max(5, int(read_timeout))
mode = str(api_mode or "chat_completions").strip().lower().replace('-', '_')
self.max_retries = max(0, int(cfg.get('max_retries', 2)))
self.connect_timeout = max(1, int(cfg.get('connect_timeout', 10)))
self.read_timeout = max(5, int(cfg.get('read_timeout', 120)))
mode = str(cfg.get('api_mode', 'chat_completions')).strip().lower().replace('-', '_')
if mode in ["responses", "response"]: self.api_mode = "responses"
else: self.api_mode = "chat_completions"
def _retry_delay(self, resp, attempt):
retry_after = None
try:
if resp is not None:
retry_after = (resp.headers or {}).get("retry-after")
if retry_after is not None:
retry_after = float(retry_after)
except:
retry_after = None
if resp is not None: retry_after = (resp.headers or {}).get("retry-after")
if retry_after is not None: retry_after = float(retry_after)
except: retry_after = None
if retry_after is None: retry_after = min(30.0, 1.5 * (2 ** attempt))
return max(0.5, float(retry_after))
@@ -312,7 +314,7 @@ class LLMSession:
content += chunk; yield chunk
if not content.startswith("Error:"):
self.raw_msgs.append({"role": "assistant", "prompt": content, "image": None})
if total_len > 5000: print(f"[Debug] Whole context length {total_len} {str(msg_lens)}.")
if total_len > self.context_win // 2: print(f"[Debug] Whole context length {total_len} {str(msg_lens)}.")
if total_len > self.context_win:
yield '[NextWillSummary]'
threading.Thread(target=self.summary_history, daemon=True).start()
@@ -321,11 +323,12 @@ class LLMSession:
class GeminiSession:
def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy):
self.api_key = api_key or google_api_key
def __init__(self, cfg):
self.api_key = cfg.get('apikey') or google_api_key
if not self.api_key: raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
self.default_model = default_model
self.proxies = {"http":proxy, "https":proxy} if proxy else None
self.default_model = cfg.get('model', 'gemini-2.0-flash-001')
p = cfg.get('proxy', proxy)
self.proxies = {"http":p, "https":p} if p else None
def ask(self, prompt, model=None, stream=False):
if model is None: model = self.default_model
url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent?key={self.api_key}"
@@ -349,13 +352,14 @@ class GeminiSession:
return iter([full_text]) if stream else full_text
class XaiSession:
def __init__(self, api_key, proxy="http://127.0.0.1:2082", default_model="grok-4-1-fast-non-reasoning"):
def __init__(self, cfg):
import xai_sdk
from xai_sdk.chat import user, system
self._user, self._system = user, system
self.default_model = default_model
self.default_model = cfg.get('model', 'grok-4-1-fast-non-reasoning')
self._last_response_id = None # 多轮对话链
os.environ["XAI_API_KEY"] = api_key
os.environ["XAI_API_KEY"] = cfg['apikey']
proxy = cfg.get('proxy', 'http://127.0.0.1:2082')
if not proxy.startswith("http"): proxy = f"http://{proxy}"
os.environ.setdefault("grpc_proxy", proxy)
self._client = xai_sdk.Client()
@@ -598,10 +602,7 @@ if __name__ == "__main__":
}
google_api_key = mykeys.get("google_api_key")
cfg = oai_configs.get("oai_config")
llmclient = ToolClient(GeminiSession(api_key=google_api_key, proxy='127.0.0.1:2082').ask)
#llmclient = ToolClient(LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model']).ask)
#llmclient = ToolClient(SiderLLMSession().ask)
llmclient = ToolClient(LLMSession(cfg))
def get_final(gen):
try:
while True: print('mid:', next(gen))