# NOTE(review): removed paste artifacts ("1822 lines", "66 KiB", "Python")
# that preceded the code and are not valid Python.
import ctypes
|
||
import sys
|
||
import os
|
||
import subprocess
|
||
import resource
|
||
import threading
|
||
import time
|
||
import argparse
|
||
import json
|
||
import uuid
|
||
from flask import Flask, request, jsonify, Response, stream_with_context
|
||
from rkllm import RKLLM, get_global_state, get_global_text, set_global_state, set_global_text
|
||
import re
|
||
from werkzeug.exceptions import BadRequest
|
||
from typing import List, Dict, Any, Optional
|
||
from enum import Enum
|
||
|
||
app = Flask(__name__)

# Set the dynamic library path
rkllm_lib = ctypes.CDLL('lib/librkllmrt.so')
# Define the structures from the library
RKLLM_Handle_t = ctypes.c_void_p  # opaque handle returned by rkllm_init
userdata = ctypes.c_void_p(None)

# Pseudo-enums: the integer constants are attached as attributes of
# ctypes.c_int instead of using enum.IntEnum.
# NOTE(review): this mutates the shared ctypes.c_int class globally;
# an IntEnum would be cleaner, kept as-is to avoid behavior changes.
LLMCallState = ctypes.c_int
LLMCallState.RKLLM_RUN_NORMAL = 0
LLMCallState.RKLLM_RUN_WAITING = 1
LLMCallState.RKLLM_RUN_FINISH = 2
LLMCallState.RKLLM_RUN_ERROR = 3

RKLLMInputType = ctypes.c_int
RKLLMInputType.RKLLM_INPUT_PROMPT = 0
RKLLMInputType.RKLLM_INPUT_TOKEN = 1
RKLLMInputType.RKLLM_INPUT_EMBED = 2
RKLLMInputType.RKLLM_INPUT_MULTIMODAL = 3

RKLLMInferMode = ctypes.c_int
RKLLMInferMode.RKLLM_INFER_GENERATE = 0
RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
RKLLMInferMode.RKLLM_INFER_GET_LOGITS = 2
|
||
|
||
class RKLLMExtendParam(ctypes.Structure):
    """ctypes mirror of the library's RKLLMExtendParam struct (extended
    init options embedded in RKLLMParam)."""
    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("embed_flash", ctypes.c_int8),
        ("enabled_cpus_num", ctypes.c_int8),
        ("enabled_cpus_mask", ctypes.c_uint32),
        ("n_batch", ctypes.c_uint8),
        ("use_cross_attn", ctypes.c_int8),
        # Padding to match the C layout — TODO confirm against the
        # librkllmrt header version actually shipped.
        ("reserved", ctypes.c_uint8 * 104)
    ]
|
||
|
||
class RKLLMParam(ctypes.Structure):
    """ctypes mirror of the library's RKLLMParam struct: model path,
    context/generation limits, sampling settings and multimodal markers."""
    _fields_ = [
        ("model_path", ctypes.c_char_p),
        ("max_context_len", ctypes.c_int32),
        ("max_new_tokens", ctypes.c_int32),
        # Sampling controls (see the defaults set in RKLLM.__init__ below).
        ("top_k", ctypes.c_int32),
        ("n_keep", ctypes.c_int32),
        ("top_p", ctypes.c_float),
        ("temperature", ctypes.c_float),
        ("repeat_penalty", ctypes.c_float),
        ("frequency_penalty", ctypes.c_float),
        ("presence_penalty", ctypes.c_float),
        ("mirostat", ctypes.c_int32),
        ("mirostat_tau", ctypes.c_float),
        ("mirostat_eta", ctypes.c_float),
        ("skip_special_token", ctypes.c_bool),
        ("is_async", ctypes.c_bool),
        # Image delimiter strings for multimodal prompts.
        ("img_start", ctypes.c_char_p),
        ("img_end", ctypes.c_char_p),
        ("img_content", ctypes.c_char_p),
        ("extend_param", RKLLMExtendParam),
    ]
|
||
|
||
class RKLLMLoraAdapter(ctypes.Structure):
    """ctypes mirror of the library's RKLLMLoraAdapter struct."""
    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float)
    ]
|
||
|
||
class RKLLMEmbedInput(ctypes.Structure):
    """ctypes mirror of RKLLMEmbedInput: raw embedding input variant."""
    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]
|
||
|
||
class RKLLMTokenInput(ctypes.Structure):
    """ctypes mirror of RKLLMTokenInput: pre-tokenized input variant."""
    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]
|
||
class RKLLMMultiModalInput(ctypes.Structure):
    """ctypes mirror of RKLLMMultiModalInput: prompt plus image embeddings."""
    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
        ("n_image", ctypes.c_size_t),
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t)
    ]
|
||
|
||
class RKLLMInputUnion(ctypes.Union):
    """Union of the four input payload variants; which member is valid is
    selected by RKLLMInput.input_type."""
    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModalInput)
    ]
|
||
|
||
class RKLLMInput(ctypes.Structure):
    """ctypes mirror of RKLLMInput: one request passed to rkllm_run."""
    _fields_ = [
        ("role", ctypes.c_char_p),
        ("enable_thinking", ctypes.c_bool),
        ("input_type", RKLLMInputType),   # selects the active union member
        ("input_data", RKLLMInputUnion)
    ]
|
||
|
||
class RKLLMLoraParam(ctypes.Structure):
    """ctypes mirror of RKLLMLoraParam: selects a loaded LoRA adapter by name."""
    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]
|
||
|
||
class RKLLMPromptCacheParam(ctypes.Structure):
    """ctypes mirror of RKLLMPromptCacheParam: prompt-cache persistence options."""
    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),
        ("prompt_cache_path", ctypes.c_char_p)
    ]
|
||
|
||
class RKLLMInferParam(ctypes.Structure):
    """ctypes mirror of RKLLMInferParam: per-call inference options."""
    _fields_ = [
        ("mode", RKLLMInferMode),
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int)
    ]
|
||
|
||
class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """ctypes mirror of RKLLMResultLastHiddenLayer (hidden-state output mode)."""
    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]
|
||
|
||
class RKLLMResultLogits(ctypes.Structure):
    """ctypes mirror of RKLLMResultLogits (logits output mode)."""
    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]
|
||
|
||
class RKLLMPerfStat(ctypes.Structure):
    """ctypes mirror of RKLLMPerfStat: timing / memory statistics per run."""
    _fields_ = [
        ("prefill_time_ms", ctypes.c_float),
        ("prefill_tokens", ctypes.c_int),
        ("generate_time_ms", ctypes.c_float),
        ("generate_tokens", ctypes.c_int),
        ("memory_usage_mb", ctypes.c_float)
    ]
|
||
|
||
class RKLLMResult(ctypes.Structure):
    """ctypes mirror of RKLLMResult, delivered to the C callback on each event."""
    _fields_ = [
        ("text", ctypes.c_char_p),     # UTF-8 encoded chunk of generated text
        ("token_id", ctypes.c_int),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),
        ("logits", RKLLMResultLogits),
        ("perf", RKLLMPerfStat)
    ]
|
||
|
||
# Create a lock to control multi-user access to the server.
lock = threading.Lock()
is_blocking = False

# Define global variables
# NOTE(review): several of these (global_text, global_state, is_blocking,
# lock, SAFE_CHAR_LIMIT, MAX_OUTPUT_CHARS, MAX_GENERATION_TIME) are
# re-assigned with different values further down in this file.
system_prompt = ''
global_text = []       # generated text accumulated by the C callback
global_state = -1      # last LLMCallState reported by the C callback
split_byte_data = bytes(b"")  # carry-over bytes for UTF-8 sequences split across chunks
recevied_messages = []  # NOTE(review): typo for "received_messages"; kept as-is

