增加了对火山引擎平台的模型适配,避免路径错误
This commit is contained in:
@@ -37,7 +37,7 @@ class SiderLLMSession:
|
|||||||
return full_text
|
return full_text
|
||||||
|
|
||||||
class ClaudeSession:
|
class ClaudeSession:
|
||||||
def __init__(self, api_key, api_base, model="claude-opus", context_win=10000):
|
def __init__(self, api_key, api_base, model="claude-opus", context_win=9000):
|
||||||
self.api_key, self.api_base, self.default_model, self.context_win = api_key, api_base.rstrip('/'), model, context_win
|
self.api_key, self.api_base, self.default_model, self.context_win = api_key, api_base.rstrip('/'), model, context_win
|
||||||
self.raw_msgs, self.lock = [], threading.Lock()
|
self.raw_msgs, self.lock = [], threading.Lock()
|
||||||
def _trim_messages(self, messages):
|
def _trim_messages(self, messages):
|
||||||
@@ -51,7 +51,7 @@ class ClaudeSession:
|
|||||||
else: break
|
else: break
|
||||||
if current > self.context_win * 3.6: print(f'[DEBUG] {len(result)} contexts, whole length {current//4} tokens.')
|
if current > self.context_win * 3.6: print(f'[DEBUG] {len(result)} contexts, whole length {current//4} tokens.')
|
||||||
return result[::-1] or messages[-2:]
|
return result[::-1] or messages[-2:]
|
||||||
def raw_ask(self, messages, model=None, temperature=0.5, max_tokens=6144):
|
def raw_ask(self, messages, model=None, temperature=0.5, max_tokens=4096):
|
||||||
model = model or self.default_model
|
model = model or self.default_model
|
||||||
headers = {"x-api-key": self.api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01"}
|
headers = {"x-api-key": self.api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01"}
|
||||||
payload = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "stream": True}
|
payload = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "stream": True}
|
||||||
@@ -100,6 +100,9 @@ class LLMSession:
|
|||||||
else: self.api_mode = "chat_completions"
|
else: self.api_mode = "chat_completions"
|
||||||
|
|
||||||
def _endpoint(self, path):
|
def _endpoint(self, path):
|
||||||
|
# 处理火山引擎API,它已经包含完整路径
|
||||||
|
if 'ark.cn-beijing.volces.com' in self.api_base or 'ark.cn-shanghai.volces.com' in self.api_base:
|
||||||
|
return f"{self.api_base}/{path.lstrip('/')}"
|
||||||
if self.api_base.endswith('/v1'): return f"{self.api_base}/{path.lstrip('/')}"
|
if self.api_base.endswith('/v1'): return f"{self.api_base}/{path.lstrip('/')}"
|
||||||
if self.api_base.endswith('$'): return f"{self.api_base.rstrip('$')}/{path.lstrip('/')}"
|
if self.api_base.endswith('$'): return f"{self.api_base.rstrip('$')}/{path.lstrip('/')}"
|
||||||
return f"{self.api_base}/v1/{path.lstrip('/')}"
|
return f"{self.api_base}/v1/{path.lstrip('/')}"
|
||||||
|
|||||||
Reference in New Issue
Block a user