refactor: unify LLM session constructors to accept cfg dict
This commit is contained in:
63
llmcore.py
63
llmcore.py
@@ -54,10 +54,10 @@ def build_multimodal_content(prompt_text, image_paths):
|
||||
return parts
|
||||
|
||||
class SiderLLMSession:
|
||||
def __init__(self, sider_cookie, default_model="gemini-3.0-flash"):
|
||||
def __init__(self, cfg):
|
||||
from sider_ai_api import Session # 不使用sider的话没必要安装这个包
|
||||
self._core = Session(cookie=sider_cookie, proxies=proxies)
|
||||
self.default_model = default_model
|
||||
self._core = Session(cookie=cfg['apikey'], proxies=proxies)
|
||||
self.default_model = cfg.get('model', 'gemini-3.0-flash')
|
||||
def ask(self, prompt, model=None, stream=False):
|
||||
if model is None: model = self.default_model
|
||||
if len(prompt) > 28000:
|
||||
@@ -68,8 +68,10 @@ class SiderLLMSession:
|
||||
return full_text
|
||||
|
||||
class ClaudeSession:
|
||||
def __init__(self, api_key, api_base, model="claude-opus", context_win=12000):
|
||||
self.api_key, self.api_base, self.default_model, self.context_win = api_key, api_base.rstrip('/'), model, context_win
|
||||
def __init__(self, cfg):
|
||||
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
|
||||
self.default_model = cfg.get('model', 'claude-opus')
|
||||
self.context_win = cfg.get('context_win', 12000)
|
||||
self.raw_msgs, self.lock = [], threading.Lock()
|
||||
def _trim_messages(self, messages):
|
||||
compress_history_tags(messages)
|
||||
@@ -118,28 +120,28 @@ class ClaudeSession:
|
||||
return _ask_gen() if stream else ''.join(list(_ask_gen()))
|
||||
|
||||
class LLMSession:
|
||||
def __init__(self, api_key, api_base, model, context_win=16000, proxy=None, api_mode="chat_completions",
|
||||
max_retries=2, connect_timeout=10, read_timeout=120):
|
||||
self.api_key = api_key; self.api_base = api_base.rstrip('/'); self.default_model = model
|
||||
self.context_win = context_win; self.raw_msgs = []; self.messages = []
|
||||
def __init__(self, cfg):
|
||||
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
|
||||
self.default_model = cfg['model']
|
||||
self.context_win = cfg.get('context_win', 16000)
|
||||
self.raw_msgs, self.messages = [], []
|
||||
proxy = cfg.get('proxy')
|
||||
self.proxies = {"http": proxy, "https": proxy} if proxy else None
|
||||
self.prompt_cache = cfg.get('prompt_cache', False)
|
||||
self.lock = threading.Lock()
|
||||
self.max_retries = max(0, int(max_retries))
|
||||
self.connect_timeout = max(1, int(connect_timeout))
|
||||
self.read_timeout = max(5, int(read_timeout))
|
||||
mode = str(api_mode or "chat_completions").strip().lower().replace('-', '_')
|
||||
self.max_retries = max(0, int(cfg.get('max_retries', 2)))
|
||||
self.connect_timeout = max(1, int(cfg.get('connect_timeout', 10)))
|
||||
self.read_timeout = max(5, int(cfg.get('read_timeout', 120)))
|
||||
mode = str(cfg.get('api_mode', 'chat_completions')).strip().lower().replace('-', '_')
|
||||
if mode in ["responses", "response"]: self.api_mode = "responses"
|
||||
else: self.api_mode = "chat_completions"
|
||||
|
||||
def _retry_delay(self, resp, attempt):
|
||||
retry_after = None
|
||||
try:
|
||||
if resp is not None:
|
||||
retry_after = (resp.headers or {}).get("retry-after")
|
||||
if retry_after is not None:
|
||||
retry_after = float(retry_after)
|
||||
except:
|
||||
retry_after = None
|
||||
if resp is not None: retry_after = (resp.headers or {}).get("retry-after")
|
||||
if retry_after is not None: retry_after = float(retry_after)
|
||||
except: retry_after = None
|
||||
if retry_after is None: retry_after = min(30.0, 1.5 * (2 ** attempt))
|
||||
return max(0.5, float(retry_after))
|
||||
|
||||
@@ -312,7 +314,7 @@ class LLMSession:
|
||||
content += chunk; yield chunk
|
||||
if not content.startswith("Error:"):
|
||||
self.raw_msgs.append({"role": "assistant", "prompt": content, "image": None})
|
||||
if total_len > 5000: print(f"[Debug] Whole context length {total_len} {str(msg_lens)}.")
|
||||
if total_len > self.context_win // 2: print(f"[Debug] Whole context length {total_len} {str(msg_lens)}.")
|
||||
if total_len > self.context_win:
|
||||
yield '[NextWillSummary]'
|
||||
threading.Thread(target=self.summary_history, daemon=True).start()
|
||||
@@ -321,11 +323,12 @@ class LLMSession:
|
||||
|
||||
|
||||
class GeminiSession:
|
||||
def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy):
|
||||
self.api_key = api_key or google_api_key
|
||||
def __init__(self, cfg):
|
||||
self.api_key = cfg.get('apikey') or google_api_key
|
||||
if not self.api_key: raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
|
||||
self.default_model = default_model
|
||||
self.proxies = {"http":proxy, "https":proxy} if proxy else None
|
||||
self.default_model = cfg.get('model', 'gemini-2.0-flash-001')
|
||||
p = cfg.get('proxy', proxy)
|
||||
self.proxies = {"http":p, "https":p} if p else None
|
||||
def ask(self, prompt, model=None, stream=False):
|
||||
if model is None: model = self.default_model
|
||||
url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent?key={self.api_key}"
|
||||
@@ -349,13 +352,14 @@ class GeminiSession:
|
||||
return iter([full_text]) if stream else full_text
|
||||
|
||||
class XaiSession:
|
||||
def __init__(self, api_key, proxy="http://127.0.0.1:2082", default_model="grok-4-1-fast-non-reasoning"):
|
||||
def __init__(self, cfg):
|
||||
import xai_sdk
|
||||
from xai_sdk.chat import user, system
|
||||
self._user, self._system = user, system
|
||||
self.default_model = default_model
|
||||
self.default_model = cfg.get('model', 'grok-4-1-fast-non-reasoning')
|
||||
self._last_response_id = None # 多轮对话链
|
||||
os.environ["XAI_API_KEY"] = api_key
|
||||
os.environ["XAI_API_KEY"] = cfg['apikey']
|
||||
proxy = cfg.get('proxy', 'http://127.0.0.1:2082')
|
||||
if not proxy.startswith("http"): proxy = f"http://{proxy}"
|
||||
os.environ.setdefault("grpc_proxy", proxy)
|
||||
self._client = xai_sdk.Client()
|
||||
@@ -598,10 +602,7 @@ if __name__ == "__main__":
|
||||
}
|
||||
google_api_key = mykeys.get("google_api_key")
|
||||
cfg = oai_configs.get("oai_config")
|
||||
|
||||
llmclient = ToolClient(GeminiSession(api_key=google_api_key, proxy='127.0.0.1:2082').ask)
|
||||
#llmclient = ToolClient(LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model']).ask)
|
||||
#llmclient = ToolClient(SiderLLMSession().ask)
|
||||
llmclient = ToolClient(LLMSession(cfg))
|
||||
def get_final(gen):
|
||||
try:
|
||||
while True: print('mid:', next(gen))
|
||||
|
||||
Reference in New Issue
Block a user