SAFE_CHAR_LIMIT = 8000
MAX_OUTPUT_CHARS = 2048
MAX_GENERATION_TIME = 60
MAX_CONTEXT_TOKENS = 14000  # reserve 2000 tokens for the response
|
||
|
||
# ==================== Token 管理类 ====================
|
||
class ConversationManager:
    """Conversation manager: builds ChatML prompts while enforcing a
    heuristic token budget."""

    def __init__(self, max_context_tokens=14000, reserve_tokens=2000):
        # Total context window, and the slice reserved for the reply.
        self.max_context_tokens = max_context_tokens
        self.reserve_tokens = reserve_tokens
        self.available_tokens = max_context_tokens - reserve_tokens

    def count_tokens(self, text):
        """Rough token estimate: ~1.5 chars/token for CJK characters,
        ~4 chars/token for everything else."""
        if not text:
            return 0
        chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
        other_chars = len(text) - chinese_chars
        return int(chinese_chars / 1.5 + other_chars / 4)

    def truncate_text(self, text, max_tokens):
        """Truncate *text* to roughly *max_tokens*, keeping head and tail.

        Uses a flat 1.5 chars/token conversion (CJK-biased, intentionally
        conservative for mostly-ASCII text).
        """
        if not text:
            return text

        max_chars = int(max_tokens * 1.5)
        if len(text) <= max_chars:
            return text

        # Keep the beginning and the end, drop the middle.
        half = max_chars // 2
        return text[:half] + "\n...[内容已截断]...\n" + text[-half:]

    def build_prompt(self, messages, system_prompt, safe_get_content_func):
        """Assemble a ChatML prompt from *messages* under the token budget.

        Messages are walked newest-first and prepended while they fit; as
        soon as one does not fit the walk stops (older history is dropped).
        If even the newest message does not fit, it is truncated instead of
        dropped.  System-role entries in *messages* are skipped — the
        explicit *system_prompt* argument is used instead.

        Args:
            messages: list of chat-message dicts (OpenAI style).
            system_prompt: text for the system turn.
            safe_get_content_func: callable(msg) -> str used to extract
                each message's textual content.
        """

        # Token cost of the system turn.
        system_text = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        system_tokens = self.count_tokens(system_text)

        available = self.available_tokens - system_tokens

        # Build the message list from newest to oldest.
        prompt_parts = []
        current_tokens = 0

        # Account for the trailing assistant start tag up front.
        end_token = "<|im_start|>assistant\n"
        current_tokens += self.count_tokens(end_token)

        for msg in reversed(messages):
            role = msg.get('role')
            if role == 'system':
                continue

            content = safe_get_content_func(msg)

            # Render the message in ChatML form.
            # NOTE(review): only tool_calls[0] is rendered; additional tool
            # calls in the same assistant message are dropped — confirm.
            if role == 'assistant':
                tool_calls = msg.get('tool_calls', [])
                if tool_calls:
                    msg_text = f"<|im_start|>assistant\nTool Call: {tool_calls[0]['function']['name']}\nArguments: {tool_calls[0]['function']['arguments']}<|im_end|>\n"
                else:
                    msg_text = f"<|im_start|>assistant\n{content}<|im_end|>\n"
            elif role == 'tool':
                msg_text = f"<|im_start|>user\nTool Result: {content}<|im_end|>\n"
            else:  # user
                msg_text = f"<|im_start|>user\n{content}<|im_end|>\n"

            msg_tokens = self.count_tokens(msg_text)

            if current_tokens + msg_tokens <= available:
                prompt_parts.insert(0, msg_text)
                current_tokens += msg_tokens
            else:
                # Message does not fit: if nothing has been kept yet,
                # keep a truncated version of this (newest) message.
                if len(prompt_parts) == 0:
                    truncated_content = self.truncate_text(content, available - current_tokens)
                    msg_text = f"<|im_start|>user\n{truncated_content}<|im_end|>\n"
                    prompt_parts.insert(0, msg_text)
                break

        # Final prompt: system turn + kept history + assistant start tag.
        full_prompt = system_text + "".join(prompt_parts) + "<|im_start|>assistant\n"

        total_tokens = self.count_tokens(full_prompt)
        print(f"Prompt built: {total_tokens}/{self.max_context_tokens} tokens, {len(prompt_parts)} messages")

        if total_tokens > self.max_context_tokens:
            print(f"⚠️ Warning: Prompt still too long: {total_tokens} tokens")
            # Hard character-level cut as a last resort.
            max_chars = int(self.max_context_tokens * 1.5)
            if len(full_prompt) > max_chars:
                full_prompt = full_prompt[:max_chars] + "\n...[对话已截断]..."

        return full_prompt
|
||
|
||
class SkillStatus(Enum):
    """Skill status values parsed from `openclaw skills list` output."""
    READY = "ready"
    DISABLED = "disabled"
    NEEDS_SETUP = "needs setup"
    UNKNOWN = "unknown"
|
||
|
||
class OpenClawSkillParser:
    """Parses the table output of the `openclaw skills list` CLI command."""

    def __init__(self, openclaw_cmd: str = "openclaw"):
        """
        Args:
            openclaw_cmd: path to the openclaw executable.
        """
        self.openclaw_cmd = openclaw_cmd

    def get_skills_from_cli(self, include_disabled: bool = False) -> List[Dict[str, Any]]:
        """
        Run `openclaw skills list` and return the parsed skill entries.

        Args:
            include_disabled: also return skills whose status is DISABLED.

        Returns:
            List of skill-info dicts (see _parse_skill_info).  Any failure
            (non-zero exit, timeout, missing binary) degrades to [].
        """
        try:
            # argv list with shell=False (default): no shell-injection surface.
            result = subprocess.run(
                [self.openclaw_cmd, "skills", "list"],
                capture_output=True,
                text=True,
                timeout=30
            )

            if result.returncode != 0:
                print(f"命令执行失败: {result.stderr}")
                return []

            # Parse the box-drawing table.
            skills = self._parse_table_output(result.stdout)

            if not include_disabled:
                skills = [s for s in skills if s['status'] != SkillStatus.DISABLED]

            return skills

        except subprocess.TimeoutExpired:
            print("命令执行超时")
            return []
        except FileNotFoundError:
            print(f"未找到 openclaw 命令: {self.openclaw_cmd}")
            return []
        except Exception as e:
            # Broad catch keeps the server alive; failure means "no skills".
            print(f"获取技能列表失败: {e}")
            return []

    def _parse_table_output(self, output: str) -> List[Dict[str, Any]]:
        """
        Parse the box-drawing table produced by the CLI.

        Example input:
        ┌──────────────┬───────────────────────┬─────────────────────────────────────────────────────────────────┬────────────────────┐
        │ Status       │ Skill                 │ Description                                                     │ Source             │
        ├──────────────┼───────────────────────┼─────────────────────────────────────────────────────────────────┼────────────────────┤
        │ ✓ ready      │ 📦 clawhub            │ Use the ClawHub CLI to search, install...                       │ openclaw-bundled   │
        │ ⏸ disabled   │ 🔐 1password          │ Set up and use 1Password CLI (op)...                            │ openclaw-bundled   │
        │ △ needs setup│ 🫧 bluebubbles        │ Use when you need to send or manage iMessages...                │ openclaw-bundled   │
        └──────────────┴───────────────────────┴─────────────────────────────────────────────────────────────────┴────────────────────┘
        """
        skills = []
        lines = output.strip().split('\n')

        # Walk the lines looking for data rows.
        in_table = False  # NOTE(review): assigned but never read; kept as-is
        for line in lines:
            # Skip the top/middle/bottom border lines.
            if '┌' in line or '├' in line or '└' in line:
                continue

            # Data rows (and the header row) contain the │ separator.
            if '│' in line:
                parts = self._parse_table_row(line)
                if parts and len(parts) >= 4:
                    # Skip the header row.
                    if 'Status' in parts[0] and 'Skill' in parts[1]:
                        continue

                    # Column extraction.
                    status_raw = parts[0].strip()
                    skill_raw = parts[1].strip()
                    description_raw = parts[2].strip()
                    source_raw = parts[3].strip() if len(parts) > 3 else ''

                    skill_info = self._parse_skill_info(status_raw, skill_raw, description_raw, source_raw)
                    if skill_info:
                        skills.append(skill_info)

        return skills

    def _parse_table_row(self, line: str) -> List[str]:
        """
        Split one table row on the │ separator.

        NOTE(review): empty cells are dropped entirely, so a blank column
        shifts all following columns left — confirm the CLI never emits
        blank cells.

        Args:
            line: a row containing │ separators.

        Returns:
            List of non-empty, stripped cell contents.
        """
        parts = line.split('│')
        # Drop empty edge elements and strip each cell.
        cleaned = []
        for part in parts:
            part = part.strip()
            if part:  # keep only non-empty cells
                cleaned.append(part)
        return cleaned

    def _parse_skill_info(self, status_raw: str, skill_raw: str,
                          description_raw: str, source_raw: str) -> Optional[Dict[str, Any]]:
        """
        Build the skill-info dict for one parsed row.

        Args:
            status_raw: raw status cell, e.g. "✓ ready", "⏸ disabled", "△ needs setup".
            skill_raw: raw skill cell, e.g. "📦 clawhub" or "🔐 1password".
            description_raw: raw description cell.
            source_raw: raw source cell.

        Returns:
            Dict with name / display_name / icon / status / flags, or None
            if status or name cannot be parsed.
        """
        # Status.
        status = self._parse_status(status_raw)
        if not status:
            return None

        # Skill name (emoji and extra whitespace removed).
        skill_name = self._parse_skill_name(skill_raw)
        if not skill_name:
            return None

        # Leading emoji icon, if any.
        skill_icon = self._parse_skill_icon(skill_raw)

        # Source.
        source = source_raw.strip() if source_raw else 'unknown'

        return {
            'name': skill_name,
            'display_name': skill_raw.strip(),
            'icon': skill_icon,
            'status': status,
            'status_raw': status_raw,
            'description': description_raw.strip(),
            'source': source,
            'is_ready': status == SkillStatus.READY,
            'is_disabled': status == SkillStatus.DISABLED,
            'needs_setup': status == SkillStatus.NEEDS_SETUP
        }

    def _parse_status(self, status_raw: str) -> Optional[SkillStatus]:
        """Map the raw status cell (keyword or symbol) to a SkillStatus."""
        status_lower = status_raw.lower()

        if 'ready' in status_lower or '✓' in status_raw:
            return SkillStatus.READY
        elif 'disabled' in status_lower or '⏸' in status_raw:
            return SkillStatus.DISABLED
        elif 'needs setup' in status_lower or '△' in status_raw:
            return SkillStatus.NEEDS_SETUP
        else:
            return SkillStatus.UNKNOWN

    def _remove_emoji(self, text):
        """
        Remove all emoji characters from *text*.
        """
        # Regex covering the common emoji code-point ranges.
        emoji_pattern = re.compile(
            pattern="["
            u"\U0001F600-\U0001F64F"  # emoticons
            u"\U0001F300-\U0001F5FF"  # symbols & pictographs
            u"\U0001F680-\U0001F6FF"  # transport & map symbols
            u"\U0001F700-\U0001F77F"  # alchemical symbols
            u"\U0001F780-\U0001F7FF"  # geometric shapes
            u"\U0001F800-\U0001F8FF"  # supplemental arrows
            u"\U0001F900-\U0001F9FF"  # supplemental symbols
            u"\U0001FA00-\U0001FA6F"  # chess symbols
            u"\U0001FA70-\U0001FAFF"  # symbols extended
            u"\U00002702-\U000027B0"  # dingbats
            u"\U000024C2-\U0001F251"
            "]+",
            flags=re.UNICODE
        )
        return emoji_pattern.sub('', text)

    def _parse_skill_name(self, skill_raw: str) -> str:
        """Normalize a skill cell into a lowercase name (emoji stripped)."""
        # Strip emoji such as 🔐, 📦, 🫧 etc.  This is a simplified
        # approach; more exotic emoji may survive.
        name = self._remove_emoji(skill_raw)
        # Collapse runs of whitespace.
        name = ' '.join(name.split())
        return name.lower().strip()

    def _parse_skill_icon(self, skill_raw: str) -> str:
        """Return the leading emoji of the skill cell, or '' if none."""
        # re is already imported at module level; the local import is
        # redundant but kept as-is.
        import re
        # Simplified leading-emoji matcher.
        emoji_pattern = re.compile(r'^[\U0001F300-\U0001F9FF]|^[\u2600-\u26FF]|^[\u2700-\u27BF]')
        match = emoji_pattern.match(skill_raw)
        if match:
            return match.group(0)
        return ''

    def get_ready_skills(self) -> List[Dict[str, Any]]:
        """Return only skills whose status is READY."""
        skills = self.get_skills_from_cli(include_disabled=False)
        return [s for s in skills if s['is_ready']]

    def get_enabled_skills(self) -> List[Dict[str, Any]]:
        """Return enabled skills (READY and NEEDS_SETUP; DISABLED filtered
        out by get_skills_from_cli)."""
        skills = self.get_skills_from_cli(include_disabled=False)
        return skills

    def get_skills_by_source(self, source: str) -> List[Dict[str, Any]]:
        """Return all skills (including disabled) whose source matches."""
        skills = self.get_skills_from_cli(include_disabled=True)
        return [s for s in skills if s['source'] == source]

    def print_skills_summary(self):
        """Print a human-readable summary of skill counts and ready skills."""
        skills = self.get_skills_from_cli(include_disabled=True)

        ready = [s for s in skills if s['is_ready']]
        needs_setup = [s for s in skills if s['needs_setup']]
        disabled = [s for s in skills if s['is_disabled']]

        print(f"技能统计:")
        print(f"  ✓ 就绪: {len(ready)}")
        print(f"  △ 需要设置: {len(needs_setup)}")
        print(f"  ⏸ 禁用: {len(disabled)}")
        print(f"  总计: {len(skills)}")

        if ready:
            print(f"\n就绪技能:")
            for skill in ready:
                print(f"{skill['name']}: {skill['description'][:60]}...")
