feat: 新增文档源和技能管理相关功能

1. 新增文档源管理模块(documentSource)
   - 控制器:documentSourceController.py
   - DAO层:documentSourceDao.py
   - 模型:documentSourceModel.py
   - 服务层:documentSourceService.py

2. 新增技能管理模块(skill)
   - 控制器:skillController.py
   - DAO层:skillDao.py
   - 模型:skillModel.py
   - 服务层:skillService.py

3. 新增AI服务(aiService.py)

4. 新增配置文件
   - AI配置:config/ai_config.py
   - 技能配置:config/skills/test-case-generator/

5. 新增SQL脚本
   - 文档权限:add_document_permissions.sql
   - 模块状态字段:add_module_status_field.sql
   - 文档源表:create_document_source_table.sql
   - 技能规则:skills_rules_pgsql.sql
This commit is contained in:
qiaoxinjiu
2026-05-18 10:23:07 +08:00
parent 65524de6fc
commit 420b9e37fa
38 changed files with 9613 additions and 0 deletions

View File

@@ -0,0 +1,266 @@
# encoding: UTF-8
import os
import re
import uuid
from datetime import datetime
from flask import current_app, g
from .baseCrudController import BaseCrudController
from ..model.documentSourceModel import DocumentSource
from ..model.productModel import Product
from ..model.projectModel import Project
from ..service.documentSourceService import DocumentSourceService
class DocumentSourceController(BaseCrudController):
UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'pdf'}
def allowed_file(self, filename):
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in self.ALLOWED_EXTENSIONS
def document_list(self):
items, total = DocumentSourceService.list(self.session, self.req_data)
return {'list': self.serialize_list(items, ['is_delete']), 'total': total}
def document_detail(self):
document_id = self._get(self.req_data, 'documentId', 'id')
if not document_id:
return {}, 'documentId 为必传参数'
item = DocumentSourceService.get_by_id(self.session, document_id)
if not item:
return {}, '未查询到对应文档!'
return self.serialize(item, ['is_delete']), ''
def document_create(self):
product_id = self._get(self.req_data, 'productId', 'product_id')
project_id = self._get(self.req_data, 'projectId', 'project_id')
source = self._get(self.req_data, 'source')
if not product_id or not project_id or not source:
return 0, 'productId、projectId、source 为必传参数'
data = {
'product_id': product_id,
'project_id': project_id,
'source': source,
'type': self._get(self.req_data, 'type', default=1),
'content': self._get(self.req_data, 'content', default=''),
'created_by': self._get(self.req_data, 'createdBy', 'created_by')
}
return DocumentSourceService.create(self.session, data)
def document_update(self):
document_id = self._get(self.req_data, 'documentId', 'id')
if not document_id:
return 0, 'documentId 为必传参数'
data = {}
fields = ['type', 'source', 'content', 'ai_model']
for field in fields:
value = self._get(self.req_data, field)
if value is not None:
data[field] = value
return DocumentSourceService.update(self.session, document_id, data)
def document_delete(self):
document_id = self._get(self.req_data, 'documentId', 'id')
if not document_id:
return 0, 'documentId 为必传参数'
result, msg = DocumentSourceService.delete(self.session, document_id)
if msg:
return 0, msg
err = self.session.done(close=False)
if err:
return 0, f'删除失败!{err}'
return result, ''
def document_refresh(self):
document_id = self._get(self.req_data, 'documentId', 'id')
if not document_id:
return False, 'documentId 为必传参数'
return DocumentSourceService.refresh_content(self.session, document_id)
def document_generate_cases(self):
# 支持单个文档ID或多个文档ID
document_id = self._get(self.req_data, 'documentId', 'id')
document_ids = self._get(self.req_data, 'documentIds', 'document_ids', default=[])
# 如果传了单个ID转换为列表
if document_id:
document_ids = [document_id]
if not document_ids or not isinstance(document_ids, list) or len(document_ids) == 0:
return [], 'documentId 或 documentIds 为必传参数'
project_id = self._get(self.req_data, 'projectId', 'project_id')
user_id = getattr(g, 'current_user_id', None) or self._get(self.req_data, 'userId', 'user_id')
if not project_id:
return [], 'projectId 为必传参数'
if not user_id:
return [], '未获取到当前登录用户'
template = {
'project_id': int(project_id),
'priority': int(self._get(self.req_data, 'priority', default=2)),
'case_type': int(self._get(self.req_data, 'caseType', 'case_type', default=1)),
'tags': self._get(self.req_data, 'tags', default=['AI生成']),
'skill_ids': self._get(self.req_data, 'skillIds', 'skill_ids', default=[]),
'rule_ids': self._get(self.req_data, 'ruleIds', 'rule_ids', default=[])
}
if isinstance(template['tags'], str):
template['tags'] = template['tags'].split(',')
# 批量生成测试用例(合并多个文档内容)
all_cases, failed_docs = DocumentSourceService.generate_cases_batch(
self.session, document_ids, template
)
if failed_docs:
return {'cases': [], 'total': 0, 'failed': failed_docs}, ''
# 直接导入到用例表,自动创建不存在的模块
success_count, msg = DocumentSourceService.import_cases(
self.session,
document_ids[0], # 使用第一个文档ID作为关联
all_cases,
user_id,
auto_create_module=True # 自动创建模块
)
if msg:
return {'cases': all_cases, 'total': len(all_cases), 'failed': [{'error': msg}]}, ''
# 提交事务
self.session.commit()
return {
'cases': all_cases,
'total': len(all_cases),
'importedCount': success_count,
'failed': []
}, ''
def document_match_modules(self):
document_id = self._get(self.req_data, 'documentId', 'id')
cases = self._get(self.req_data, 'cases', default=[])
if not document_id:
return [], 'documentId 为必传参数'
document = DocumentSourceService.get_by_id(self.session, document_id)
if not document:
return [], '文档不存在'
return DocumentSourceService.match_modules(self.session, document.project_id, cases), ''
def document_import_cases(self):
document_id = self._get(self.req_data, 'documentId', 'id')
cases = self._get(self.req_data, 'cases', default=[])
user_id = self._get(self.req_data, 'userId', 'user_id')
if not document_id:
return 0, 'documentId 为必传参数'
if not isinstance(cases, list):
return 0, 'cases 必须为数组'
return DocumentSourceService.import_cases(self.session, document_id, cases, user_id)
def document_batch_create_modules(self):
project_id = self._get(self.req_data, 'projectId', 'project_id')
module_names = self._get(self.req_data, 'moduleNames', 'module_names', default=[])
if not project_id:
return [], 'projectId 为必传参数'
if not isinstance(module_names, list):
return [], 'moduleNames 必须为数组'
modules = DocumentSourceService.batch_create_modules(self.session, project_id, module_names)
return self.serialize_list(modules, ['is_delete']), ''
def document_upload(self):
if 'file' not in self.req_data.files:
return None, '未找到上传文件'
file = self.req_data.files['file']
if file.filename == '':
return None, '文件名不能为空'
if not self.allowed_file(file.filename):
return None, '不支持的文件格式仅支持pdf'
# 文件上传使用 form 表单获取参数
product_id = self.req_data.form.get('productId')
project_id = self.req_data.form.get('projectId')
created_by = self.req_data.form.get('createdBy')
if not product_id or not project_id:
return None, 'productId、projectId 为必传参数'
# 获取产品和项目名称
product = self.session.query(Product).filter(Product.id == int(product_id), Product.is_delete == 0).first()
if not product:
return None, '产品不存在'
project = self.session.query(Project).filter(Project.id == int(project_id), Project.is_delete == 0).first()
if not project:
return None, '项目不存在'
try:
# 创建文件夹结构uploads/{产品名称}/{项目名称}
base_upload_path = os.path.join(os.getcwd(), self.UPLOAD_FOLDER)
product_folder = os.path.join(base_upload_path, product.name)
project_folder = os.path.join(product_folder, project.name)
os.makedirs(project_folder, exist_ok=True)
# 获取原始文件扩展名
ext = file.filename.rsplit('.', 1)[1].lower()
# 生成安全的文件名(保留原始文件名的主要部分,替换特殊字符)
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
# 从原始文件名中提取主名称(不包含扩展名)
original_name = file.filename.rsplit('.', 1)[0]
# 替换特殊字符为下划线,但保留中文字符
safe_name = re.sub(r'[^\w\u4e00-\u9fa5-]', '_', original_name)
# 限制文件名长度,避免过长
safe_name = safe_name[:50] if len(safe_name) > 50 else safe_name
# 组合文件名
new_filename = f'{timestamp}-{safe_name}-{uuid.uuid4().hex[:8]}.{ext}'
# 保存文件
file_path = os.path.join(project_folder, new_filename)
file.save(file_path)
# 计算相对路径用于数据库存储
relative_path = os.path.join(self.UPLOAD_FOLDER, product.name, project.name, new_filename)
# 转换为统一的路径格式
relative_path = relative_path.replace('\\', '/')
# 创建文档源记录
data = {
'product_id': product_id,
'project_id': project_id,
'source': relative_path,
'type': 1,
'content': '',
'created_by': created_by
}
document_id, msg = DocumentSourceService.create(self.session, data)
if msg:
return None, msg
# 提交事务
self.session.commit()
return {'documentId': document_id, 'filePath': relative_path}, ''
except Exception as e:
self.session.rollback()
return None, f'文件上传失败:{str(e)}'

View File

@@ -0,0 +1,46 @@
# encoding: UTF-8
from flask import g
from .baseCrudController import BaseCrudController
from ..service.skillService import SkillService
class SkillController(BaseCrudController):
def skill_create(self):
return SkillService.create_skill(self.session, self.req_data, getattr(g, 'current_user_id', None))
def skill_update(self):
return SkillService.update_skill(self.session, self.req_data)
def skill_delete(self):
return SkillService.delete_skill(self.session, self.req_data)
def skill_detail(self):
skill_id = self._get(self.req_data, 'skillId', 'id')
if not skill_id:
return {}, 'skillId 为必传参数'
return SkillService.skill_detail(self.session, skill_id)
def skill_list(self):
return SkillService.skill_list(self.session, self.req_data)
def skill_rule_list(self):
return SkillService.skill_rule_list(self.session, self.req_data)
def business_rule_create(self):
return SkillService.create_business_rule(self.session, self.req_data, getattr(g, 'current_user_id', None))
def business_rule_update(self):
return SkillService.update_business_rule(self.session, self.req_data)
def business_rule_delete(self):
return SkillService.delete_business_rule(self.session, self.req_data)
def business_rule_detail(self):
rule_id = self._get(self.req_data, 'ruleId', 'id')
if not rule_id:
return {}, 'ruleId 为必传参数'
return SkillService.business_rule_detail(self.session, rule_id)
def business_rule_list(self):
return SkillService.business_rule_list(self.session, self.req_data)

