feat: publish ocr_utils, vision_sop, ui_detect to public repo

This commit is contained in:
Liang Jiaqing
2026-04-15 00:00:14 +08:00
parent d0a9b5de08
commit 97560a4fba
4 changed files with 289 additions and 0 deletions

135
memory/ui_detect.py Normal file
View File

@@ -0,0 +1,135 @@
#!/usr/bin/env python3
"""
极简UI元素检测脚本 - 基于OmniParser的YOLO模型
依赖: ultralytics, rapidocr-onnxruntime, pillow, numpy
"""
import sys
from pathlib import Path
from ultralytics import YOLO
from PIL import Image, ImageDraw
import numpy as np
DEFAULT_MODEL = str(Path(__file__).resolve().parent.parent / 'temp' / 'weights' / 'icon_detect' / 'model.pt')
# 可选使用rapidocr做OCR
try:
from rapidocr_onnxruntime import RapidOCR
ocr_engine = RapidOCR()
HAS_OCR = True
except ImportError:
HAS_OCR = False
print("警告: rapidocr未安装跳过OCR功能")
def detect_ui_elements(image_path, model_path=None, conf_threshold=0.25):
model_path = model_path or DEFAULT_MODEL
"""检测UI元素并返回边界框"""
# 加载模型
model = YOLO(model_path)
# 推理
results = model(image_path, conf=conf_threshold, verbose=False)
# 提取检测结果
detections = []
for result in results:
boxes = result.boxes
for box in boxes:
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
conf = float(box.conf[0])
cls = int(box.cls[0])
detections.append({
'bbox': [int(x1), int(y1), int(x2), int(y2)],
'confidence': conf,
'class': cls
})
return detections
def ocr_text(image_path):
"""OCR识别文本"""
if not HAS_OCR:
return []
result, _ = ocr_engine(image_path)
if not result:
return []
texts = []
for item in result:
bbox, text, conf = item
texts.append({
'text': text,
'bbox': bbox,
'confidence': conf
})
return texts
def visualize(image_path, detections, ocr_results=None, output_path=None):
"""可视化检测结果"""
img = Image.open(image_path)
draw = ImageDraw.Draw(img)
# 画UI元素框红色
for det in detections:
x1, y1, x2, y2 = det['bbox']
draw.rectangle([x1, y1, x2, y2], outline='red', width=2)
draw.text((x1, y1-10), f"{det['confidence']:.2f}", fill='red')
# 画OCR文本框蓝色
if ocr_results:
for ocr in ocr_results:
bbox = ocr['bbox']
points = [(bbox[i][0], bbox[i][1]) for i in range(4)]
draw.polygon(points, outline='blue')
draw.text((points[0][0], points[0][1]-10), ocr['text'][:10], fill='blue')
if output_path:
img.save(output_path)
return img
def main():
if len(sys.argv) < 2:
print("用法: python ui_detect.py <图片路径> <模型路径> [输出路径]")
print("示例: python ui_detect.py screenshot.png weights/icon_detect/model.pt output.png")
sys.exit(1)
image_path = sys.argv[1]
model_path = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_MODEL
output_path = sys.argv[3] if len(sys.argv) > 3 else "output.png"
print(f"检测图片: {image_path}")
print(f"使用模型: {model_path}")
# UI元素检测
print("\n[1/2] YOLO检测UI元素...")
detections = detect_ui_elements(image_path, model_path)
print(f"检测到 {len(detections)} 个UI元素")
for i, det in enumerate(detections, 1):
print(f" {i}. bbox={det['bbox']}, conf={det['confidence']:.3f}")
# OCR文本识别
ocr_results = None
if HAS_OCR:
print("\n[2/2] OCR识别文本...")
ocr_results = ocr_text(image_path)
print(f"识别到 {len(ocr_results)} 个文本区域")
for i, ocr in enumerate(ocr_results, 1):
print(f" {i}. text='{ocr['text']}', conf={ocr['confidence']:.3f}")
# 可视化
print(f"\n保存结果到: {output_path}")
visualize(image_path, detections, ocr_results, output_path)
# 输出JSON格式结果
import json
result = {
'ui_elements': detections,
'ocr_texts': ocr_results or []
}
json_path = output_path.replace('.png', '.json')
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"JSON结果: {json_path}")
if __name__ == "__main__":
main()