|
||
|
||
# ==================== 消息提取函数 ====================
|
||
def extract_text_content(content):
    """Pull plain text out of a message `content` field.

    Handles None (-> ""), plain strings (stripped), and OpenAI-style lists
    of parts; anything else is stringified.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        pieces = []
        for element in content:
            if isinstance(element, dict) and element.get('type') == 'text':
                # Standard text part: take its "text" field.
                pieces.append(element.get('text', ''))
            else:
                # Non-text part or non-dict entry: stringify verbatim.
                pieces.append(str(element))
        return "\n".join(pieces).strip()
    return str(content)
|
||
|
||
def clean_openclaw_message(text):
    """Strip OpenClaw transport metadata from a user message.

    Drops "Sender (untrusted metadata)" lines, fenced ```json blocks and
    "[... UTC]" timestamp lines, then joins what remains with spaces.  If a
    "使用<tool>查询<query>" pattern is present, only the query is returned.
    Non-string input is returned unchanged.
    """
    if not isinstance(text, str):
        return text

    kept = []
    inside_json = False
    source_lines = text.split('\n')
    for raw in source_lines:
        stripped = raw.strip()
        if "Sender (untrusted metadata)" in raw:
            continue
        if stripped.startswith('```json'):
            inside_json = True
            continue
        if inside_json:
            # Swallow everything until (and including) the closing fence.
            if stripped == '```':
                inside_json = False
            continue
        if re.match(r'\[.*?UTC\]', stripped):
            continue
        if stripped:
            kept.append(stripped)

    if not kept:
        # Nothing survived filtering: fall back to the last raw line.
        return source_lines[-1].strip() if source_lines else text

    merged = " ".join(kept)

    # Extract the query body if present.
    query = re.search(r'使用\w+查询(.+)', merged)
    return query.group(1).strip() if query else merged
|
||
|
||
def safe_get_content(message):
    """Return the cleaned text content of one chat-message dict."""
    role = message.get('role')
    text = extract_text_content(message.get('content'))
    # Only user-authored messages carry OpenClaw transport metadata.
    return clean_openclaw_message(text) if role == 'user' else text
|
||
|
||
|
||
# Define the callback function
|
||
def callback_impl(result, userdata, state):
    """C callback invoked by librkllmrt for every generation event.

    Args:
        result:   POINTER(RKLLMResult) carrying the new text chunk.
        userdata: opaque user pointer (unused).
        state:    one of the LLMCallState constants.

    Side effects: updates module-level ``global_state`` and appends the
    decoded text to ``global_text``.  Returns 0 as the library requires.
    """
    global global_text, global_state, split_byte_data
    if state == LLMCallState.RKLLM_RUN_FINISH:
        global_state = state
        print("\n")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        global_state = state
        print("run error")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_NORMAL:
        global_state = state
        if result and result.contents and result.contents.text:
            # Bug fix: a chunk may end in the middle of a multi-byte UTF-8
            # sequence (CJK text, emoji), and decoding it directly raised
            # UnicodeDecodeError.  Buffer the incomplete tail bytes in the
            # module-level split_byte_data (declared for this purpose) and
            # prepend them to the next chunk.
            data = split_byte_data + result.contents.text
            try:
                decoded = data.decode('utf-8')
                split_byte_data = b""
            except UnicodeDecodeError:
                decoded = None
                # A UTF-8 sequence is at most 4 bytes, so at most 3 tail
                # bytes can belong to an unfinished character.
                for cut in (1, 2, 3):
                    try:
                        decoded = data[:-cut].decode('utf-8')
                        split_byte_data = data[-cut:]
                        break
                    except UnicodeDecodeError:
                        continue
                if decoded is None:
                    # Genuinely invalid bytes: degrade gracefully.
                    decoded = data.decode('utf-8', errors='replace')
                    split_byte_data = b""
            global_text += decoded  # same accumulation semantics as before
    return 0
