Ai
2 Star 2 Fork 0

Kenny小狼/python-tools

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
testdata_tools.py 8.29 KB
一键复制 编辑 原始数据 按行查看 历史
KennyLee 提交于 2023-09-15 11:44 +08:00 . refactor: 优化代码
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import random
import threading
import time
from datetime import datetime
import pymysql
class ClassUtils:
    """Reflection helpers that flatten plain objects into (name, value) pairs."""

    @staticmethod
    def get_all_items(obj):
        """Return the non-dunder attributes of ``obj`` as ``(name, value)`` pairs.

        Attributes stored in ``obj.__class__.__dict__`` are merged with the
        instance's own ``obj.__dict__``. When both define the same name, the
        instance value wins (as in normal attribute lookup), so every name
        appears exactly once. Names starting with ``__`` are skipped.

        :param obj: any object that has a ``__dict__``.
        :return: list of ``(name, value)`` tuples sorted by name.
        """
        # Merge into a dict instead of taking a set union of ``items()`` views.
        # The set union required every value to be hashable (list/dict values
        # raised TypeError). It also kept both the class and the instance
        # pair when an instance attribute shadowed a class attribute, which
        # produced duplicate column names.
        merged = {**obj.__class__.__dict__, **obj.__dict__}
        return sorted(
            ((name, value) for name, value in merged.items() if not name.startswith('__')),
            # Sort on the name only, so values never need to be comparable.
            key=lambda pair: pair[0],
        )
class SqlBuilder:
    """Build a parameterised multi-row ``INSERT`` statement from plain objects.

    Each test-data object is flattened with ``ClassUtils.get_all_items`` into
    name-sorted ``(column, value)`` pairs. The columns of the first object
    determine the generated SQL.
    """

    def __init__(self, table_name, test_datas=None):
        """
        :param table_name: target table name, inserted verbatim into the SQL.
        :param test_datas: iterable of objects whose attributes map to columns.
        """
        if test_datas is None:
            test_datas = []
        self.table_name = table_name
        # One list of (column, value) pairs per test-data object.
        self.test_data_db_dict = [ClassUtils.get_all_items(test_data) for test_data in test_datas]

    def build_insert_sql(self):
        """Return ``INSERT INTO <table> (`c1`,`c2`,...) VALUES (%s,%s,...)``.

        Column names are backtick-quoted so that attributes named after MySQL
        reserved words (e.g. ``order``, ``desc``, ``key``) still produce valid
        SQL. Values are left as ``%s`` placeholders for ``cursor.executemany``.

        :raises ValueError: if the builder holds no test data.
        """
        if not self.test_data_db_dict:
            raise ValueError('SqlBuilder has no test data to build an INSERT statement from')
        columns = [name for name, _ in self.test_data_db_dict[0]]
        # Double any embedded backtick, as MySQL identifier quoting requires.
        column_sql = ','.join('`' + name.replace('`', '``') + '`' for name in columns)
        placeholders = ','.join(['%s'] * len(columns))
        return 'INSERT INTO ' + self.table_name + ' (' + column_sql + ') VALUES (' + placeholders + ')'

    def get_vals(self):
        """Return one list of values per test-data object, in column order."""
        return [[value for _, value in row] for row in self.test_data_db_dict]
class GenerateDataConfig:
    """Parameters for one test-data generation run."""

    def __init__(self, size=10, slice_size=1):
        """
        :param size: number of records to generate.
        :param slice_size: number of slices the work is split into.
        """
        # Location of the JSON file with the database settings.
        self.config_file_path = 'assets/db_config.json'
        self.size, self.slice_size = size, slice_size
class DbConfig:
    """MySQL connection settings taken from a parsed config dictionary."""

    def __init__(self, db_config, database):
        """
        :param db_config: parsed config; must contain a ``'mysql'`` mapping with
            ``host``, ``port``, ``user`` and ``password`` entries.
        :param database: name of the database to connect to.
        :raises KeyError: if any of the required keys is missing.
        """
        mysql_section = db_config['mysql']
        self.host = mysql_section['host']
        self.port = mysql_section['port']
        self.user = mysql_section['user']
        self.password = mysql_section['password']
        self.database = database
