555 lines
21 KiB
Python
555 lines
21 KiB
Python
# -*- coding:utf-8 -*-
|
||
"""
|
||
Author: qiaoxinjiu
|
||
Create Date: 2020/11/6 17:30
|
||
"""
|
||
import time
|
||
import re
|
||
from retrying import retry
|
||
from base_framework.public_tools import log
|
||
from base_framework.public_tools.db_dbutils_init import get_pg_connection # 假设有PostgreSQL连接池
|
||
from base_framework.public_tools.read_config import get_current_config, get_current_env
|
||
from base_framework.public_tools.huohua_dbs import HuoHuaDBS
|
||
|
||
# Shared module-level logger; PgSqlHelper writes its SQL statements, results and errors through it.
obj_log = log.get_logger()
|
||
|
||
|
||
class PgSqlHelper:
    """
    PostgreSQL database helper.

    Outside the SIM environment every statement goes through the pool returned by
    get_pg_connection(). In the SIM environment select_all / select_all_as_list / select_one /
    insert_one / delete / update are routed to HuoHuaDBS instead, and execute() refuses to run.
    """

    def __init__(self):
        self.db = get_pg_connection()  # PostgreSQL connection pool
        self.current_business = get_current_config(section='run_evn_name', key='current_business')
        self.current_evn = get_current_env()
        # Cache filled by get_qa_to_sim_dbs(): QA database name -> SIM instance name.
        self.qa_db_to_sim_instance_name = {}
        self.sim_dbs = HuoHuaDBS()

    # ------------------------------------------------------------------ private helpers

    @staticmethod
    def _log_sql(sql, param=None):
        """Log a statement (and its bind parameters, when given) before it is executed."""
        if param is not None:
            obj_log.info('SQL语句:{}|{}'.format(sql, param))
        else:
            obj_log.info('SQL语句:{}'.format(sql))

    @staticmethod
    def _first_value(row):
        """Return the first column of a fetched row, or None when there is no row.

        Rows may be mappings (dict-style cursors, as select_all_as_list expects) or
        sequences (tuple-style cursors); both are supported.
        """
        from collections.abc import Mapping
        if row is None:
            return None
        if isinstance(row, Mapping):
            return next(iter(row.values()), None)
        return row[0]

    @staticmethod
    def _row_values(row):
        """Return the column values of a fetched row as a list, for mapping or sequence rows."""
        from collections.abc import Mapping
        if isinstance(row, Mapping):
            return list(row.values())
        return list(row)

    @staticmethod
    def _append_returning_id(sql):
        """Append 'RETURNING id' to an INSERT statement that has no RETURNING clause yet.

        Trailing whitespace is stripped before trailing semicolons, so a statement ending in
        ';' plus a newline does not become '...; RETURNING id;' (a syntax error).
        Statements that do not contain INSERT INTO are returned unchanged.
        """
        if 'RETURNING' in sql.upper():
            return sql
        if not re.search(r'INSERT\s+INTO\s+(\w+\.)?(\w+)', sql, re.IGNORECASE):
            return sql
        return sql.rstrip().rstrip(';').rstrip() + ' RETURNING id;'

    @staticmethod
    def _normalize_many_params(param):
        """Return the parameter sequence for multi-row inserts.

        A string is parsed as a Python literal (e.g. "[(1, 'a'), (2, 'b')]"); any other value is
        passed through unchanged. This replaces eval(str(param)), which executed arbitrary code
        and failed for values whose repr cannot be evaluated here (datetime, Decimal, generators).
        """
        if isinstance(param, str):
            import ast
            return ast.literal_eval(param)
        return param

    @staticmethod
    def _rollback_quietly(conn):
        """Roll back conn on an error path; a failing rollback is logged, never raised."""
        if conn is None:
            return
        try:
            conn.rollback()
        except Exception as rollback_error:
            obj_log.error('PostgreSQL rollback failed: {}'.format(rollback_error))

    def _discard(self, cursor, conn):
        """Best-effort cleanup after a failed statement: roll back, then close cursor and connection."""
        self._rollback_quietly(conn)
        if cursor is None or conn is None:
            return
        try:
            self.close(cursor, conn)
        except Exception as close_error:
            obj_log.error('PostgreSQL connection close failed: {}'.format(close_error))

    @staticmethod
    def _log_connection_failure(err, choose_db):
        """Print and log the configured PostgreSQL settings (password masked) after getconn() failed."""
        try:
            from base_framework.public_tools.read_config import ReadConfig
            from base_framework.base_config.current_pth import config_file_path
            rc = ReadConfig(config_file_path)
            db_host = rc.get_value(sections='PostgreSQL', options='db_test_host')
            db_port = rc.get_value(sections='PostgreSQL', options='db_test_port')
            db_name = rc.get_value(sections='PostgreSQL', options='db_test_dbname')
            db_user = rc.get_value(sections='PostgreSQL', options='db_test_user')
            db_password = rc.get_value(sections='PostgreSQL', options='db_test_password')

            error_info = "\n".join([
                "",
                "PostgreSQL连接失败!",
                "连接配置信息:",
                "主机(Host): {}",
                "端口(Port): {}",
                "数据库名(Database): {}",
                "用户名(User): {}",
                "密码(Password): {} (已隐藏)",
                "选择数据库(ChooseDB): {}",
                "错误详情: {}",
                "",
            ]).format(
                db_host, db_port, db_name, db_user,
                '*' * len(db_password) if db_password else 'None',
                choose_db or 'default',
                str(err)
            )
            print(error_info)
            obj_log.error(error_info)
        except Exception as config_error:
            error_info = "PostgreSQL连接失败: {} (无法读取配置信息: {})".format(str(err), str(config_error))
            print(error_info)
            obj_log.error(error_info)

    # ------------------------------------------------------------------ core execution

    def execute(self, sql, param=None, auto_close=False, choose_db=None):
        """
        | Description: | Execute one SQL statement and commit it (also for SELECT) |
        | Params: | sql | SQL statement to execute |
        |  | param=None | Bind parameters for %s placeholders (values may also be inlined in sql) |
        |  | auto_close=False | Close cursor and connection after committing; default: keep them open |
        |  | choose_db=None | Passed through to the connection pool's getconn() |
        | Returns: | cursor, conn, count | Cursor, connection and affected/returned row count |
        | Raises: | Exception | In the SIM environment (use huohua_dbs.py there) |
        |  | ValueError | When no connection can be obtained or the statement fails |

        On failure the connection is rolled back and released here, because the caller never
        receives the cursor/connection and therefore cannot clean it up.
        """
        if self.current_evn.lower() == "sim":
            raise Exception("SIM环境请直接使用huohua_dbs.py中的函数")

        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)  # take a connection from the pool
        except Exception as e:
            self._log_connection_failure(e, choose_db)
            raise ValueError("PostgreSQL连接失败: {}".format(str(e))) from e

        try:
            # PostgreSQL uses %s placeholders, the same as MySQL.
            result = cursor.execute(sql, param) if param else cursor.execute(sql)
            # DB-API leaves execute()'s return value undefined (drivers such as psycopg2 return
            # None), so fall back to the standard rowcount attribute.
            count = cursor.rowcount if result is None else result
            conn.commit()
            if auto_close:
                self.close(cursor, conn)
        except Exception as e:
            obj_log.error("PostgreSQL执行SQL失败: {}, SQL: {}".format(str(e), sql))
            self._discard(cursor, conn)
            raise ValueError("数据库操作失败,SQL语句:{}, 错误: {}".format(sql, str(e))) from e
        return cursor, conn, count

    @staticmethod
    def close(cursor, conn):
        """Close the cursor, then the connection; the connection is closed even if the cursor close fails."""
        try:
            cursor.close()
        finally:
            conn.close()

    # ------------------------------------------------------------------ queries

    def select_all(self, sql, param=None, choose_db=None, show_log=True):
        """
        | Description: | Run a query and return all rows |
        | Params: | sql | Query to run |
        |  | param=None | Bind parameters |
        |  | choose_db=None | Passed through to the connection pool |
        |  | show_log=True | Log the SQL and the fetched rows |
        | Returns: | rows | Result of cursor.fetchall() (SIM: HuoHuaDBS.dbs_select result) |
        | Raises: | RuntimeError | When the query fails |
        """
        if self.current_evn.lower() == "sim":
            return self.sim_dbs.dbs_select(sql_content=sql)

        if show_log:
            self._log_sql(sql, param)

        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, param, choose_db=choose_db)
            res = cursor.fetchall()
            if show_log:
                obj_log.info('数据库查询结果:{}'.format(res))
            return res
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            # Always release the connection (previously leaked whenever show_log was True).
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def select_all_as_list(self, sql, choose_db=None, param=None):
        """
        | Description: | Same as select_all, but return plain lists instead of rows |
        | Params: | sql | Query to run |
        |  | choose_db=None | Passed through to the connection pool |
        |  | param=None | Bind parameters (ignored in the SIM environment) |
        | Returns: | list | Single-column results: flat list of values; otherwise one list per row |
        | Raises: | RuntimeError | When the query fails |
        """
        if self.current_evn.lower() == "sim":
            return self.sim_dbs.dbs_select(sql_content=sql, r_type='list')

        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, param, choose_db=choose_db)
            rows = [self._row_values(row) for row in cursor.fetchall()]

            # Legacy behaviour kept on purpose: when every value of every row is falsy
            # (no rows, or only NULL / 0 / '' values) an empty list is returned.
            if not any(any(values) for values in rows):
                return []

            res_list = []
            for values in rows:
                if len(values) > 1:
                    res_list.append(values)
                else:
                    res_list.extend(values)
            return res_list
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def select_one(self, sql, param=None, choose_db=None):
        """
        | Description: | Run a query and return its first row |
        | Params: | sql | Query to run |
        |  | param=None | Bind parameters |
        |  | choose_db=None | Passed through to the connection pool |
        | Returns: | row | First row, or None when there is none (SIM: first row or {}) |
        | Raises: | RuntimeError | When the query fails |
        """
        if self.current_evn.lower() == "sim":
            if 'limit' not in sql.lower():
                sql = sql.split(';')[0] + ' limit 1;'
            sim_data = self.sim_dbs.dbs_select(sql_content=sql)
            if sim_data:
                return sim_data[0]
            else:
                return {}

        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, param, choose_db=choose_db)
            return cursor.fetchone()
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # ------------------------------------------------------------------ inserts

    def insert_one(self, sql, param=None, choose_db=None, log_level='info'):
        """
        | Description: | Insert one row and commit |
        | Params: | sql | INSERT statement |
        |  | param=None | Bind parameters |
        |  | choose_db=None | Passed through to the connection pool |
        |  | log_level='info' | Any other value suppresses SQL/count logging |
        | Returns: | count | Number of inserted rows (SIM: HuoHuaDBS.dbs_execute_sql result) |
        | Raises: | RuntimeError | When the insert fails |
        """
        if self.current_evn.lower() == "sim":
            return self.sim_dbs.dbs_execute_sql(sql_content=sql)

        verbose = log_level.lower() == 'info'
        if verbose:
            self._log_sql(sql, param)

        cursor, conn = None, None
        try:
            # execute() commits on success and rolls back + releases the connection on failure.
            cursor, conn, count = self.execute(sql, param, choose_db=choose_db)
            if verbose:
                obj_log.info('插入数据库条数:{}'.format(count))
            return count
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def insert_one_extension(self, sql, param=None, choose_db=None):
        """
        | Description: | Insert one row and return the id of the inserted row |
        | Params: | sql | INSERT statement; 'RETURNING id' is appended when it has no RETURNING clause |
        |  | param=None | Bind parameters |
        |  | choose_db=None | Passed through to the connection pool |
        | Returns: | dict | {'insert_count': rowcount, 'insert_id': first returned column or None} |
        | Raises: | RuntimeError | When the insert fails (the transaction is rolled back) |
        """
        return_dict = dict()
        self._log_sql(sql, param)

        # PostgreSQL has no LAST_INSERT_ID; the id must come from a RETURNING clause.
        sql = self._append_returning_id(sql)
        has_returning = 'RETURNING' in sql.upper()

        cursor, conn = None, None
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)
            if param:
                cursor.execute(sql, param)
            else:
                cursor.execute(sql)

            inserted_id = self._first_value(cursor.fetchone()) if has_returning else None
            count = cursor.rowcount

            conn.commit()
            obj_log.info('插入数据库条数:{}'.format(count))
            return_dict['insert_count'] = count
            return_dict['insert_id'] = inserted_id
            return return_dict
        except Exception as e:
            self._rollback_quietly(conn)
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def insert_many(self, sql, param=None, choose_db=None):
        """
        | Description: | Insert several rows with one executemany() call and commit |
        | Params: | sql | INSERT statement with placeholders |
        |  | param=None | Sequence of parameter tuples, or its Python-literal string form |
        |  | choose_db=None | Passed through to the connection pool |
        | Returns: | count | Number of inserted rows |
        | Raises: | RuntimeError | When the insert fails (the transaction is rolled back) |
        """
        self._log_sql(sql, param)

        cursor, conn = None, None
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)
            result = cursor.executemany(sql, self._normalize_many_params(param))
            # Like execute(), executemany()'s return value is driver-specific; prefer rowcount.
            count = cursor.rowcount if result is None else result
            conn.commit()
            obj_log.info('插入数据库条数:{}'.format(count))
            return count
        except Exception as e:
            self._rollback_quietly(conn)
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def insert_many_extension(self, sql, param=None, choose_db=None):
        """
        | Description: | Insert several rows and return the list of inserted ids |
        | Params: | sql | INSERT statement; 'RETURNING id' is appended when it has no RETURNING clause |
        |  | param=None | Sequence of parameter tuples, or its Python-literal string form |
        |  | choose_db=None | Passed through to the connection pool |
        | Returns: | dict | {'insert_count', 'insert_ids'} plus 'insert_id' (first id) when ids were returned |
        | Raises: | RuntimeError | When the insert fails (the whole batch is rolled back) |
        """
        return_dict = dict()
        self._log_sql(sql, param)

        sql = self._append_returning_id(sql)
        has_returning = 'RETURNING' in sql.upper()

        cursor, conn = None, None
        try:
            cursor, conn = self.db.getconn(choose_db=choose_db)
            inserted_ids = []
            if param and has_returning:
                # executemany() discards result sets (psycopg2 documents this), so RETURNING rows
                # would be lost; run each parameter set separately inside the same transaction.
                count = 0
                for row_param in self._normalize_many_params(param):
                    cursor.execute(sql, row_param)
                    inserted_ids.extend(self._first_value(row) for row in cursor.fetchall())
                    if cursor.rowcount and cursor.rowcount > 0:
                        count += cursor.rowcount
            elif param:
                result = cursor.executemany(sql, self._normalize_many_params(param))
                count = cursor.rowcount if result is None else result
            else:
                cursor.execute(sql)
                if has_returning:
                    inserted_ids = [self._first_value(row) for row in cursor.fetchall()]
                count = cursor.rowcount

            conn.commit()
            obj_log.info('插入数据库条数:{}'.format(count))
            return_dict['insert_count'] = count
            return_dict['insert_ids'] = inserted_ids
            if inserted_ids:
                return_dict['insert_id'] = inserted_ids[0]  # id of the first inserted row
            return return_dict
        except Exception as e:
            self._rollback_quietly(conn)
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # ------------------------------------------------------------------ delete / update

    def delete(self, sql, param=None, choose_db=None):
        """
        | Description: | Delete rows and commit |
        | Params: | sql | DELETE statement |
        |  | param=None | Bind parameters |
        |  | choose_db=None | Passed through to the connection pool |
        | Returns: | count | Number of deleted rows (SIM: HuoHuaDBS.dbs_execute_sql result) |
        | Raises: | RuntimeError | When the delete fails |
        """
        if self.current_evn.lower() == "sim":
            return self.sim_dbs.dbs_execute_sql(sql_content=sql)

        self._log_sql(sql, param)

        cursor, conn = None, None
        try:
            cursor, conn, count = self.execute(sql, param, choose_db=choose_db)
            obj_log.info('删除数据库条数:{}'.format(count))
            return count
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def update(self, sql, param=None, choose_db=None):
        """
        | Description: | Update rows and commit |
        | Params: | sql | UPDATE statement |
        |  | param=None | Bind parameters |
        |  | choose_db=None | Passed through to the connection pool |
        | Returns: | count | Number of updated rows (SIM: HuoHuaDBS.dbs_execute_sql result) |
        | Raises: | RuntimeError | When the update fails |
        """
        if self.current_evn.lower() == "sim":
            return self.sim_dbs.dbs_execute_sql(sql_content=sql)

        self._log_sql(sql, param)

        cursor, conn = None, None
        try:
            # execute() commits on success and rolls back + releases the connection on failure.
            cursor, conn, count = self.execute(sql, param, choose_db=choose_db)
            obj_log.info('更新数据库条数:{}'.format(count))
            return count
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # ------------------------------------------------------------------ result checks

    def check_result_exist(self, *select_statement, retry_count=3):
        """
        | Description: | Assert that every given query returns at least one row |
        | Params: | *select_statement | Queries to check |
        |  | retry_count=3 | Extra attempts per query, one second apart |
        | Raises: | RuntimeError | When a query still returns no rows after all attempts |

        Each statement gets its own retry budget; previously checking stopped at the first
        statement that returned rows, and an empty call looped forever.
        """
        for sql in select_statement:
            attempt = 0
            while True:
                res = self.select_all(sql=sql)
                if res:
                    obj_log.info('find [{0}] results when exec :{1}'.format(len(res), sql))
                    break
                if attempt >= retry_count:
                    obj_log.info('the result is not exist,but expect at least one')
                    raise RuntimeError(u'No results, when exec sql: %s' % sql)
                obj_log.info('the result is not exist,retry {}'.format(attempt + 1))
                attempt += 1
                time.sleep(1)

    def check_result_not_exist(self, *select_statement):
        """
        | Description: | Assert that none of the given queries returns rows |
        | Raises: | RuntimeError | On the first query that returns rows |
        """
        for sql in select_statement:
            res = self.select_all(sql=sql)
            if res:
                obj_log.info('find [{0}] results, but expect 0'.format(len(res)))
                raise RuntimeError(u'the result exist, when exec sql: %s' % sql)
            obj_log.info('the result is not exist, this step pass...')

    def row_count(self, selectStatement, param=None, choose_db=None):
        """
        | Description: | Return the number of rows produced by a statement |
        | Params: | selectStatement | Statement to run |
        |  | param=None | Bind parameters |
        |  | choose_db=None | Passed through to the connection pool |
        | Returns: | count | Row count, or 0 when the statement fails (the error is logged) |
        """
        cursor, conn = None, None
        try:
            cursor, conn, count = self.execute(selectStatement, param, choose_db=choose_db)
            return count
        except Exception as e:
            # Deliberate best-effort contract kept: failures are reported as 0 rows.
            obj_log.error("error_msg: {}".format(e.args))
            return 0
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    # ------------------------------------------------------------------ database assertions (PostgreSQL)

    def kw_check_if_exists_in_database(self, selectStatement, choose_db=None):
        """Assert that selectStatement returns at least one row; return True on success."""
        obj_log.info('Executing : Check If Exists In Database | %s ' % selectStatement)
        if not self.select_one(selectStatement, choose_db=choose_db):
            raise AssertionError("Expected to have at least one row from '%s' "
                                 "but got 0 rows." % selectStatement)
        return True

    def kw_check_if_not_exists_in_database(self, selectStatement, choose_db=None):
        """Assert that selectStatement returns no rows; return True on success."""
        obj_log.info('Executing : Check If Not Exists In Database | %s ' % selectStatement)
        queryResults = self.select_one(selectStatement, choose_db=choose_db)
        if queryResults:
            raise AssertionError("Expected to have no rows from '%s' "
                                 "but got some rows : %s." % (selectStatement, queryResults))
        return True

    def kw_row_count_is_0(self, selectStatement):
        """Assert that selectStatement produces zero rows; return True on success."""
        obj_log.info('Executing : Row Count Is 0 | %s ' % selectStatement)
        num_rows = self.row_count(selectStatement)
        if num_rows > 0:
            raise AssertionError("Expected zero rows to be returned from '%s' "
                                 "but got rows back. Number of rows returned was %s" % (selectStatement, num_rows))
        return True

    def kw_row_count_is_equal_to_x(self, selectStatement, numRows):
        """Assert that selectStatement produces exactly numRows rows; return True on success."""
        obj_log.info('Executing : Row Count Is Equal To X | %s | %s ' % (selectStatement, numRows))
        num_rows = self.row_count(selectStatement)
        if num_rows != int(str(numRows)):
            raise AssertionError("Expected same number of rows to be returned from '%s' "
                                 "than the returned rows of %s" % (selectStatement, num_rows))
        return True

    def kw_row_count_is_greater_than_x(self, selectStatement, numRows):
        """Assert that selectStatement produces more than numRows rows; return True on success."""
        obj_log.info('Executing : Row Count Is Greater Than X | %s | %s ' % (selectStatement, numRows))
        num_rows = self.row_count(selectStatement)
        if num_rows <= int(numRows):
            raise AssertionError("Expected more rows to be returned from '%s' "
                                 "than the returned rows of %s" % (selectStatement, num_rows))
        return True

    def kw_row_count_is_less_than_x(self, selectStatement, numRows):
        """Assert that selectStatement produces fewer than numRows rows; return True on success."""
        obj_log.info('Executing : Row Count Is Less Than X | %s | %s ' % (selectStatement, numRows))
        num_rows = self.row_count(selectStatement)
        if num_rows >= int(numRows):
            raise AssertionError("Expected less rows to be returned from '%s' "
                                 "than the returned rows of %s" % (selectStatement, num_rows))
        return True

    def kw_table_must_exist(self, tableName, schema=None):
        """
        | Description: | Assert that a table exists (information_schema.tables) |
        | Params: | tableName | Table name, optionally schema-qualified as "schema.table" |
        |  | schema=None | Schema to search; defaults to the prefix of tableName, else 'public' |
        | Returns: | True | When the table exists |
        | Raises: | AssertionError | When the table is missing or the check itself fails |
        """
        obj_log.info('Executing : Table Must Exist | %s ' % tableName)
        table = tableName
        if schema is None:
            if '.' in tableName:
                schema, table = tableName.split('.', 1)
            else:
                schema = 'public'

        selectStatement = """
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_schema = %s
                AND table_name = %s
            );
        """
        try:
            result = self.select_one(selectStatement, (schema, table))
        except Exception as e:
            raise AssertionError("Error checking table existence: %s" % str(e)) from e
        # The EXISTS value is the first column, whether the row is a dict or a tuple.
        if not self._first_value(result):
            raise AssertionError("Table '%s' does not exist in the database" % tableName)
        return True

    # ------------------------------------------------------------------ streaming / SIM mapping

    def select_all_as_generator(self, sql, param=None, choose_db=None):
        """
        | Description: | Run a query and yield its rows one at a time (each row is logged) |
        | Params: | sql | Query to run |
        |  | param=None | Bind parameters |
        |  | choose_db=None | Passed through to the connection pool |
        | Raises: | RuntimeError | When the query fails |

        Rows are fetched until fetchone() returns None rather than relying on the row count,
        which drivers may report as -1 or None.
        """
        self._log_sql(sql, param)

        cursor, conn = None, None
        try:
            cursor, conn, _ = self.execute(sql, param, choose_db=choose_db)
            while True:
                item = cursor.fetchone()
                if item is None:
                    break
                obj_log.info(item)
                yield item
        except Exception as e:
            raise RuntimeError(e.args) from e
        finally:
            if cursor is not None and conn is not None:
                self.close(cursor, conn)

    def get_qa_to_sim_dbs(self):
        """Load the QA-database -> SIM-instance mapping into self.qa_db_to_sim_instance_name and return it."""
        get_sql = "select qa_db,sim_instance from sparkatp.qa_db_mapping_sim_dbs where status=1 and is_delete=0"
        res_info = self.sim_dbs.dbs_query_by_db_name("qadb-slave", "sparkatp", get_sql, limit_num=1000)
        for item in res_info:
            self.qa_db_to_sim_instance_name.update({item[0]: item[1]})
        return self.qa_db_to_sim_instance_name

    def get_instance_name(self, sql_str):
        """
        | Description: | Look up the SIM instance for the database referenced in a SQL string |
        | Params: | sql_str | SQL such as "select * from `db`.table" |
        | Returns: | instance | Cached SIM instance name (see get_qa_to_sim_dbs), or None if unknown |
        """
        # The database name is the last whitespace-separated token before the first '.'
        # (split() also handles newlines/tabs, e.g. "from\n`db`.table").
        tokens = sql_str.split(".")[0].split()
        db_name = tokens[-1] if tokens else ""
        db_name = db_name.replace("`", "").replace('"', "")
        return self.qa_db_to_sim_instance_name.get(db_name)
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test against the configured database: one query, then one INSERT ... RETURNING.
    helper = PgSqlHelper()

    # Query demo: print the first row of the result.
    query_sql = "SELECT * FROM account.account LIMIT 5;"
    first_row = helper.select_one(sql=query_sql)
    print(first_row)

    # Insert demo: the statement already has a RETURNING clause, so it is sent unchanged.
    demo_insert_sql = """
INSERT INTO account.account (username, email, created_at)
VALUES (%s, %s, NOW())
RETURNING id;
"""
    demo_params = ('test_user', 'test@example.com')
    outcome = helper.insert_one_extension(demo_insert_sql, demo_params)
    print("插入结果: {}".format(outcome))