|
||
|
||
# CFUNCTYPE(int return, RKLLMResult*, void* userdata, int state).  Keep a
# module-level reference to the wrapped callback so it is not garbage
# collected while the native library still holds its function pointer.
callback_type = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int)
callback = callback_type(callback_impl)
|
||
|
||
# Define the RKLLM class
|
||
class RKLLM(object):
    """Wrapper around the librkllmrt C API for one loaded model.

    NOTE(review): this class shadows the `RKLLM` imported from the local
    `rkllm` module at the top of the file — confirm which one callers use.
    """

    def __init__(self, model_path, lora_model_path = None, prompt_cache_path = None, platform = "rk3588"):
        # NOTE(review): lora_model_path and prompt_cache_path are accepted
        # but never used below.
        rkllm_param = RKLLMParam()
        rkllm_param.model_path = bytes(model_path, 'utf-8')
        rkllm_param.max_context_len = 16000
        rkllm_param.max_new_tokens = 2048
        rkllm_param.skip_special_token = True
        rkllm_param.n_keep = -1
        # Sampling configuration.
        rkllm_param.top_k = 50
        rkllm_param.top_p = 0.9
        rkllm_param.temperature = 0.7
        rkllm_param.repeat_penalty = 1.1
        rkllm_param.frequency_penalty = 0.0
        rkllm_param.presence_penalty = 0.0
        rkllm_param.mirostat = 0
        rkllm_param.mirostat_tau = 5.0
        rkllm_param.mirostat_eta = 0.1
        rkllm_param.is_async = False
        # No multimodal image markers configured.
        rkllm_param.img_start = "".encode('utf-8')
        rkllm_param.img_end = "".encode('utf-8')
        rkllm_param.img_content = "".encode('utf-8')

        rkllm_param.extend_param.base_domain_id = 0
        rkllm_param.extend_param.embed_flash = 1
        rkllm_param.extend_param.n_batch = 1
        rkllm_param.extend_param.use_cross_attn = 0
        rkllm_param.extend_param.enabled_cpus_num = 4
        # CPUs 4-7 on rk3576/rk3588 (presumably the big cores — confirm),
        # otherwise CPUs 0-3.
        if platform.lower() in ["rk3576", "rk3588"]:
            rkllm_param.extend_param.enabled_cpus_mask = (1 << 4)|(1 << 5)|(1 << 6)|(1 << 7)
        else:
            rkllm_param.extend_param.enabled_cpus_mask = (1 << 0)|(1 << 1)|(1 << 2)|(1 << 3)

        # Bind and call rkllm_init.
        self.handle = RKLLM_Handle_t()
        self.rkllm_init = rkllm_lib.rkllm_init
        self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type]
        self.rkllm_init.restype = ctypes.c_int
        ret = self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback)
        if ret != 0:
            print("\nrkllm init failed\n")
            # NOTE(review): exits the whole process with status 0 even on
            # failure — a non-zero exit code would be more conventional.
            exit(0)
        else:
            print("\nrkllm init success!\n")

        # Bind the remaining C entry points once.
        self.rkllm_run = rkllm_lib.rkllm_run
        self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p]
        self.rkllm_run.restype = ctypes.c_int

        self.set_function_tools_ = rkllm_lib.rkllm_set_function_tools
        self.set_function_tools_.argtypes = [RKLLM_Handle_t, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
        self.set_function_tools_.restype = ctypes.c_int

        self.rkllm_destroy = rkllm_lib.rkllm_destroy
        self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
        self.rkllm_destroy.restype = ctypes.c_int

        # No argtypes/restype declared for rkllm_abort — ctypes defaults apply.
        self.rkllm_abort = rkllm_lib.rkllm_abort

        # Initialize inference params (zeroed, then generate mode, no history).
        self.rkllm_infer_params = RKLLMInferParam()
        ctypes.memset(ctypes.byref(self.rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam))
        self.rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        self.rkllm_infer_params.keep_history = 0

        # Cache of the last tools JSON passed to set_function_tools.
        self.tools = None

    def set_function_tools(self, system_prompt, tools, tool_response_str):
        """Register function tools with the runtime; skipped when *tools*
        is unchanged since the last call."""
        if self.tools is None or not self.tools == tools:
            self.tools = tools
            self.set_function_tools_(self.handle, ctypes.c_char_p(system_prompt.encode('utf-8')),
                                     ctypes.c_char_p(tools.encode('utf-8')),
                                     ctypes.c_char_p(tool_response_str.encode('utf-8')))

    def run(self, *param):
        """Run blocking generation for one prompt.

        Expects (role, enable_thinking, prompt); output is delivered through
        the module-level callback into global_text / global_state.
        """
        role, enable_thinking, prompt = param
        rkllm_input = RKLLMInput()
        # Default role "user" and thinking disabled when not provided.
        rkllm_input.role = role.encode('utf-8') if role is not None else "user".encode('utf-8')
        rkllm_input.enable_thinking = ctypes.c_bool(enable_thinking if enable_thinking is not None else False)
        rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
        rkllm_input.input_data.prompt_input = ctypes.c_char_p(prompt.encode('utf-8'))
        self.rkllm_run(self.handle, ctypes.byref(rkllm_input), ctypes.byref(self.rkllm_infer_params), None)
        return

    def abort(self):
        """Abort the in-flight generation, returning the library's status code."""
        return self.rkllm_abort(self.handle)

    def release(self):
        """Destroy the native model handle and free its resources."""
        self.rkllm_destroy(self.handle)
|
||
|
||
def create_openai_stream_chunk(content, model_name, completion_id, finish_reason=None):
    """Build one OpenAI-compatible `chat.completion.chunk` payload.

    Args:
        content: delta text for this chunk, or None for an empty delta
            (used for the final chunk carrying only finish_reason).
        model_name: model identifier echoed back to the client.
        completion_id: stable id shared by all chunks of one completion.
        finish_reason: e.g. "stop" on the terminating chunk, else None.
    """
    delta = {} if content is None else {"content": content}
    return {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}]
    }
|
||
|
||
from collections import deque
|
||
import threading
|
||
import time
|
||
import uuid
|
||
import json
|
||
from flask import request, jsonify, Response
|
||
from werkzeug.exceptions import BadRequest
|
||
|
||
# Global variables.
# NOTE(review): these re-assign names already defined earlier in this file
# (global_text becomes a deque instead of a list, and a brand-new Lock
# replaces the earlier `lock`); the earlier SAFE_CHAR_LIMIT /
# MAX_OUTPUT_CHARS / MAX_GENERATION_TIME values are also overridden here.
global_text = deque()  # switched to a deque
global_state = -1
is_blocking = False
lock = threading.Lock()

# Safety limits (tune against RK3588 measurements).
SAFE_CHAR_LIMIT = 3000         # max total input characters
MAX_OUTPUT_CHARS = 30000       # max total output characters (comment said ~1200 tokens — verify)
MAX_GENERATION_TIME = 1200     # max generation time (seconds)
|
||
|
||
# ==================== 消息提取辅助函数 ====================
|
||
|
||
def extract_text_content(content):
    """Extract plain text from message content in any supported shape.

    Accepts strings, OpenAI-style lists of content parts, and arbitrary
    objects (stringified).  Image parts become the "[Image]" placeholder.
    """
    if content is None:
        return ""

    # Plain string: just strip it.
    if isinstance(content, str):
        return content.strip()

    # List of content parts: collect text from each.
    if isinstance(content, list):
        collected = []
        for part in content:
            if not isinstance(part, dict):
                collected.append(str(part))
                continue
            kind = part.get('type')
            if kind == 'text':
                # Standard OpenAI multimodal text part.
                collected.append(part.get('text', ''))
            elif kind == 'image_url':
                collected.append("[Image]")
            else:
                # Probe common text-bearing fields in order.
                for field in ('text', 'content', 'message', 'description'):
                    if field in part:
                        collected.append(str(part[field]))
                        break
                else:
                    # No recognizable field: stringify the whole dict.
                    collected.append(str(part))
        return "\n".join(collected).strip()

    # Anything else: stringify and strip.
    return str(content).strip()
|
||
|
||
def clean_openclaw_message(text):
    """Remove OpenClaw transport metadata from a user message.

    Filters out "Sender (untrusted metadata)" lines, fenced ```json blocks
    and "[... UTC]" timestamp lines, then joins the survivors with single
    spaces.  When a "使用<tool>查询<query>" pattern remains, only the query
    part is returned.  Non-string input passes through untouched.
    """
    if not isinstance(text, str):
        return text

    survivors = []
    in_fence = False
    raw_lines = text.split('\n')
    for raw in raw_lines:
        body = raw.strip()
        if "Sender (untrusted metadata)" in raw:
            continue
        if body.startswith('```json'):
            in_fence = True
            continue
        if in_fence:
            # Consume the block, including its closing fence.
            if body == '```':
                in_fence = False
            continue
        if re.match(r'\[.*?UTC\]', body):
            continue
        if body:
            survivors.append(body)

    if not survivors:
        # Everything was filtered away: fall back to the final raw line.
        return raw_lines[-1].strip() if raw_lines else text

    joined = " ".join(survivors)

    # Pull out the query payload when present.
    found = re.search(r'使用\w+查询(.+)', joined)
    return found.group(1).strip() if found else joined
