feat: add ClaudeSession for native Claude Messages API support
This commit is contained in:
@@ -4,7 +4,7 @@ if sys.stdout is None: sys.stdout = open(os.devnull, "w")
|
|||||||
if sys.stderr is None: sys.stderr = open(os.devnull, "w")
|
if sys.stderr is None: sys.stderr = open(os.devnull, "w")
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
from sidercall import SiderLLMSession, LLMSession, ToolClient
|
from sidercall import SiderLLMSession, LLMSession, ToolClient, ClaudeSession
|
||||||
from agent_loop import agent_runner_loop, StepOutcome, BaseHandler
|
from agent_loop import agent_runner_loop, StepOutcome, BaseHandler
|
||||||
from ga import GenericAgentHandler, smart_format, get_global_memory, format_error
|
from ga import GenericAgentHandler, smart_format, get_global_memory, format_error
|
||||||
|
|
||||||
@@ -28,12 +28,14 @@ def get_system_prompt():
|
|||||||
class GeneraticAgent:
|
class GeneraticAgent:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not os.path.exists('temp'): os.makedirs('temp')
|
if not os.path.exists('temp'): os.makedirs('temp')
|
||||||
from sidercall import sider_cookie, oai_configs
|
from sidercall import sider_cookie, oai_configs, claude_configs
|
||||||
llm_sessions = []
|
llm_sessions = []
|
||||||
if sider_cookie: llm_sessions += [SiderLLMSession(default_model=x) for x in \
|
if sider_cookie: llm_sessions += [SiderLLMSession(default_model=x) for x in \
|
||||||
["gemini-3.0-flash", "claude-haiku-4.5", "kimi-k2"]]
|
["gemini-3.0-flash", "claude-haiku-4.5", "kimi-k2"]]
|
||||||
for cfg in oai_configs.values():
|
for cfg in oai_configs.values():
|
||||||
llm_sessions += [LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
|
llm_sessions += [LLMSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
|
||||||
|
for cfg in claude_configs.values():
|
||||||
|
llm_sessions += [ClaudeSession(api_key=cfg['apikey'], api_base=cfg['apibase'], model=cfg['model'])]
|
||||||
if len(llm_sessions) > 0:
|
if len(llm_sessions) > 0:
|
||||||
llmclient = ToolClient(llm_sessions, auto_save_tokens=True)
|
llmclient = ToolClient(llm_sessions, auto_save_tokens=True)
|
||||||
self.llmclient = llmclient
|
self.llmclient = llmclient
|
||||||
|
|||||||
@@ -15,3 +15,10 @@ oai_config2 = {
|
|||||||
'apibase':"http://133.145.139.147:3001/v1",
|
'apibase':"http://133.145.139.147:3001/v1",
|
||||||
'model':"claude-opus-4-6-20260206"
|
'model':"claude-opus-4-6-20260206"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
claude_config = {
|
||||||
|
'apikey':'klURcj...',
|
||||||
|
'apibase':"http://233.145.139.147:3001/",
|
||||||
|
'model':"claude-opus"
|
||||||
|
}
|
||||||
|
|||||||
56
sidercall.py
56
sidercall.py
@@ -1,13 +1,13 @@
|
|||||||
import os, json, re, time, requests, sys, threading
|
import os, json, re, time, requests, sys, threading, urllib3
|
||||||
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
try: import mykey
|
try: import mykey
|
||||||
except: raise Exception('[ERROR] mykey.py not found, please copy mykey_template.py to mykey.py and fill your LLM backend.')
|
except: raise Exception('[ERROR] mykey.py not found, please copy mykey_template.py to mykey.py and fill your LLM backend.')
|
||||||
|
|
||||||
mykeys = vars(mykey)
|
mykeys = vars(mykey)
|
||||||
sider_cookie = mykeys.get("sider_cookie")
|
sider_cookie = mykeys.get("sider_cookie")
|
||||||
oai_configs = {
|
oai_configs = {k: v for k, v in vars(mykey).items() if k.startswith("oai_config") and v}
|
||||||
k: v for k, v in vars(mykey).items() if k.startswith("oai_config") and v
|
claude_configs = {k: v for k, v in vars(mykey).items() if k.startswith("claude_config") and v}
|
||||||
}
|
|
||||||
google_api_key = mykeys.get("google_api_key")
|
google_api_key = mykeys.get("google_api_key")
|
||||||
|
|
||||||
proxy = mykeys.get("proxy", 'http://127.0.0.1:2082')
|
proxy = mykeys.get("proxy", 'http://127.0.0.1:2082')
|
||||||
@@ -55,6 +55,54 @@ class GeminiSession:
|
|||||||
return f"[GeminiError] invalid response format: {e}"
|
return f"[GeminiError] invalid response format: {e}"
|
||||||
return iter([full_text]) if stream else full_text
|
return iter([full_text]) if stream else full_text
|
||||||
|
|
||||||
|
class ClaudeSession:
    """Chat session that talks to a Claude Messages API endpoint (`/v1/messages`).

    The session keeps its own conversation history in ``raw_msgs`` as a list of
    ``{"role": ..., "prompt": ...}`` dicts and converts it to the wire format
    (``{"role": ..., "content": ...}``) right before each request.

    Args:
        api_key: Key sent in the ``x-api-key`` header.
        api_base: Base URL of the endpoint; trailing slashes are stripped and
            ``/v1/messages`` is appended when a request is made.
        model: Model name used when a call does not pass its own.
        context_win: Approximate history budget, in "tokens" estimated as
            ``len(prompt) // 4``.
    """

    # Value sent in the `anthropic-version` request header.
    API_VERSION = "2023-06-01"

    def __init__(self, api_key, api_base, model="claude-opus", context_win=32000):
        self.api_key = api_key
        self.api_base = api_base.rstrip('/')
        self.default_model = model
        self.context_win = context_win
        self.raw_msgs = []
        self.lock = threading.Lock()  # guards raw_msgs

    def _trim_messages(self, messages):
        """Return the most recent slice of *messages* that fits the budget.

        Size is estimated as ``len(prompt) // 4`` per message. If the whole
        history fits in ``context_win`` it is returned unchanged (same list
        object). Otherwise messages are kept newest-first until 90% of
        ``context_win`` would be exceeded; if not even the newest message
        fits, the last two messages are used as a fallback. Leading non-user
        messages of the shortened window are then dropped (always keeping at
        least one message) so the window does not open with an assistant turn.
        """
        if sum(len(m['prompt']) // 4 for m in messages) <= self.context_win:
            return messages

        budget = self.context_win * 0.9
        used = 0
        kept = []
        # Walk newest -> oldest with a running total instead of re-summing.
        for msg in reversed(messages):
            cost = len(msg['prompt']) // 4
            if used + cost > budget:
                break
            used += cost
            kept.append(msg)
        kept.reverse()
        if not kept:
            kept = list(messages[-2:])

        # A cut can land in the middle of an exchange and leave an assistant
        # reply at the front; skip ahead to the first user turn.
        start = 0
        while start < len(kept) - 1 and kept[start].get('role') != 'user':
            start += 1
        return kept[start:]

    def raw_ask(self, messages, model=None, temperature=0.5, max_tokens=4096):
        """Stream one completion for *messages* and yield its text pieces.

        Args:
            messages: Wire-format messages (``role`` / ``content`` dicts).
            model: Model name; falls back to the session default.
            temperature: Sampling temperature placed in the request body.
            max_tokens: ``max_tokens`` value placed in the request body.

        Yields:
            Text from each ``content_block_delta`` / ``text_delta`` event.
            Failures are never raised to the caller: a request exception or a
            streamed ``error`` event is reported as a single chunk of the form
            ``"Error: ..."``.
        """
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": self.API_VERSION,
            "Content-Type": "application/json",
        }
        payload = {
            "model": model or self.default_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        }
        try:
            # NOTE(review): verify=False disables TLS certificate checks while
            # the API key is being sent; kept as-is for compatibility with the
            # existing deployment, but worth revisiting.
            with requests.post(f"{self.api_base}/v1/messages", headers=headers, json=payload,
                               stream=True, timeout=(5, 60), verify=False) as r:
                r.raise_for_status()
                for line in r.iter_lines():
                    if not line:
                        continue
                    if isinstance(line, bytes):
                        line = line.decode("utf-8")
                    if not line.startswith("data:"):
                        continue
                    data = line[5:].lstrip()
                    if data == "[DONE]":
                        break
                    # Only the JSON decode is guarded. The yields below must
                    # stay outside any catch-all handler, otherwise closing
                    # the generator early (GeneratorExit) would be swallowed.
                    try:
                        obj = json.loads(data)
                    except ValueError:
                        continue
                    if not isinstance(obj, dict):
                        continue
                    kind = obj.get("type")
                    if kind == "content_block_delta":
                        delta = obj.get("delta")
                        if isinstance(delta, dict) and delta.get("type") == "text_delta":
                            text = delta.get("text", "")
                            if text:
                                yield text
                    elif kind == "error":
                        # Surface a streamed error instead of ending silently.
                        err = obj.get("error")
                        detail = err.get("message") if isinstance(err, dict) else err
                        yield f"Error: {detail or data}"
                        break
        except Exception as e:
            yield f"Error: {str(e)}"

    def make_messages(self, raw_list):
        """Trim *raw_list* to the context budget and convert it to wire format."""
        return [{"role": m['role'], "content": m['prompt']} for m in self._trim_messages(raw_list)]

    def ask(self, prompt, model=None, stream=False):
        """Send *prompt* as the next user turn and return the reply.

        The user turn is appended to ``raw_msgs`` before the request. The
        reply is appended as an assistant turn only when it is non-empty and
        does not start with ``"Error:"``; otherwise the history is left with
        the user turn only.

        Args:
            prompt: Text of the user turn.
            model: Optional model override for this call.
            stream: When true, return a generator of text chunks (history is
                updated once the generator is exhausted); otherwise return
                the full reply as one string.
        """
        def _ask_gen():
            with self.lock:
                self.raw_msgs.append({"role": "user", "prompt": prompt})
                messages = self.make_messages(self.raw_msgs)
            parts = []
            for chunk in self.raw_ask(messages, model):
                parts.append(chunk)
                yield chunk
            content = ''.join(parts)
            # An empty reply is not stored: it would put an empty-content
            # assistant turn into every later request.
            if content and not content.startswith("Error:"):
                with self.lock:
                    self.raw_msgs.append({"role": "assistant", "prompt": content})

        return _ask_gen() if stream else ''.join(_ask_gen())
|
||||||
|
|
||||||
class LLMSession:
|
class LLMSession:
|
||||||
def __init__(self, api_key, api_base, model, context_win=16000):
|
def __init__(self, api_key, api_base, model, context_win=16000):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|||||||
Reference in New Issue
Block a user