# -*- coding:utf-8 -*-
"""
Author: qiaoxinjiu
Create Date: 2020/11/6 17:30

PostgreSQL helper for the test framework: query / insert / update / delete
wrappers plus Robot-Framework-style ``kw_*`` assertion keywords.
"""
import ast
import re
import time
from collections.abc import Mapping

from retrying import retry
from base_framework.public_tools import log
from base_framework.public_tools.db_dbutils_init import get_pg_connection  # PostgreSQL connection pool (assumed)
from base_framework.public_tools.read_config import get_current_config, get_current_env
from base_framework.public_tools.huohua_dbs import HuoHuaDBS

obj_log = log.get_logger()

# Whole-word RETURNING clause; a plain substring test would also match
# identifiers such as "returning_flag".
_RETURNING_RE = re.compile(r'\bRETURNING\b', re.IGNORECASE)
# Detects an INSERT statement with an (optionally schema-qualified) target table.
_INSERT_INTO_RE = re.compile(r'INSERT\s+INTO\s+(\w+\.)?(\w+)', re.IGNORECASE)


class PgSqlHelper:
    """
    PostgreSQL database helper.

    Outside the SIM environment statements run on a (cursor, connection) pair
    obtained from ``get_pg_connection().getconn(choose_db=...)``. In the SIM
    environment most helpers delegate to ``HuoHuaDBS`` instead, and
    ``execute`` refuses to run.
    """

    def __init__(self):
        self.db = get_pg_connection()  # PostgreSQL connection pool
        self.current_business = get_current_config(section='run_evn_name', key='current_business')
        self.current_evn = get_current_env()
        # qa database name -> SIM instance name, filled by get_qa_to_sim_dbs()
        self.qa_db_to_sim_instance_name = {}
        self.sim_dbs = HuoHuaDBS()

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    def _is_sim(self):
        """Return True when the configured environment is SIM (case-insensitive)."""
        return self.current_evn.lower() == "sim"

    @staticmethod
    def _log_sql(sql, param=None):
        """Log the SQL text, followed by its parameters when any were given."""
        if param is not None:
            obj_log.info('SQL语句:{}|{}'.format(sql, param))
        else:
            obj_log.info('SQL语句:{}'.format(sql))

    @staticmethod
    def _first_value(row):
        """
        Return the first column of *row*, or None when *row* is None/empty.

        Works for both tuple-like rows and dict-like rows (e.g. dict cursors),
        because the cursor type handed out by the pool is not visible here.
        """
        if not row:
            return None
        if isinstance(row, Mapping):
            return next(iter(row.values()))
        return row[0]

    @staticmethod
    def _row_values(row):
        """Return the column values of a dict-like or tuple-like row as a list."""
        if isinstance(row, Mapping):
            return list(row.values())
        return list(row)

    @staticmethod
    def _append_returning_id(sql):
        """Append ``RETURNING id`` to *sql* unless it already has a RETURNING clause."""
        if _RETURNING_RE.search(sql):
            return sql
        # Strip trailing whitespace first so "...;\n" does not become "...;\n RETURNING id;".
        return sql.rstrip().rstrip(';') + ' RETURNING id;'

    @staticmethod
    def _parse_many_params(param):
        """
        Normalize the parameter sequence passed to the insert_many* helpers.

        Strings (e.g. passed from a Robot keyword) are parsed as Python
        literals with ``ast.literal_eval``; anything else is used as-is.
        ``eval`` is deliberately not used: it would execute arbitrary code.
        """
        if isinstance(param, str):
            return ast.literal_eval(param)
        return param

    def _rollback_and_close(self, cursor, conn):
        """Best-effort rollback and release after a failure; secondary errors are only logged."""
        try:
            conn.rollback()
        except Exception as rollback_error:
            obj_log.error('PostgreSQL rollback failed: {}'.format(rollback_error))
        try:
            self.close(cursor, conn)
        except Exception as close_error:
            obj_log.error('PostgreSQL close failed: {}'.format(close_error))

    @staticmethod
    def _log_connection_error(error, choose_db):
        """Print and log the configured connection settings (password masked) after a failed connect."""
        try:
            from base_framework.public_tools.read_config import ReadConfig
            from base_framework.base_config.current_pth import config_file_path
            rc = ReadConfig(config_file_path)
            db_host = rc.get_value(sections='PostgreSQL', options='db_test_host')
            db_port = rc.get_value(sections='PostgreSQL', options='db_test_port')
            db_name = rc.get_value(sections='PostgreSQL', options='db_test_dbname')
            db_user = rc.get_value(sections='PostgreSQL', options='db_test_user')
            db_password = rc.get_value(sections='PostgreSQL', options='db_test_password')
            error_info = (
                "\nPostgreSQL连接失败!\n"
                "连接配置信息:\n"
                "  主机(Host): {}\n"
                "  端口(Port): {}\n"
                "  数据库名(Database): {}\n"
                "  用户名(User): {}\n"
                "  密码(Password): {} (已隐藏)\n"
                "  选择数据库(ChooseDB): {}\n"
                "错误详情: {}\n"
            ).format(
                db_host,
                db_port,
                db_name,
                db_user,
                # Fixed-length mask: a per-character mask would reveal the password length.
                '******' if db_password else 'None',
                choose_db or 'default',
                str(error),
            )
        except Exception as config_error:
            error_info = "PostgreSQL连接失败: {} (无法读取配置信息: {})".format(str(error), str(config_error))
        print(error_info)
        obj_log.error(error_info)

    # ------------------------------------------------------------------ #
    # Core execution
    # ------------------------------------------------------------------ #
    def execute(self, sql, param=None, auto_close=False, choose_db=None):
        """
        | Purpose: | Execute one SQL statement and commit it |
        | Args: | sql | statement to execute |
        | | param=None | bind parameters for ``%s`` placeholders (values may also be inlined in sql) |
        | | auto_close=False | close cursor and connection after the commit |
        | | choose_db=None | database selector forwarded to the connection pool |
        | Returns: | cursor, conn, count | cursor, connection, ``cursor.rowcount`` |

        Raises ``Exception`` in the SIM environment, and ``ValueError`` when
        no connection can be obtained or the statement/commit fails. On a
        statement failure the transaction is rolled back and the connection is
        released here, because the caller never receives it.
        """
        if self._is_sim():
            raise Exception("SIM环境请直接使用huohua_dbs.py中的函数")
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)  # take a connection from the pool
        except Exception as e:
            self._log_connection_error(e, choose_db)
            raise ValueError("PostgreSQL连接失败: {}".format(str(e))) from e
        try:
            # PostgreSQL drivers use %s placeholders, the same as MySQL drivers.
            if param:
                cursor.execute(sql, param)
            else:
                cursor.execute(sql)
            # PEP 249 leaves execute()'s return value undefined (psycopg2 returns
            # None), so the affected/returned row count comes from rowcount.
            count = cursor.rowcount
            conn.commit()
        except Exception as e:
            obj_log.error("PostgreSQL执行SQL失败: {}, SQL: {}".format(str(e), sql))
            self._rollback_and_close(cursor, conn)
            raise ValueError("数据库操作失败,SQL语句:{}, 错误: {}".format(sql, str(e))) from e
        if auto_close:
            self.close(cursor, conn)
        return cursor, conn, count

    @staticmethod
    def close(cursor, conn):
        """Close *cursor*, then *conn* (for a pooled connection this presumably returns it to the pool)."""
        cursor.close()
        conn.close()

    # ------------------------------------------------------------------ #
    # Queries
    # ------------------------------------------------------------------ #
    def select_all(self, sql, param=None, choose_db=None, show_log=True):
        """
        | Purpose: | Run a query and return all rows (``cursor.fetchall()``) |
        | Args: | show_log=True | log the SQL and the result |

        In the SIM environment the query goes to ``HuoHuaDBS.dbs_select``
        (``param`` and ``show_log`` are not used there). Errors are re-raised
        as ``RuntimeError``.
        """
        if self._is_sim():
            return self.sim_dbs.dbs_select(sql_content=sql)
        if show_log:
            self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, param, choose_db=choose_db)
            res = cursor.fetchall()
            if show_log:
                obj_log.info('数据库查询结果:{}'.format(res))
            return res
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            # Always release the connection (it used to leak whenever show_log was True).
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def select_all_as_list(self, sql, choose_db=None):
        """
        | Purpose: | Same as select_all, but return plain lists |

        Multi-column rows become lists of column values; single-column rows
        are flattened into one list of values. When every column of every row
        is falsy, an empty list is returned.
        """
        if self._is_sim():
            return self.sim_dbs.dbs_select(sql_content=sql, r_type='list')
        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, choose_db=choose_db)
            rows = [self._row_values(row) for row in cursor.fetchall()]
            if not any(any(values) for values in rows):
                return []
            res_list = []
            for values in rows:
                if len(values) > 1:
                    res_list.append(values)
                else:
                    res_list.extend(values)  # single column: flatten
            return res_list
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def select_one(self, sql, param=None, choose_db=None):
        """
        | Purpose: | Run a query and return its first row (``cursor.fetchone()``) |

        In the SIM environment ``limit 1`` is appended when the SQL has no
        ``limit``; ``{}`` is returned when SIM yields no rows.
        """
        if self._is_sim():
            if 'limit' not in sql.lower():
                sql = sql.split(';')[0] + ' limit 1;'
            sim_data = self.sim_dbs.dbs_select(sql_content=sql)
            return sim_data[0] if sim_data else {}
        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, param, choose_db=choose_db)
            return cursor.fetchone()
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # ------------------------------------------------------------------ #
    # Inserts
    # ------------------------------------------------------------------ #
    def insert_one(self, sql, param=None, choose_db=None, log_level='info'):
        """
        | Purpose: | Insert one row and return the affected row count |
        | Args: | log_level='info' | only the value 'info' (any case) enables logging |
        """
        if self._is_sim():
            return self.sim_dbs.dbs_execute_sql(sql_content=sql)
        verbose = log_level.lower() == 'info'
        if verbose:
            self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            # execute() commits on success and rolls back/releases on failure.
            cursor, conn, count = self.execute(sql, param, choose_db=choose_db)
            if verbose:
                obj_log.info('插入数据库条数:{}'.format(count))
            return count
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def insert_one_extension(self, sql, param=None, choose_db=None):
        """
        | Purpose: | Insert one row and also return the inserted id |
        | Returns: | dict | {'insert_count': rowcount, 'insert_id': id or None} |

        ``RETURNING id`` is appended to an INSERT without a RETURNING clause
        (this assumes the target table has an ``id`` column). When the
        statement returns no row (e.g. ``ON CONFLICT DO NOTHING``),
        ``insert_id`` is None. Not routed to SIM.
        """
        self._log_sql(sql, param)
        if not _RETURNING_RE.search(sql) and _INSERT_INTO_RE.search(sql):
            sql = self._append_returning_id(sql)
        has_returning = bool(_RETURNING_RE.search(sql))
        cursor, conn = None, None
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)
            if param:
                cursor.execute(sql, param)
            else:
                cursor.execute(sql)
            inserted_id = self._first_value(cursor.fetchone()) if has_returning else None
            count = cursor.rowcount
            conn.commit()
            obj_log.info('插入数据库条数:{}'.format(count))
            return {'insert_count': count, 'insert_id': inserted_id}
        except Exception as e:
            if conn is not None:
                conn.rollback()
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def insert_many(self, sql, param=None, choose_db=None):
        """
        | Purpose: | Insert many rows with ``executemany`` and return the affected row count |
        | Args: | param | sequence of parameter tuples, or its Python-literal string form |
        """
        self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)
            cursor.executemany(sql, self._parse_many_params(param))
            count = cursor.rowcount  # executemany()'s return value is undefined per PEP 249
            conn.commit()
            obj_log.info('插入数据库条数:{}'.format(count))
            return count
        except Exception as e:
            if conn is not None:
                conn.rollback()
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def insert_many_extension(self, sql, param=None, choose_db=None):
        """
        | Purpose: | Insert many rows and return their ids |
        | Returns: | dict | {'insert_count', 'insert_ids', and 'insert_id' (first id) when any} |

        ``RETURNING id`` is appended when the SQL has no RETURNING clause.
        Each parameter set is executed separately, because ``executemany``
        discards RETURNING result sets (psycopg2 then fails in ``fetchall``).
        """
        return_dict = dict()
        self._log_sql(sql, param)
        sql = self._append_returning_id(sql)
        cursor, conn = None, None
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)
            inserted_ids = []
            count = 0
            if param:
                for row_params in self._parse_many_params(param):
                    cursor.execute(sql, row_params)
                    inserted_ids.extend(self._first_value(row) for row in cursor.fetchall())
                    count += max(cursor.rowcount, 0)  # rowcount is -1 when unknown
            else:
                cursor.execute(sql)
                inserted_ids = [self._first_value(row) for row in cursor.fetchall()]
                count = cursor.rowcount
            conn.commit()
            obj_log.info('插入数据库条数:{}'.format(count))
            return_dict['insert_count'] = count
            return_dict['insert_ids'] = inserted_ids
            if inserted_ids:
                return_dict['insert_id'] = inserted_ids[0]  # id of the first inserted row
            return return_dict
        except Exception as e:
            if conn is not None:
                conn.rollback()
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # ------------------------------------------------------------------ #
    # Delete / update
    # ------------------------------------------------------------------ #
    def delete(self, sql, param=None, choose_db=None):
        """
        | Purpose: | Delete rows and return the affected row count |
        """
        if self._is_sim():
            return self.sim_dbs.dbs_execute_sql(sql_content=sql)
        self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            cursor, conn, count = self.execute(sql, param, choose_db=choose_db)
            obj_log.info('删除数据库条数:{}'.format(count))
            return count
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def update(self, sql, param=None, choose_db=None):
        """
        | Purpose: | Update rows and return the affected row count |
        """
        if self._is_sim():
            return self.sim_dbs.dbs_execute_sql(sql_content=sql)
        self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            # execute() commits on success and rolls back/releases on failure.
            cursor, conn, count = self.execute(sql, param, choose_db=choose_db)
            obj_log.info('更新数据库条数:{}'.format(count))
            return count
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # ------------------------------------------------------------------ #
    # Result checks
    # ------------------------------------------------------------------ #
    def check_result_exist(self, *select_statement, retry_count=3):
        """
        Assert that EVERY statement in *select_statement* returns at least one row.

        Each statement gets up to ``retry_count + 1`` attempts, with a one-second
        sleep between attempts. Raises RuntimeError for the first statement
        that still returns nothing after its last attempt.
        """
        for sql in select_statement:
            for attempt in range(retry_count + 1):
                res = self.select_all(sql=sql)
                if res:
                    obj_log.info('find [{0}] results when exec :{1}'.format(len(res), sql))
                    break
                if attempt == retry_count:
                    obj_log.info('the result is not exist,but expect at least one')
                    raise RuntimeError(u'No results, when exec sql: %s' % sql)
                obj_log.info('the result is not exist,retry {}'.format(attempt + 1))
                time.sleep(1)

    def check_result_not_exist(self, *select_statement):
        """Assert that every statement in *select_statement* returns no rows (RuntimeError otherwise)."""
        for sql in select_statement:
            res = self.select_all(sql=sql)
            if res:
                obj_log.info('find [{0}] results, but expect 0'.format(len(res)))
                raise RuntimeError(u'the result exist, when exec sql: %s' % sql)
            obj_log.info('the result is not exist, this step pass...')

    def row_count(self, selectStatement, param=None, choose_db=None):
        """
        Return ``cursor.rowcount`` for *selectStatement*.

        Best effort: any error is logged and 0 is returned, so callers should
        not treat 0 as proof that the statement succeeded.
        """
        cursor, conn = None, None
        try:
            cursor, conn, count = self.execute(selectStatement, param, choose_db=choose_db)
            return count
        except Exception as e:
            obj_log.error("error_msg: {}".format(e.args))
            return 0
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # ------------------------------------------------------------------ #
    # Robot-style assertion keywords (PostgreSQL)
    # ------------------------------------------------------------------ #
    def kw_check_if_exists_in_database(self, selectStatement, choose_db=None):
        """Assert that *selectStatement* returns at least one row; return True."""
        obj_log.info('Executing : Check If Exists In Database | %s ' % selectStatement)
        if not self.select_one(selectStatement, choose_db=choose_db):
            raise AssertionError("Expected to have at least one row from '%s' "
                                 "but got 0 rows." % selectStatement)
        return True

    def kw_check_if_not_exists_in_database(self, selectStatement, choose_db=None):
        """Assert that *selectStatement* returns no rows; return True."""
        obj_log.info('Executing : Check If Not Exists In Database | %s ' % selectStatement)
        queryResults = self.select_one(selectStatement, choose_db=choose_db)
        if queryResults:
            raise AssertionError("Expected to have no rows from '%s' "
                                 "but got some rows : %s." % (selectStatement, queryResults))
        return True

    def kw_row_count_is_0(self, selectStatement):
        """Assert that *selectStatement* yields zero rows (via row_count); return True."""
        obj_log.info('Executing : Row Count Is 0 | %s ' % selectStatement)
        num_rows = self.row_count(selectStatement)
        if num_rows > 0:
            raise AssertionError("Expected zero rows to be returned from '%s' "
                                 "but got rows back. Number of rows returned was %s"
                                 % (selectStatement, num_rows))
        return True

    def kw_row_count_is_equal_to_x(self, selectStatement, numRows):
        """Assert that *selectStatement* yields exactly *numRows* rows; return True."""
        obj_log.info('Executing : Row Count Is Equal To X | %s | %s ' % (selectStatement, numRows))
        num_rows = self.row_count(selectStatement)
        if num_rows != int(str(numRows)):
            raise AssertionError("Expected same number of rows to be returned from '%s' "
                                 "than the returned rows of %s" % (selectStatement, num_rows))
        return True

    def kw_row_count_is_greater_than_x(self, selectStatement, numRows):
        """Assert that *selectStatement* yields more than *numRows* rows; return True."""
        obj_log.info('Executing : Row Count Is Greater Than X | %s | %s ' % (selectStatement, numRows))
        num_rows = self.row_count(selectStatement)
        if num_rows <= int(numRows):
            raise AssertionError("Expected more rows to be returned from '%s' "
                                 "than the returned rows of %s" % (selectStatement, num_rows))
        return True

    def kw_row_count_is_less_than_x(self, selectStatement, numRows):
        """Assert that *selectStatement* yields fewer than *numRows* rows; return True."""
        obj_log.info('Executing : Row Count Is Less Than X | %s | %s ' % (selectStatement, numRows))
        num_rows = self.row_count(selectStatement)
        if num_rows >= int(numRows):
            raise AssertionError("Expected less rows to be returned from '%s' "
                                 "than the returned rows of %s" % (selectStatement, num_rows))
        return True

    def kw_table_must_exist(self, tableName):
        """
        Assert that a table exists, using ``information_schema.tables``; return True.

        *tableName* may be ``table`` (looked up in schema ``public``) or
        ``schema.table``. Query errors are raised as AssertionError
        ("Error checking table existence: ...").
        """
        obj_log.info('Executing : Table Must Exist | %s ' % tableName)
        parts = tableName.split('.')
        table = parts[-1]
        schema = parts[-2] if len(parts) > 1 else 'public'
        selectStatement = """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.tables
                WHERE table_schema = %s AND table_name = %s
            );
        """
        try:
            result = self.select_one(selectStatement, (schema, table))
        except Exception as e:
            raise AssertionError("Error checking table existence: %s" % str(e)) from e
        if not self._first_value(result):  # EXISTS returned False, or no row at all
            raise AssertionError("Table '%s' does not exist in the database" % tableName)
        return True

    # ------------------------------------------------------------------ #
    # Misc
    # ------------------------------------------------------------------ #
    def select_all_as_generator(self, sql, param=None, choose_db=None):
        """
        Yield the rows of *sql* one at a time (each row is also logged).

        The connection is released when the generator is exhausted or closed.
        Errors are re-raised as RuntimeError.
        """
        self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, param, choose_db=choose_db)
            # Fetch until exhausted instead of trusting a driver-reported count.
            for item in iter(cursor.fetchone, None):
                obj_log.info(item)
                yield item
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def get_qa_to_sim_dbs(self):
        """
        Load the active qa-db -> SIM-instance mapping from
        ``sparkatp.qa_db_mapping_sim_dbs`` into ``self.qa_db_to_sim_instance_name``
        and return that dict.
        """
        get_sql = "select qa_db,sim_instance from sparkatp.qa_db_mapping_sim_dbs where status=1 and is_delete=0"
        res_info = self.sim_dbs.dbs_query_by_db_name("qadb-slave", "sparkatp", get_sql, limit_num=1000)
        self.qa_db_to_sim_instance_name.update((item[0], item[1]) for item in res_info or [])
        return self.qa_db_to_sim_instance_name

    def get_instance_name(self, sql_str):
        """
        Return the SIM instance mapped to the database referenced in *sql_str*.

        The database name is taken as the last space-separated token before
        the first '.' (backticks removed), e.g. ``select * from db.t`` -> ``db``.
        Returns None when unmapped; call get_qa_to_sim_dbs() first.
        """
        db_name = sql_str.split(".")[0].split(" ")[-1].replace("`", "").replace(" ", "")
        return self.qa_db_to_sim_instance_name.get(db_name)


if __name__ == '__main__':
    # Manual smoke test: needs a reachable database; the insert writes a real row.
    db = PgSqlHelper()
    # Query test
    sql = "SELECT * FROM account.account LIMIT 5;"
    res = db.select_one(sql=sql)
    print(res)
    # Insert test (with RETURNING)
    insert_sql = """
        INSERT INTO account.account (username, email, created_at)
        VALUES (%s, %s, NOW())
        RETURNING id;
    """
    insert_params = ('test_user', 'test@example.com')
    insert_result = db.insert_one_extension(insert_sql, insert_params)
    print(f"插入结果: {insert_result}")