class FileUtils:
    """File-system helpers."""

    @staticmethod
    def load_json_file(file_path, encoding='utf-8'):
        """Parse and return the JSON document stored at ``file_path``.

        :param file_path: path to a JSON file.
        :param encoding: text encoding of the file. Defaults to UTF-8 (the
            standard JSON encoding) rather than the platform locale encoding,
            which on e.g. Chinese-locale Windows is GBK and fails on UTF-8 text.
        :return: the decoded JSON value.
        """
        with open(file_path, 'r', encoding=encoding) as f:
            return json.load(f)
class DbUtils:
    """Helpers for loading database connection settings."""

    # Default location of the JSON file that holds the ``mysql`` settings.
    DEFAULT_CONFIG_FILE_PATH = 'assets/db_config.json'

    @staticmethod
    def get_db_config(database, config_file_path=None):
        """Load MySQL settings from a JSON config file and bind them to ``database``.

        :param database: name of the database to connect to.
        :param config_file_path: path of the JSON config file; ``None`` uses
            ``DbUtils.DEFAULT_CONFIG_FILE_PATH``.
        :return: a ``DbConfig`` built from the file's contents.
        """
        if config_file_path is None:
            config_file_path = DbUtils.DEFAULT_CONFIG_FILE_PATH
        db_config_json = FileUtils.load_json_file(config_file_path)
        return DbConfig(db_config_json, database)
class TimeUtils:
    """Wall-clock helpers based on ``datetime.now()`` (local time)."""

    @staticmethod
    def current_timemillis():
        """Return the current Unix time truncated to whole milliseconds."""
        # timestamp() is a float count of seconds; scale to ms and truncate.
        return int(datetime.now().timestamp() * 1000)

    @staticmethod
    def current_time_str():
        """Return the local time as a 17-digit ``YYYYmmddHHMMSS`` + milliseconds string."""
        stamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
        # %f yields 6 microsecond digits; dropping the last 3 keeps milliseconds.
        return stamp[:-3]
class SnowflakeIDGenerator:
    """Thread-safe Snowflake-style integer ID generator.

    Bit layout, from most to least significant::

        | ms since ``epoch`` | machine id (10 bits) | sequence (12 bits) |

    IDs returned by one instance are strictly increasing, as long as the
    system clock does not move backwards.
    """

    MACHINE_ID_BITS = 10
    SEQUENCE_BITS = 12
    MAX_MACHINE_ID = (1 << MACHINE_ID_BITS) - 1
    SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
    MACHINE_ID_SHIFT = SEQUENCE_BITS
    TIMESTAMP_SHIFT = SEQUENCE_BITS + MACHINE_ID_BITS

    def __init__(self):
        # Custom epoch: local time 2020-01-01 00:00:00, expressed in
        # *milliseconds* so it has the same unit as the clock readings it is
        # subtracted from (the old value was in seconds).
        self.epoch = int(time.mktime(time.strptime('2020-01-01 00:00:00', '%Y-%m-%d %H:%M:%S'))) * 1000
        # Sequence number used for the most recent ID within ``last_timestamp``.
        self.sequence = 0
        # Millisecond (relative to ``epoch``) of the last issued ID; -1 = none yet.
        self.last_timestamp = -1
        self.machine_id = self.get_machine_id()
        if not 0 <= self.machine_id <= self.MAX_MACHINE_ID:
            raise ValueError(f"machine_id must be in [0, {self.MAX_MACHINE_ID}], got {self.machine_id}")
        self.lock = threading.Lock()

    @staticmethod
    def get_machine_id():
        """Return this node's machine id (must fit in 10 bits: 0-1023).

        A fixed value is fine on a single machine. When several machines
        generate IDs concurrently, each must use a distinct value.
        """
        return 1

    def _current_millis(self):
        """Return the milliseconds elapsed since ``self.epoch``."""
        return int(time.time() * 1000) - self.epoch

    def generate_id(self):
        """Return the next unique ID.

        :raises Exception: if the clock reads earlier than ``epoch`` or earlier
            than the previously issued ID's timestamp.
        """
        with self.lock:
            timestamp = self._current_millis()
            if timestamp < 0 or timestamp < self.last_timestamp:
                raise Exception("Time moved backwards, cannot generate ID")
            if timestamp == self.last_timestamp:
                self.sequence = (self.sequence + 1) & self.SEQUENCE_MASK
                if self.sequence == 0:
                    # All 4096 sequence values for this millisecond are used:
                    # spin until the clock reaches the next millisecond.
                    while timestamp <= self.last_timestamp:
                        timestamp = self._current_millis()
            else:
                self.sequence = 0
            self.last_timestamp = timestamp
            return ((timestamp << self.TIMESTAMP_SHIFT)
                    | (self.machine_id << self.MACHINE_ID_SHIFT)
                    | self.sequence)
