Files
GenericAgent/frontends/continue_cmd.py
weijia fd4c833511 feat(stapp): restore chat history bubbles after /continue N (#138)
When the user runs '/continue N' in stapp, the agent's in-memory context
is restored, but the UI previously showed only a single ' restored' line
— all prior chat bubbles were missing.

This change parses the target session log and reconstructs the
user/assistant message pairs into st.session_state.messages, so reopening
a session feels like the conversation was never interrupted.

* continue_cmd.py: add extract_ui_messages(path)
  - parses model_responses log into [{role, content}, ...]
  - groups multi-turn LLM calls (prompts whose text starts with the
    '### [WORKING MEMORY]' header) into a single assistant bubble,
    inserting the existing '**LLM Running (Turn N) ...**' marker so
    fold_turns() renders them as collapsible segments.
  - two small helpers (_user_text / _assistant_text) keep parsing local.

* stapp.py: in the /continue branch, resolve the target log path BEFORE
  calling handle_frontend_command (which snapshots the current log and
  would otherwise shift list_sessions indices), then replace
  session_state.messages with the reconstructed history on success.
  Falls back to the previous behavior for bare /continue or failure.

Co-authored-by: wjl2023 <wjl2023@users.noreply.github.com>
2026-04-23 16:36:41 +08:00

297 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""`/continue` command: list & restore past model_responses sessions.
Pure functions + one `install(cls)` monkey-patch entry. No side effects at import.
"""
import ast, glob, json, os, re, time
_LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
'temp', 'model_responses')
_LOG_GLOB = os.path.join(_LOG_DIR, 'model_responses_*.txt')
_BLOCK_RE = re.compile(r'^=== (Prompt|Response) ===.*?\n(.*?)(?=^=== (?:Prompt|Response) ===|\Z)',
re.DOTALL | re.MULTILINE)
_SUMMARY_RE = re.compile(r'<summary>\s*(.*?)\s*</summary>', re.DOTALL)
def _rel_time(mtime):
d = int(time.time() - mtime)
if d < 60: return f'{d}秒前'
if d < 3600: return f'{d // 60}分前'
if d < 86400: return f'{d // 3600}小时前'
return f'{d // 86400}天前'
def _pairs(content):
blocks, pairs, pending = _BLOCK_RE.findall(content or ''), [], None
for label, body in blocks:
if label == 'Prompt': pending = body.strip()
elif pending is not None:
pairs.append((pending, body.strip())); pending = None
return pairs
def _first_user(pairs):
for p, _ in pairs:
try: msg = json.loads(p)
except Exception: continue
if not isinstance(msg, dict): continue
for blk in msg.get('content', []) or []:
if isinstance(blk, dict) and blk.get('type') == 'text':
t = (blk.get('text') or '').strip()
if t and '<history>' not in t and not t.startswith('### [WORKING MEMORY]'):
return t
for p, _ in pairs[:1]:
for line in p.splitlines():
s = line.strip()
if s and not s.startswith('###'): return s
return ''
def _last_summary(pairs):
for _, response_body in reversed(pairs):
try:
blocks = ast.literal_eval(response_body)
except Exception:
continue
if not isinstance(blocks, list):
continue
text_parts = []
for block in blocks:
if isinstance(block, dict) and block.get('type') == 'text':
text = block.get('text', '')
if isinstance(text, str) and text:
text_parts.append(text)
match = _SUMMARY_RE.search('\n'.join(text_parts))
if match:
summary = match.group(1).strip()
if summary:
return summary
return ''
def _preview_text(pairs):
return _last_summary(pairs) or _first_user(pairs)
def _parse_native_history(pairs):
history = []
for p, r in pairs:
try: user_msg = json.loads(p)
except Exception: return None
try: blocks = ast.literal_eval(r)
except Exception: return None
if not (isinstance(user_msg, dict) and user_msg.get('role') == 'user'): return None
if not isinstance(blocks, list): return None
history.append(user_msg)
history.append({'role': 'assistant', 'content': blocks})
return history
def list_sessions(exclude_pid=None):
"""Newest-first list of (path, mtime, first_user_text, n_rounds)."""
files = glob.glob(_LOG_GLOB)
if exclude_pid is not None:
tag = f'model_responses_{exclude_pid}.txt'
files = [f for f in files if not f.endswith(tag)]
out = []
for f in files:
try:
with open(f, encoding='utf-8', errors='replace') as fh:
content = fh.read()
except Exception: continue
pairs = _pairs(content)
if not pairs: continue
out.append((f, os.path.getmtime(f), _preview_text(pairs), len(pairs)))
out.sort(key=lambda x: x[1], reverse=True)
return out
_MD_ESCAPE_RE = re.compile(r'([\\`*_\[\]])')
def _escape_md(s): return _MD_ESCAPE_RE.sub(r'\\\1', s)
def _agent_clients(agent):
clients = []
for client in getattr(agent, 'llmclients', []) or []:
if client not in clients:
clients.append(client)
current = getattr(agent, 'llmclient', None)
if current is not None and current not in clients:
clients.insert(0, current)
return clients
def _replace_backend_history(agent, history):
backend = getattr(getattr(agent, 'llmclient', None), 'backend', None)
if backend is not None and hasattr(backend, 'history'):
backend.history = list(history or [])
def _current_log_path(pid=None):
pid = os.getpid() if pid is None else pid
return os.path.join(_LOG_DIR, f'model_responses_{pid}.txt')
def _snapshot_current_log(pid=None):
"""Persist current PID log as a standalone recoverable snapshot, then clear it."""
path = _current_log_path(pid)
if not os.path.isfile(path):
return None
try:
with open(path, encoding='utf-8', errors='replace') as fh:
content = fh.read()
except Exception:
return None
if not _pairs(content):
return None
os.makedirs(_LOG_DIR, exist_ok=True)
pid = os.getpid() if pid is None else pid
stamp = time.strftime('%Y%m%d_%H%M%S')
snapshot = os.path.join(_LOG_DIR, f'model_responses_snapshot_{pid}_{stamp}_{time.time_ns() % 1_000_000_000:09d}.txt')
with open(snapshot, 'w', encoding='utf-8', errors='replace') as fh:
fh.write(content)
with open(path, 'w', encoding='utf-8', errors='replace'):
pass
return snapshot
def reset_conversation(agent, message='🆕 已开启新对话,当前上下文已清空'):
"""Abort current work and clear all known frontend-visible conversation state."""
try:
agent.abort()
except Exception:
pass
_snapshot_current_log()
if hasattr(agent, 'history'):
agent.history = []
for client in _agent_clients(agent):
backend = getattr(client, 'backend', None)
if backend is not None and hasattr(backend, 'history'):
backend.history = []
if hasattr(client, 'last_tools'):
client.last_tools = ''
if hasattr(agent, 'handler'):
agent.handler = None
return message
def format_list(sessions, limit=20):
if not sessions: return '❌ 没有可恢复的历史会话'
lines = ['**可恢复会话**(输入 `/continue N` 恢复第 N 个):', '']
for i, (_, mtime, first, n) in enumerate(sessions[:limit], 1):
preview = _escape_md((first or '(无法预览)').replace('\n', ' ')[:60])
lines.append(f'{i}. `{_rel_time(mtime)}` · **{n} 轮** · {preview}')
return '\n'.join(lines)
def restore(agent, path):
"""Restore session at path. Returns (msg, is_full)."""
try:
with open(path, encoding='utf-8', errors='replace') as fh:
content = fh.read()
except Exception as e: return f'❌ 读取失败: {e}', False
pairs = _pairs(content)
if not pairs: return f'{os.path.basename(path)} 为空或格式不符', False
history = _parse_native_history(pairs)
name = os.path.basename(path)
if history is not None:
agent.abort()
_replace_backend_history(agent, history)
return f'✅ 已恢复 {len(pairs)} 轮完整对话({name}\n(已写入 backend.history可直接继续)', True
from chatapp_common import _restore_native_history, _restore_text_pairs
summary = _restore_text_pairs(content) or _restore_native_history(content)
if not summary: return f'{name} 无法解析(非 native 且无摘要可提取)', False
agent.abort()
agent.history.extend(summary)
n = sum(1 for l in summary if l.startswith('[USER]: '))
return f'⚠️ 非 native 格式,已降级恢复 {n} 轮摘要({name}\n(请输入新问题继续)', False
def handle(agent, query, display_queue):
"""Dispatch /continue or /continue N. Returns None if consumed else original query."""
s = (query or '').strip()
if s == '/continue':
display_queue.put({'done': format_list(list_sessions(exclude_pid=os.getpid())), 'source': 'system'})
return None
m = re.match(r'/continue\s+(\d+)\s*$', s)
if m:
sessions = list_sessions(exclude_pid=os.getpid())
idx = int(m.group(1)) - 1
if not (0 <= idx < len(sessions)):
display_queue.put({'done': f'❌ 索引越界(有效范围 1-{len(sessions)}', 'source': 'system'})
return None
reset_conversation(agent, message=None)
msg, _ = restore(agent, sessions[idx][0])
display_queue.put({'done': msg, 'source': 'system'})
return None
return query
def _user_text(prompt_body):
"""User-typed text from a prompt JSON; '' if this is an agent auto-continuation."""
try: msg = json.loads(prompt_body)
except Exception: return ''
if not isinstance(msg, dict): return ''
for blk in msg.get('content', []) or []:
if isinstance(blk, dict) and blk.get('type') == 'text':
t = (blk.get('text') or '').strip()
if t and not t.startswith('### [WORKING MEMORY]'): return t
return ''
def _assistant_text(response_body):
"""Joined text from a response blocks repr; '' on parse failure."""
try: blocks = ast.literal_eval(response_body)
except Exception: return ''
if not isinstance(blocks, list): return ''
return '\n'.join(b['text'] for b in blocks
if isinstance(b, dict) and b.get('type') == 'text'
and isinstance(b.get('text'), str) and b['text'].strip())
_TURN_MARK = '**LLM Running (Turn {}) ...**\n\n'
def extract_ui_messages(path):
"""Parse a model_responses log into [{role, content}, ...] for UI replay.
Auto-continuation turns are folded into one assistant bubble with Turn markers,
matching live chat rendering via fold_turns().
"""
try:
with open(path, encoding='utf-8', errors='replace') as f: content = f.read()
except Exception: return []
rounds = [] # [(user_text, [turn_text, ...]), ...]
for prompt, response in _pairs(content):
user = _user_text(prompt)
if user or not rounds: rounds.append((user, []))
rounds[-1][1].append(_assistant_text(response))
out = []
for user, turns in rounds:
if not user or not any(turns): continue
body = '\n\n'.join(t if i == 0 else _TURN_MARK.format(i + 1) + t
for i, t in enumerate(turns))
out += [{'role': 'user', 'content': user},
{'role': 'assistant', 'content': body}]
return out
def handle_frontend_command(agent, query, exclude_pid=None):
"""Frontend-friendly /continue entry that returns text directly."""
s = (query or '').strip()
exclude_pid = os.getpid() if exclude_pid is None else exclude_pid
if s == '/continue':
return format_list(list_sessions(exclude_pid=exclude_pid))
m = re.match(r'/continue\s+(\d+)\s*$', s)
if not m:
return '用法: /continue 或 /continue N'
sessions = list_sessions(exclude_pid=exclude_pid)
idx = int(m.group(1)) - 1
if not (0 <= idx < len(sessions)):
return f'❌ 索引越界(有效范围 1-{len(sessions)}'
reset_conversation(agent, message=None)
msg, _ = restore(agent, sessions[idx][0])
return msg
def install(cls):
"""Wrap cls._handle_slash_cmd so /continue is handled before original dispatch."""
orig = cls._handle_slash_cmd
if getattr(orig, '_continue_patched', False): return
def patched(self, raw_query, display_queue):
if (raw_query or '').startswith('/continue'):
r = handle(self, raw_query, display_queue)
if r is None: return None
return orig(self, raw_query, display_queue)
patched._continue_patched = True
cls._handle_slash_cmd = patched