View File

@@ -0,0 +1,75 @@
# encoding: UTF-8
from sqlalchemy import func
from ..model.documentSourceModel import DocumentSource
class DocumentSourceDao:
@staticmethod
def create(session, document_source):
session.add(document_source)
session.flush()
return document_source.id
@staticmethod
def get_by_id(session, document_id):
return session.query(DocumentSource).filter(
DocumentSource.id == document_id,
DocumentSource.is_delete == 0
).first()
@staticmethod
def get_by_source(session, source):
return session.query(DocumentSource).filter(
DocumentSource.source == source,
DocumentSource.is_delete == 0
).first()
@staticmethod
def list_by_filters(session, filters, page_no=1, page_size=20, order_by=None):
query = session.query(DocumentSource).filter(*filters)
if order_by is not None:
query = query.order_by(order_by)
total = query.count()
items = query.offset((page_no - 1) * page_size).limit(page_size).all()
return items, total
@staticmethod
def update_by_id(session, document_id, update_info):
result = session.query(DocumentSource).filter(
DocumentSource.id == document_id,
DocumentSource.is_delete == 0
).update(update_info)
session.flush()
return result
@staticmethod
def delete_by_id(session, document_id):
return session.query(DocumentSource).filter(
DocumentSource.id == document_id,
DocumentSource.is_delete == 0
).update({'is_delete': 1})
@staticmethod
def get_latest_version(session, product_id, project_id, source):
return session.query(DocumentSource).filter(
DocumentSource.product_id == product_id,
DocumentSource.project_id == project_id,
DocumentSource.source == source,
DocumentSource.is_delete == 0
).order_by(DocumentSource.version.desc()).first()
@staticmethod
def get_max_version(session, product_id, project_id, source):
result = session.query(func.max(DocumentSource.version)).filter(
DocumentSource.product_id == product_id,
DocumentSource.project_id == project_id,
DocumentSource.source == source,
DocumentSource.is_delete == 0
).scalar()
return result if result else 0

163
app/api/dao/skillDao.py Normal file
View File

@@ -0,0 +1,163 @@
# encoding: UTF-8
from sqlalchemy import or_
from logger import logger
from ..model.caseModel import Module
from ..model.productModel import Product
from ..model.projectModel import Project
from ..model.skillModel import TestSkill, TestBusinessRule, TestAiGenerationContext
class SkillDao(object):
@staticmethod
def create(session, model_cls, add_info):
obj = model_cls(**add_info)
session.add(obj)
err = session.done(close=False)
if err:
logger.warning(f'{model_cls.__name__}新增失败!{err}')
return 0, f'新增失败!{err}'
return obj.id, ''
@staticmethod
def update_by_id(session, model_cls, obj_id, update_info):
update_res = session.query(model_cls).filter(model_cls.id == int(obj_id), model_cls.is_delete == 0).update(update_info)
err = session.done(close=False)
if err:
logger.error(f'{model_cls.__name__}更新失败id: {obj_id}, err: {err}')
return 0, f'更新失败!{err}'
if not update_res:
return 0, '未查询到对应记录!'
return int(obj_id), ''
@staticmethod
def get_by_id(session, model_cls, obj_id):
return session.query(model_cls).filter(model_cls.id == int(obj_id), model_cls.is_delete == 0).first()
@staticmethod
def get_skill_by_project_code(session, project_id, code):
return session.query(TestSkill).filter(
TestSkill.project_id == int(project_id),
TestSkill.code == code,
TestSkill.is_delete == 0
).first()
@staticmethod
def get_business_rule_by_project_code(session, project_id, rule_code):
return session.query(TestBusinessRule).filter(
TestBusinessRule.project_id == int(project_id),
TestBusinessRule.rule_code == rule_code,
TestBusinessRule.is_delete == 0
).first()
@staticmethod
def list_skill(session, filters, page=1, limit=20, keyword=None, tag=None):
query = session.query(TestSkill).filter(TestSkill.is_delete == 0, *filters)
if keyword:
like_keyword = f'%{keyword}%'
query = query.filter(or_(
TestSkill.name.like(like_keyword),
TestSkill.code.like(like_keyword),
TestSkill.description.like(like_keyword),
TestSkill.trigger_condition.like(like_keyword)
))
if tag:
query = query.filter(TestSkill.tags.contains([tag]))
total = query.count()
items = query.order_by(TestSkill.created_time.desc()).offset((int(page) - 1) * int(limit)).limit(int(limit)).all()
return items, total
@staticmethod
def list_business_rule(session, filters, page=1, limit=20, keyword=None, tag=None):
query = session.query(TestBusinessRule).filter(TestBusinessRule.is_delete == 0, *filters)
if keyword:
like_keyword = f'%{keyword}%'
query = query.filter(or_(
TestBusinessRule.name.like(like_keyword),
TestBusinessRule.rule_code.like(like_keyword),
TestBusinessRule.rule_content.like(like_keyword),
TestBusinessRule.applicable_scene.like(like_keyword)
))
if tag:
query = query.filter(TestBusinessRule.tags.contains([tag]))
total = query.count()
items = query.order_by(TestBusinessRule.created_time.desc()).offset((int(page) - 1) * int(limit)).limit(int(limit)).all()
return items, total
@staticmethod
def delete_by_id(session, model_cls, obj_id):
return SkillDao.update_by_id(session, model_cls, obj_id, {'is_delete': 1})
@staticmethod
def get_project_by_product(session, product_id, project_id):
return session.query(Project).filter(
Project.id == int(project_id),
Project.product_id == int(product_id),
Project.is_delete == 0
).first()
@staticmethod
def list_skills_by_project(session, project_id, status=None):
query = session.query(TestSkill).filter(
TestSkill.project_id == int(project_id),
TestSkill.is_delete == 0
)
if status not in (None, ''):
query = query.filter(TestSkill.status == int(status))
return query.order_by(TestSkill.created_time.desc()).all()
@staticmethod
def list_business_rules_by_project(session, project_id, status=None):
query = session.query(TestBusinessRule).filter(
TestBusinessRule.project_id == int(project_id),
TestBusinessRule.is_delete == 0
)
if status not in (None, ''):
query = query.filter(TestBusinessRule.status == int(status))
return query.order_by(TestBusinessRule.created_time.desc()).all()
@staticmethod
def list_skills_by_ids(session, project_id, skill_ids):
if not skill_ids:
return []
return session.query(TestSkill).filter(
TestSkill.project_id == int(project_id),
TestSkill.id.in_([int(skill_id) for skill_id in skill_ids]),
TestSkill.is_delete == 0
).all()
@staticmethod
def list_business_rules_by_ids(session, project_id, rule_ids):
if not rule_ids:
return []
return session.query(TestBusinessRule).filter(
TestBusinessRule.project_id == int(project_id),
TestBusinessRule.id.in_([int(rule_id) for rule_id in rule_ids]),
TestBusinessRule.is_delete == 0
).all()
@staticmethod
def get_skill_path_context(session, project_id, module_id=None):
project = session.query(Project).filter(Project.id == int(project_id), Project.is_delete == 0).first()
product = None
module = None
if project and project.product_id:
product = session.query(Product).filter(Product.id == int(project.product_id), Product.is_delete == 0).first()
if module_id:
module = session.query(Module).filter(Module.id == int(module_id), Module.is_delete == 0).first()
return {
'product_name': product.name if product else '未关联产品',
'project_name': project.name if project else f'项目{project_id}',
'module_name': module.name if module else '项目通用'
}
@staticmethod
def batch_create_generation_context(session, rows):
if not rows:
return 0, ''
session.add_all([TestAiGenerationContext(**row) for row in rows])
err = session.done(close=False)
if err:
logger.warning(f'TestAiGenerationContext批量新增失败{err}')
return 0, f'批量新增失败!{err}'
return len(rows), ''

View File

@@ -0,0 +1,26 @@
# encoding: UTF-8
from sqlalchemy import BigInteger, Column, Integer, SmallInteger, String, TIMESTAMP, Text, text
from sqlalchemy.ext.declarative import declarative_base
from common.sqlSession import to_dict
Base = declarative_base()
Base.to_dict = to_dict
class DocumentSource(Base):
__tablename__ = 'document_source'
id = Column(BigInteger, primary_key=True, autoincrement=True, comment='主键ID')
product_id = Column(BigInteger, nullable=False, comment='产品ID')
project_id = Column(BigInteger, nullable=False, comment='项目ID')
type = Column(SmallInteger, default=1, comment='类型1-PDF文件2-飞书链接')
source = Column(String(512), nullable=False, comment='文件路径或飞书链接')
content = Column(Text, comment='解析后的文本内容(缓存)')
version = Column(Integer, default=1, comment='版本号')
status = Column(SmallInteger, default=0, comment='状态0-待解析1-已解析2-已生成用例')
ai_model = Column(String(64), comment='使用的AI模型')
created_by = Column(BigInteger, comment='创建人ID')
is_delete = Column(Integer, default=0, comment='0未删除1已删除')
created_time = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP'), comment='创建时间')
updated_time = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP'), server_onupdate=text('CURRENT_TIMESTAMP'), comment='更新时间')

View File

