代码拉取完成,页面将自动刷新
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from utils.strings import get_font_prop
import warnings

# Non-interactive Agg backend: figures are rendered without needing a display.
matplotlib.use('Agg')
# Install a global filter that suppresses every UserWarning for the rest of the process.
warnings.filterwarnings(action='ignore', category=UserWarning)
def draw_interaction_anova(ax, factor, mean, ci_lower, ci_upper,
                           chosen_best=5,
                           ylabel=r'$RPD$(%)', value='rpd'):
    """Draw an interaction plot (group means with error bars) onto ``ax``.

    Each column of ``mean`` becomes one connected errorbar series with one
    point per row (factor level).  Rows are sorted by their numeric index
    value and placed at evenly spaced x positions labelled with that value.

    Side effects: applies ``sns.set(style='white')`` and sets
    ``plt.rcParams['font.family']`` globally, and calls ``plt.ioff()`` when
    drawing.  The input frames themselves are not modified.

    Args:
        ax: Matplotlib Axes to draw on.  If falsy (e.g. ``None``) the
            function returns after the range computation without drawing.
        factor: Factor name, used as the x-axis label and passed to
            ``get_font_prop`` to choose its font.
        mean: DataFrame with one row per factor level (index values must be
            convertible by ``pd.to_numeric``) and one column per sub-group.
            The plotted y value of each cell is ``cell[value]``.
            NOTE(review): the y-range computation also evaluates
            ``mean +/- float`` followed by ``astype('float')``, so cells
            presumably are scalar-like containers (e.g. one-element Series)
            -- confirm against the caller.
        ci_lower: Lower confidence bounds.  Rows are selected with
            ``mean.index``, and it must contain every column of ``mean``.
        ci_upper: Upper confidence bounds, same layout as ``ci_lower``.
        chosen_best: Number of x slots used to lay out the axis.  Levels
            beyond this count are drawn outside the x limits.
        ylabel: Y-axis label.
        value: Key used to extract the plotted number from each cell of
            ``mean``.
    """
    sns.set(style='white')
    plt.rcParams['font.family'] = 'Times New Roman'

    top_groups = mean.index
    ci_lower = ci_lower.loc[top_groups]
    ci_upper = ci_upper.loc[top_groups]
    ci_interval = (ci_upper - ci_lower) / 2
    # Every half-width collapses to the smallest one over all cells (NaN
    # cells stay NaN), so all error bars share a single width.
    ci_interval = ci_interval - (ci_interval - ci_interval.min().min())
    min_lower = (mean - ci_interval).astype('float').min().min()
    max_upper = (mean + ci_interval).astype('float').max().max()

    if not ax:
        return
    plt.ioff()

    y_spacing = max((max_upper - min_lower) / (chosen_best + 1) / 2, 0.2)
    y_padding = y_spacing
    _, _, ax_width, _ = ax.get_position().bounds
    spacing = ax_width / (chosen_best + 1)
    padding = spacing

    # Work on re-labelled copies (set_axis returns a new frame) so the
    # caller's index is untouched.  Both frames get the same numeric index
    # and the same deterministic sort, so every error bar stays with its mean.
    numeric_levels = pd.to_numeric(mean.index)
    mean = mean.set_axis(numeric_levels, axis=0).sort_index()
    ci_interval = ci_interval.set_axis(numeric_levels, axis=0).sort_index()

    # Integer-count positions: np.arange with a float step can emit an
    # extra point through rounding, and x must match the number of rows.
    x_positions = (padding + spacing * np.arange(len(mean))).tolist()

    colors = ['steelblue', 'red', 'green', 'purple', 'orange']
    for i, sub_group in enumerate(mean.columns):
        color = colors[i % len(colors)]  # cycle rather than IndexError past 5 series
        # Plot the means and interval half-widths with connecting lines.
        ax.errorbar(x=x_positions,
                    y=[cell[value] for cell in mean[sub_group]],
                    yerr=ci_interval[sub_group].to_numpy(),
                    fmt='o', linestyle='-', capsize=10, color=color, ecolor=color,
                    elinewidth=2, capthick=1.5,
                    label=sub_group)
    ax.legend(fontsize=14)
    ax.set_xticks(x_positions, mean.index, fontsize=14)

    y_ticks = [round(tick, 1) if isinstance(tick, float) else tick
               for tick in np.arange(int(min_lower - y_padding),
                                     max_upper + y_padding,
                                     round(y_spacing, 1))]
    ax.set_yticks(y_ticks, y_ticks, fontsize=14)
    ax.set_xlim(0, (chosen_best + 1) * spacing)
    ax.set_ylim((min_lower - y_padding, max_upper + y_padding))
    ax.set_xlabel(factor, fontsize=18, fontproperties=get_font_prop(factor))
    ax.set_ylabel(ylabel, fontsize=18)
    ax.tick_params(axis='both', which='both', length=6, width=1, direction='in', labelsize=16)
    # Light dashed horizontal guide line at every y tick.
    for y in y_ticks:
        ax.axhline(y=y, color='gray', linestyle='--', alpha=0.4)
