# encoding: UTF-8
from sqlalchemy import func

from ..model.caseModel import CaseReview, CaseSnapshot, Module, TestCase
from logger import logger


class CaseDao(object):
    """Generic DAO for the test-case domain.

    Shares the basic create / read / update / soft-delete operations across the
    Module, TestCase, CaseSnapshot and CaseReview models. Every method is a
    ``staticmethod`` that receives the session explicitly.
    """

    @staticmethod
    def _id_filters(model_cls, obj_id, soft_delete=True):
        """Build the filter list matching primary key ``obj_id``.

        When ``soft_delete`` is true and the model has an ``is_delete`` column,
        a filter excluding soft-deleted rows (``is_delete != 0``) is appended.
        """
        filters = [model_cls.id == int(obj_id)]
        if soft_delete and hasattr(model_cls, 'is_delete'):
            filters.append(model_cls.is_delete == 0)
        return filters

    @staticmethod
    def create(session, model_cls, add_info):
        """Create a record from ``add_info`` and commit the transaction.

        :param session: project session wrapper exposing ``add`` and ``done``.
        :param model_cls: ORM model class to instantiate.
        :param add_info: mapping of column name -> value, passed as kwargs.
        :return: ``(new_id, '')`` on success, ``(0, error_message)`` when
            ``session.done`` reports an error.
        """
        obj = model_cls(**add_info)
        session.add(obj)
        # session.done is a project helper; a truthy return value signals a
        # failed commit. NOTE(review): close=False presumably keeps the
        # session open for further use - confirm against the helper.
        err = session.done(close=False)
        if err:
            logger.warning(f'{model_cls.__name__}新增失败!{err}')
            return 0, f'新增失败!{err}'
        return obj.id, ''

    @staticmethod
    def update_by_id(session, model_cls, obj_id, update_info, soft_delete=True):
        """Update the record whose primary key is ``obj_id`` and commit.

        :param obj_id: primary key; converted with ``int()``.
        :param update_info: mapping of column name -> new value.
        :param soft_delete: when true and the model has ``is_delete``,
            soft-deleted rows are not matched.
        :return: ``(int(obj_id), '')`` on success; ``(0, error_message)`` when
            the commit fails or no row matched.
        """
        filters = CaseDao._id_filters(model_cls, obj_id, soft_delete)
        # Query.update returns the number of rows matched by the filters.
        update_res = session.query(model_cls).filter(*filters).update(update_info)
        err = session.done(close=False)
        if err:
            logger.error(f'{model_cls.__name__}更新失败!id: {obj_id}, err: {err}')
            return 0, f'更新失败!{err}'
        if not update_res:
            return 0, '未查询到对应记录!'
        return int(obj_id), ''

    @staticmethod
    def get_by_id(session, model_cls, obj_id, soft_delete=True):
        """Return the record with primary key ``obj_id``, or ``None``.

        When ``soft_delete`` is true and the model has ``is_delete``,
        soft-deleted rows are treated as missing.
        """
        filters = CaseDao._id_filters(model_cls, obj_id, soft_delete)
        return session.query(model_cls).filter(*filters).first()

    @staticmethod
    def list_by_filters(session, model_cls, filter_list, page=1, limit=20,
                        order_column=None, soft_delete=True):
        """Return one page of records matching ``filter_list`` plus the total.

        :param filter_list: iterable of SQLAlchemy filter expressions; ``None``
            means "no extra filters".
        :param page: 1-based page number; values below 1 are treated as 1.
        :param limit: page size; negative values are treated as 0.
        :param order_column: optional column; rows are ordered by it descending.
        :param soft_delete: when true (the default, matching earlier behavior)
            and the model has ``is_delete``, soft-deleted rows are excluded.
        :return: ``(rows, total)``, where ``total`` counts every matching row
            before pagination is applied.
        """
        extra_filters = filter_list if filter_list is not None else ()
        query = session.query(model_cls).filter(*extra_filters)
        if soft_delete and hasattr(model_cls, 'is_delete'):
            query = query.filter(model_cls.is_delete == 0)
        total = query.count()
        if order_column is not None:
            query = query.order_by(order_column.desc())
        # Clamp so page=0 or a negative size never yields a negative
        # OFFSET/LIMIT, which databases such as PostgreSQL reject.
        page = max(int(page), 1)
        limit = max(int(limit), 0)
        rets = query.offset((page - 1) * limit).limit(limit).all()
        return rets, total

    @staticmethod
    def delete_by_id(session, model_cls, obj_id):
        """Soft-delete a record by setting ``is_delete = 1``.

        Delegates to ``update_by_id``, so an already soft-deleted record is
        reported as not found. NOTE(review): assumes ``model_cls`` has an
        ``is_delete`` column; otherwise the update itself will fail.
        """
        return CaseDao.update_by_id(session, model_cls, obj_id, {'is_delete': 1})

    @staticmethod
    def next_case_key(session, project_id):
        """Generate the next case key for a project, e.g. ``'TC-001'``.

        The number is the count of all TestCase rows in the project
        (soft-deleted rows included) plus one. NOTE(review): count-based keys
        can collide if rows are ever hard-deleted or two cases are created
        concurrently - confirm that neither can happen.
        """
        count_num = session.query(func.count(TestCase.id)).filter(
            TestCase.project_id == int(project_id)).scalar() or 0
        return 'TC-{:03d}'.format(count_num + 1)

    @staticmethod
    def next_snapshot_version(session, case_id):
        """Return the next snapshot version for a case (max existing + 1, starting at 1)."""
        max_version = session.query(func.max(CaseSnapshot.version)).filter(
            CaseSnapshot.case_id == int(case_id)).scalar() or 0
        return int(max_version) + 1

    @staticmethod
    def module_model():
        """Return the Module model class."""
        return Module

    @staticmethod
    def case_model():
        """Return the TestCase model class."""
        return TestCase

    @staticmethod
    def snapshot_model():
        """Return the CaseSnapshot model class."""
        return CaseSnapshot

    @staticmethod
    def review_model():
        """Return the CaseReview model class."""
        return CaseReview

    @staticmethod
    def get_module_name_map(session, module_ids):
        """Map each non-deleted module id in ``module_ids`` to its name.

        Returns ``{}`` without querying when ``module_ids`` is empty; ids of
        missing or soft-deleted modules are simply absent from the result.
        """
        if not module_ids:
            return {}
        module_items = session.query(Module).filter(
            Module.id.in_(module_ids), Module.is_delete == 0).all()
        return {module.id: module.name for module in module_items}