refactor: unify LLM session constructors to accept cfg dict
This commit is contained in:
23
agentmain.py
23
agentmain.py
@@ -44,16 +44,10 @@ class GeneraticAgent:
|
|||||||
for k, cfg in mykeys.items():
|
for k, cfg in mykeys.items():
|
||||||
if not any(x in k for x in ['api', 'config', 'cookie']): continue
|
if not any(x in k for x in ['api', 'config', 'cookie']): continue
|
||||||
try:
|
try:
|
||||||
if 'claude' in k: llm_sessions += [ClaudeSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
|
if 'claude' in k: llm_sessions += [ClaudeSession(cfg=cfg)]
|
||||||
if 'oai' in k: llm_sessions += [LLMSession(
|
if 'oai' in k: llm_sessions += [LLMSession(cfg=cfg)]
|
||||||
api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'], proxy=cfg.get('proxy'),
|
if 'xai' in k: llm_sessions += [XaiSession(cfg=cfg)]
|
||||||
api_mode=cfg.get('api_mode', 'chat_completions'),
|
if 'sider' in k: llm_sessions += [SiderLLMSession(cfg={'apikey': cfg, 'model': x}) for x in \
|
||||||
max_retries=cfg.get('max_retries', 2),
|
|
||||||
connect_timeout=cfg.get('connect_timeout', 10),
|
|
||||||
read_timeout=cfg.get('read_timeout', 120),
|
|
||||||
)]
|
|
||||||
if 'xai' in k: llm_sessions += [XaiSession(cfg, mykeys.get('proxy', ''))]
|
|
||||||
if 'sider' in k: llm_sessions += [SiderLLMSession(cfg, default_model=x) for x in \
|
|
||||||
["gemini-3.0-flash", "gpt-5.4"]]
|
["gemini-3.0-flash", "gpt-5.4"]]
|
||||||
except: pass
|
except: pass
|
||||||
if len(llm_sessions) > 0: self.llmclient = ToolClient(llm_sessions, auto_save_tokens=True)
|
if len(llm_sessions) > 0: self.llmclient = ToolClient(llm_sessions, auto_save_tokens=True)
|
||||||
@@ -62,10 +56,8 @@ class GeneraticAgent:
|
|||||||
self.history = []
|
self.history = []
|
||||||
self.task_queue = queue.Queue()
|
self.task_queue = queue.Queue()
|
||||||
self.is_running, self.stop_sig = False, False
|
self.is_running, self.stop_sig = False, False
|
||||||
self.llm_no = 0
|
self.llm_no = 0; self.inc_out = False
|
||||||
self.inc_out = False
|
self.handler = None; self.verbose = True
|
||||||
self.handler = None
|
|
||||||
self.verbose = True
|
|
||||||
|
|
||||||
def next_llm(self, n=-1):
|
def next_llm(self, n=-1):
|
||||||
self.llm_no = ((self.llm_no + 1) if n < 0 else n) % len(self.llmclient.backends)
|
self.llm_no = ((self.llm_no + 1) if n < 0 else n) % len(self.llmclient.backends)
|
||||||
@@ -79,8 +71,7 @@ class GeneraticAgent:
|
|||||||
print('Abort current task...')
|
print('Abort current task...')
|
||||||
if not self.is_running: return
|
if not self.is_running: return
|
||||||
self.stop_sig = True
|
self.stop_sig = True
|
||||||
if self.handler is not None:
|
if self.handler is not None: self.handler.code_stop_signal.append(1)
|
||||||
self.handler.code_stop_signal.append(1)
|
|
||||||
|
|
||||||
def put_task(self, query, source="user", images=None):
|
def put_task(self, query, source="user", images=None):
|
||||||
display_queue = queue.Queue()
|
display_queue = queue.Queue()
|
||||||
|
|||||||
63
llmcore.py
63
llmcore.py
@@ -54,10 +54,10 @@ def build_multimodal_content(prompt_text, image_paths):
|
|||||||
return parts
|
return parts
|
||||||
|
|
||||||
class SiderLLMSession:
|
class SiderLLMSession:
|
||||||
def __init__(self, sider_cookie, default_model="gemini-3.0-flash"):
|
def __init__(self, cfg):
|
||||||
from sider_ai_api import Session # 不使用sider的话没必要安装这个包
|
from sider_ai_api import Session # 不使用sider的话没必要安装这个包
|
||||||
self._core = Session(cookie=sider_cookie, proxies=proxies)
|
self._core = Session(cookie=cfg['apikey'], proxies=proxies)
|
||||||
self.default_model = default_model
|
self.default_model = cfg.get('model', 'gemini-3.0-flash')
|
||||||
def ask(self, prompt, model=None, stream=False):
|
def ask(self, prompt, model=None, stream=False):
|
||||||
if model is None: model = self.default_model
|
if model is None: model = self.default_model
|
||||||
if len(prompt) > 28000:
|
if len(prompt) > 28000:
|
||||||
@@ -68,8 +68,10 @@ class SiderLLMSession:
|
|||||||
return full_text
|
return full_text
|
||||||
|
|
||||||
class ClaudeSession:
|
class ClaudeSession:
|
||||||
def __init__(self, api_key, api_base, model="claude-opus", context_win=12000):
|
def __init__(self, cfg):
|
||||||
self.api_key, self.api_base, self.default_model, self.context_win = api_key, api_base.rstrip('/'), model, context_win
|
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
|
||||||
|
self.default_model = cfg.get('model', 'claude-opus')
|
||||||
|
self.context_win = cfg.get('context_win', 12000)
|
||||||
self.raw_msgs, self.lock = [], threading.Lock()
|
self.raw_msgs, self.lock = [], threading.Lock()
|
||||||
def _trim_messages(self, messages):
|
def _trim_messages(self, messages):
|
||||||
compress_history_tags(messages)
|
compress_history_tags(messages)
|
||||||
@@ -118,28 +120,28 @@ class ClaudeSession:
|
|||||||
return _ask_gen() if stream else ''.join(list(_ask_gen()))
|
return _ask_gen() if stream else ''.join(list(_ask_gen()))
|
||||||
|
|
||||||
class LLMSession:
|
class LLMSession:
|
||||||
def __init__(self, api_key, api_base, model, context_win=16000, proxy=None, api_mode="chat_completions",
|
def __init__(self, cfg):
|
||||||
max_retries=2, connect_timeout=10, read_timeout=120):
|
self.api_key = cfg['apikey']; self.api_base = cfg['apibase'].rstrip('/')
|
||||||
self.api_key = api_key; self.api_base = api_base.rstrip('/'); self.default_model = model
|
self.default_model = cfg['model']
|
||||||
self.context_win = context_win; self.raw_msgs = []; self.messages = []
|
self.context_win = cfg.get('context_win', 16000)
|
||||||
|
self.raw_msgs, self.messages = [], []
|
||||||
|
proxy = cfg.get('proxy')
|
||||||
self.proxies = {"http": proxy, "https": proxy} if proxy else None
|
self.proxies = {"http": proxy, "https": proxy} if proxy else None
|
||||||
|
self.prompt_cache = cfg.get('prompt_cache', False)
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
self.max_retries = max(0, int(max_retries))
|
self.max_retries = max(0, int(cfg.get('max_retries', 2)))
|
||||||
self.connect_timeout = max(1, int(connect_timeout))
|
self.connect_timeout = max(1, int(cfg.get('connect_timeout', 10)))
|
||||||
self.read_timeout = max(5, int(read_timeout))
|
self.read_timeout = max(5, int(cfg.get('read_timeout', 120)))
|
||||||
mode = str(api_mode or "chat_completions").strip().lower().replace('-', '_')
|
mode = str(cfg.get('api_mode', 'chat_completions')).strip().lower().replace('-', '_')
|
||||||
if mode in ["responses", "response"]: self.api_mode = "responses"
|
if mode in ["responses", "response"]: self.api_mode = "responses"
|
||||||
else: self.api_mode = "chat_completions"
|
else: self.api_mode = "chat_completions"
|
||||||
|
|
||||||
def _retry_delay(self, resp, attempt):
|
def _retry_delay(self, resp, attempt):
|
||||||
retry_after = None
|
retry_after = None
|
||||||
try:
|
try:
|
||||||
if resp is not None:
|
if resp is not None: retry_after = (resp.headers or {}).get("retry-after")
|
||||||
retry_after = (resp.headers or {}).get("retry-after")
|
if retry_after is not None: retry_after = float(retry_after)
|
||||||
if retry_after is not None:
|
except: retry_after = None
|
||||||
retry_after = float(retry_after)
|
|
||||||
except:
|
|
||||||
retry_after = None
|
|
||||||
if retry_after is None: retry_after = min(30.0, 1.5 * (2 ** attempt))
|
if retry_after is None: retry_after = min(30.0, 1.5 * (2 ** attempt))
|
||||||
return max(0.5, float(retry_after))
|
return max(0.5, float(retry_after))
|
||||||
|
|
||||||
@@ -312,7 +314,7 @@ class LLMSession:
|
|||||||
content += chunk; yield chunk
|
content += chunk; yield chunk
|
||||||
if not content.startswith("Error:"):
|
if not content.startswith("Error:"):
|
||||||
self.raw_msgs.append({"role": "assistant", "prompt": content, "image": None})
|
self.raw_msgs.append({"role": "assistant", "prompt": content, "image": None})
|
||||||
if total_len > 5000: print(f"[Debug] Whole context length {total_len} {str(msg_lens)}.")
|
if total_len > self.context_win // 2: print(f"[Debug] Whole context length {total_len} {str(msg_lens)}.")
|
||||||
if total_len > self.context_win:
|
if total_len > self.context_win:
|
||||||
yield '[NextWillSummary]'
|
yield '[NextWillSummary]'
|
||||||
threading.Thread(target=self.summary_history, daemon=True).start()
|
threading.Thread(target=self.summary_history, daemon=True).start()
|
||||||
@@ -321,11 +323,12 @@ class LLMSession:
|
|||||||
|
|
||||||
|
|
||||||
class GeminiSession:
|
class GeminiSession:
|
||||||
def __init__(self, api_key=None, default_model="gemini-2.0-flash-001", proxy=proxy):
|
def __init__(self, cfg):
|
||||||
self.api_key = api_key or google_api_key
|
self.api_key = cfg.get('apikey') or google_api_key
|
||||||
if not self.api_key: raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
|
if not self.api_key: raise ValueError("google_api_key 未配置或为空,请在 mykey.py 中设置")
|
||||||
self.default_model = default_model
|
self.default_model = cfg.get('model', 'gemini-2.0-flash-001')
|
||||||
self.proxies = {"http":proxy, "https":proxy} if proxy else None
|
p = cfg.get('proxy', proxy)
|
||||||
|
self.proxies = {"http":p, "https":p} if p else None
|
||||||
def ask(self, prompt, model=None, stream=False):
|
def ask(self, prompt, model=None, stream=False):
|
||||||
if model is None: model = self.default_model
|
if model is None: model = self.default_model
|
||||||
url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent?key={self.api_key}"
|
url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent?key={self.api_key}"
|
||||||
@@ -349,13 +352,14 @@ class GeminiSession:
|
|||||||
return iter([full_text]) if stream else full_text
|
return iter([full_text]) if stream else full_text
|
||||||
|
|
||||||
class XaiSession:
|
class XaiSession:
|
||||||
def __init__(self, api_key, proxy="http://127.0.0.1:2082", default_model="grok-4-1-fast-non-reasoning"):
|
def __init__(self, cfg):
|
||||||
import xai_sdk
|
import xai_sdk
|
||||||
from xai_sdk.chat import user, system
|
from xai_sdk.chat import user, system
|
||||||
self._user, self._system = user, system
|
self._user, self._system = user, system
|
||||||
self.default_model = default_model
|
self.default_model = cfg.get('model', 'grok-4-1-fast-non-reasoning')
|
||||||
self._last_response_id = None # 多轮对话链
|
self._last_response_id = None # 多轮对话链
|
||||||
os.environ["XAI_API_KEY"] = api_key
|
os.environ["XAI_API_KEY"] = cfg['apikey']
|
||||||
|
proxy = cfg.get('proxy', 'http://127.0.0.1:2082')
|
||||||
if not proxy.startswith("http"): proxy = f"http://{proxy}"
|
if not proxy.startswith("http"): proxy = f"http://{proxy}"
|
||||||
os.environ.setdefault("grpc_proxy", proxy)
|
os.environ.setdefault("grpc_proxy", proxy)
|
||||||
self._client = xai_sdk.Client()
|
self._client = xai_sdk.Client()
|
||||||
@@ -598,10 +602,7 @@ if __name__ == "__main__":
|
|||||||
}
|
}
|
||||||
google_api_key = mykeys.get("google_api_key")
|
google_api_key = mykeys.get("google_api_key")
|
||||||
cfg = oai_configs.get("oai_config")
|
cfg = oai_configs.get("oai_config")
|
||||||
|
llmclient = ToolClient(LLMSession(cfg))
|
||||||
llmclient = ToolClient(GeminiSession(api_key=google_api_key, proxy='127.0.0.1:2082').ask)
|
|
||||||
#llmclient = ToolClient(LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model']).ask)
|
|
||||||
#llmclient = ToolClient(SiderLLMSession().ask)
|
|
||||||
def get_final(gen):
|
def get_final(gen):
|
||||||
try:
|
try:
|
||||||
while True: print('mid:', next(gen))
|
while True: print('mid:', next(gen))
|
||||||
|
|||||||
Reference in New Issue
Block a user