118 lines
4.9 KiB
Python
118 lines
4.9 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
# @Author : qiaoxinjiu
|
|
# @Time : 2021/07/13
|
|
# @File : websocket_api.py
|
|
|
|
import time
|
|
import json
|
|
from functools import wraps
|
|
import requests
|
|
from websocket import create_connection
|
|
from base_framework.public_tools.log import get_logger
|
|
from base_framework.public_tools.sqlhelper import MySqLHelper
|
|
from base_framework.public_tools.read_config import InitConfig
|
|
|
|
|
|
def check_conn(func):
    """Decorate a method so it always runs on a freshly opened websocket connection.

    Before each call the wrapper opens a new connection with ``self.get_ws_conn()``
    and stores it on ``self.ws``. The connection previously held in ``self.ws`` (if
    any) is closed, so repeated calls no longer leak sockets. The decorated
    instance must provide ``get_ws_conn()`` and a ``logger`` attribute.
    """

    @wraps(func)
    def wrap_check_conn(self, *args, **kwargs):
        previous = getattr(self, 'ws', None)
        self.ws = self.get_ws_conn()
        if previous is not None and previous is not self.ws:
            # Best-effort cleanup: failing to close a stale connection must not
            # prevent the call from running on the new one.
            try:
                previous.close()
            except Exception as err:
                self.logger.warning('failed to close previous websocket connection: %s' % err)
        self.logger.info('%s connected websocket server' % func.__name__)
        return func(self, *args, **kwargs)

    return wrap_check_conn
