diff --git a/zhyy/test_case/run_tests.py b/zhyy/test_case/run_tests.py index 8b3bdd2..af54c52 100644 --- a/zhyy/test_case/run_tests.py +++ b/zhyy/test_case/run_tests.py @@ -64,34 +64,44 @@ def is_importable(file_path): except Exception: return False -def find_test_files(directory, include_all=False): - """递归查找所有测试文件""" +def contains_test_class(file_path): + """检查文件是否包含测试类(以 Test 开头的类或包含 allure 装饰器)""" + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + # 检查是否包含 test_ 开头的函数或 Test 开头的类 + if 'def test_' in content or 'class Test' in content: + return True + # 检查是否包含 allure 装饰器 + if '@allure.feature' in content or '@allure.story' in content: + return True + return False + except Exception: + return False + + +def find_test_files(directory): + """递归查找所有测试文件(自动识别包含测试类的文件)""" test_files = [] for root, dirs, files in os.walk(directory): for file in files: if file.endswith('.py') and not file.startswith('__'): file_path = os.path.join(root, file) - if include_all: - # 包含所有 .py 文件(用于特殊目录如 SZPurchase) + # 如果文件以 test_ 开头,直接加入 + if file.startswith('test_'): test_files.append(file_path) else: - # 只查找以 test_ 开头的 Python 文件(符合 pytest 命名约定) - if file.startswith('test_'): + # 否则检查文件内容是否包含测试类 + if contains_test_class(file_path): test_files.append(file_path) return test_files def get_all_test_files(): - """获取所有测试文件(包含标准 test_*.py 和 SZPurchase 目录下的所有 .py 文件)""" + """获取所有测试文件(自动发现所有包含测试类的文件)""" test_files = find_test_files(case_dir) - - szpurchase_dir = os.path.join(case_dir, '接口', 'SZPurchase') - if os.path.exists(szpurchase_dir): - szpurchase_files = find_test_files(szpurchase_dir, include_all=True) - test_files.extend(szpurchase_files) - print(f"添加 SZPurchase 目录下的 {len(szpurchase_files)} 个文件") - + print(f"共发现 {len(test_files)} 个测试文件") return test_files @@ -150,9 +160,6 @@ def run_tests(target, test_type='all', **kwargs): return 1 print(f"按目录运行: {target}") test_files = find_test_files(full_path) - if not test_files: - print("未找到 test_*.py 文件,尝试查找所有 .py 文件...") - test_files = find_test_files(full_path, include_all=True) if not test_files: print("错误: 未找到测试文件") return 1