def draw_interaction_anova_oldway(ax, name, data: pd.DataFrame, chosen_best=5,
                                  ylabel=r'$RPD$(%)'):
    """Draw an interaction plot from a long-format DataFrame onto ``ax``.

    Each distinct value of ``data['name']`` becomes one connected errorbar
    series.  Its rows are plotted, in their existing order, at consecutive
    evenly spaced x positions with error bars of ``1.96 * err``.

    Side effects: applies ``sns.set(style='white')`` and sets
    ``plt.rcParams['font.family']`` globally, and calls ``plt.ioff()`` when
    drawing.  ``data`` itself is not modified.

    Args:
        ax: Matplotlib Axes to draw on.  If falsy (e.g. ``None``) the
            function returns after the range computation without drawing.
        name: Column of ``data`` holding the factor levels.  Its unique
            values (in order of first appearance) label the x ticks, and it
            is also used as the x-axis label.
            NOTE(review): assumes each series' rows are already ordered like
            ``data[name].unique()`` -- confirm against the caller.
        data: Rows with at least the columns ``'name'`` (series identifier),
            ``name`` (factor level), ``'mean'`` and ``'err'``.
        chosen_best: Number of x slots used to lay out the axis.
        ylabel: Y-axis label (defaults to the previously hard-coded label).
    """
    sns.set(style='white')
    plt.rcParams['font.family'] = 'Times New Roman'
    # Two-sided 95% critical value of the standard normal (z) distribution.
    confidence_bound = 1.96
    # assign() returns a new frame, so the caller's ``data`` gains no column.
    data = data.assign(err_bar=data['err'] * confidence_bound)
    min_lower = (data['mean'] - data['err_bar']).min()
    max_upper = (data['mean'] + data['err_bar']).max()

    if not ax:
        return
    plt.ioff()

    y_spacing = max((max_upper - min_lower) / (chosen_best + 1) / 2, 0.2)
    y_padding = y_spacing
    _, _, ax_width, _ = ax.get_position().bounds
    spacing = ax_width / (chosen_best + 1)
    padding = spacing

    colors = ['steelblue', 'red', 'green', 'purple', 'orange']
    for i, scheduler in enumerate(data['name'].unique()):
        scheduler_data = data[data['name'] == scheduler]
        color = colors[i % len(colors)]  # cycle rather than IndexError past 5 series
        # Read whole columns: itertuples() rows are namedtuples and cannot be
        # indexed by column name.
        x_data = (padding + spacing * np.arange(len(scheduler_data))).tolist()
        # Plot the means with +/- err_bar intervals and connecting lines.
        ax.errorbar(x=x_data,
                    y=scheduler_data['mean'].tolist(),
                    yerr=scheduler_data['err_bar'].tolist(),
                    fmt='o', linestyle='-', capsize=10, color=color, ecolor=color,
                    elinewidth=2, capthick=1.5,
                    label=scheduler)
    ax.legend(fontsize=14)

    levels = data[name].unique()
    # One tick per label, using an integer count.  A float-step np.arange
    # could produce a different count and make set_xticks raise.
    x_ticks = (padding + spacing * np.arange(len(levels))).tolist()
    ax.set_xticks(x_ticks, levels, fontsize=14)

    y_ticks = [round(tick, 1) if isinstance(tick, float) else tick
               for tick in np.arange(int(min_lower - y_padding),
                                     max_upper + y_padding,
                                     round(y_spacing, 1))]
    ax.set_yticks(y_ticks, y_ticks, fontsize=14)
    ax.set_xlim(0, (chosen_best + 1) * spacing)
    ax.set_ylim((min_lower - y_padding, max_upper + y_padding))
    ax.set_xlabel(name, fontsize=18, fontproperties=get_font_prop(name))
    ax.set_ylabel(ylabel, fontsize=18)
    ax.tick_params(axis='both', which='both', length=6, width=1, direction='in', labelsize=16)
    # Light dashed horizontal guide line at every y tick.
    for y in y_ticks:
        ax.axhline(y=y, color='gray', linestyle='--', alpha=0.4)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。