|
|
|
|
|
|
class WebSocketAPI(InitConfig):
    """Websocket client that authenticates via SSO + OAuth authorization code before connecting.

    Endpoint URLs, the websocket server address and the ``sparkle-manage``
    Authorization value are read from ``self.all_cfg`` (provided by ``InitConfig``)
    for the current environment.
    """

    def __init__(self, user_name=None, env_name=None, timeout=30):
        """Load endpoint configuration for the environment and create logger / DB helper.

        :param user_name: forwarded to ``InitConfig`` as ``run_user_name``.
        :param env_name: forwarded to ``InitConfig`` as ``current_evn``.
        :param timeout: accepted for backward compatibility; not used by this class.
        """
        InitConfig.__init__(self, run_user_name=user_name, current_evn=env_name)
        env_cfg = self.all_cfg[self.current_evn]
        self.sso_url = env_cfg['sso_url']
        self.code_url = env_cfg['code_url']
        self.token_url = env_cfg['token_url']
        self.redirect_url = env_cfg['teach_opt_url']
        self.auth_code = self.all_cfg['Authorization']['sparkle-manage']
        self.logger = get_logger()
        self.db_conn = MySqLHelper()

    def get_session(self):
        """Return a ``requests.Session`` carrying the Authorization and accesstoken headers.

        The access token is fetched once and cached on ``self.access_token``.
        """
        session = requests.Session()
        session.headers.update({'Authorization': self.auth_code})
        if not hasattr(self, 'access_token'):
            self.access_token = self._get_access_token()
        session.headers.update({'accesstoken': self.access_token})
        return session

    def _get_access_token(self):
        """Run the SSO login and OAuth authorization-code flow; return the access token.

        1. POST credentials to ``sso_url`` (expects 302).
        2. GET ``code_url`` for client ``tm-manage`` (expects 302, with the code in ``Location``).
        3. POST the code to ``token_url`` and read ``access_token`` from the JSON body.

        :raises AssertionError: on an unexpected status code, a missing authorization
            code, or a token response without ``access_token``.
        """
        from urllib.parse import parse_qs, urlparse

        # Explicit raises (instead of ``assert``) so the checks survive ``python -O``.
        with requests.Session() as token_session:
            login_data = {'showUsername': self.show_username, 'username': self.username,
                          'password': self.password}
            resp1 = token_session.post(self.sso_url, data=login_data, allow_redirects=False)
            if resp1.status_code != 302:
                raise AssertionError('incorrect response code %s' % resp1.status_code)

            code_params = {'client_id': 'tm-manage', 'response_type': 'code',
                           'redirect_uri': self.redirect_url}
            resp2 = token_session.get(self.code_url, params=code_params, allow_redirects=False)
            if resp2.status_code != 302:
                raise AssertionError('incorrect response code %s' % resp2.status_code)

            # Parse the redirect properly: ``split('=')[1]`` returned "abc&state"
            # whenever the redirect carried more than one query parameter.
            location = resp2.headers.get('Location') or ''
            parsed = urlparse(location)
            codes = (parse_qs(parsed.query) or parse_qs(parsed.fragment)).get('code')
            if not codes:
                raise AssertionError('authorization code not found in redirect: %s' % location)

            token_data = {'grant_type': 'authorization_code', 'code': codes[0],
                          'redirect_uri': self.redirect_url}
            token_session.headers.update({'Authorization': self.auth_code})
            resp3 = token_session.post(self.token_url, data=token_data)

        token = json.loads(resp3.text).get('access_token')
        if not token:
            raise AssertionError('access_token missing from token response: %s' % resp3.text)
        return token

    def get_ws_conn(self, timeout=60):
        """Fetch a fresh access token and open a websocket connection to the configured server.

        The token is appended as ``?accesstoken=...`` unless the configured URI already
        contains ``accesstoken``. The final URI is stored on ``self.uri``.

        :param timeout: socket timeout in seconds, passed to ``create_connection``.
        :return: the connected websocket.
        :raises Exception: if the handshake status is not 101; the socket is closed first.
        """
        self.logger.info('websocket connecting...')
        self.access_token = self._get_access_token()
        uri = self.all_cfg[self.current_evn]['websocket_server']
        if 'accesstoken' in uri:
            self.uri = uri
        else:
            # Avoid a double slash ("//?accesstoken=") when the URI already ends with '/'.
            separator = '?' if uri.endswith('/') else '/?'
            self.uri = '%s%saccesstoken=%s' % (uri, separator, self.access_token)
        ws_conn = create_connection(self.uri, timeout=timeout)
        if ws_conn.getstatus() != 101:
            ws_conn.close()
            raise Exception('connect websocket server failed!')
        return ws_conn

    def _close_ws(self):
        """Close ``self.ws`` if a connection is held; otherwise do nothing."""
        ws = getattr(self, 'ws', None)
        if ws is not None:
            ws.close()
            self.logger.info('connection closed success!')

    @check_conn
    def send_msg_and_get_response(self, send_data: 'dict | str', timeout=3):
        """Send ``send_data`` on a freshly opened connection and wait for its reply.

        ``@check_conn`` opens the connection and stores it on ``self.ws`` first.

        :param send_data: a dict (serialised with ``json.dumps``) or a JSON string.
            Its ``timestamp`` field, if present, identifies the reply.
        :param timeout: seconds to wait for the reply; forwarded to ``get_response``.
        :return: the decoded JSON reply.
        :raises Exception: if ``send_data`` is neither dict nor str, or no reply arrives in time.
        """
        if isinstance(send_data, dict):
            data = json.dumps(send_data)
        elif isinstance(send_data, str):
            data = send_data
        else:
            raise Exception('send_data can only support dict or str type')
        self.ws.send(data)
        self.logger.info(">>> heart beat!")
        self.logger.info(">>> send message:%s" % data)
        return self.get_response(ws_conn=self.ws, tm_stamp=self._extract_timestamp(send_data), timeout=timeout)

    @staticmethod
    def _extract_timestamp(send_data):
        """Return the ``timestamp`` field of an outgoing payload (dict or JSON string), or None."""
        payload = json.loads(send_data) if isinstance(send_data, str) else send_data
        return payload.get("timestamp") if isinstance(payload, dict) else None

    def get_response(self, ws_conn, tm_stamp=None, timeout=3):
        """Read messages from ``ws_conn`` until a matching one arrives or ``timeout`` elapses.

        If ``tm_stamp`` is truthy, return the first message whose raw text contains
        ``str(tm_stamp)``. Otherwise return the first non-empty message. The result
        is decoded with ``json.loads``.

        :param timeout: total wait budget in seconds (fractions allowed).
        :raises Exception: if nothing matches in time; ``self.ws`` is closed first.
        """
        from websocket import WebSocketTimeoutException

        deadline = time.time() + float(timeout)
        original_timeout = ws_conn.gettimeout()
        try:
            while True:
                remaining = deadline - time.time()
                if remaining <= 0:
                    break
                # Bound each recv by the remaining budget. Otherwise recv blocks for the
                # connection-level timeout (60 s by default) and ``timeout`` is ignored.
                ws_conn.settimeout(remaining)
                try:
                    out_put = ws_conn.recv()
                except WebSocketTimeoutException:
                    break
                if isinstance(out_put, bytes):
                    out_put = out_put.decode('utf-8', errors='replace')
                if tm_stamp:
                    if str(tm_stamp) in out_put:
                        self.logger.info(">>> received message:%s" % out_put)
                        return json.loads(out_put)
                elif out_put:
                    return json.loads(out_put)
        finally:
            # Restore the connection's own timeout for later callers.
            ws_conn.settimeout(original_timeout)
        self._close_ws()
        raise Exception('>>> received no response, please check ws connect!')
|