@@ -0,0 +1,67 @@
from sqlalchemy import BigInteger, Column, Integer, SmallInteger, String, TIMESTAMP, Text, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base
from common.sqlSession import to_dict
Base = declarative_base()
Base.to_dict = to_dict
class TestSkill(Base):
__tablename__ = 'test_skill'
id = Column(BigInteger, primary_key=True, autoincrement=True, comment='id')
project_id = Column(BigInteger, nullable=False, comment='项目id')
module_id = Column(BigInteger, comment='模块id空表示项目级通用')
name = Column(String(128), nullable=False, comment='Skill名称')
code = Column(String(64), nullable=False, comment='Skill编码项目内唯一')
description = Column(Text, comment='Skill描述')
trigger_condition = Column(Text, nullable=False, comment='触发条件')
reasoning_path = Column(Text, comment='推理路径')
output_spec = Column(Text, comment='输出规范')
skill_file_path = Column(String(512), comment='Skill文件路径指向config/skills下生成的SKILL.md')
skill_type = Column(SmallInteger, nullable=False, default=1, comment='类型1通用测试策略 2历史缺陷模式 3边界场景 4接口测试 5UI测试 6性能测试 7安全测试 8数据一致性 9并发幂等 99其他')
risk_level = Column(SmallInteger, nullable=False, default=2, comment='风险等级0高 1中高 2中 3低')
tags = Column(JSONB, nullable=False, server_default=text("'[]'::jsonb"), comment='标签数组')
status = Column(SmallInteger, nullable=False, default=1, comment='状态1启用 2停用 3草稿')
owner_id = Column(BigInteger, comment='负责人用户id')
created_by = Column(BigInteger, comment='创建人用户id')
usage_count = Column(Integer, nullable=False, default=0, comment='使用次数')
is_delete = Column(Integer, nullable=False, default=0, comment='0未删除 1已删除')
created_time = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP'), nullable=True, comment='创建时间')
updated_time = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP'), server_onupdate=text('CURRENT_TIMESTAMP'), nullable=True, comment='修改时间')
class TestBusinessRule(Base):
__tablename__ = 'test_business_rule'
id = Column(BigInteger, primary_key=True, autoincrement=True, comment='id')
project_id = Column(BigInteger, nullable=False, comment='项目id')
module_id = Column(BigInteger, comment='模块id空表示项目级通用')
name = Column(String(128), nullable=False, comment='业务规则名称')
rule_code = Column(String(64), comment='业务规则编码,项目内唯一')
rule_content = Column(Text, nullable=False, comment='业务规则内容')
applicable_scene = Column(Text, comment='适用场景')
example = Column(Text, comment='示例')
rule_file_path = Column(String(512), comment='业务规则文件路径指向config/rules下生成的RULE.md')
priority = Column(SmallInteger, nullable=False, default=2, comment='优先级0高 1中高 2中 3低')
tags = Column(JSONB, nullable=False, server_default=text("'[]'::jsonb"), comment='标签数组')
status = Column(SmallInteger, nullable=False, default=1, comment='状态1启用 2停用 3草稿')
owner_id = Column(BigInteger, comment='负责人用户id')
created_by = Column(BigInteger, comment='创建人用户id')
usage_count = Column(Integer, nullable=False, default=0, comment='使用次数')
is_delete = Column(Integer, nullable=False, default=0, comment='0未删除 1已删除')
created_time = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP'), nullable=True, comment='创建时间')
updated_time = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP'), server_onupdate=text('CURRENT_TIMESTAMP'), nullable=True, comment='修改时间')
class TestAiGenerationContext(Base):
__tablename__ = 'test_ai_generation_context'
id = Column(BigInteger, primary_key=True, autoincrement=True, comment='id')
generation_id = Column(BigInteger, comment='AI生成任务id兼容现有生成任务')
project_id = Column(BigInteger, nullable=False, comment='项目id')
module_id = Column(BigInteger, comment='模块id')
source_type = Column(SmallInteger, nullable=False, comment='来源类型1 Skill 2业务规则')
source_id = Column(BigInteger, nullable=False, comment='来源id')
source_name = Column(String(128), comment='来源名称快照')
match_score = Column(Integer, nullable=False, default=0, comment='匹配分数')
created_time = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP'), nullable=True, comment='创建时间')

View File