class TestUtils:
    """Random identifiers and numbers for generating test records."""

    # Shared generator behind gen_snowflake_id(); created once at class definition.
    id_generator = SnowflakeIDGenerator()

    @staticmethod
    def gen_id():
        """Return a numeric string ID: a jittered ms timestamp followed by 6 random digits.

        The random suffix is zero-padded to exactly 6 digits. Without padding,
        IDs varied in length, and different (timestamp, suffix) pairs could
        concatenate to the same string.
        """
        # Current time in milliseconds, shifted forward by two random
        # offsets of 1..999999 each.
        timestamp = int(time.time() * 1000)
        timestamp += TestUtils.six_digit_random() + TestUtils.six_digit_random()
        random_part = TestUtils.six_digit_random_str()
        return f"{timestamp}{random_part}"

    @staticmethod
    def gen_snowflake_id():
        """Return the next integer ID from the shared ``SnowflakeIDGenerator``."""
        return TestUtils.id_generator.generate_id()

    @staticmethod
    def six_digit_random():
        """Return a random integer between 1 and 999999 inclusive."""
        return random.randint(1, 999999)

    @staticmethod
    def six_digit_random_str():
        """Return ``six_digit_random()`` as a 6-character, zero-padded string."""
        return "{:06d}".format(TestUtils.six_digit_random())

    @staticmethod
    def gen_order_no():
        """Return an order number: ``TimeUtils.current_time_str()`` plus 6 random digits."""
        return TimeUtils.current_time_str() + TestUtils.six_digit_random_str()