|
||
|
||
def parse_react_to_tool_call(text, available_tools, skill_manager, messages=None):
|
||
"""解析 ReAct 格式,支持技能调用和工具名称映射,增强 JSON 容错,检测循环"""
|
||
|
||
print(f"\n=== Parsing tool call ===")
|
||
print(f"Input text: {text[:300]}")
|
||
|
||
# ========== 1. 首先尝试纯 JSON 格式 ==========
|
||
skip_react = False
|
||
thought = ""
|
||
tool_name = ""
|
||
tool_input_str = ""
|
||
|
||
try:
|
||
# 清理可能的 markdown 代码块
|
||
clean_text = text.strip()
|
||
if clean_text.startswith('```json'):
|
||
clean_text = clean_text[7:]
|
||
if clean_text.startswith('```'):
|
||
clean_text = clean_text[3:]
|
||
if clean_text.endswith('```'):
|
||
clean_text = clean_text[:-3]
|
||
clean_text = clean_text.strip()
|
||
|
||
json_data = json.loads(clean_text)
|
||
print(f" Found pure JSON: {json_data}")
|
||
|
||
# ========== 支持多种 JSON 格式 ==========
|
||
|
||
# 格式1: {"action": "web_search", "action_input": {...}}
|
||
if 'action' in json_data and 'action_input' in json_data:
|
||
tool_name = json_data['action']
|
||
tool_input_str = json.dumps(json_data['action_input'])
|
||
thought = "Using tool based on JSON format"
|
||
skip_react = True
|
||
print(f" Parsed as action/action_input format: {tool_name}")
|
||
|
||
# 格式2: {"tool": "web_search", "input": {...}}
|
||
elif 'tool' in json_data and 'input' in json_data:
|
||
tool_name = json_data['tool']
|
||
tool_input_str = json.dumps(json_data['input'])
|
||
thought = "Using tool based on JSON format"
|
||
skip_react = True
|
||
print(f" Parsed as tool/input format: {tool_name}")
|
||
|
||
# 格式3: {"action": "web_search", "query": "..."} (直接参数)
|
||
elif 'action' in json_data:
|
||
tool_name = json_data['action']
|
||
# 提取除了 action 之外的所有字段作为参数
|
||
action_input = {k: v for k, v in json_data.items() if k != 'action'}
|
||
tool_input_str = json.dumps(action_input)
|
||
thought = "Using tool based on JSON format"
|
||
skip_react = True
|
||
print(f" Parsed as action with params format: {tool_name}")
|
||
print(f" Extracted params: {action_input}")
|
||
|
||
# 格式4: {"name": "web_search", "parameters": {...}}
|
||
elif 'name' in json_data and 'parameters' in json_data:
|
||
tool_name = json_data['name']
|
||
tool_input_str = json.dumps(json_data['parameters'])
|
||
thought = "Using tool based on JSON format"
|
||
skip_react = True
|
||
print(f" Parsed as name/parameters format: {tool_name}")
|
||
|
||
# 格式5: {"skill": "web_search", "input": {...}}
|
||
elif 'skill' in json_data and 'input' in json_data:
|
||
tool_name = 'invoke_skill'
|
||
tool_input_str = json.dumps(json_data)
|
||
thought = "Using skill based on JSON format"
|
||
skip_react = True
|
||
print(f" Parsed as skill format: {json_data.get('skill')}")
|
||
|
||
# 格式6: {"action": "cron", "method": "add", "arguments": {...}}
|
||
elif 'action' in json_data and 'method' in json_data and 'arguments' in json_data:
|
||
tool_name = json_data['action']
|
||
# 将 method 和 arguments 合并
|
||
tool_input = {
|
||
"method": json_data['method'],
|
||
**json_data['arguments']
|
||
}
|
||
# 也保留其他顶层字段
|
||
for key in ['schedule', 'payload', 'sessionTarget', 'enabled']:
|
||
if key in json_data:
|
||
tool_input[key] = json_data[key]
|
||
tool_input_str = json.dumps(tool_input)
|
||
thought = "Using tool based on JSON format"
|
||
skip_react = True
|
||
print(f" Parsed as action/method/arguments format: {tool_name}")
|
||
print(f" Extracted input: {tool_input}")
|
||
|
||
# 格式7: {"action": "cron", "method": "add", "schedule": {...}, ...}
|
||
elif 'action' in json_data and 'method' in json_data:
|
||
tool_name = json_data['action']
|
||
# 提取所有参数(除了 action)
|
||
tool_input = {k: v for k, v in json_data.items() if k != 'action'}
|
||
tool_input_str = json.dumps(tool_input)
|
||
thought = "Using tool based on JSON format"
|
||
skip_react = True
|
||
print(f" Parsed as action/method format: {tool_name}")
|
||
print(f" Extracted input: {tool_input}")
|
||
|
||
# 格式8: {"action": "cron", "arguments": {...}} (只有 arguments)
|
||
elif 'action' in json_data and 'arguments' in json_data:
|
||
tool_name = json_data['action']
|
||
tool_input_str = json.dumps(json_data['arguments'])
|
||
thought = "Using tool based on JSON format"
|
||
skip_react = True
|
||
print(f" Parsed as action/arguments format: {tool_name}")
|
||
|
||
# 格式9: {"action": "cron", "schedule": "...", "command": "..."} (旧格式,需要转换)
|
||
elif 'action' in json_data and ('schedule' in json_data or 'command' in json_data):
|
||
tool_name = json_data['action']
|
||
# 转换为 OpenClaw 期望的格式
|
||
tool_input = {
|
||
"method": "add"
|
||
}
|
||
|
||
# 转换 schedule
|
||
if 'schedule' in json_data:
|
||
tool_input["schedule"] = {
|
||
"kind": "cron",
|
||
"expression": json_data['schedule']
|
||
}
|
||
else:
|
||
tool_input["schedule"] = {
|
||
"kind": "cron",
|
||
"expression": "0 14 * * 1-5"
|
||
}
|
||
|
||
# 转换 command 或 text
|
||
if 'command' in json_data:
|
||
tool_input["payload"] = {
|
||
"kind": "systemEvent",
|
||
"text": json_data['command']
|
||
}
|
||
else:
|
||
tool_input["payload"] = {
|
||
"kind": "systemEvent",
|
||
"text": "Reminder"
|
||
}
|
||
|
||
tool_input["sessionTarget"] = "main"
|
||
tool_input["enabled"] = True
|
||
|
||
tool_input_str = json.dumps(tool_input)
|
||
thought = "Using tool based on JSON format (converted to OpenClaw format)"
|
||
skip_react = True
|
||
print(f" Parsed and converted cron format: {tool_name}")
|
||
print(f" Converted input: {tool_input}")
|
||
|
||
# 格式10: {"web_search": {"query": "..."}} (单键格式)
|
||
elif len(json_data) == 1:
|
||
possible_tool = list(json_data.keys())[0]
|
||
if possible_tool in available_tools:
|
||
tool_name = possible_tool
|
||
tool_input_str = json.dumps(json_data[possible_tool])
|
||
thought = "Using tool based on JSON format"
|
||
skip_react = True
|
||
print(f" Parsed as single-key tool format: {tool_name}")
|
||
else:
|
||
# 不是工具调用格式,继续尝试 ReAct
|
||
print(" JSON is not a tool call format, trying ReAct...")
|
||
skip_react = False
|
||
|
||
except json.JSONDecodeError as e:
|
||
print(f" Not pure JSON: {e}")
|
||
# 格式11 Tool Call: web_search\n Arguments: {"query":"...."}
|
||
skip_react = False
|
||
if 'Tool Call' in text:
|
||
pattern = r'.*?Tool Call: (.*?)\nArguments: (.*)'
|
||
match = re.search(pattern, text)
|
||
|
||
if match:
|
||
skip_react = True
|
||
tool_call = match.group(1)
|
||
arguments = json.loads(match.group(2))
|
||
tool_name = tool_call
|
||
print(f"Tool Call: {tool_call}")
|
||
print(f"Arguments: {arguments}")
|
||
tool_input_str = json.dumps(arguments)
|
||
# skip_react = False
|
||
except Exception as e:
|
||
print(f" JSON parsing error: {e}")
|
||
skip_react = False
|
||
|
||
|
||
# ========== 2. 如果没有找到 JSON 格式,尝试 ReAct 格式 ==========
|
||
if not skip_react:
|
||
# ReAct 格式的正则表达式
|
||
patterns = [
|
||
# 标准格式
|
||
r'Thought:\s*(.*?)\nAction:\s*(\w+)\s*\nAction Input:\s*(\{.*\})',
|
||
# 没有 Thought
|
||
r'Action:\s*(\w+)\s*\nAction Input:\s*(\{.*\})',
|
||
# 同一行
|
||
r'Action:\s*(\w+)\s+Action Input:\s*(\{.*\})',
|
||
# 带冒号的 Action
|
||
r'Action:\s*(\w+).*?\nAction Input:\s*(\{.*\})',
|
||
]
|
||
|
||
for pattern in patterns:
|
||
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
|
||
if match:
|
||
if len(match.groups()) == 3:
|
||
thought = match.group(1).strip()
|
||
tool_name = match.group(2).strip()
|
||
tool_input_str = match.group(3).strip()
|
||
else:
|
||
tool_name = match.group(1).strip()
|
||
tool_input_str = match.group(2).strip()
|
||
print(f" Found ReAct format: Action={tool_name}")
|
||
break
|
||
|
||
if not tool_name:
|
||
print(" No tool call found in output")
|
||
return None
|
||
|
||
print(f" Found Action: {tool_name}")
|
||
print(f" Raw Action Input: {tool_input_str[:200]}")
|
||
|
||
# ========== 3. 清理和修复 JSON 字符串 ==========
|
||
def clean_json_string(json_str):
    """Normalize whitespace in a JSON-ish string and trim it to its outermost braces.

    Collapses all whitespace runs (including newlines) to single spaces,
    drops any junk before the first '{' and after the last '}'.
    Returns the input unchanged when it is empty/falsy.
    """
    if not json_str:
        return json_str

    # Collapse every run of whitespace into a single space.
    json_str = ' '.join(json_str.split())
    json_str = json_str.strip()

    # Strip any leading text before the first '{'; give up if none exists.
    if not json_str.startswith('{'):
        brace_at = json_str.find('{')
        if brace_at == -1:
            return json_str
        json_str = json_str[brace_at:]

    # Strip any trailing text after the last '}'.
    if not json_str.endswith('}'):
        close_at = json_str.rfind('}')
        if close_at != -1:
            json_str = json_str[:close_at + 1]

    return json_str