@@ -0,0 +1,536 @@
# encoding: UTF-8
"""
AI服务类 - 用于调用大模型生成测试用例、测试 Skill 和业务规则
"""
import json
import re
import time
import traceback
from pathlib import Path
from flask import current_app
class AIService:
"""AI服务类"""
@staticmethod
def generate_test_cases(document_content, template=None):
try:
from openai import OpenAI
from config.ai_config import AIConfig
import httpx
api_key = AIConfig.get_api_key()
api_base = AIConfig.get_api_base()
model = AIConfig.get_model()
provider = AIConfig.MODEL_PROVIDER
key_source = AIConfig.get_api_key_source()
if not api_key or api_key == '请替换为你的Meteor API Key':
return [], '未配置API密钥请在.env中配置METEOR_API_KEY'
is_plan_key = provider == 'custom' and api_key.startswith('plan-')
request_base = AIService._normalize_plan_api_base(api_base) if is_plan_key else AIService._normalize_api_base(api_base)
current_app.logger.info(f'AI配置: provider={provider}, base={request_base}, model={model}, key_source={key_source}, key_prefix={api_key[:8]}, plan_key={is_plan_key}')
timeout = httpx.Timeout(connect=AIConfig.CONNECT_TIMEOUT, read=AIConfig.READ_TIMEOUT, write=AIConfig.READ_TIMEOUT, pool=AIConfig.CONNECT_TIMEOUT)
skill_content = AIService._load_skill_content()
chunks = AIService._split_document_content(document_content)
all_cases = []
for chunk_index, chunk in enumerate(chunks, 1):
prompt = AIService._build_prompt(chunk['content'], template, skill_content, chunk_index, len(chunks), chunk['title'])
result = AIService._request_model(OpenAI, AIConfig, api_key, request_base, model, is_plan_key, prompt, timeout, httpx)
try:
parsed_result = json.loads(AIService._extract_json_text(result))
all_cases.extend(AIService._normalize_cases(parsed_result, template, chunk['title']))
except json.JSONDecodeError:
return [], f'{chunk_index}段解析结果失败: {result[:200]}'
return AIService._deduplicate_cases(all_cases), ''
except Exception as e:
current_app.logger.error(f'AI生成测试用例失败: {str(e)}')
current_app.logger.error(traceback.format_exc())
return [], f'AI生成失败: {str(e)}'
@staticmethod
def _request_model(OpenAI, AIConfig, api_key, request_base, model, is_plan_key, prompt, timeout, httpx):
max_retries = AIConfig.MAX_RETRIES
retry_delay = AIConfig.RETRY_DELAY
for attempt in range(max_retries):
try:
if is_plan_key:
return AIService._create_plan_message(api_key, request_base, model, prompt, timeout)
client = OpenAI(api_key=api_key, base_url=request_base, http_client=httpx.Client(timeout=timeout, trust_env=False))
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "你是一个专业的测试知识资产生成助手。必须最终只输出可解析JSON。"},
{"role": "user", "content": prompt}
],
max_tokens=AIConfig.OPENAI_MAX_TOKENS,
temperature=AIConfig.OPENAI_TEMPERATURE
)
return response.choices[0].message.content
except Exception as e:
if attempt < max_retries - 1:
current_app.logger.warning(f'AI请求第{attempt + 1}次失败,{retry_delay}秒后重试: {str(e)}')
time.sleep(retry_delay * (2 ** attempt))
else:
raise
@staticmethod
def _normalize_api_base(api_base):
if not api_base:
return 'https://api.routin.ai/v1'
return api_base.rstrip('/').replace('/chat/completions', '')
@staticmethod
def _normalize_plan_api_base(api_base):
if not api_base:
return 'https://api.routin.ai/plan/v1'
normalized = api_base.rstrip('/').replace('/chat/completions', '')
if '/plan/v1' in normalized:
return normalized
return normalized.replace('/v1', '/plan/v1')
@staticmethod
def _create_plan_message(api_key, api_base, model, prompt, timeout):
import httpx
response = httpx.post(
f'{api_base}/messages',
headers={'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'},
json={'model': model, 'messages': [{'role': 'user', 'content': prompt}], 'max_tokens': 4096, 'temperature': 0.7},
timeout=timeout,
trust_env=False
)
response.raise_for_status()
return AIService._extract_message_text(response.json())
@staticmethod
def _extract_message_text(data):
if isinstance(data, dict):
content = data.get('content')
if isinstance(content, list):
texts = [part['text'] for part in content if isinstance(part, dict) and part.get('text')]
if texts:
return ''.join(texts)
if isinstance(content, str):
return content
return json.dumps(data, ensure_ascii=False)
@staticmethod
def _extract_json_text(result):
text = result.strip()
fence_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
if fence_match:
text = fence_match.group(1).strip()
if text.startswith('{') or text.startswith('['):
return text
json_match = re.search(r'(\{[\s\S]*\}|\[[\s\S]*\])', text)
if json_match:
return json_match.group(1).strip()
return text
@staticmethod
def generate_skill_content(req_data):
return AIService._generate_asset_content(
req_data=req_data,
prompt=AIService._build_skill_create_prompt(req_data),
markdown_key='skill_md',
normalizer=AIService._normalize_skill_markdown,
error_prefix='AI生成 Skill 内容'
)
@staticmethod
def generate_business_rule_content(req_data):
return AIService._generate_asset_content(
req_data=req_data,
prompt=AIService._build_business_rule_create_prompt(req_data),
markdown_key='rule_md',
normalizer=AIService._normalize_rule_markdown,
error_prefix='AI生成业务规则内容'
)
@staticmethod
def _generate_asset_content(req_data, prompt, markdown_key, normalizer, error_prefix):
try:
from openai import OpenAI
from config.ai_config import AIConfig
import httpx
api_key = AIConfig.get_api_key()
api_base = AIConfig.get_api_base()
model = AIConfig.get_model()
provider = AIConfig.MODEL_PROVIDER
if not api_key or api_key == '请替换为你的Meteor API Key':
return {}, '未配置API密钥请在.env中配置METEOR_API_KEY'
is_plan_key = provider == 'custom' and api_key.startswith('plan-')
request_base = AIService._normalize_plan_api_base(api_base) if is_plan_key else AIService._normalize_api_base(api_base)
timeout = httpx.Timeout(connect=AIConfig.CONNECT_TIMEOUT, read=AIConfig.READ_TIMEOUT, write=AIConfig.READ_TIMEOUT, pool=AIConfig.CONNECT_TIMEOUT)
result = AIService._request_model(OpenAI, AIConfig, api_key, request_base, model, is_plan_key, prompt, timeout, httpx)
parsed_result = json.loads(AIService._extract_json_text(result))
if not isinstance(parsed_result, dict):
return {}, f'{error_prefix}格式错误'
md = parsed_result.get(markdown_key) or parsed_result.get(markdown_key.replace('_', ''))
if not md or not isinstance(md, str):
return {}, f'{error_prefix}缺少 {markdown_key}'
parsed_result[markdown_key] = normalizer(md, req_data)
return parsed_result, ''
except json.JSONDecodeError:
return {}, f'{error_prefix}不是合法 JSON'
except Exception as e:
current_app.logger.error(f'{error_prefix}失败: {str(e)}')
current_app.logger.error(traceback.format_exc())
return {}, f'{error_prefix}失败: {str(e)}'
@staticmethod
def _normalize_skill_markdown(skill_md, req_data):
return AIService._normalize_markdown(skill_md, req_data, 'generated-skill')
@staticmethod
def _normalize_rule_markdown(rule_md, req_data):
return AIService._normalize_markdown(rule_md, req_data, 'generated-rule')
@staticmethod
def _normalize_markdown(markdown, req_data, fallback_name):
content = markdown.strip()
content = re.sub(r'^```(?:markdown|md)?\s*', '', content)
content = re.sub(r'\s*```$', '', content).strip()
if content.startswith('---'):
return content
raw_name = str(req_data.get('name') or fallback_name).strip()
frontmatter_name = re.sub(r'[^a-zA-Z0-9_-]+', '-', raw_name.lower()).strip('-') or fallback_name
description = str(req_data.get('description') or raw_name).strip()
return f'---\nname: {frontmatter_name}\ndescription: {description}\n---\n\n{content}'
@staticmethod
def get_default_case_generation_trigger_condition():
return '当用户基于 PRD、需求文档、用户故事、功能说明、接口说明、UI 交互说明或业务规则生成、补充、优化、评审测试用例时触发。'
@staticmethod
def get_default_case_generation_output_spec():
return '''输出必须兼容当前 AI 生成用例入库结构:最终只输出 JSON 对象,不输出 Markdown、解释文本或代码块。JSON 对象结构为 {"cases": [{"title": "用例名称/测试点名称", "module_name": "父模块/子模块/叶子模块", "precondition": "前置条件", "steps": "步骤1\\n步骤2", "expected_result": "预期结果1\\n预期结果2", "priority": 2, "case_type": 1, "tags": ["AI生成"]}]}。每条用例 title 需要细化到具体场景steps 和 expected_result 每一行带数字编号,信息不足时标记“待确认”,不能编造需求。'''
@staticmethod
def _load_skill_creator_content():
skill_path = Path(__file__).resolve().parents[3] / 'config' / 'skills' / 'skill-creator' / 'SKILL.md'
if not skill_path.exists():
raise FileNotFoundError(f'Skill创建规则不存在: {skill_path}')
return skill_path.read_text(encoding='utf-8')
@staticmethod
def _load_skill_content():
skill_path = Path(__file__).resolve().parents[3] / 'config' / 'skills' / 'test-case-generator' / 'SKILL.md'
if not skill_path.exists():
raise FileNotFoundError(f'测试用例生成技能不存在: {skill_path}')
return skill_path.read_text(encoding='utf-8')
@staticmethod
def _build_skill_create_prompt(req_data):
skill_creator_content = AIService._load_skill_creator_content()
default_trigger_condition = AIService.get_default_case_generation_trigger_condition()
default_output_spec = AIService.get_default_case_generation_output_spec()
return f'''
你现在要严格按照下面 skill-creator 的 SKILL.md 规范,为测试平台创建一个新的 Skill 文件。
<skill-creator-skill-md>
{skill_creator_content}
</skill-creator-skill-md>
<new-skill-input>
Skill 名称:{req_data.get('name') or ''}
用户补充描述:{req_data.get('description') or ''}
标签:{req_data.get('tags') or []}
Skill 类型枚举值:{req_data.get('skillType') or req_data.get('skill_type') or 1}
风险等级枚举值:{req_data.get('riskLevel') or req_data.get('risk_level') or 2}
</new-skill-input>
<platform-contract>
这个 Skill 的目标是增强当前平台“AI 根据 PRD/需求生成测试用例”的能力。
触发条件固定理解为:{default_trigger_condition}
输出规范固定理解为:{default_output_spec}
</platform-contract>
请只输出 JSON 对象:
{{
"description": "适合列表展示的 Skill 简介80字以内",
"reasoning_path": "面向测试用例生成的推理路径摘要,简洁步骤描述",
"tags": ["标签1", "标签2"],
"skill_type": 1,
"risk_level": 2,
"skill_md": "完整的 SKILL.md 文件内容,包含 YAML frontmatter 和 Markdown body"
}}
约束skill_md 必须包含 YAML frontmatter至少包含 name 和 descriptionbody 必须是面向测试用例生成的 Markdown 指令;不要复制 skill-creator 原文;不要输出代码块或额外说明。
'''.strip()
@staticmethod
def _build_business_rule_create_prompt(req_data):
input_rule_content = req_data.get('ruleContent') or req_data.get('rule_content') or req_data.get('description') or ''
return f'''
请为测试平台创建一条“业务规则”知识资产,用于增强 AI 根据 PRD/需求生成测试用例时对确定性业务约束、校验条件、状态流转、边界条件和异常处理的理解。
<business-rule-input>
规则名称:{req_data.get('name') or ''}
用户输入的规则原文:{input_rule_content}
用户补充描述:{req_data.get('description') or ''}
标签:{req_data.get('tags') or []}
优先级枚举值:{req_data.get('priority') or 2}
</business-rule-input>
硬性约束:
1. 不要随机生成、替换或改变“用户输入的规则原文”的业务含义。
2. 返回 JSON 中的 rule_content 必须逐字等于“用户输入的规则原文”。
3. 你只能基于用户输入补充 applicable_scene、example、tags、priority并生成用于测试用例生成的 RULE.md。
4. RULE.md 的“## Rule”章节必须逐字包含“用户输入的规则原文”不能改写成另一条规则。
请只输出 JSON 对象:
{{
"rule_content": "逐字返回用户输入的规则原文",
"applicable_scene": "该规则适用的业务场景",
"example": "输入/场景/预期的示例",
"tags": ["标签1", "标签2"],
"priority": 2,
"rule_md": "完整的 RULE.md 文件内容,包含 YAML frontmatter 和 Markdown body"
}}
RULE.md 要求:必须包含 YAML frontmatter至少包含 name 和 descriptionbody 建议包含规则说明、适用场景、测试关注点、正反例、生成用例时的约束内容必须面向测试用例生成priority 只能是 0、1、2、3tags 最多 8 个;不要输出代码块或额外说明。
'''.strip()
@staticmethod
def _split_document_content(document_content, max_chars=8000):
content = (document_content or '').strip()
if not content:
return []
sections = AIService._split_by_headings(content)
chunks = []
current_parts = []
current_len = 0
current_title = '文档内容'
for section in sections:
section_text = section['content'].strip()
if not section_text:
continue
if len(section_text) > max_chars:
if current_parts:
chunks.append({'title': current_title, 'content': '\n\n'.join(current_parts)})
current_parts = []
current_len = 0
chunks.extend(AIService._split_large_section(section['title'], section_text, max_chars))
continue
if current_parts and current_len + len(section_text) > max_chars:
chunks.append({'title': current_title, 'content': '\n\n'.join(current_parts)})
current_parts = []
current_len = 0
if not current_parts:
current_title = section['title']
current_parts.append(section_text)
current_len += len(section_text)
if current_parts:
chunks.append({'title': current_title, 'content': '\n\n'.join(current_parts)})
return chunks or [{'title': '文档内容', 'content': content}]
@staticmethod
def _split_by_headings(content):
heading_pattern = re.compile(r'(?m)^(#{1,6}\s+.+|第[一二三四五六七八九十百千万\d]+[章节部分篇].*|\d+(?:\.\d+)*[、.]\s*.+)$')
matches = list(heading_pattern.finditer(content))
if not matches:
return [{'title': '文档内容', 'content': content}]
sections = []
if matches[0].start() > 0:
sections.append({'title': '文档开头', 'content': content[:matches[0].start()].strip()})
for index, match in enumerate(matches):
start = match.start()
end = matches[index + 1].start() if index + 1 < len(matches) else len(content)
title = match.group(0).strip().lstrip('#').strip()
sections.append({'title': title[:80] or '文档内容', 'content': content[start:end].strip()})
return sections
@staticmethod
def _split_large_section(title, section_text, max_chars):
paragraphs = re.split(r'\n\s*\n', section_text)
chunks = []
current_parts = []
current_len = 0
part_index = 1
for paragraph in paragraphs:
paragraph = paragraph.strip()
if not paragraph:
continue
while len(paragraph) > max_chars:
if current_parts:
chunks.append({'title': f'{title}(第{part_index}部分)', 'content': '\n\n'.join(current_parts)})
part_index += 1
current_parts = []
current_len = 0
chunks.append({'title': f'{title}(第{part_index}部分)', 'content': paragraph[:max_chars]})
part_index += 1
paragraph = paragraph[max_chars:]
if current_parts and current_len + len(paragraph) > max_chars:
chunks.append({'title': f'{title}(第{part_index}部分)', 'content': '\n\n'.join(current_parts)})
part_index += 1
current_parts = []
current_len = 0
current_parts.append(paragraph)
current_len += len(paragraph)
if current_parts:
chunks.append({'title': f'{title}(第{part_index}部分)', 'content': '\n\n'.join(current_parts)})
return chunks
@staticmethod
def _deduplicate_cases(cases):
seen = {}
deduplicated = []
for case in cases:
key = f"{case.get('module_name', '')}::{case.get('title', '')}".strip().lower()
if not key or key in seen:
continue
seen[key] = True
deduplicated.append(case)
return deduplicated
@staticmethod
def _normalize_cases(parsed_result, template=None, chunk_title=''):
template = template or {}
raw_cases = AIService._collect_case_items(parsed_result)
normalized = []
for index, item in enumerate(raw_cases, 1):
if not isinstance(item, dict):
continue
tags = item.get('tags') or item.get('标签') or template.get('tags', ['AI生成'])
if isinstance(tags, str):
tags = [tag.strip() for tag in re.split(r'[,]', tags) if tag.strip()]
normalized.append({
'selected': item.get('selected', True),
'module_name': AIService._normalize_module_name(item.get('module_name') or item.get('所属模块') or item.get('module') or '未分类'),
'title': item.get('title') or item.get('用例名称') or item.get('case_name') or item.get('name') or f'AI生成用例{index}',
'precondition': item.get('precondition') or item.get('前置条件') or '',
'steps': AIService._number_lines(item.get('steps') or item.get('步骤描述') or item.get('操作步骤') or ''),
'expected_result': AIService._number_lines(item.get('expected_result') or item.get('expected_results') or item.get('预期结果') or item.get('期望结果') or ''),
'priority': AIService._normalize_priority(item.get('priority') or item.get('用例等级'), template.get('priority', 2)),
'case_type': AIService._normalize_case_type(item.get('case_type') or item.get('类型') or item.get('标签'), template.get('case_type', 1)),
'tags': tags or ['AI生成']
})
return normalized
@staticmethod
def _collect_case_items(value):
if isinstance(value, list):
items = []
for item in value:
items.extend(AIService._collect_case_items(item))
return items
if not isinstance(value, dict):
return []
case_keys = {'title', '用例名称', 'case_name', 'name', 'steps', '步骤描述', '操作步骤', 'expected_result', '预期结果', '期望结果'}
if any(key in value for key in case_keys):
return [value]
items = []
for nested_value in value.values():
items.extend(AIService._collect_case_items(nested_value))
return items
@staticmethod
def _normalize_module_name(module_name):
parts = [part.strip() for part in re.split(r'[/\\>|]', str(module_name or '')) if part.strip()]
return '/'.join(parts[:3]) if parts else '未分类'
@staticmethod
def _number_lines(value):
if isinstance(value, list):
lines = [str(item).strip() for item in value if str(item).strip()]
else:
lines = [line.strip() for line in re.split(r'\n+', str(value or '')) if line.strip()]
normalized = []
for index, line in enumerate(lines, 1):
cleaned_line = re.sub(r'^(?:步骤|预期结果)?\s*\d+\s*[.、.]\s*', '', line).strip()
normalized.append(f'{index}. {cleaned_line}')
return '\n'.join(normalized)
@staticmethod
def _normalize_priority(value, default=2):
if isinstance(value, int):
return value
return {'P0': 0, 'P1': 1, 'P2': 2, 'P3': 3, 'P4': 3, 'P5': 3}.get(str(value).upper(), default)
@staticmethod
def _normalize_case_type(value, default=1):
if isinstance(value, int):
return value
text = str(value or '')
if '性能' in text:
return 2
if '安全' in text:
return 3
if '接口' in text or 'API' in text.upper():
return 4
return default
@staticmethod
def _build_generation_context(template):
template = template or {}
skill_contexts = template.get('skill_contexts') or []
rule_contexts = template.get('rule_contexts') or []
if not skill_contexts and not rule_contexts:
return ''
parts = ['<selected-generation-context>']
if skill_contexts:
parts.append('请在生成测试用例时结合以下用户指定 Skill')
for item in skill_contexts:
parts.append(f'''<selected-skill id="{item.get('id')}" name="{item.get('name')}">
{item.get('content') or ''}
</selected-skill>''')
if rule_contexts:
parts.append('请在生成测试用例时严格覆盖以下用户指定业务规则:')
for item in rule_contexts:
parts.append(f'''<selected-rule id="{item.get('id')}" name="{item.get('name')}">
{item.get('content') or ''}
</selected-rule>''')
parts.append('</selected-generation-context>')
return '\n\n'.join(parts)
@staticmethod
def _build_prompt(document_content, template=None, skill_content='', chunk_index=1, total_chunks=1, chunk_title='文档内容'):
template = template or {'priority': 2, 'case_type': 1, 'tags': ['AI生成']}
generation_context = AIService._build_generation_context(template)
return f'''
请使用下面的 test-case-generator skill 对需求文档分段进行深度测试用例设计。最终只输出 JSON。
<test-case-generator-skill>
{skill_content}
</test-case-generator-skill>
{generation_context}
<document-chunk-info>
当前分段:{chunk_index}/{total_chunks}
分段标题:{chunk_title}
</document-chunk-info>
<requirement-document-chunk>
{document_content}
</requirement-document-chunk>
平台入库配置:
- 默认优先级(priority): {template['priority']}
- 默认用例类型(case_type): {template['case_type']}
- 默认标签(tags): {template['tags']}
输出 JSON 结构:
{{"cases":[{{"title":"用例名称/测试点名称","module_name":"父模块/子模块/叶子模块","precondition":"前置条件","steps":"步骤1\\n步骤2","expected_result":"预期结果1\\n预期结果2","priority":2,"case_type":1,"tags":["AI生成"]}}]}}
'''.strip()
@staticmethod
def parse_pdf_and_generate_cases(pdf_path, template=None):
try:
from PyPDF2 import PdfReader
reader = PdfReader(pdf_path)
content = ''
for page in reader.pages:
page_content = page.extract_text()
if page_content:
content += page_content + '\n'
if not content.strip():
return [], 'PDF文件内容为空'
return AIService.generate_test_cases(content, template)
except Exception as e:
current_app.logger.error(f'解析PDF并生成用例失败: {str(e)}')
return [], f'解析PDF失败: {str(e)}'