class DbThreadRunner:
    """Insert generated test rows into one MySQL table from several worker threads.

    ``total_data_number`` rows are split as evenly as possible across
    ``max_threads`` workers. Each worker opens its own connection and inserts
    its share in batches of at most ``commit_size`` rows, committing after
    every batch.
    """

    def __init__(self, total_data_number=10, max_threads=6, db_name=None, tb_name=None):
        """
        :param total_data_number: total number of rows to insert.
        :param max_threads: number of worker threads (must be >= 1).
        :param db_name: database name passed to ``DbUtils.get_db_config``.
        :param tb_name: target table name passed to the SQL-builder factory.
        :raises ValueError: if ``max_threads`` is less than 1.
        """
        if max_threads < 1:
            raise ValueError(f"max_threads must be >= 1, got {max_threads}")
        self.db_name = db_name
        self.tb_name = tb_name
        # Rows inserted and committed per executemany() batch.
        self.commit_size = 2000
        # Total number of rows to generate.
        self.total_data_number = total_data_number
        # Number of worker threads.
        self.max_threads = max_threads
        # Even share per worker; the first ``remainder`` workers take one extra row.
        self.average_number, self.remainder = divmod(self.total_data_number, self.max_threads)
        # Number of workers currently in flight, guarded by thread_count_lock.
        self.thread_count = 0
        self.thread_count_lock = threading.Lock()
        # Set when the in-flight count drops back to zero (see run_task).
        self.all_threads_done = threading.Event()

    def create_thread(self, thread_id, task_func, generate_data_func=None):
        """Run ``task_func(thread_id, share, generate_data_func)`` for one worker and block until it returns.

        ``share`` is this worker's row count. ``thread_count`` is incremented
        for the duration of the call, and ``all_threads_done`` is set when it
        returns to zero.
        """
        with self.thread_count_lock:
            self.thread_count += 1
        share = self.average_number + 1 if thread_id < self.remainder else self.average_number
        try:
            # run_task() already gives every worker its own thread, so run the
            # task directly instead of starting a second thread only to join it.
            task_func(thread_id, share, generate_data_func)
        finally:
            with self.thread_count_lock:
                self.thread_count -= 1
                if self.thread_count == 0:
                    self.all_threads_done.set()

    def insert_test_datas(self, count=10, get_sqlbuilder_func=None):
        """Insert ``count`` generated rows into ``self.tb_name`` over a fresh connection.

        Rows are generated and inserted in batches of at most ``commit_size``,
        and each successful batch is committed. On a database error, the error
        is printed and ``SystemExit(1)`` is raised (in a worker thread this ends
        only that worker). The uncommitted batch is discarded when the
        connection closes.

        :param count: number of rows to insert.
        :param get_sqlbuilder_func: callable ``(table_name, batch_size)`` that
            returns an object with ``build_insert_sql()`` and ``get_vals()``.
        :raises ValueError: if rows are requested but no builder factory is given.
        """
        if count > 0 and get_sqlbuilder_func is None:
            raise ValueError("get_sqlbuilder_func is required to generate rows")
        db_config = DbUtils.get_db_config(self.db_name)
        # The configured port was previously ignored (pymysql defaulted to 3306).
        conn = pymysql.connect(host=db_config.host,
                               port=int(db_config.port),
                               user=db_config.user, password=db_config.password,
                               database=db_config.database)
        try:
            remaining_amount = count
            while remaining_amount > 0:
                # Size of this batch, capped at commit_size.
                deduction = min(remaining_amount, self.commit_size)
                remaining_amount -= deduction
                sql_builder = get_sqlbuilder_func(self.tb_name, deduction)
                # Parameterised INSERT template plus one value list per row.
                sql = sql_builder.build_insert_sql()
                val = sql_builder.get_vals()
                try:
                    with conn.cursor() as my_cursor:
                        my_cursor.executemany(sql, val)
                    conn.commit()
                except pymysql.Error as error:
                    print("Failed to insert record into MySQL table {}".format(error))
                    # Same effect as the former exit(1), without relying on the
                    # `site`-provided builtin.
                    raise SystemExit(1) from error
        finally:
            conn.close()

    def insert_test_datas_task(self, thread_id, count, get_sqlbuilder_func=None):
        """Worker body: insert ``count`` rows and print start/finish timing."""
        _start_time = time.time()
        print(f"线程 {thread_id},数据量 {count} 开始执行...")
        self.insert_test_datas(count, get_sqlbuilder_func)
        _execution_time = time.time() - _start_time
        print(f"线程 {thread_id},数据量 {count} 执行完毕.执行时间{_execution_time}秒")

    def run_task(self, task_func=None):
        """Start ``max_threads`` insert workers and block until all of them finish.

        :param task_func: SQL-builder factory passed to each worker as
            ``get_sqlbuilder_func``.
        """
        self.all_threads_done.clear()
        # Hold one slot in thread_count for the whole run. Otherwise a worker
        # that finishes before the others have started (e.g. one with a
        # zero-row share) could drive the count to zero and set
        # all_threads_done prematurely.
        with self.thread_count_lock:
            self.thread_count += 1
        threads = [
            threading.Thread(target=self.create_thread, args=(i, self.insert_test_datas_task, task_func))
            for i in range(self.max_threads)
        ]
        try:
            for thread in threads:
                thread.start()
            # Join every worker rather than trusting the event alone.
            for thread in threads:
                thread.join()
        finally:
            with self.thread_count_lock:
                self.thread_count -= 1
                if self.thread_count == 0:
                    self.all_threads_done.set()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/kennylee/python-tools.git
git@gitee.com:kennylee/python-tools.git
kennylee
python-tools
python-tools
master

搜索帮助