Files
smart-management-auto-test/base_framework/public_tools/pgsqlhelper.py
qiaoxinjiu 6994b185a3 addproject
2026-01-22 19:10:37 +08:00

555 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding:utf-8 -*-
"""
Author: qiaoxinjiu
Create Date: 2020/11/6 17:30
"""
import ast
import re
import time
from collections.abc import Mapping

from retrying import retry

from base_framework.public_tools import log
from base_framework.public_tools.db_dbutils_init import get_pg_connection  # assumed PostgreSQL connection pool
from base_framework.public_tools.read_config import get_current_config, get_current_env
from base_framework.public_tools.huohua_dbs import HuoHuaDBS
# Module-level logger used for SQL, result and error logging throughout this module.
obj_log = log.get_logger()
class PgSqlHelper:
    """
    | Description: | PostgreSQL database helper |
    | Notes: | Statements run on connections taken from the pool returned by ``get_pg_connection()``. In the SIM environment most helpers delegate to ``HuoHuaDBS`` instead. |
    """

    # Matches RETURNING as a whole word so identifiers such as "returning_flag"
    # are not mistaken for an existing RETURNING clause.
    _RETURNING_RE = re.compile(r'\bRETURNING\b', re.IGNORECASE)
    # Recognises "INSERT INTO [schema.]table"; insert_one_extension only appends
    # "RETURNING id" to statements of this shape.
    _INSERT_INTO_RE = re.compile(r'INSERT INTO\s+(\w+\.)?(\w+)', re.IGNORECASE)

    def __init__(self):
        self.db = get_pg_connection()  # PostgreSQL connection pool
        self.current_business = get_current_config(section='run_evn_name', key='current_business')
        self.current_evn = get_current_env()
        # QA database name -> SIM instance name; populated by get_qa_to_sim_dbs().
        self.qa_db_to_sim_instance_name = {}
        self.sim_dbs = HuoHuaDBS()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _log_sql(sql, param=None):
        """Log the statement at info level, plus its parameters when any were given."""
        if param is not None:
            obj_log.info('SQL语句{}|{}'.format(sql, param))
        else:
            obj_log.info('SQL语句{}'.format(sql))

    @staticmethod
    def _row_values(row):
        """Return a fetched row's column values as a list.

        Supports mapping rows (dict-style cursors, as select_all_as_list expects)
        as well as sequence rows (tuple-style cursors).
        """
        if isinstance(row, Mapping):
            return list(row.values())
        return list(row)

    @classmethod
    def _first_value(cls, row):
        """Return the first column of ``row``, or None when there is no row / no column."""
        if row is None:
            return None
        values = cls._row_values(row)
        return values[0] if values else None

    @classmethod
    def _has_returning(cls, sql):
        """Return True when ``sql`` already contains a RETURNING clause."""
        return cls._RETURNING_RE.search(sql) is not None

    @classmethod
    def _append_returning_id(cls, sql):
        """Append ``RETURNING id`` to ``sql`` unless it already has a RETURNING clause.

        Trailing whitespace and semicolons are stripped first, so a statement that
        ends in ";" followed by a newline still produces valid SQL.
        Assumes the target table has an ``id`` column.
        """
        if cls._has_returning(sql):
            return sql
        return sql.rstrip().rstrip(';').rstrip() + ' RETURNING id;'

    @staticmethod
    def _parse_param_rows(param):
        """Return the parameter rows for a multi-row insert.

        A string is parsed as a Python literal (e.g. ``"[(1, 'a'), (2, 'b')]"``);
        any other object is returned unchanged. ``ast.literal_eval`` replaces the
        former ``eval(str(param))`` round-trip so no arbitrary code is executed.
        """
        if isinstance(param, str):
            return ast.literal_eval(param)
        return param

    @staticmethod
    def _log_connection_failure(error, choose_db):
        """Print and log the configured PostgreSQL settings (password masked) after getconn() fails."""
        try:
            from base_framework.public_tools.read_config import ReadConfig
            from base_framework.base_config.current_pth import config_file_path
            rc = ReadConfig(config_file_path)
            db_host = rc.get_value(sections='PostgreSQL', options='db_test_host')
            db_port = rc.get_value(sections='PostgreSQL', options='db_test_port')
            db_name = rc.get_value(sections='PostgreSQL', options='db_test_dbname')
            db_user = rc.get_value(sections='PostgreSQL', options='db_test_user')
            db_password = rc.get_value(sections='PostgreSQL', options='db_test_password')
            error_info = (
                "\n"
                "PostgreSQL连接失败\n"
                "连接配置信息:\n"
                "主机(Host): {}\n"
                "端口(Port): {}\n"
                "数据库名(Database): {}\n"
                "用户名(User): {}\n"
                "密码(Password): {} (已隐藏)\n"
                "选择数据库(ChooseDB): {}\n"
                "错误详情: {}\n"
            ).format(
                db_host, db_port, db_name, db_user,
                '*' * len(db_password) if db_password else 'None',
                choose_db or 'default',
                str(error)
            )
            print(error_info)
            obj_log.error(error_info)
        except Exception as config_error:
            error_info = "PostgreSQL连接失败: {} (无法读取配置信息: {})".format(str(error), str(config_error))
            print(error_info)
            obj_log.error(error_info)

    @classmethod
    def _discard(cls, cursor, conn):
        """Roll back and release a connection after a failed statement (best effort)."""
        try:
            conn.rollback()
        except Exception as rollback_error:
            obj_log.error('PostgreSQL回滚失败: {}'.format(rollback_error))
        try:
            cls.close(cursor, conn)
        except Exception as close_error:
            obj_log.error('PostgreSQL释放连接失败: {}'.format(close_error))

    # ------------------------------------------------------------------ core

    def execute(self, sql, param=None, auto_close=False, choose_db=None):
        """
        | Description: | Execute one SQL statement and commit it |
        | Parameters: | sql | statement to execute |
        |  | param=None | values for the %s placeholders; they may also be written directly into sql |
        |  | auto_close=False | close cursor and connection before returning (default: leave them open for the caller) |
        |  | choose_db=None | database selector forwarded to the pool's getconn() |
        | Returns: | cursor, conn, count | cursor, connection and cursor.rowcount (rows affected / returned) |
        | Raises: | Exception | in the SIM environment (use huohua_dbs.py there) |
        |  | ValueError | when no connection can be obtained, or the statement / commit fails (the connection is rolled back and released first) |
        """
        if self.current_evn.lower() == "sim":
            raise Exception("SIM环境请直接使用huohua_dbs.py中的函数")
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)  # take a connection from the pool
        except Exception as e:
            self._log_connection_failure(e, choose_db)
            raise ValueError("PostgreSQL连接失败: {}".format(str(e))) from e
        try:
            if param:
                # PostgreSQL drivers use %s placeholders, the same as MySQL drivers.
                cursor.execute(sql, param)
            else:
                cursor.execute(sql)
            # DB-API leaves execute()'s return value undefined (psycopg2 returns None),
            # so the affected/returned row count is read from cursor.rowcount.
            count = cursor.rowcount
            conn.commit()
        except Exception as e:
            obj_log.error("PostgreSQL执行SQL失败: {}, SQL: {}".format(str(e), sql))
            # Callers never receive cursor/conn when this raises, so roll back the
            # failed transaction and release the connection here to avoid a leak.
            self._discard(cursor, conn)
            raise ValueError("数据库操作失败SQL语句{}, 错误: {}".format(sql, str(e))) from e
        if auto_close:
            self.close(cursor, conn)
        return cursor, conn, count

    @staticmethod
    def close(cursor, conn):
        """Close the cursor, then the connection (the connection is closed even if the cursor close raises)."""
        try:
            cursor.close()
        finally:
            conn.close()

    # ------------------------------------------------------------------ queries

    def select_all(self, sql, param=None, choose_db=None, show_log=True):
        """
        | Description: | Query the database and return all rows |
        | Parameters: | show_log=True | log the statement and the fetched rows |
        | Raises: | RuntimeError | wrapping any execution/fetch error |
        """
        if self.current_evn.lower() == "sim":
            return self.sim_dbs.dbs_select(sql_content=sql)
        if show_log:
            self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, param, choose_db=choose_db)
            res = cursor.fetchall()
            if show_log:
                obj_log.info('数据库查询结果:{}'.format(res))
            return res
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            # Always release the connection (previously skipped whenever show_log was True).
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def select_all_as_list(self, sql, choose_db=None):
        """
        | Description: | Same as select_all, but returns plain lists |
        | Returns: | list | one inner list per multi-column row; single-column rows are flattened into bare values; [] when no row has a truthy column value |
        """
        if self.current_evn.lower() == "sim":
            return self.sim_dbs.dbs_select(sql_content=sql, r_type='list')
        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, choose_db=choose_db)
            rows = [self._row_values(row) for row in cursor.fetchall()]
            # Result sets whose every column value is falsy are reported as empty.
            if not any(value for values in rows for value in values):
                return []
            res_list = []
            for values in rows:
                if len(values) > 1:
                    res_list.append(values)
                else:
                    res_list.extend(values)
            return res_list
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def select_one(self, sql, param=None, choose_db=None):
        """
        | Description: | Query the database and return the first row (None when there is none) |
        | Notes: | In SIM, "limit 1" is appended when the statement has no limit, and {} is returned for no rows |
        """
        if self.current_evn.lower() == "sim":
            if 'limit' not in sql.lower():
                sql = sql.split(';')[0] + ' limit 1;'
            sim_data = self.sim_dbs.dbs_select(sql_content=sql)
            if sim_data:
                return sim_data[0]
            else:
                return {}
        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, param, choose_db=choose_db)
            return cursor.fetchone()
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def select_all_as_generator(self, sql, param=None, choose_db=None):
        """
        | Description: | Yield the query's rows one at a time, logging each row |
        """
        self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, param, choose_db=choose_db)
            # Stream until fetchone() signals exhaustion instead of trusting a row count.
            for item in iter(cursor.fetchone, None):
                obj_log.info(item)
                yield item
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # ------------------------------------------------------------------ writes

    def insert_one(self, sql, param=None, choose_db=None, log_level='info'):
        """
        | Description: | Insert one row |
        | Parameters: | log_level='info' | any other value suppresses the SQL / count logging |
        | Returns: | int | rows inserted |
        """
        if self.current_evn.lower() == "sim":
            return self.sim_dbs.dbs_execute_sql(sql_content=sql)
        verbose = log_level.lower() == 'info'
        if verbose:
            self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            # execute() commits on success.
            cursor, conn, count = self.execute(sql, param, choose_db=choose_db)
            if verbose:
                obj_log.info('插入数据库条数:{}'.format(count))
            return count
        except Exception as e:
            if conn is not None:
                conn.rollback()
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def insert_one_extension(self, sql, param=None, choose_db=None):
        """
        | Description: | Insert one row and return the inserted id |
        | Returns: | dict | {'insert_count': rows inserted, 'insert_id': first RETURNING value or None} |
        | Notes: | PostgreSQL needs a RETURNING clause to hand back the id, so "RETURNING id" is appended to INSERT INTO statements that lack one (assumes an ``id`` column) |
        """
        return_dict = dict()
        self._log_sql(sql, param)
        if self._INSERT_INTO_RE.search(sql):
            sql = self._append_returning_id(sql)
        cursor, conn = None, None
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)
            if param:
                cursor.execute(sql, param)
            else:
                cursor.execute(sql)
            # Read back the value produced by the RETURNING clause, if there is one.
            inserted_id = self._first_value(cursor.fetchone()) if self._has_returning(sql) else None
            count = cursor.rowcount
            conn.commit()
            obj_log.info('插入数据库条数:{}'.format(count))
            return_dict['insert_count'] = count
            return_dict['insert_id'] = inserted_id
            return return_dict
        except Exception as e:
            if conn is not None:
                conn.rollback()
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def insert_many(self, sql, param=None, choose_db=None):
        """
        | Description: | Insert several rows with executemany |
        | Parameters: | param | sequence of parameter rows, or a string holding such a Python literal |
        | Returns: | int | cursor.rowcount after executemany |
        """
        self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)
            cursor.executemany(sql, self._parse_param_rows(param))
            count = cursor.rowcount
            conn.commit()
            obj_log.info('插入数据库条数:{}'.format(count))
            return count
        except Exception as e:
            if conn is not None:
                conn.rollback()
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def insert_many_extension(self, sql, param=None, choose_db=None):
        """
        | Description: | Insert several rows and return their ids |
        | Returns: | dict | {'insert_count': rows inserted, 'insert_ids': all RETURNING values, 'insert_id': first id (only when any id was returned)} |
        | Notes: | "RETURNING id" is appended when the statement has no RETURNING clause (assumes an ``id`` column) |
        """
        return_dict = dict()
        self._log_sql(sql, param)
        sql = self._append_returning_id(sql)
        cursor, conn = None, None
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)
            if param:
                # executemany() discards result sets in drivers such as psycopg2, so each
                # parameter row is executed separately to collect every RETURNING value.
                inserted_ids = []
                count = 0
                for row_param in self._parse_param_rows(param):
                    cursor.execute(sql, row_param)
                    inserted_ids.extend(self._first_value(row) for row in cursor.fetchall())
                    count += max(cursor.rowcount, 0)  # rowcount is -1 when unknown
            else:
                cursor.execute(sql)
                inserted_ids = [self._first_value(row) for row in cursor.fetchall()]
                count = cursor.rowcount
            conn.commit()
            obj_log.info('插入数据库条数:{}'.format(count))
            return_dict['insert_count'] = count
            return_dict['insert_ids'] = inserted_ids
            if inserted_ids:
                return_dict['insert_id'] = inserted_ids[0]  # id of the first inserted row
            return return_dict
        except Exception as e:
            if conn is not None:
                conn.rollback()
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def delete(self, sql, param=None, choose_db=None):
        """
        | Description: | Delete database records |
        | Returns: | int | rows deleted |
        """
        if self.current_evn.lower() == "sim":
            return self.sim_dbs.dbs_execute_sql(sql_content=sql)
        self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            cursor, conn, count = self.execute(sql, param, choose_db=choose_db)
            obj_log.info('删除数据库条数:{}'.format(count))
            return count
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def update(self, sql, param=None, choose_db=None):
        """
        | Description: | Update database records |
        | Returns: | int | rows updated |
        """
        if self.current_evn.lower() == "sim":
            return self.sim_dbs.dbs_execute_sql(sql_content=sql)
        self._log_sql(sql, param)
        cursor, conn = None, None
        try:
            # execute() commits on success.
            cursor, conn, count = self.execute(sql, param, choose_db=choose_db)
            obj_log.info('更新数据库条数:{}'.format(count))
            return count
        except Exception as e:
            if conn is not None:
                conn.rollback()
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # ------------------------------------------------------------------ checks

    def check_result_exist(self, *select_statement, retry_count=3):
        """Assert that every statement returns at least one row.

        Each statement is attempted up to ``retry_count + 1`` times, sleeping one
        second between attempts; RuntimeError is raised for the first statement
        that still has no rows.
        """
        for sql in select_statement:
            attempt = 0
            while True:
                res = self.select_all(sql=sql)
                if res:
                    obj_log.info('find [{0}] results when exec {1}'.format(len(res), sql))
                    break
                if attempt >= retry_count:
                    obj_log.info('the result is not exist,but expect at least one')
                    raise RuntimeError(u'No results, when exec sql: %s' % sql)
                obj_log.info('the result is not exist,retry {}'.format(attempt + 1))
                attempt += 1
                time.sleep(1)

    def check_result_not_exist(self, *select_statement):
        """Assert that none of the statements returns any row (RuntimeError otherwise)."""
        for sql in select_statement:
            res = self.select_all(sql=sql)
            if res:
                obj_log.info('find [{0}] results, but expect 0'.format(len(res)))
                raise RuntimeError(u'the result exist, when exec sql: %s' % sql)
            else:
                obj_log.info('the result is not exist, this step pass...')

    def row_count(self, selectStatement, param=None, choose_db=None):
        """Return cursor.rowcount for the statement.

        Best effort: any error is printed/logged and reported as 0 rows.
        NOTE(review): a failing query therefore satisfies kw_row_count_is_0.
        """
        cursor, conn = None, None
        try:
            cursor, conn, count = self.execute(selectStatement, param, choose_db=choose_db)
            return count
        except Exception as e:
            print("error_msg:", e.args)
            obj_log.error("row_count error_msg: {}".format(e.args))
            return 0
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # Database assertion keywords (PostgreSQL variant).
    def kw_check_if_exists_in_database(self, selectStatement, choose_db=None):
        """Assert that the query returns at least one row."""
        obj_log.info('Executing : Check If Exists In Database | %s ' % selectStatement)
        if not self.select_one(selectStatement, choose_db=choose_db):
            raise AssertionError("Expected to have have at least one row from '%s' "
                                 "but got 0 rows." % selectStatement)
        return True

    def kw_check_if_not_exists_in_database(self, selectStatement, choose_db=None):
        """Assert that the query returns no rows."""
        obj_log.info('Executing : Check If Not Exists In Database | %s ' % selectStatement)
        queryResults = self.select_one(selectStatement, choose_db=choose_db)
        if queryResults:
            raise AssertionError("Expected to have have no rows from '%s' "
                                 "but got some rows : %s." % (selectStatement, queryResults))
        return True

    def kw_row_count_is_0(self, selectStatement, choose_db=None):
        """Assert that the statement's row count is 0."""
        obj_log.info('Executing : Row Count Is 0 | %s ' % selectStatement)
        num_rows = self.row_count(selectStatement, choose_db=choose_db)
        if num_rows > 0:
            raise AssertionError("Expected zero rows to be returned from '%s' "
                                 "but got rows back. Number of rows returned was %s" % (selectStatement, num_rows))
        return True

    def kw_row_count_is_equal_to_x(self, selectStatement, numRows, choose_db=None):
        """Assert that the statement's row count equals ``numRows``."""
        obj_log.info('Executing : Row Count Is Equal To X | %s | %s ' % (selectStatement, numRows))
        num_rows = self.row_count(selectStatement, choose_db=choose_db)
        if num_rows != int(str(numRows)):
            raise AssertionError("Expected same number of rows to be returned from '%s' "
                                 "than the returned rows of %s" % (selectStatement, num_rows))
        return True

    def kw_row_count_is_greater_than_x(self, selectStatement, numRows, choose_db=None):
        """Assert that the statement's row count is greater than ``numRows``."""
        obj_log.info('Executing : Row Count Is Greater Than X | %s | %s ' % (selectStatement, numRows))
        num_rows = self.row_count(selectStatement, choose_db=choose_db)
        if num_rows <= int(numRows):
            raise AssertionError("Expected more rows to be returned from '%s' "
                                 "than the returned rows of %s" % (selectStatement, num_rows))
        return True

    def kw_row_count_is_less_than_x(self, selectStatement, numRows, choose_db=None):
        """Assert that the statement's row count is less than ``numRows``."""
        obj_log.info('Executing : Row Count Is Less Than X | %s | %s ' % (selectStatement, numRows))
        num_rows = self.row_count(selectStatement, choose_db=choose_db)
        if num_rows >= int(numRows):
            raise AssertionError("Expected less rows to be returned from '%s' "
                                 "than the returned rows of %s" % (selectStatement, num_rows))
        return True

    def kw_table_must_exist(self, tableName, choose_db=None):
        """
        Assert that a table exists, using information_schema.tables.

        ``tableName`` may be "schema.table"; a bare name is looked up in schema "public".
        """
        obj_log.info('Executing : Table Must Exist | %s ' % tableName)
        schema_name, _, table_name = tableName.rpartition('.')
        schema_name = schema_name or 'public'
        selectStatement = """
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_schema = %s
                AND table_name = %s
            );
        """
        try:
            result = self.select_one(selectStatement, (schema_name, table_name), choose_db=choose_db)
        except Exception as e:
            raise AssertionError("Error checking table existence: %s" % str(e)) from e
        # The EXISTS value is read positionally, which works for dict and tuple rows alike.
        if not self._first_value(result):
            raise AssertionError("Table '%s' does not exist in the database" % tableName)
        return True

    # ------------------------------------------------------------------ SIM mapping

    def get_qa_to_sim_dbs(self):
        """Load the QA-database -> SIM-instance mapping from sparkatp and cache it on the instance."""
        get_sql = "select qa_db,sim_instance from sparkatp.qa_db_mapping_sim_dbs where status=1 and is_delete=0"
        res_info = self.sim_dbs.dbs_query_by_db_name("qadb-slave", "sparkatp", get_sql, limit_num=1000)
        for item in res_info:
            self.qa_db_to_sim_instance_name[item[0]] = item[1]
        return self.qa_db_to_sim_instance_name

    def get_instance_name(self, sql_str):
        """Return the SIM instance mapped to the database referenced in ``sql_str``.

        The database name is the last whitespace-separated token before the first
        '.', with backtick and double-quote identifier quoting removed. Returns None
        when no mapping is cached (see get_qa_to_sim_dbs).
        """
        tokens = sql_str.split(".")[0].split()
        db_name = tokens[-1].replace("`", "").replace('"', '') if tokens else ""
        return self.qa_db_to_sim_instance_name.get(db_name)
if __name__ == '__main__':
    # Manual smoke test: runs against the configured database and inserts a
    # test row into account.account.
    helper = PgSqlHelper()

    # Query example: select_one returns only the first of the fetched rows.
    query_sql = "SELECT * FROM account.account LIMIT 5;"
    first_row = helper.select_one(sql=query_sql)
    print(first_row)

    # Insert example with an explicit RETURNING clause, so the new id is reported.
    insert_sql = (
        "\n"
        "INSERT INTO account.account (username, email, created_at)\n"
        "VALUES (%s, %s, NOW())\n"
        "RETURNING id;\n"
    )
    insert_result = helper.insert_one_extension(insert_sql, ('test_user', 'test@example.com'))
    print(f"插入结果: {insert_result}")