View File

@@ -0,0 +1,507 @@
# encoding: UTF-8
import os
import re
from ..model.documentSourceModel import DocumentSource
from ..model.caseModel import TestCase, Module
from ..dao.documentSourceDao import DocumentSourceDao
from ..dao.caseDao import CaseDao
from ..dao.skillDao import SkillDao
from .aiService import AIService
class DocumentSourceService:
DOCUMENT_TYPE_PDF = 1
DOCUMENT_TYPE_FEISHU = 2
DOCUMENT_STATUS_PENDING = 0
DOCUMENT_STATUS_PARSED = 1
DOCUMENT_STATUS_GENERATED = 2
@staticmethod
def create(session, data):
product_id = data.get('productId') or data.get('product_id')
project_id = data.get('projectId') or data.get('project_id')
document_type = data.get('type', 1)
source = data.get('source')
content = data.get('content', '')
created_by = data.get('createdBy') or data.get('created_by')
if not product_id or not project_id or not source:
return 0, 'productId、projectId、source 为必传参数'
max_version = DocumentSourceDao.get_max_version(session, product_id, project_id, source)
document_source = DocumentSource(
product_id=product_id,
project_id=project_id,
type=document_type,
source=source,
content=content,
version=max_version + 1,
status=DocumentSourceService.DOCUMENT_STATUS_PENDING,
created_by=created_by,
is_delete=0
)
if document_type == DocumentSourceService.DOCUMENT_TYPE_FEISHU:
content = DocumentSourceService._fetch_feishu_content(source)
if content:
document_source.content = content
document_source.status = DocumentSourceService.DOCUMENT_STATUS_PARSED
doc_id = DocumentSourceDao.create(session, document_source)
return doc_id, ''
@staticmethod
def _fetch_feishu_content(url):
try:
import requests
from bs4 import BeautifulSoup
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
}
response = requests.get(url, headers=headers, timeout=30)
if response.status_code == 200:
soup = BeautifulSoup(response.content, 'html.parser')
return soup.get_text(strip=True)[:10000]
return None
except Exception:
return None
@staticmethod
def parse_pdf_content(pdf_path):
try:
from PyPDF2 import PdfReader
reader = PdfReader(pdf_path)
content = ''
for page in reader.pages:
text = page.extract_text()
if text:
content += text
return content
except Exception:
return None
@staticmethod
def get_by_id(session, document_id):
return DocumentSourceDao.get_by_id(session, document_id)
@staticmethod
def list(session, req_data):
filters = [DocumentSource.is_delete == 0]
product_id = req_data.get('productId') or req_data.get('product_id')
if product_id:
filters.append(DocumentSource.product_id == product_id)
project_id = req_data.get('projectId') or req_data.get('project_id')
if project_id:
filters.append(DocumentSource.project_id == project_id)
doc_type = req_data.get('type')
if doc_type is not None:
filters.append(DocumentSource.type == doc_type)
status = req_data.get('status')
if status is not None:
filters.append(DocumentSource.status == status)
keyword = req_data.get('keyword')
if keyword:
filters.append(DocumentSource.source.like(f'%{keyword}%'))
page_no = int(req_data.get('pageNo', req_data.get('page', 1)))
page_size = int(req_data.get('pageSize', req_data.get('size', 20)))
items, total = DocumentSourceDao.list_by_filters(
session, filters, page_no, page_size, DocumentSource.created_time.desc()
)
return items, total
@staticmethod
def update(session, document_id, data):
update_info = {}
fields = ['type', 'source', 'content', 'ai_model']
for field in fields:
if field in data:
update_info[field] = data[field]
if update_info:
return DocumentSourceDao.update_by_id(session, document_id, update_info)
return 1
@staticmethod
def delete(session, document_id):
import os
from flask import current_app
# 先查询文档信息
document = DocumentSourceDao.get_by_id(session, document_id)
if not document:
return 0, '文档不存在'
# 如果是PDF文件类型删除对应的文件
if document.type == DocumentSourceService.DOCUMENT_TYPE_PDF and document.source:
# source字段存储的是相对路径uploads/zhyy/v2.0/xxx.pdf
file_path = os.path.join(os.getcwd(), document.source)
try:
if os.path.exists(file_path):
os.remove(file_path)
current_app.logger.info(f'已删除文件: {file_path}')
except Exception as e:
current_app.logger.error(f'删除文件失败: {file_path}, 错误: {str(e)}')
# 软删除数据库记录
result = DocumentSourceDao.delete_by_id(session, document_id)
return result, ''
@staticmethod
def refresh_content(session, document_id):
document = DocumentSourceDao.get_by_id(session, document_id)
if not document:
return False, '文档不存在'
if document.type == DocumentSourceService.DOCUMENT_TYPE_FEISHU:
content = DocumentSourceService._fetch_feishu_content(document.source)
if content:
DocumentSourceDao.update_by_id(session, document_id, {
'content': content,
'status': DocumentSourceService.DOCUMENT_STATUS_PARSED
})
return True, ''
return False, '获取飞书内容失败'
return False, '仅支持刷新飞书链接内容'
@staticmethod
def generate_cases(session, document_id, template=None):
document = DocumentSourceDao.get_by_id(session, document_id)
if not document:
return [], '文档不存在'
# 如果是PDF类型且内容为空先解析PDF
if document.type == DocumentSourceService.DOCUMENT_TYPE_PDF and not document.content:
# 解析PDF内容
pdf_path = os.path.join(os.getcwd(), document.source)
if not os.path.exists(pdf_path):
return [], 'PDF文件不存在'
# 使用AI服务解析PDF并生成用例
cases, msg = AIService.parse_pdf_and_generate_cases(pdf_path, template)
if msg:
return [], msg
# 更新文档内容和状态
DocumentSourceDao.update_by_id(session, document_id, {
'content': DocumentSourceService._extract_content_from_pdf(pdf_path),
'status': DocumentSourceService.DOCUMENT_STATUS_GENERATED
})
return cases, ''
if not document.content:
return [], '文档内容为空'
# 使用AI服务生成测试用例
cases, msg = AIService.generate_test_cases(document.content, template)
if msg:
return [], msg
# 更新文档状态为已生成用例
DocumentSourceDao.update_by_id(session, document_id, {
'status': DocumentSourceService.DOCUMENT_STATUS_GENERATED
})
return cases, ''
@staticmethod
def _extract_content_from_pdf(pdf_path):
"""提取PDF内容"""
try:
from PyPDF2 import PdfReader
reader = PdfReader(pdf_path)
content = ''
for page in reader.pages:
page_content = page.extract_text()
if page_content:
content += page_content + '\n'
return content
except Exception:
return ''
@staticmethod
def generate_cases_batch(session, document_ids, template=None):
"""
批量生成测试用例,支持多个文档
:param session: 数据库会话
:param document_ids: 文档ID列表
:param template: 用例模板配置
:return: 所有测试用例列表,失败文档列表
"""
all_cases = []
failed_docs = []
combined_content = []
template = template or {}
for doc_id in document_ids:
document = DocumentSourceDao.get_by_id(session, doc_id)
if not document:
failed_docs.append({'documentId': doc_id, 'error': '文档不存在'})
continue
content = document.content
# 如果是PDF类型且内容为空先解析PDF
if document.type == DocumentSourceService.DOCUMENT_TYPE_PDF and not content:
pdf_path = os.path.join(os.getcwd(), document.source)
if not os.path.exists(pdf_path):
failed_docs.append({'documentId': doc_id, 'error': 'PDF文件不存在'})
continue
# 提取PDF内容
content = DocumentSourceService._extract_content_from_pdf(pdf_path)
if not content:
failed_docs.append({'documentId': doc_id, 'error': 'PDF内容为空'})
continue
# 更新文档内容
DocumentSourceDao.update_by_id(session, doc_id, {
'content': content,
'status': DocumentSourceService.DOCUMENT_STATUS_PARSED
})
if not content:
failed_docs.append({'documentId': doc_id, 'error': '文档内容为空'})
continue
# 添加文档标识
combined_content.append(f"【文档ID: {doc_id}\n{content}\n")
if not combined_content:
return [], failed_docs
# 合并所有文档内容
merged_content = "\n---\n".join(combined_content)
context_template, context_err = DocumentSourceService._attach_generation_context(session, template)
if context_err:
return [], [{'documentId': 'all', 'error': context_err}]
# 使用AI服务生成测试用例基于合并后的内容
cases, msg = AIService.generate_test_cases(merged_content, context_template)
if msg:
return [], [{'documentId': 'all', 'error': msg}]
# 更新所有文档状态为已生成用例
for doc_id in document_ids:
if doc_id not in [f['documentId'] for f in failed_docs]:
DocumentSourceDao.update_by_id(session, doc_id, {
'status': DocumentSourceService.DOCUMENT_STATUS_GENERATED
})
return cases, failed_docs
@staticmethod
def _attach_generation_context(session, template):
template = dict(template or {})
skill_ids = template.get('skill_ids') or []
rule_ids = template.get('rule_ids') or []
if not skill_ids and not rule_ids:
return template, ''
project_id = template.get('project_id')
if not project_id:
return template, 'projectId 为必传参数'
try:
skill_ids = [int(item) for item in skill_ids]
rule_ids = [int(item) for item in rule_ids]
except (TypeError, ValueError):
return template, 'skillIds、ruleIds 必须是数字数组'
skills = SkillDao.list_skills_by_ids(session, project_id, skill_ids)
rules = SkillDao.list_business_rules_by_ids(session, project_id, rule_ids)
if len(skills) != len(set(skill_ids)):
return template, '存在未查询到的 Skill 或 Skill 不属于当前项目'
if len(rules) != len(set(rule_ids)):
return template, '存在未查询到的业务规则或业务规则不属于当前项目'
skill_contexts, err_msg = DocumentSourceService._load_asset_contexts(skills, 'skill_file_path', 'Skill')
if err_msg:
return template, err_msg
rule_contexts, err_msg = DocumentSourceService._load_asset_contexts(rules, 'rule_file_path', '业务规则')
if err_msg:
return template, err_msg
template['skill_contexts'] = skill_contexts
template['rule_contexts'] = rule_contexts
return template, ''
@staticmethod
def _load_asset_contexts(items, path_field, source_label):
contexts = []
workspace_root = os.getcwd()
for item in items:
file_path = getattr(item, path_field, None)
if not file_path:
return [], f'{source_label}{getattr(item, "name", "")}」未配置文件路径'
if not os.path.isabs(file_path):
file_path = os.path.join(workspace_root, file_path)
normalized_path = os.path.abspath(file_path)
if not os.path.exists(normalized_path):
return [], f'{source_label}{getattr(item, "name", "")}」文件不存在'
try:
with open(normalized_path, 'r', encoding='utf-8') as file_obj:
content = file_obj.read()
except Exception as e:
return [], f'{source_label}{getattr(item, "name", "")}」文件读取失败:{str(e)}'
contexts.append({
'id': item.id,
'name': item.name,
'path': normalized_path,
'content': content
})
return contexts, ''
@staticmethod
def match_modules(session, project_id, cases):
for case in cases:
module_name = case.get('module_name')
case['module_id'] = DocumentSourceService._find_module_by_path(session, project_id, module_name) if module_name else None
return cases
@staticmethod
def import_cases(session, document_id, cases, user_id, auto_create_module=False):
document = DocumentSourceDao.get_by_id(session, document_id)
if not document:
return 0, '文档不存在'
success_count = 0
for case_data in cases:
if not case_data.get('selected', True):
continue
module_id = case_data.get('module_id')
module_name = case_data.get('module_name', '未分类')
if not module_id:
if auto_create_module:
module_id = DocumentSourceService._get_or_create_module_path(session, document.project_id, module_name)
else:
module_id = DocumentSourceService._find_module_by_path(session, document.project_id, module_name)
if not module_id:
continue
case_info = {
'project_id': document.project_id,
'module_id': module_id,
'case_key': CaseDao.next_case_key(session, document.project_id, module_id, document.product_id),
'title': case_data.get('title', ''),
'preconditions': case_data.get('precondition', ''),
'steps': case_data.get('steps', ''),
'expected_results': case_data.get('expected_result', ''),
'priority': case_data.get('priority', 2),
'case_type': case_data.get('case_type', 1),
'tags': case_data.get('tags', []),
'is_ai_generated': 1,
'status': 0,
'is_delete': 0,
'created_by': user_id
}
case_id, err_msg = CaseDao.create(session, TestCase, case_info)
if err_msg:
return success_count, err_msg
success_count += 1
DocumentSourceDao.update_by_id(session, document_id, {
'status': DocumentSourceService.DOCUMENT_STATUS_GENERATED
})
return success_count, ''
@staticmethod
def batch_create_modules(session, project_id, module_names):
created_modules = []
for name in module_names:
module = DocumentSourceService._get_or_create_module_path(session, project_id, name, return_model=True)
if module:
created_modules.append(module)
session.flush()
return created_modules
@staticmethod
def _find_module_by_path(session, project_id, module_name):
parts = DocumentSourceService._parse_module_path(module_name)
parent_id = 0
module_id = None
for name in parts:
module = session.query(Module).filter(
Module.project_id == project_id,
Module.parent_id == parent_id,
Module.name == name,
Module.is_delete == 0
).first()
if not module:
return None
module_id = module.id
parent_id = module.id
return module_id
@staticmethod
def _get_or_create_module_path(session, project_id, module_name, return_model=False):
parts = DocumentSourceService._parse_module_path(module_name)
parent_id = 0
current_module = None
for name in parts:
current_module = session.query(Module).filter(
Module.project_id == project_id,
Module.parent_id == parent_id,
Module.name == name,
Module.is_delete == 0
).first()
if not current_module:
current_module = Module(
project_id=project_id,
parent_id=parent_id,
name=name,
sort_order=DocumentSourceService._next_module_sort_order(session, project_id, parent_id),
path=DocumentSourceService._build_module_path(session, parent_id, name),
is_delete=0,
status=0
)
session.add(current_module)
session.flush()
parent_id = current_module.id
return current_module if return_model else current_module.id
@staticmethod
def _parse_module_path(module_name):
module_name = str(module_name or '').strip() or '未分类'
parts = [part.strip() for part in re.split(r'[/\\>|]', module_name) if part.strip()]
return (parts or ['未分类'])[:3]
@staticmethod
def _next_module_sort_order(session, project_id, parent_id):
last_module = session.query(Module).filter(
Module.project_id == project_id,
Module.parent_id == parent_id,
Module.is_delete == 0
).order_by(Module.sort_order.desc()).first()
return (last_module.sort_order if last_module and last_module.sort_order is not None else 0) + 1
@staticmethod
def _build_module_path(session, parent_id, name):
if not parent_id:
return name
parent = session.query(Module).filter(Module.id == parent_id, Module.is_delete == 0).first()
if parent and parent.path:
return f'{parent.path}/{name}'
if parent:
return f'{parent.name}/{name}'
return name

