feat: add continue/new support to chat frontends
This commit is contained in:
@@ -2,10 +2,12 @@
|
||||
Pure functions + one `install(cls)` monkey-patch entry. No side effects at import.
|
||||
"""
|
||||
import ast, glob, json, os, re, time
|
||||
_LOG_GLOB = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
'temp', 'model_responses', 'model_responses_*.txt')
|
||||
_LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
'temp', 'model_responses')
|
||||
_LOG_GLOB = os.path.join(_LOG_DIR, 'model_responses_*.txt')
|
||||
_BLOCK_RE = re.compile(r'^=== (Prompt|Response) ===.*?\n(.*?)(?=^=== (?:Prompt|Response) ===|\Z)',
|
||||
re.DOTALL | re.MULTILINE)
|
||||
_SUMMARY_RE = re.compile(r'<summary>\s*(.*?)\s*</summary>', re.DOTALL)
|
||||
|
||||
def _rel_time(mtime):
|
||||
d = int(time.time() - mtime)
|
||||
@@ -38,6 +40,32 @@ def _first_user(pairs):
|
||||
if s and not s.startswith('###'): return s
|
||||
return ''
|
||||
|
||||
|
||||
def _last_summary(pairs):
|
||||
for _, response_body in reversed(pairs):
|
||||
try:
|
||||
blocks = ast.literal_eval(response_body)
|
||||
except Exception:
|
||||
continue
|
||||
if not isinstance(blocks, list):
|
||||
continue
|
||||
text_parts = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
text = block.get('text', '')
|
||||
if isinstance(text, str) and text:
|
||||
text_parts.append(text)
|
||||
match = _SUMMARY_RE.search('\n'.join(text_parts))
|
||||
if match:
|
||||
summary = match.group(1).strip()
|
||||
if summary:
|
||||
return summary
|
||||
return ''
|
||||
|
||||
|
||||
def _preview_text(pairs):
|
||||
return _last_summary(pairs) or _first_user(pairs)
|
||||
|
||||
def _parse_native_history(pairs):
|
||||
history = []
|
||||
for p, r in pairs:
|
||||
@@ -60,16 +88,82 @@ def list_sessions(exclude_pid=None):
|
||||
out = []
|
||||
for f in files:
|
||||
try:
|
||||
content = open(f, encoding='utf-8', errors='replace').read()
|
||||
with open(f, encoding='utf-8', errors='replace') as fh:
|
||||
content = fh.read()
|
||||
except Exception: continue
|
||||
pairs = _pairs(content)
|
||||
if not pairs: continue
|
||||
out.append((f, os.path.getmtime(f), _first_user(pairs), len(pairs)))
|
||||
out.append((f, os.path.getmtime(f), _preview_text(pairs), len(pairs)))
|
||||
out.sort(key=lambda x: x[1], reverse=True)
|
||||
return out
|
||||
_MD_ESCAPE_RE = re.compile(r'([\\`*_\[\]])')
|
||||
def _escape_md(s): return _MD_ESCAPE_RE.sub(r'\\\1', s)
|
||||
|
||||
|
||||
def _agent_clients(agent):
|
||||
clients = []
|
||||
for client in getattr(agent, 'llmclients', []) or []:
|
||||
if client not in clients:
|
||||
clients.append(client)
|
||||
current = getattr(agent, 'llmclient', None)
|
||||
if current is not None and current not in clients:
|
||||
clients.insert(0, current)
|
||||
return clients
|
||||
|
||||
|
||||
def _replace_backend_history(agent, history):
|
||||
backend = getattr(getattr(agent, 'llmclient', None), 'backend', None)
|
||||
if backend is not None and hasattr(backend, 'history'):
|
||||
backend.history = list(history or [])
|
||||
|
||||
|
||||
def _current_log_path(pid=None):
|
||||
pid = os.getpid() if pid is None else pid
|
||||
return os.path.join(_LOG_DIR, f'model_responses_{pid}.txt')
|
||||
|
||||
|
||||
def _snapshot_current_log(pid=None):
|
||||
"""Persist current PID log as a standalone recoverable snapshot, then clear it."""
|
||||
path = _current_log_path(pid)
|
||||
if not os.path.isfile(path):
|
||||
return None
|
||||
try:
|
||||
with open(path, encoding='utf-8', errors='replace') as fh:
|
||||
content = fh.read()
|
||||
except Exception:
|
||||
return None
|
||||
if not _pairs(content):
|
||||
return None
|
||||
os.makedirs(_LOG_DIR, exist_ok=True)
|
||||
pid = os.getpid() if pid is None else pid
|
||||
stamp = time.strftime('%Y%m%d_%H%M%S')
|
||||
snapshot = os.path.join(_LOG_DIR, f'model_responses_snapshot_{pid}_{stamp}_{time.time_ns() % 1_000_000_000:09d}.txt')
|
||||
with open(snapshot, 'w', encoding='utf-8', errors='replace') as fh:
|
||||
fh.write(content)
|
||||
with open(path, 'w', encoding='utf-8', errors='replace'):
|
||||
pass
|
||||
return snapshot
|
||||
|
||||
|
||||
def reset_conversation(agent, message='🆕 已开启新对话,当前上下文已清空'):
|
||||
"""Abort current work and clear all known frontend-visible conversation state."""
|
||||
try:
|
||||
agent.abort()
|
||||
except Exception:
|
||||
pass
|
||||
_snapshot_current_log()
|
||||
if hasattr(agent, 'history'):
|
||||
agent.history = []
|
||||
for client in _agent_clients(agent):
|
||||
backend = getattr(client, 'backend', None)
|
||||
if backend is not None and hasattr(backend, 'history'):
|
||||
backend.history = []
|
||||
if hasattr(client, 'last_tools'):
|
||||
client.last_tools = ''
|
||||
if hasattr(agent, 'handler'):
|
||||
agent.handler = None
|
||||
return message
|
||||
|
||||
def format_list(sessions, limit=20):
|
||||
if not sessions: return '❌ 没有可恢复的历史会话'
|
||||
lines = ['**可恢复会话**(输入 `/continue N` 恢复第 N 个):', '']
|
||||
@@ -80,7 +174,9 @@ def format_list(sessions, limit=20):
|
||||
|
||||
def restore(agent, path):
|
||||
"""Restore session at path. Returns (msg, is_full)."""
|
||||
try: content = open(path, encoding='utf-8', errors='replace').read()
|
||||
try:
|
||||
with open(path, encoding='utf-8', errors='replace') as fh:
|
||||
content = fh.read()
|
||||
except Exception as e: return f'❌ 读取失败: {e}', False
|
||||
pairs = _pairs(content)
|
||||
if not pairs: return f'❌ {os.path.basename(path)} 为空或格式不符', False
|
||||
@@ -88,7 +184,7 @@ def restore(agent, path):
|
||||
name = os.path.basename(path)
|
||||
if history is not None:
|
||||
agent.abort()
|
||||
agent.llmclient.backend.history = history
|
||||
_replace_backend_history(agent, history)
|
||||
return f'✅ 已恢复 {len(pairs)} 轮完整对话({name})\n(已写入 backend.history,可直接继续)', True
|
||||
from chatapp_common import _restore_native_history, _restore_text_pairs
|
||||
summary = _restore_text_pairs(content) or _restore_native_history(content)
|
||||
@@ -111,10 +207,31 @@ def handle(agent, query, display_queue):
|
||||
if not (0 <= idx < len(sessions)):
|
||||
display_queue.put({'done': f'❌ 索引越界(有效范围 1-{len(sessions)})', 'source': 'system'})
|
||||
return None
|
||||
reset_conversation(agent, message=None)
|
||||
msg, _ = restore(agent, sessions[idx][0])
|
||||
display_queue.put({'done': msg, 'source': 'system'})
|
||||
return None
|
||||
return query
|
||||
|
||||
|
||||
def handle_frontend_command(agent, query, exclude_pid=None):
|
||||
"""Frontend-friendly /continue entry that returns text directly."""
|
||||
s = (query or '').strip()
|
||||
exclude_pid = os.getpid() if exclude_pid is None else exclude_pid
|
||||
if s == '/continue':
|
||||
return format_list(list_sessions(exclude_pid=exclude_pid))
|
||||
m = re.match(r'/continue\s+(\d+)\s*$', s)
|
||||
if not m:
|
||||
return '用法: /continue 或 /continue N'
|
||||
sessions = list_sessions(exclude_pid=exclude_pid)
|
||||
idx = int(m.group(1)) - 1
|
||||
if not (0 <= idx < len(sessions)):
|
||||
return f'❌ 索引越界(有效范围 1-{len(sessions)})'
|
||||
reset_conversation(agent, message=None)
|
||||
msg, _ = restore(agent, sessions[idx][0])
|
||||
return msg
|
||||
|
||||
|
||||
def install(cls):
|
||||
"""Wrap cls._handle_slash_cmd so /continue is handled before original dispatch."""
|
||||
orig = cls._handle_slash_cmd
|
||||
|
||||
Reference in New Issue
Block a user