|
||
|
||
# 清理 JSON 字符串
|
||
cleaned_input = clean_json_string(tool_input_str)
|
||
print(f" Cleaned Action Input: {cleaned_input[:200]}")
|
||
|
||
# ========== 4. 修复 JSON 格式 ==========
|
||
def fix_json_format(json_str):
    """Best-effort repair of common JSON formatting mistakes from model output.

    Heuristics applied, in order: insert missing commas between strings,
    objects, key/value pairs and arrays; drop trailing commas; convert
    single quotes outside string literals to double quotes; lowercase
    Python-style booleans; map None -> null. Returns the input unchanged
    when it is empty/falsy.
    """
    if not json_str:
        return json_str

    # Missing comma between two adjacent string literals.
    json_str = re.sub(r'"\s+"', '", "', json_str)

    # Missing comma between adjacent objects.
    json_str = re.sub(r'}\s+{', '},{', json_str)

    # Missing comma between key/value pairs.
    json_str = re.sub(r'("[^"]+"\s*:\s*[^,}]+)\s+(")', r'\1,\2', json_str)

    # Missing comma between adjacent arrays.
    json_str = re.sub(r']\s+\[', '],[', json_str)

    # Trailing commas before a closing brace/bracket.
    json_str = re.sub(r',\s*}', '}', json_str)
    json_str = re.sub(r',\s*]', ']', json_str)

    # Replace single quotes with double quotes, but only when we are not
    # inside a double-quoted string (tracked by toggling on unescaped '"').
    inside = False
    repaired = []
    prev = None
    for ch in json_str:
        if ch == '"' and prev != '\\':
            inside = not inside
        repaired.append('"' if (not inside and ch == "'") else ch)
        prev = ch
    json_str = ''.join(repaired)

    # Python-style booleans -> JSON booleans.
    # NOTE: plain replace can also touch occurrences inside string values.
    json_str = json_str.replace('True', 'true').replace('False', 'false')

    # Python None -> JSON null.
    json_str = json_str.replace('None', 'null')

    return json_str
|
||
|
||
# ========== 5. 工具名称映射 ==========
|
||
mapping = {
|
||
"add": "cron",
|
||
"create_task": "cron",
|
||
"schedule": "cron",
|
||
"remind": "cron",
|
||
"set_reminder": "cron",
|
||
}
|
||
|
||
# 应用映射
|
||
if tool_name in mapping and mapping[tool_name] in available_tools:
|
||
original_name = tool_name
|
||
tool_name = mapping[tool_name]
|
||
print(f" Tool name mapped: {original_name} -> {tool_name}")
|
||
|
||
# ========== 6. 解析 JSON ==========
|
||
tool_arguments = None
|
||
|
||
# 方法1:直接解析清理后的 JSON
|
||
try:
|
||
tool_arguments = json.loads(cleaned_input)
|
||
print(f" ✅ JSON parsed successfully (method 1)")
|
||
except json.JSONDecodeError as e1:
|
||
print(f" Method 1 failed: {e1}")
|
||
|
||
# 方法2:修复格式后解析
|
||
try:
|
||
fixed_json = fix_json_format(cleaned_input)
|
||
tool_arguments = json.loads(fixed_json)
|
||
print(f" ✅ JSON parsed after format fix (method 2)")
|
||
except json.JSONDecodeError as e2:
|
||
print(f" Method 2 failed: {e2}")
|
||
|
||
# 方法3:使用正则提取键值对
|
||
try:
|
||
pattern = r'"(\w+)"\s*:\s*"([^"]*)"'
|
||
matches = re.findall(pattern, cleaned_input)
|
||
if matches:
|
||
tool_arguments = {}
|
||
for key, value in matches:
|
||
tool_arguments[key] = value
|
||
print(f" ✅ JSON parsed via regex (method 3)")
|
||
else:
|
||
raise Exception("No key-value pairs found")
|
||
except Exception as e3:
|
||
print(f" Method 3 failed: {e3}")
|
||
|
||
# 方法4:尝试 ast.literal_eval
|
||
try:
|
||
import ast
|
||
py_str = cleaned_input.replace('null', 'None').replace('true', 'True').replace('false', 'False')
|
||
tool_arguments = ast.literal_eval(py_str)
|
||
print(f" ✅ JSON parsed via ast.literal_eval (method 4)")
|
||
except Exception as e4:
|
||
print(f" Method 4 failed: {e4}")
|
||
print(f" Raw input: {cleaned_input[:500]}")
|
||
return None
|
||
|
||
if tool_arguments is None:
|
||
return None
|
||
|
||
# ========== 7. 检测循环调用 ==========
|
||
def detect_loop(tool_name, tool_arguments, messages):
    """Return True when the same tool/skill is being invoked repeatedly.

    Scans the last 6 messages for assistant tool calls; if the current
    call (same skill name for invoke_skill, or same tool name + identical
    arguments otherwise) already appears at least twice among the three
    most recent recorded calls, the call is treated as a loop.
    """
    if not messages:
        return False

    # Canonical signature of the call being attempted right now.
    current_call = {
        'name': tool_name,
        'args': json.dumps(tool_arguments, sort_keys=True) if tool_arguments else ''
    }

    # Gather recent tool invocations from the tail of the conversation.
    recent_calls = []
    for msg in reversed(messages[-6:]):
        if msg.get('role') != 'assistant' or not msg.get('tool_calls'):
            continue
        for tc in msg.get('tool_calls', []):
            fn = tc.get('function', {})
            tc_name = fn.get('name')
            if tc_name == 'invoke_skill':
                # Record the target skill rather than the wrapper tool.
                try:
                    tc_args = json.loads(fn.get('arguments', '{}'))
                    skill_name = tc_args.get('skill')
                    if skill_name:
                        recent_calls.append({'name': 'invoke_skill', 'skill': skill_name})
                except:
                    pass
            else:
                recent_calls.append({'name': tc_name, 'args': fn.get('arguments', '')})

    if tool_name == 'invoke_skill':
        # Skill calls loop when the same skill name repeats.
        current_skill = tool_arguments.get('skill')
        if current_skill:
            repeats = sum(
                1 for call in recent_calls[-3:]
                if call.get('name') == 'invoke_skill' and call.get('skill') == current_skill
            )
            if repeats >= 2:
                print(f" ⚠️ Loop detected: Skill '{current_skill}' called {repeats + 1} times")
                return True
    else:
        # Ordinary tools loop when name AND serialized args both repeat.
        repeats = sum(
            1 for call in recent_calls[-3:]
            if call.get('name') == tool_name and call.get('args') == current_call['args']
        )
        if repeats >= 2:
            print(f" ⚠️ Loop detected: Tool '{tool_name}' called {repeats + 1} times")
            return True

    return False
|
||
|
||
# 检查循环
|
||
if messages and detect_loop(tool_name, tool_arguments, messages):
|
||
print(" ❌ Preventing loop: Returning None to force text response")
|
||
return None
|
||
|
||
# ========== 8. 检查是否是技能调用 ==========
|
||
|
||
installed_skills = []
|
||
for skill in skill_manager:
|
||
installed_skills.append(skill['name'])
|
||
|
||
print(f" Installed skills: {installed_skills}")
|
||
|
||
# 判断是否为技能调用
|
||
is_skill_call = False
|
||
skill_name = None
|
||
skill_params = None
|
||
|
||
if tool_name == 'invoke_skill':
|
||
is_skill_call = True
|
||
skill_name = tool_arguments.get('skill')
|
||
skill_params = tool_arguments.get('input', {})
|
||
elif tool_name in installed_skills:
|
||
is_skill_call = True
|
||
skill_name = tool_name
|
||
skill_params = tool_arguments
|
||
elif 'skill' in tool_arguments:
|
||
is_skill_call = True
|
||
skill_name = tool_arguments.get('skill')
|
||
skill_params = tool_arguments.get('input', {})
|
||
|
||
if is_skill_call and skill_name:
|
||
# 验证技能是否已安装
|
||
if skill_name not in installed_skills:
|
||
print(f" ⚠️ Warning: Skill '{skill_name}' not in installed skills list")
|
||
print(f" Available skills: {installed_skills}")
|
||
|
||
print(f"\n✅ Skill call detected:")
|
||
print(f" Skill: {skill_name}")
|
||
print(f" Params: {skill_params}")
|
||
|
||
return {
|
||
"tool_calls": [{
|
||
"id": f"call_{uuid.uuid4().hex}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": "invoke_skill",
|
||
"arguments": json.dumps({
|
||
"skill": skill_name,
|
||
"input": skill_params
|
||
}, ensure_ascii=False)
|
||
}
|
||
}],
|
||
"content": thought
|
||
}
|
||
|
||
# ========== 9. 处理内置工具 ==========
|
||
|
||
|
||
if tool_name in available_tools:
|
||
print(f"\n✅ Tool call detected:")
|
||
print(f" Tool: {tool_name}")
|
||
print(f" Args: {tool_arguments}")
|
||
|
||
return {
|
||
"tool_calls": [{
|
||
"id": f"call_{uuid.uuid4().hex}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": tool_name,
|
||
"arguments": json.dumps(tool_arguments, ensure_ascii=False)
|
||
}
|
||
}],
|
||
"content": thought
|
||
}
|
||
|
||
print(f"\n⚠️ Unknown tool/skill: {tool_name}")
|
||
return None
|
||
|
||
def detect_tool_call_in_stream(text):
    """Return True when streamed text appears to contain a ReAct tool call.

    Intended for early detection while tokens are still arriving: a tool
    call is assumed present once both the "Action:" and "Action Input:"
    markers have shown up in the accumulated text.
    """
    markers = ("Action:", "Action Input:")
    return all(marker in text for marker in markers)