View File

@@ -0,0 +1,571 @@
# encoding: UTF-8
import re
import shutil
from datetime import datetime
from pathlib import Path
from ..dao.skillDao import SkillDao
from ..model.skillModel import TestSkill, TestBusinessRule
from .aiService import AIService
class SkillService(object):
VALID_SKILL_TYPES = {1, 2, 3, 4, 5, 6, 7, 8, 9, 99}
VALID_STATUS = {1, 2, 3}
VALID_LEVELS = {0, 1, 2, 3}
@staticmethod
def _get(req_data, *keys, default=None):
for key in keys:
value = req_data.get(key)
if value not in (None, ''):
return value
return default
@staticmethod
def _ensure_list(value, field_name):
if value in (None, ''):
return [], ''
if not isinstance(value, list):
return [], f'{field_name} 必须是数组'
return value, ''
@staticmethod
def _normalize_generated_tags(value, fallback):
if isinstance(value, list):
tags = [str(item).strip() for item in value if str(item).strip()]
elif isinstance(value, str):
tags = [item.strip() for item in re.split(r'[,,、\s]+', value) if item.strip()]
else:
tags = []
return tags[:8] or fallback
@staticmethod
def _generate_unique_code(session, project_id, name, prefix, exists_checker):
name_text = str(name or '').strip().upper()
letters = re.sub(r'[^A-Z0-9]+', '_', name_text).strip('_')
code_prefix = (letters[:24] if letters else prefix) or prefix
time_part = datetime.now().strftime('%Y%m%d%H%M%S%f')[:20]
code = f'{code_prefix}_{time_part}'[:64]
if not exists_checker(session, project_id, code):
return code
for index in range(1, 100):
candidate = f'{code_prefix}_{time_part}_{index}'[:64]
if not exists_checker(session, project_id, candidate):
return candidate
return f'{prefix}_{time_part}'[:64]
@staticmethod
def _generate_skill_code(session, project_id, name):
return SkillService._generate_unique_code(session, project_id, name, 'SKILL', SkillDao.get_skill_by_project_code)
@staticmethod
def _generate_rule_code(session, project_id, name):
return SkillService._generate_unique_code(session, project_id, name, 'RULE', SkillDao.get_business_rule_by_project_code)
@staticmethod
def _safe_path_name(value, fallback):
value = str(value or '').strip() or fallback
value = re.sub(r'[\\/:*?"<>|\r\n\t]+', '_', value)
value = re.sub(r'\s+', ' ', value).strip(' .')
return (value or fallback)[:80]
@staticmethod
def _build_rule_file_content(rule_info):
tags = rule_info.get('tags') or []
tags_text = ', '.join([str(tag) for tag in tags])
frontmatter_name = re.sub(r'[^a-zA-Z0-9_-]+', '-', str(rule_info.get('name') or 'generated-rule').lower()).strip('-') or 'generated-rule'
description = rule_info.get('rule_content') or rule_info.get('description') or rule_info.get('name') or ''
return f'''---
name: {frontmatter_name}
description: {description}
---
# {rule_info.get('name')}
## Rule
{rule_info.get('rule_content') or ''}
## Applicable scene
{rule_info.get('applicable_scene') or ''}
## Example
{rule_info.get('example') or ''}
## Test design constraints
- Generate cases that verify this rule is satisfied in normal flows.
- Generate negative and boundary cases when the rule describes validation, limits, state changes, permissions, or data constraints.
- Mark missing prerequisites as “待确认” instead of inventing behavior.
## Metadata
- Code: {rule_info.get('rule_code') or ''}
- Product: {rule_info.get('product_name') or ''}
- Project: {rule_info.get('project_name') or ''}
- Module: {rule_info.get('module_name') or ''}
- Priority: {rule_info.get('priority')}
- Tags: {tags_text}
'''
@staticmethod
def _build_skill_file_content(skill_info):
skill_md = skill_info.get('skill_md') or skill_info.get('skillMd')
if isinstance(skill_md, str) and skill_md.strip():
return skill_md.strip() + '\n'
tags = skill_info.get('tags') or []
tags_text = ', '.join([str(tag) for tag in tags])
frontmatter_name = re.sub(r'[^a-zA-Z0-9_-]+', '-', str(skill_info.get('name') or 'generated-skill').lower()).strip('-') or 'generated-skill'
description = skill_info.get('description') or skill_info.get('name') or ''
return f'''---
name: {frontmatter_name}
description: {description}
---
# {skill_info.get('name')}
Use this skill when PRD, requirement, user story, interface specification, UI interaction, or business rule content needs to be transformed into high-quality test cases. This skill helps the model apply project-specific testing experience when designing functional, interface, boundary, exception, and regression cases.
## When to use this skill
{skill_info.get('trigger_condition') or ''}
## Analysis workflow
{skill_info.get('reasoning_path') or ''}
## Test design guidance
- Identify the core business flow, state changes, inputs, outputs, permissions, and data dependencies.
- Cover normal paths, boundary values, invalid inputs, exception handling, idempotency, concurrency, and regression risks when applicable.
- Mark missing or ambiguous requirements as “待确认” rather than inventing behavior.
## Output format
{skill_info.get('output_spec') or ''}
## Metadata
- Code: {skill_info.get('code') or ''}
- Product: {skill_info.get('product_name') or ''}
- Project: {skill_info.get('project_name') or ''}
- Module: {skill_info.get('module_name') or ''}
- Skill Type: {skill_info.get('skill_type')}
- Risk Level: {skill_info.get('risk_level')}
- Tags: {tags_text}
'''
@staticmethod
def _create_asset_file(session, project_id, module_id, asset_info, root_folder, folder_name, file_name, content_builder):
context = SkillDao.get_skill_path_context(session, project_id, module_id)
product_name = SkillService._safe_path_name(context.get('product_name'), '未关联产品')
project_name = SkillService._safe_path_name(context.get('project_name'), f'项目{project_id}')
module_name = SkillService._safe_path_name(context.get('module_name'), '项目通用')
asset_name = SkillService._safe_path_name(folder_name, '未命名')
base_dir = Path(__file__).resolve().parents[3] / 'config' / root_folder
asset_dir = base_dir / product_name / project_name / module_name / asset_name
if asset_dir.exists():
suffix = datetime.now().strftime('%Y%m%d%H%M%S%f')[:20]
asset_dir = asset_dir.with_name(f'{asset_dir.name}_{suffix}')
asset_dir.mkdir(parents=True, exist_ok=False)
asset_path = asset_dir / file_name
file_info = dict(asset_info)
file_info.update({
'product_name': context.get('product_name'),
'project_name': context.get('project_name'),
'module_name': context.get('module_name')
})
asset_path.write_text(content_builder(file_info), encoding='utf-8')
return str(asset_path), str(asset_dir)
@staticmethod
def _create_skill_file(session, project_id, module_id, skill_info):
return SkillService._create_asset_file(
session, project_id, module_id, skill_info, 'skills', skill_info.get('name'), 'SKILL.md', SkillService._build_skill_file_content
)
@staticmethod
def _create_rule_file(session, project_id, module_id, rule_info):
return SkillService._create_asset_file(
session, project_id, module_id, rule_info, 'rules', rule_info.get('name'), 'RULE.md', SkillService._build_rule_file_content
)
@staticmethod
def _remove_asset_file_path(asset_file_path, root_folder):
if not asset_file_path:
return
asset_path = Path(asset_file_path)
base_dir = Path(__file__).resolve().parents[3] / 'config' / root_folder
try:
resolved_asset_path = asset_path.resolve()
resolved_base_dir = base_dir.resolve()
if resolved_base_dir not in resolved_asset_path.parents:
return
asset_dir = resolved_asset_path.parent
if asset_dir.exists() and asset_dir.name not in {root_folder, 'config'}:
shutil.rmtree(asset_dir)
except FileNotFoundError:
return
@staticmethod
def _remove_skill_file_path(skill_file_path):
SkillService._remove_asset_file_path(skill_file_path, 'skills')
@staticmethod
def _remove_rule_file_path(rule_file_path):
SkillService._remove_asset_file_path(rule_file_path, 'rules')
@staticmethod
def create_skill(session, req_data, user_id=None):
project_id = SkillService._get(req_data, 'projectId', 'project_id')
name = SkillService._get(req_data, 'name')
if not project_id or not name:
return 0, 'projectId、name 为必传参数'
input_tags, err_msg = SkillService._ensure_list(SkillService._get(req_data, 'tags', default=[]), 'tags')
if err_msg:
return 0, err_msg
generated_info, err_msg = AIService.generate_skill_content(req_data)
if err_msg:
return 0, err_msg
generated_skill_type = generated_info.get('skill_type') or generated_info.get('skillType')
generated_risk_level = generated_info.get('risk_level') or generated_info.get('riskLevel')
skill_type = int(generated_skill_type if generated_skill_type is not None else SkillService._get(req_data, 'skillType', 'skill_type', default=1))
risk_level = int(generated_risk_level if generated_risk_level is not None else SkillService._get(req_data, 'riskLevel', 'risk_level', default=2))
status = int(SkillService._get(req_data, 'status', default=1))
if skill_type not in SkillService.VALID_SKILL_TYPES:
skill_type = 1
if risk_level not in SkillService.VALID_LEVELS:
risk_level = 2
if status not in SkillService.VALID_STATUS:
return 0, 'status 不合法'
generated_tags = SkillService._normalize_generated_tags(generated_info.get('tags'), input_tags)
module_id_value = SkillService._get(req_data, 'moduleId', 'module_id')
module_id = int(module_id_value) if module_id_value else None
add_info = {
'project_id': int(project_id),
'module_id': module_id,
'name': name,
'code': SkillService._generate_skill_code(session, project_id, name),
'description': generated_info.get('description') or SkillService._get(req_data, 'description') or name,
'trigger_condition': AIService.get_default_case_generation_trigger_condition(),
'reasoning_path': generated_info.get('reasoning_path') or generated_info.get('reasoningPath'),
'output_spec': AIService.get_default_case_generation_output_spec(),
'skill_type': skill_type,
'risk_level': risk_level,
'tags': generated_tags,
'status': status,
'owner_id': int(user_id) if user_id else None,
'created_by': user_id,
'usage_count': 0,
'is_delete': 0
}
skill_file_info = dict(add_info)
skill_file_info['skill_md'] = generated_info.get('skill_md')
skill_file_path, skill_dir = SkillService._create_skill_file(session, int(project_id), module_id, skill_file_info)
add_info['skill_file_path'] = skill_file_path
obj_id, create_err = SkillDao.create(session, TestSkill, add_info)
if create_err:
shutil.rmtree(skill_dir, ignore_errors=True)
return 0, create_err
return obj_id, ''
@staticmethod
def update_skill(session, req_data):
skill_id = SkillService._get(req_data, 'skillId', 'id')
if not skill_id:
return 0, 'skillId 为必传参数'
item = SkillDao.get_by_id(session, TestSkill, skill_id)
if not item:
return 0, '未查询到对应 Skill'
update_info = {}
mapping = [
('name', 'name'), ('description', 'description'),
('triggerCondition', 'trigger_condition'), ('trigger_condition', 'trigger_condition'),
('reasoningPath', 'reasoning_path'), ('reasoning_path', 'reasoning_path'),
('outputSpec', 'output_spec'), ('output_spec', 'output_spec')
]
for req_key, column_key in mapping:
value = SkillService._get(req_data, req_key)
if value is not None:
update_info[column_key] = value
module_id = SkillService._get(req_data, 'moduleId', 'module_id')
if module_id is not None:
update_info['module_id'] = int(module_id) if module_id != '' else None
owner_id = SkillService._get(req_data, 'ownerId', 'owner_id')
if owner_id is not None:
update_info['owner_id'] = int(owner_id) if owner_id != '' else None
tags = SkillService._get(req_data, 'tags')
if tags is not None:
tags, err_msg = SkillService._ensure_list(tags, 'tags')
if err_msg:
return 0, err_msg
update_info['tags'] = tags
for req_key, column_key, valid_set in [
('skillType', 'skill_type', SkillService.VALID_SKILL_TYPES),
('skill_type', 'skill_type', SkillService.VALID_SKILL_TYPES),
('riskLevel', 'risk_level', SkillService.VALID_LEVELS),
('risk_level', 'risk_level', SkillService.VALID_LEVELS),
('status', 'status', SkillService.VALID_STATUS)
]:
value = SkillService._get(req_data, req_key)
if value is not None:
value = int(value)
if value not in valid_set:
return 0, f'{req_key} 不合法'
update_info[column_key] = value
if not update_info:
return int(skill_id), ''
merged_info = item.to_dict()
merged_info.update(update_info)
new_skill_file_path = None
new_skill_dir = None
try:
new_skill_file_path, new_skill_dir = SkillService._create_skill_file(
session,
int(merged_info.get('project_id')),
merged_info.get('module_id'),
merged_info
)
update_info['skill_file_path'] = new_skill_file_path
except Exception as e:
return 0, f'Skill 文件创建失败:{str(e)}'
obj_id, err_msg = SkillDao.update_by_id(session, TestSkill, skill_id, update_info)
if err_msg:
if new_skill_dir:
shutil.rmtree(new_skill_dir, ignore_errors=True)
return obj_id, err_msg
SkillService._remove_skill_file_path(item.skill_file_path)
return obj_id, ''
@staticmethod
def delete_skill(session, req_data):
skill_id = SkillService._get(req_data, 'skillId', 'id')
if not skill_id:
return 0, 'skillId 为必传参数'
item = SkillDao.get_by_id(session, TestSkill, skill_id)
if not item:
return 0, '未查询到对应 Skill'
obj_id, err_msg = SkillDao.delete_by_id(session, TestSkill, skill_id)
if err_msg:
return obj_id, err_msg
SkillService._remove_skill_file_path(item.skill_file_path)
return obj_id, ''
@staticmethod
def skill_detail(session, skill_id):
item = SkillDao.get_by_id(session, TestSkill, skill_id)
if not item:
return {}, '未查询到对应 Skill'
return item.to_dict(), ''
@staticmethod
def skill_list(session, req_data):
filters = []
project_id = SkillService._get(req_data, 'projectId', 'project_id')
module_id = SkillService._get(req_data, 'moduleId', 'module_id')
status = SkillService._get(req_data, 'status')
skill_type = SkillService._get(req_data, 'skillType', 'skill_type')
risk_level = SkillService._get(req_data, 'riskLevel', 'risk_level')
if project_id:
filters.append(TestSkill.project_id == int(project_id))
if module_id not in (None, ''):
filters.append(TestSkill.module_id == int(module_id))
if status not in (None, ''):
filters.append(TestSkill.status == int(status))
if skill_type not in (None, ''):
filters.append(TestSkill.skill_type == int(skill_type))
if risk_level not in (None, ''):
filters.append(TestSkill.risk_level == int(risk_level))
items, total = SkillDao.list_skill(
session, filters,
SkillService._get(req_data, 'pageNo', 'page', default=1),
SkillService._get(req_data, 'pageSize', 'size', default=20),
SkillService._get(req_data, 'keyword'),
SkillService._get(req_data, 'tag')
)
return {'list': [item.to_dict() for item in items], 'total': total}
@staticmethod
def create_business_rule(session, req_data, user_id=None):
project_id = SkillService._get(req_data, 'projectId', 'project_id')
name = SkillService._get(req_data, 'name')
if not project_id or not name:
return 0, 'projectId、name 为必传参数'
input_tags, err_msg = SkillService._ensure_list(SkillService._get(req_data, 'tags', default=[]), 'tags')
if err_msg:
return 0, err_msg
generated_info, err_msg = AIService.generate_business_rule_content(req_data)
if err_msg:
return 0, err_msg
input_priority = SkillService._get(req_data, 'priority')
priority_value = input_priority if input_priority is not None else generated_info.get('priority')
priority = int(priority_value if priority_value is not None else 2)
status = int(SkillService._get(req_data, 'status', default=1))
if priority not in SkillService.VALID_LEVELS:
priority = 2
if status not in SkillService.VALID_STATUS:
return 0, 'status 不合法'
generated_tags = SkillService._normalize_generated_tags(generated_info.get('tags'), input_tags)
module_id_value = SkillService._get(req_data, 'moduleId', 'module_id')
module_id = int(module_id_value) if module_id_value else None
input_rule_content = SkillService._get(req_data, 'ruleContent', 'rule_content') or SkillService._get(req_data, 'description') or name
add_info = {
'project_id': int(project_id),
'module_id': module_id,
'name': name,
'rule_code': SkillService._generate_rule_code(session, project_id, name),
'rule_content': input_rule_content,
'applicable_scene': SkillService._get(req_data, 'applicableScene', 'applicable_scene') or generated_info.get('applicable_scene') or generated_info.get('applicableScene'),
'example': SkillService._get(req_data, 'example') or generated_info.get('example'),
'priority': priority,
'tags': input_tags or generated_tags,
'status': status,
'owner_id': int(user_id) if user_id else None,
'created_by': user_id,
'usage_count': 0,
'is_delete': 0
}
rule_file_info = dict(add_info)
rule_file_path, rule_dir = SkillService._create_rule_file(session, int(project_id), module_id, rule_file_info)
add_info['rule_file_path'] = rule_file_path
obj_id, create_err = SkillDao.create(session, TestBusinessRule, add_info)
if create_err:
shutil.rmtree(rule_dir, ignore_errors=True)
return 0, create_err
return obj_id, ''
@staticmethod
def update_business_rule(session, req_data):
rule_id = SkillService._get(req_data, 'ruleId', 'id')
if not rule_id:
return 0, 'ruleId 为必传参数'
item = SkillDao.get_by_id(session, TestBusinessRule, rule_id)
if not item:
return 0, '未查询到对应业务规则'
update_info = {}
mapping = [
('name', 'name'), ('ruleContent', 'rule_content'), ('rule_content', 'rule_content'),
('applicableScene', 'applicable_scene'), ('applicable_scene', 'applicable_scene'),
('example', 'example')
]
for req_key, column_key in mapping:
value = SkillService._get(req_data, req_key)
if value is not None:
update_info[column_key] = value
module_id = SkillService._get(req_data, 'moduleId', 'module_id')
if module_id is not None:
update_info['module_id'] = int(module_id) if module_id != '' else None
owner_id = SkillService._get(req_data, 'ownerId', 'owner_id')
if owner_id is not None:
update_info['owner_id'] = int(owner_id) if owner_id != '' else None
tags = SkillService._get(req_data, 'tags')
if tags is not None:
tags, err_msg = SkillService._ensure_list(tags, 'tags')
if err_msg:
return 0, err_msg
update_info['tags'] = tags
priority = SkillService._get(req_data, 'priority')
if priority is not None:
priority = int(priority)
if priority not in SkillService.VALID_LEVELS:
return 0, 'priority 不合法'
update_info['priority'] = priority
status = SkillService._get(req_data, 'status')
if status is not None:
status = int(status)
if status not in SkillService.VALID_STATUS:
return 0, 'status 不合法'
update_info['status'] = status
if not update_info:
return int(rule_id), ''
merged_info = item.to_dict()
merged_info.update(update_info)
new_rule_file_path = None
new_rule_dir = None
try:
new_rule_file_path, new_rule_dir = SkillService._create_rule_file(
session,
int(merged_info.get('project_id')),
merged_info.get('module_id'),
merged_info
)
update_info['rule_file_path'] = new_rule_file_path
except Exception as e:
return 0, f'业务规则文件创建失败:{str(e)}'
obj_id, err_msg = SkillDao.update_by_id(session, TestBusinessRule, rule_id, update_info)
if err_msg:
if new_rule_dir:
shutil.rmtree(new_rule_dir, ignore_errors=True)
return obj_id, err_msg
SkillService._remove_rule_file_path(item.rule_file_path)
return obj_id, ''
@staticmethod
def delete_business_rule(session, req_data):
rule_id = SkillService._get(req_data, 'ruleId', 'id')
if not rule_id:
return 0, 'ruleId 为必传参数'
item = SkillDao.get_by_id(session, TestBusinessRule, rule_id)
if not item:
return 0, '未查询到对应业务规则'
obj_id, err_msg = SkillDao.delete_by_id(session, TestBusinessRule, rule_id)
if err_msg:
return obj_id, err_msg
SkillService._remove_rule_file_path(item.rule_file_path)
return obj_id, ''
@staticmethod
def business_rule_detail(session, rule_id):
item = SkillDao.get_by_id(session, TestBusinessRule, rule_id)
if not item:
return {}, '未查询到对应业务规则'
return item.to_dict(), ''
@staticmethod
def skill_rule_list(session, req_data):
product_id = SkillService._get(req_data, 'productId', 'product_id')
project_id = SkillService._get(req_data, 'projectId', 'project_id')
status = SkillService._get(req_data, 'status')
if not product_id or not project_id:
return {}, 'productId、projectId 为必传参数'
project = SkillDao.get_project_by_product(session, product_id, project_id)
if not project:
return {}, '未查询到对应产品下的项目'
skills = SkillDao.list_skills_by_project(session, project_id, status)
rules = SkillDao.list_business_rules_by_project(session, project_id, status)
return {
'productId': int(product_id),
'projectId': int(project_id),
'skills': [item.to_dict() for item in skills],
'rules': [item.to_dict() for item in rules],
'skillTotal': len(skills),
'ruleTotal': len(rules)
}, ''
@staticmethod
def business_rule_list(session, req_data):
filters = []
project_id = SkillService._get(req_data, 'projectId', 'project_id')
module_id = SkillService._get(req_data, 'moduleId', 'module_id')
status = SkillService._get(req_data, 'status')
priority = SkillService._get(req_data, 'priority')
if project_id:
filters.append(TestBusinessRule.project_id == int(project_id))
if module_id not in (None, ''):
filters.append(TestBusinessRule.module_id == int(module_id))
if status not in (None, ''):
filters.append(TestBusinessRule.status == int(status))
if priority not in (None, ''):
filters.append(TestBusinessRule.priority == int(priority))
items, total = SkillDao.list_business_rule(
session, filters,
SkillService._get(req_data, 'pageNo', 'page', default=1),
SkillService._get(req_data, 'pageSize', 'size', default=20),
SkillService._get(req_data, 'keyword'),
SkillService._get(req_data, 'tag')
)
return {'list': [item.to_dict() for item in items], 'total': total}