代码拉取完成,页面将自动刷新
#!/usr/bin/env python
# -*- coding:UTF-8
# Copyright (c) 2023 CSUDATA.COM and/or its affiliates. All rights reserved.
# CLup is licensed under AGPLv3.
# See the GNU AFFERO GENERAL PUBLIC LICENSE v3 for more details.
# You can use this software according to the terms and conditions of the AGPLv3.
#
# THIS SOFTWARE IS PROVIDED BY CSUDATA.COM "AS IS" AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT, ARE
# DISCLAIMED. IN NO EVENT SHALL CSUDATA.COM BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
@Author: tangcheng
@description: rpc调用的模块
"""
import contextlib
import errno
import fcntl
import logging
import os
import pickle
import queue
import select
import socket
import struct
import sys
import threading
import time
import traceback
from typing import Any, cast
import cs_low_trans
logger = logging.getLogger('csurpc')
# 列出服务端的各个服务名
CMD_FUNC_LIST = 100
# 调用服务端的某个服务
CMD_CALL_FUNC = 200
DEBUG_LOG_MAX_LEN = 8192
# 线程池中的线程
class _WorkThread(threading.Thread):
"""
线程池中的工作线程
"""
def __init__(self, thread_pool, thread_name):
"""
:param thread_pool: 线程池对象
"""
threading.Thread.__init__(self, name=thread_name)
self.thread_pool = thread_pool
# is_exit: 是一个函数,在每次循环时,调用此函数,如果此函数返回为趁,则此线程结束
self.is_exit = thread_pool.is_exit
self.setDaemon(True)
self.work_queue = thread_pool.work_queue
self.start()
def run(self):
while not self.is_exit():
try:
work_func, args, kwargs = self.work_queue.get(timeout=0.1)
except queue.Empty: # 任务队列为空,继续循环
continue
try:
# 执行任务
self.thread_pool.inc_busy()
work_func(*args, **kwargs)
self.thread_pool.dec_busy()
except Exception:
self.thread_pool.dec_busy()
logger.error(f"RPC ERROR: {traceback.format_exc()}.")
class _ThreadPool:
"""
线程池对象
"""
def __init__(self, name, is_exit_func, pool_size=10):
"""
:param pool_size: 线程池启动的线程数
"""
self.name = name
self.mutex = threading.Lock()
self.is_exit = is_exit_func
self.busy_count = 0 # 繁忙的线程数
self.work_queue = queue.Queue()
self.threads = []
self.__create_thread(pool_size)
def inc_busy(self):
self.mutex.acquire()
self.busy_count += 1
self.mutex.release()
def dec_busy(self):
self.mutex.acquire()
self.busy_count -= 1
self.mutex.release()
def get_busy_threads_count(self):
self.mutex.acquire()
busy_count = self.busy_count
self.mutex.release()
return busy_count
def __create_thread(self, num_of_threads):
for i in range(num_of_threads):
thread = _WorkThread(self, "%s-%d" % (self.name, i))
self.threads.append(thread)
def wait_for_complete(self):
"""
等待所有线程完成。
"""
for thread in self.threads:
# 等待线程结束
if thread.isAlive(): # 判断线程是否还存活来决定是否调用join
thread.join()
def add_job(self, work_func, *args, **kwargs):
"""
把工作任务加到任务队列中
"""
self.work_queue.put((work_func, args, kwargs))
def _get_member_func(obj):
func_list = []
attr_name_list = dir(obj)
for attr_name in attr_name_list:
if attr_name[0] == '_':
continue
attr = getattr(obj, attr_name)
if callable(attr):
func_list.append(attr_name)
return func_list
def _handler_connect(sock, srv_obj):
"""
:param sock: 新连接的socket句柄
:param srv_obj: 服务类的一个实例
:return: 无返回值
"""
global DEBUG_LOG_MAX_LEN
try:
err, _msg = cs_low_trans.auth_connect(sock, srv_obj.password, srv_obj.timeout)
if err:
return # 验证失败或socket错误,直接返回
except Exception:
traceback.print_exc()
return
while True:
try:
err, _msg, cmd, data = cs_low_trans.recv_cmd(sock, srv_obj.timeout)
if err:
break
if cmd == CMD_FUNC_LIST: # 客户端请求handler中有哪些函数可以调用
ret_data = pickle.dumps(srv_obj.srv_func_list)
err, _msg = cs_low_trans.reply_cmd(sock, 0, ret_data, srv_obj.timeout)
if err:
break
elif cmd == CMD_CALL_FUNC:
try:
func_name, func_args, func_kwargs = pickle.loads(data)
if logger.level <= logging.DEBUG:
str_args = repr(func_args)
if len(str_args) > DEBUG_LOG_MAX_LEN:
str_args = str_args[:DEBUG_LOG_MAX_LEN] + " ... "
str_kwargs = repr(func_kwargs)
if len(str_kwargs) > DEBUG_LOG_MAX_LEN:
str_kwargs = str_kwargs[:DEBUG_LOG_MAX_LEN] + " ... "
rpc_info = f"RECV RPC: {func_name} :\nargs={str_args}"
if func_kwargs:
rpc_info += f"\nkwargs={str_kwargs}"
logger.debug(rpc_info)
except Exception as e:
err, _msg = cs_low_trans.reply_cmd(
sock, 1, f"decode func args failed: {str(e)}".encode('utf-8'), srv_obj.timeout)
if err:
break
continue
if func_name not in srv_obj.srv_func_list:
err, _msg = cs_low_trans.reply_cmd(
sock, 1, ("Function(%s) does not exist" % func_name).encode('utf-8'), srv_obj.timeout)
if err:
break
continue
call_func = getattr(srv_obj.handler, func_name)
try:
ret = call_func(*func_args, **func_kwargs)
ret_data = pickle.dumps(ret)
ret_code = 0
if logger.level <= logging.DEBUG:
str_ret = repr(ret)
if len(str_ret) > DEBUG_LOG_MAX_LEN:
str_ret = str_ret[:DEBUG_LOG_MAX_LEN]
rpc_info = f"RETURN RECV RPC: {func_name} :\nreturn={str_ret}"
logger.debug(rpc_info)
except Exception as e:
if logger.level <= logging.DEBUG:
logger.debug(f"call {func_name} failed:\n{traceback.format_exc()}")
exc_type, _, exc_tb = sys.exc_info()
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
err_msg = '%s: %s in %s:%d' % (exc_type.__name__, str(e), fname, exc_tb.tb_lineno)
ret_code = 1
ret_data = err_msg.encode()
err, _msg = cs_low_trans.reply_cmd(sock, ret_code, ret_data, srv_obj.timeout)
if err:
break
except Exception:
traceback.print_exc()
break
with contextlib.suppress(Exception):
sock.close()
def parse_connect_url(conn_url):
"""
解析连接字符串
:param conn_url:
:return: (protocol, ip, port)
"""
cells = conn_url.split('://')
if len(cells) != 2:
raise Exception(f"Invalid connect url: {conn_url}.")
protocol = cells[0].lower()
if protocol != 'tcp': # 注意目前只支持tcp协议
raise Exception("Unsupported protocol: %s" % cells[0])
ipport = cells[1]
cells = ipport.split(':')
if len(cells) != 2:
raise Exception("Invalid connect url: %s" % conn_url)
try:
ip = socket.gethostbyname(cells[0])
except socket.gaierror as wrong:
raise Exception(f"Invalid server or ip addr in url: {conn_url}.") from wrong
if not cells[1].isdigit():
raise Exception("Invalid port in connect url:%s" % cells[1])
port = int(cells[1])
return protocol, ip, port
class Server:
"""
使用方法:
s = Server(MyHandle):
s.bind("tcp://0.0.0.0:4342")
s.run()
"""
def __init__(self, name, handler, is_exit_func, password='cstechRpc', thread_count=30, timeout=300, debug=0):
self.name = name
self.handler = handler
self.thread_count = thread_count
self.timeout = timeout
self.ss = None
self.is_exit = is_exit_func
self.thread_pool = None
self.srv_func_list = _get_member_func(handler)
self.password = password
self.debug = debug
def get_busy_threads_count(self):
"""
获得正在执行任务的线程数
:return: int, 返回正在执行任务的线程数
"""
if not self.thread_pool:
return 0
return self.thread_pool.get_busy_threads_count()
def bind(self, conn_url):
"""
:param conn_url: 连接url,格式为: 'protocol://ip:port',目前protocol只支持tcp。
:return: 无返回值
使用例子: s.bind('tcp://0.0.0.0:4342')
"""
protocol, ip, port = parse_connect_url(conn_url)
if protocol != 'tcp':
raise Exception("Unsupported protocol:%s" % protocol)
self.ss = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
fcntl.fcntl(self.ss.fileno(), fcntl.F_SETFD, fcntl.FD_CLOEXEC)
self.ss.settimeout(self.timeout)
self.ss.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.ss.bind((ip, port))
def run(self):
"""
运行服务
:return: 无返回值
"""
self.thread_pool = _ThreadPool(self.name, self.is_exit, self.thread_count)
self.ss.listen(10)
while not self.is_exit():
# if self.debug:
# print "Busy threads count: %s" % self.get_busy_threads_count()
try:
readable, _writeable, _exceptional = select.select([self.ss], [], [], 1)
if not readable:
# logger.debug("select timeout.")
continue
except select.error as e:
if e.args[0] == errno.EINTR: # 这是收到信号打断了select函数
continue
raise Exception("call select() failed: %s" % repr(e)) from e
(client, _address) = self.ss.accept()
# if self.debug:
# print("Accept connection from %s" % str(address))
linger = struct.pack('ii', 1, 1)
r = client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger)
if r:
raise Exception("Set socket SO_LINGER failed!")
client.settimeout(self.timeout)
self.thread_pool.add_job(_handler_connect, client, self)
self.ss.close()
class CsuTimeoutError(Exception):
"""
定义一个超时错误,当在异步调用时,获得结果超时会抛出此错误
"""
def __init__(self, msg):
Exception.__init__(self, msg)
class RunError(Exception):
"""
定义一个超时错误,当在异步调用时,获得结果超时会抛出此错误
"""
def __init__(self, msg):
Exception.__init__(self, msg)
class AsyncError(Exception):
"""
当在同一个socket连接中正在执行一个异步请求还没有结束,又发起了一个请求时,就会抛出此错误
"""
def __init__(self, msg):
Exception.__init__(self, msg)
class _Trans:
"""
传输对象,把socket对象封装进来
"""
ip: str = ''
port: int = 0
def __init__(self, sock, call_timeout, ip, port, msg_callback=None):
self.sock = sock
self.call_timeout = call_timeout
self.ip = ip
self.port = port
self.pending = False # pending==True时,表示在异步模式下正在处理异步任务,这时不能再发起异步请求
self.msg_callback = msg_callback
# 客户端异步调用时使用的线程
class _AsyncCallTask(threading.Thread):
"""
异步调用的任务类
"""
def __init__(self, trans, func_name, *args, **kwargs):
"""
:param trans: 传输对象
:param func_name: 要调用的远程函数名称
"""
threading.Thread.__init__(self)
self.mutex = threading.Lock()
self.result_queue = queue.Queue(maxsize=1)
self.trans = trans
self.func_name = func_name
self.args = args
self.kwargs = kwargs
self.setDaemon(True)
if trans.pending:
raise AsyncError("Previous async task is running, please get result of previous async task, then retry!")
trans.pending = True
self.start()
def run(self):
data = pickle.dumps([self.func_name, self.args, self.kwargs])
err, msg, ret_code, ret_data = cs_low_trans.send_cmd(self.trans.sock, CMD_CALL_FUNC, data, self.trans.call_timeout)
if err:
self.result_queue.put((err, msg))
return
if ret_code:
self.result_queue.put((ret_code, ret_data))
return
func_ret = pickle.loads(ret_data)
self.result_queue.put((0, func_ret))
def get(self, timeout=None):
"""
获得调用的结果
:param timeout: 如果不设置,则一直等到用户端返回
:return: result, result是远程函数的返回值
"""
if not self.trans.pending:
raise AsyncError("Async task not running!")
try:
err, result = self.result_queue.get(timeout=timeout)
self.trans.pending = False
if err:
raise RunError(result)
return result
except queue.Empty as wrong:
raise CsuTimeoutError(f"Timeout: unable to obtain results within {timeout} seconds") from wrong
def call_remote_func(trans, func_name, *args, **kwargs) -> Any:
"""
调用远程的函数
:param trans:
:param func_name: 远程的函数名
:param args: 传给远程函数的列表形式的参数
:param kwargs: 传给远程函数的字典形式的参数
:return: 如果是同步模式,则返回远程函数的返回值,如果是异步模式返回异步对象async_task,通过async_task.get()可以获得执行结果
"""
global DEBUG_LOG_MAX_LEN
call_timeout = trans.call_timeout
sock = trans.sock
if logger.level <= logging.DEBUG:
str_args = repr(args)
if len(str_args) > DEBUG_LOG_MAX_LEN:
str_args = str_args[:DEBUG_LOG_MAX_LEN] + " ... "
str_kwargs = repr(kwargs)
if len(str_kwargs) > DEBUG_LOG_MAX_LEN:
str_kwargs = str_kwargs[:DEBUG_LOG_MAX_LEN] + " ... "
rpc_info = f"CALL RPC({trans.ip}:{trans.port}): {func_name} :\n args={str_args}"
if kwargs:
rpc_info += f"\n kwargs={str_kwargs}"
logger.debug(rpc_info)
if trans.msg_callback:
trans.msg_callback(rpc_info)
async_mode = False
if 'async_mode' in kwargs:
if kwargs['async_mode']:
async_mode = True
del kwargs['async_mode']
if async_mode: # 异步模式
async_task = _AsyncCallTask(trans, func_name, *args, **kwargs)
return async_task
else: # 同步模式
try:
data = pickle.dumps([func_name, args, kwargs])
except Exception as e:
raise UserWarning(f"unsupport args type: {repr(e)}") from e
err, msg, ret_code, ret_data = cs_low_trans.send_cmd(sock, CMD_CALL_FUNC, data, call_timeout)
if err:
raise UserWarning(f"socket error: {msg}")
if ret_code:
raise UserWarning(ret_data.decode())
func_ret = pickle.loads(ret_data)
if logger.level <= logging.DEBUG:
str_ret = repr(func_ret)
if len(str_ret) > DEBUG_LOG_MAX_LEN:
str_ret = str_ret[:DEBUG_LOG_MAX_LEN]
rpc_info = f"RETURN CALL RPC({trans.ip}:{trans.port}) : {func_name} :\n return={str_ret}"
logger.debug(rpc_info)
if trans.msg_callback:
trans.msg_callback(rpc_info)
return func_ret
class Client:
"""
使用方法:
c = Client():
c.connect("tcp://127.0.0.1:4342")
c.my_func()
异步模式:
rs = c.hello('hello', 'world', async_mode=True)
ret = rs.get(10)
注意即使在异步模式下,对同一个连接也只能同时发一个请求,不能同时发多个请求,如果需要同时发多个请求,请建多个连接
"""
def __init__(self, call_timeout=300, msg_callback=None):
self.trans: _Trans = _Trans(sock=None, call_timeout=call_timeout, ip=None, port=None, msg_callback=msg_callback)
self.mutex = threading.Lock()
self.conn_timeout = 300
self.msg_callback = msg_callback
def connect(self, conn_url, password='cstechRpc', conn_timeout=10, data_timeout=300):
"""
s.bind('tcp://127.0.0.1:4342')
:param conn_url:
:param password:
:param timeout: 注意是连接超时,不是调用超时
:return:
"""
self.conn_timeout = conn_timeout
self.data_timeout = data_timeout
protocol, ip, port = parse_connect_url(conn_url)
if protocol != 'tcp':
raise Exception("Unsupported protocol:%s" % protocol)
self.trans.ip = ip
self.trans.port = port
err, msg, mysock = cs_low_trans.connect(ip, port, password, conn_timeout, data_timeout)
if err:
raise Exception(msg)
self.trans.sock = cast(socket.socket, mysock)
err, msg, ret_code, ret_data = cs_low_trans.send_cmd(self.trans.sock, CMD_FUNC_LIST, b'', data_timeout)
if err:
raise Exception("socket error: %s" % msg)
if ret_code:
raise Exception("rpc error: %s" % ret_data)
# 把远程服务中存在的函数名加到本地的类上,这样调用本地类上的函数,相当于调用了远程的函数
self.func_list = pickle.loads(ret_data)
def close(self):
if self.trans.sock is not None:
self.trans.sock.close()
self.trans.sock = None
def __del__(self):
if self.trans.sock is not None:
self.trans.sock.close()
self.sock = None
def __bool__(self):
return True
def __call__(self, method, *args, **kwargs):
return call_remote_func(self.trans, method, *args, **kwargs)
def __getattr__(self, method):
if method not in self.func_list:
raise Exception("function not support %s" % method)
return lambda *args, **kargs: self(method, *args, **kargs)
def __str__(self):
return f"csurpc({self.trans.ip}:{self.trans.port})"
"""
#############################################################################
# 下面为测试代码 #
#############################################################################
"""
class _TestHandle:
def __init__(self):
pass
def func1(self, *args):
time.sleep(0.01)
return '.'.join(args)
def func2(self, *args):
time.sleep(0.01)
return '#'.join(args)
def main():
"""
此函数用做测试
:return: 无
"""
exit_flag = 0
def is_exit():
return exit_flag
def _usage():
print("In Server: %s -s " % sys.argv[0])
print("In Client: %s -c <anystr1> [anystr2] [anystr3] ..." % sys.argv[0])
print("Usage: %s -s" % (sys.argv[0]))
if len(sys.argv) < 2:
_usage()
return
if sys.argv[1] == '-h' or sys.argv[1] == '--help':
_usage()
return
if sys.argv[1] not in {'-s', '-c'}:
_usage()
return
if sys.argv[1] == '-s':
handle = _TestHandle()
srv = Server('ha-service', handle, is_exit, password='cstechRpc', thread_count=10, debug=1)
srv.bind("tcp://0.0.0.0:4342")
srv.run()
elif sys.argv[1] == '-c':
for _i in range(30):
proxy = Client()
proxy.connect("tcp://127.0.0.1:4342")
# 测试同步调用
ret1 = proxy.func1(*sys.argv[2:])
print("Sync call func1(), return: %s " % ret1)
ret2 = proxy.func2(*sys.argv[2:])
print("Sync call func1(), return: %s " % ret2)
# 测试异步调用
c2 = Client()
c2.connect("tcp://127.0.0.1:4342")
print("ASync call func1() ... ")
rs1 = proxy.func1(*sys.argv[2:], async_mode=True)
print("ASync call func2() ... ")
rs2 = c2.func2(*sys.argv[2:], async_mode=True)
ret1 = rs1.get(10)
print("ASync call func1() return : %s" % ret1)
ret2 = rs2.get(10)
print("ASync call func2() return : %s" % ret2)
if __name__ == '__main__':
main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。