|
||
|
||
|
||
# ==================== Response builders ====================
|
||
|
||
def build_normal_response(completion_id, model, content, finish_reason="stop"):
    """Build an OpenAI-compatible plain-text chat completion response.

    Args:
        completion_id: id echoed back in the response body.
        model: model name reported to the client.
        content: assistant message text.
        finish_reason: completion termination reason (default "stop").
    """
    message = {"role": "assistant", "content": content}
    choice = {"index": 0, "message": message, "finish_reason": finish_reason}
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [choice],
    }
|
||
|
||
def build_tool_call_response(completion_id, model, tool_calls, content=None):
    """Build an OpenAI-compatible chat completion carrying tool calls.

    When tool_calls is non-empty the message includes them and the
    finish_reason is "tool_calls"; otherwise a plain message with
    finish_reason "stop" is produced. content may be None.
    """
    message = {"role": "assistant", "content": content}

    if tool_calls:
        message["tool_calls"] = tool_calls
        finish = "tool_calls"
    else:
        finish = "stop"

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "message": message,
            "finish_reason": finish,
        }],
    }
|
||
|
||
def build_stream_chunk(completion_id, content=None, role=None, finish_reason=None):
    """Build one OpenAI-compatible streaming (SSE) chat completion chunk.

    role and content are only added to the delta when truthy, matching
    the OpenAI convention of omitting absent fields.
    """
    delta = {}
    if role:
        delta["role"] = role
    if content:
        delta["content"] = content

    return {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": "rkllm-model",
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason,
        }],
    }
|
||
|
||
# ==================== System prompt builder ====================
|
||
def build_system_prompt(tools, skill_manager):
    """Compose the system prompt from the tool list and installed skills.

    Args:
        tools: OpenAI-style tool specs; only entries with type == 'function'
            contribute a name/description line.
        skill_manager: iterable of skill dicts with 'name' and 'description'.

    Returns:
        The system prompt string, truncated to ~4000 characters to stay
        within the device's limited context budget.
    """
    tool_names = []
    tool_descriptions = []

    for tool in tools:
        if tool.get('type') == 'function':
            func = tool['function']
            name = func['name']
            tool_names.append(name)
            desc = func.get('description', '')
            # Cap each description so the prompt stays small.
            if len(desc) > 200:
                desc = desc[:200] + "..."
            tool_descriptions.append(f"- {name}: {desc}")

    # Installed skills are advertised alongside the regular tools.
    skill_list = []
    for skill in skill_manager:
        skill_list.append(skill['name'])

        description = skill['description']
        if len(description) > 200:
            description = description[:200] + '...'

        tool_descriptions.append(f"- {skill['name']}: {description}")

    # FIX: the original had a dangling `else:` with no matching `if`; the
    # skill-usage instructions only apply when at least one skill exists.
    if skill_list:
        skills_instruction = f"""
**How to use installed skills:**
You have access to the following installed skills:
{chr(10).join(skill_list)}

To use a skill, you MUST use the 'invoke_skill' tool with this exact format:
Action: invoke_skill
Action Input: {{"skill": "skill_name", "input": {{"action": "action_name", "params": {{...}}}}}}

For example:
- To use Agent Browser: Action: invoke_skill, Input: {{"skill": "agent-browser", "input": {{"action": "navigate", "url": "https://example.com"}}}}
- To use akshare-stock: Action: invoke_skill, Input: {{"skill": "akshare-stock", "input": {{"action": "get_price", "symbol": "000001"}}}}
"""
    else:
        skills_instruction = """
**Installed skills:** None found. Skills will be detected automatically when installed.
"""

    # cron usage notes, only when the cron tool is actually offered.
    cron_instruction = ""
    if 'cron' in tool_names:
        cron_instruction = """
**For scheduling reminders/tasks:**
- Use the 'cron' tool to create scheduled tasks
- Schedule format options:
  * Every X milliseconds: {"kind": "every", "everyMs": 86400000}
  * Cron expression: {"kind": "cron", "expression": "30 8 * * *"}
- Example for daily reminder at 8:30 AM:
  Action: cron
  Action Input: {"schedule": {"kind": "cron", "expression": "30 8 * * *"}, "payload": {"kind": "systemEvent", "text": "Your reminder message"}, "sessionTarget": "main", "enabled": true}
"""

    # FIX: the original wrapped already-quoted fragments in a triple-quoted
    # string, so literal quote characters and "\n" escape text leaked into
    # the prompt. Rebuilt as properly concatenated literals.
    tool_instruction = (
        "**CRITICAL RULES:**\n"
        "When you need to use a tool, you MUST output ONLY the tool call in JSON format.\n"
        "Do NOT add any explanatory text before or after the JSON.\n"
        "Do NOT tell the user how to call tools. Just call them yourself.\n\n"
        "Example of correct output:\n"
        '{"action": "web_search", "query": "上海天气"}\n\n'
        "Example of INCORRECT output (do NOT do this):\n"
        "I will now call the web_search tool...\n"
        '{"action": "web_search", "query": "上海天气"}\n'
        "That's how you do it...\n\n"
        "If you have the information and can answer directly, output your answer in plain text.\n"
        "But for any real-time information, ALWAYS output ONLY the JSON tool call."
    )

    system_content = (
        "You are an AI assistant that can use tools to help users.\n\n"
        f"**AVAILABLE TOOLS:**\n{chr(10).join(tool_descriptions)}\n\n"
        f"{tool_instruction}\n"
        f"{skills_instruction}\n"
        # FIX: cron_instruction was built but never included in the prompt.
        + (f"{cron_instruction}\n" if cron_instruction else "")
        + "**Response Format:**\n"
        "Thought: [your reasoning about what to do]\n"
        "Action: [tool_name]\n"
        "Action Input: {\"param\": \"value\"}\n\n"
        "After receiving tool results, provide your final answer.\n\n"
    )

    # Keep the prompt within a safe budget for the 16k on-device context.
    if len(system_content) > 4000:
        system_content = system_content[:4000] + "\n...(更多工具信息已省略)..."

    return system_content
|
||
|
||
# ==================== Conversation prompt builder ====================
|
||
def build_conversation_prompt(messages, system_prompt):
    """Assemble the full chat prompt in Qwen (<|im_start|>/<|im_end|>) format.

    system messages in the history are skipped (the caller-supplied
    system_prompt is used instead); assistant tool-call turns and tool
    results are rendered explicitly so the model sees what was executed.
    Ends with an open assistant tag so generation continues from there.
    """
    parts = [f"<|im_start|>system\n{system_prompt}<|im_end|>\n"]

    for msg in messages:
        role = msg.get('role')

        # History system messages are replaced by system_prompt above.
        if role == 'system':
            continue

        content = safe_get_content(msg)

        if role == 'assistant':
            tool_calls = msg.get('tool_calls', [])
            if tool_calls:
                # Render the model's tool-call request(s).
                for tc in tool_calls:
                    parts.append(
                        f"<|im_start|>assistant\nTool Call: {tc['function']['name']}\nArguments: {tc['function']['arguments']}<|im_end|>\n"
                    )
            else:
                parts.append(f"<|im_start|>assistant\n{content}<|im_end|>\n")

        elif role == 'tool':
            # Tool execution result — presented as a user turn with an
            # explicit instruction so the model answers instead of re-calling.
            tool_call_id = msg.get('tool_call_id', 'unknown')
            parts.append(
                f"<|im_start|>user\nTool Result ({tool_call_id}):\n{content}\n\nBased on this tool result, please provide a helpful answer to the user. Do NOT call the same tool again unless necessary.<|im_end|>\n"
            )
            print(f" Added tool result to prompt: {content[:100]}...")

        else:  # user
            parts.append(f"<|im_start|>user\n{content}<|im_end|>\n")

    # Open assistant tag: generation continues from here.
    parts.append("<|im_start|>assistant\n")

    return ''.join(parts)
|
||
|
||
# ==================== Error response helper ====================
|
||
|
||
def error_response(message, code=400):
    """Return a Flask (json_body, status_code) pair describing an error.

    The payload mirrors the OpenAI error envelope: {"error": {"message",
    "code"}} with the HTTP status repeated inside the body.
    """
    payload = {'error': {'message': message, 'code': code}}
    return jsonify(payload), code
|
||
|
||
# Assumed global configuration (defined elsewhere in this file):
|
||
# SAFE_CHAR_LIMIT = 8000 # 针对 RK3588 16k context 的安全字符上限
|
||
# MAX_OUTPUT_CHARS = 100000
|
||
# MAX_GENERATION_TIME = 6000
|
||
|
||
# ==================== Main route ====================
|
||
@app.route('/v1/chat/completions', methods=['POST'])
@app.route('/chat/completions', methods=['POST'])
def chat_completions():
    """OpenAI-compatible chat completion endpoint.

    Serializes all inference through a global lock (one request at a time
    on the NPU), builds a Qwen-format prompt from the message history, runs
    the RKLLM model in a worker thread while draining tokens from
    global_text, then either replays the finished output as an SSE stream
    or returns one JSON response. ReAct/JSON tool calls detected in the
    model output are converted to OpenAI `tool_calls`.
    """
    global global_text, is_blocking, global_state

    # NOTE(review): this busy-check runs before lock.acquire(), so two
    # requests can race past it; the lock below still serializes them.
    if is_blocking:
        return jsonify({'error': {'message': 'System Busy', 'code': 503}}), 503

    data = request.json
    if not data:
        return jsonify({'error': 'Invalid JSON data'}), 400

    messages = data.get('messages', [])
    model = data.get('model', 'rkllm-model')
    stream = data.get('stream', False)
    tools = data.get('tools', [])

    # Initialize managers. ConversationManager, OpenClawSkillParser and
    # MAX_CONTEXT_TOKENS are defined elsewhere in this file.
    conv_manager = ConversationManager(max_context_tokens=MAX_CONTEXT_TOKENS, reserve_tokens=2000)
    parser = OpenClawSkillParser()


    # Log request info.
    print(f"\n{'='*80}")
    print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] New request")
    print(f"Stream: {stream}, Tools: {len(tools)}")

    # Collect the names of the tools offered by the client.
    available_tools = []
    for tool in tools:
        if tool.get('type') == 'function':
            available_tools.append(tool['function']['name'])
    print(f"Available tools: {available_tools}")

    # Discover installed skills.
    ready_skills = parser.get_ready_skills()
    installed_skills = []
    for skill in ready_skills:
        installed_skills.append(skill['name'])

    print(installed_skills)
    if installed_skills:
        print(f"Installed skills: {installed_skills}")

    # ========== Build the system prompt ==========
    system_prompt = build_system_prompt(tools, ready_skills)

    # ========== Build the input prompt (with token management) ==========
    input_prompt = conv_manager.build_prompt(messages, system_prompt, safe_get_content)

    print(f"\n--- Input Prompt (first 300 chars) ---")
    print(input_prompt[:300])
    print(f"Total prompt length: {len(input_prompt)} chars")

    # ========== Run inference ==========
    lock.acquire()
    is_blocking = True

    try:
        # Clear any leftover state from a previous generation.
        try:
            rkllm_model.abort()
        except:
            pass
        global_text.clear()
        global_state = -1

        # Start the inference worker; tokens arrive via the shared
        # global_text deque filled by the RKLLM callback.
        model_thread = threading.Thread(
            target=rkllm_model.run,
            args=(None, False, input_prompt)
        )
        model_thread.daemon = True
        model_thread.start()

        completion_id = f"chatcmpl-{uuid.uuid4().hex}"

        # ========== Collect the response ==========
        accumulated_text = ""
        start_time = time.time()

        print(f"\n--- Model output ---")

        # Drain tokens until the worker exits (or the time budget runs out).
        while model_thread.is_alive() or len(global_text) > 0:
            if time.time() - start_time > MAX_GENERATION_TIME:
                print(f"\nTimeout after {MAX_GENERATION_TIME}s")
                try:
                    rkllm_model.abort()
                except:
                    pass
                break

            if global_text:
                chunk = global_text.popleft()
                accumulated_text += chunk
                print(chunk, end='', flush=True)
            else:
                # No token pending yet; avoid a busy spin.
                time.sleep(0.05)

        print(f"\n\n--- Generation complete ---")
        print(f"Output length: {len(accumulated_text)} chars")


        # NOTE(review): parser/ready_skills are re-created here although the
        # instances from the top of the function are still in scope.
        parser = OpenClawSkillParser()

        # Fetch ready skills (again).
        ready_skills = parser.get_ready_skills()
        skill_list = []  # NOTE(review): built but never used below
        for skill in ready_skills:
            skill_list.append(skill['name'])

        # ========== Parse the model output ==========
        parsed = parse_react_to_tool_call(accumulated_text, available_tools, ready_skills)

        # Detect the model DESCRIBING a tool call instead of emitting one.
        if "以下是修正后的正确调用方式" in accumulated_text or "正确的工具调用应包含" in accumulated_text:
            print(" ⚠️ Model output description instead of tool call. Forcing tool usage...")

            # Append a forcing instruction and regenerate.
            force_prompt = input_prompt + "\n\nIMPORTANT: Do NOT explain how to call tools. Output ONLY the JSON tool call itself.\n"

            # NOTE(review): the restarted generation's output is never
            # collected — the response below is still built from the old
            # accumulated_text, and this thread races the next request.
            # Looks unfinished; confirm intent.
            global_text.clear()
            model_thread = threading.Thread(target=rkllm_model.run, args=(None, False, force_prompt))
            model_thread.start()

        # ========== Build the response ==========
        if stream:
            # "Streaming" replays the already-complete output as SSE chunks;
            # generation itself finished above (pseudo-streaming).
            def generate():
                # Send the role delta first.
                yield f"data: {json.dumps({'id': completion_id, 'choices': [{'delta': {'role': 'assistant'}, 'index': 0}]})}\n\n"

                if parsed and parsed["tool_calls"]:
                    # Tool call detected: emit it as a single delta.
                    tool_call = parsed["tool_calls"][0]
                    yield f"data: {json.dumps({'id': completion_id, 'choices': [{'delta': {'tool_calls': [tool_call]}, 'index': 0}]})}\n\n"
                else:
                    # Plain text: emit character by character.
                    for char in accumulated_text:
                        yield f"data: {json.dumps({'id': completion_id, 'choices': [{'delta': {'content': char}, 'index': 0}]})}\n\n"
                        time.sleep(0.01)

                # End-of-stream marker.
                yield "data: [DONE]\n\n"

            return Response(stream_with_context(generate()), mimetype='text/event-stream')

        else:
            # Non-streaming response.
            if parsed and parsed["tool_calls"]:
                response_data = {
                    "id": completion_id,
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": parsed["content"] if parsed["content"] else None,
                            "tool_calls": parsed["tool_calls"]
                        },
                        "finish_reason": "tool_calls"
                    }]
                }
                print(f"\n--- Returning tool call response ---")
                print(f"Tool: {parsed['tool_calls'][0]['function']['name']}")
            else:
                final_content = accumulated_text.strip()
                if not final_content:
                    # Fallback apology when the model produced nothing.
                    final_content = "抱歉,我暂时无法回答这个问题。"
                response_data = {
                    "id": completion_id,
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": final_content
                        },
                        "finish_reason": "stop"
                    }]
                }
                print(f"\n--- Returning normal response ---")
                print(f"Content: {final_content[:200]}...")

            return jsonify(response_data)

    except Exception as e:
        import traceback
        print(f"\n[ERROR] Exception in chat_completions:")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
    finally:
        # Always release the busy flag and lock, even on error paths.
        time.sleep(0.1)
        is_blocking = False
        lock.release()
|
||
|
||
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report service status and the served model name."""
    payload = {"status": "ok", "model": "rkllm-qwen3-8b"}
    return jsonify(payload), 200
|
||
|
||
@app.route('/v1/models', methods=['GET'])
def list_models():
    """OpenAI-compatible model listing.

    Reports a 32k context window so the OpenClaw client treats this
    backend as large-context (intentionally over-stated relative to the
    actual on-device limit).
    """
    model_entry = {
        "id": "rkllm-model",
        "object": "model",
        "created": int(time.time()),
        "owned_by": "rockchip",
        "context_window": 32768,  # deliberately advertised as 32k
        "permission": [],
    }
    return jsonify({"object": "list", "data": [model_entry]})
|
||
|
||
@app.route('/rkllm_chat', methods=['POST'])
def receive_message():
    """Legacy endpoint; delegates to the OpenAI-compatible handler."""
    return chat_completions()
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--rkllm_model_path', type=str, required=True)
    parser.add_argument('--target_platform', type=str, required=True)
    parser.add_argument('--lora_model_path', type=str)
    parser.add_argument('--prompt_cache_path', type=str)
    parser.add_argument('--port', type=int, default=8080)
    parser.add_argument('--host', type=str, default='0.0.0.0')
    args = parser.parse_args()

    if not os.path.exists(args.rkllm_model_path):
        print("Error: Model path not found")
        sys.exit(1)

    # Pin CPU/NPU frequencies for stable inference performance.
    # FIX: list-form invocation instead of shell=True with an interpolated
    # f-string, so a platform string can never be shell-interpreted.
    subprocess.run(['sudo', 'bash', f'fix_freq_{args.target_platform}.sh'])

    # Raise the open-file limit for the server process.
    resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))

    # Initialize the RKLLM model (module-level name: the Flask handlers
    # reference this global).
    print("========= Initializing RKLLM ==========")
    rkllm_model = RKLLM(args.rkllm_model_path, args.lora_model_path, args.prompt_cache_path, args.target_platform)
    print("=======================================")

    print(f"Server starting on {args.host}:{args.port}")
    print("WARNING: Only sending the latest user message!")
    try:
        # app.run() blocks until the server shuts down.
        app.run(host=args.host, port=args.port, threaded=True, debug=False)
    finally:
        # FIX: previously this cleanup was unreachable — app.run() blocks,
        # and a Ctrl+C (KeyboardInterrupt) skipped straight past it.
        print("Releasing RKLLM model...")
        try:
            rkllm_model.release()
        except Exception:
            # Best-effort release; never mask the original shutdown reason.
            pass
|