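# NOTE: Gitee web-page residue captured together with this file; it is not part of the test module.
#   "Code pull complete; the page will refresh automatically."
#   "Syncing will force-sync from 连享会/Python-causalml. This overwrites every change made since the repository was forked and cannot be undone!"
#   "After you confirm, the sync runs in the background and the page refreshes when it finishes; please be patient."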
import numpy as np
import pandas as pd
import pytest
from causalml.match import NearestNeighborMatch, MatchOptimizer
from causalml.propensity import ElasticNetPropensityModel
from .const import RANDOM_SEED, TREATMENT_COL, SCORE_COL, GROUP_COL
@pytest.fixture
def generate_unmatched_data(generate_regression_data):
    """Provide a factory that builds an unbalanced treatment/control frame.

    The factory pulls a dataset from the ``generate_regression_data`` fixture,
    stacks the control rows twice on top of the treated rows, and then adds a
    score column predicted by ``ElasticNetPropensityModel`` plus a random
    binary group column.

    Yields:
        callable: a zero-argument function returning ``(df, features)``, where
        ``df`` is the assembled ``pandas.DataFrame`` and ``features`` is the
        list of covariate column names (``x0``, ``x1``, ...).
    """

    def _generate_data():
        # The previous version wrapped this body in ``if not generated:`` with
        # a flag that was never set to True, so the guard was dead code and
        # would have left ``df``/``features`` unbound had it ever fired.
        # The data is therefore rebuilt on every call, exactly as before.
        _y, X, treatment, _tau, _b, _e = generate_regression_data()

        features = ['x{}'.format(i) for i in range(X.shape[1])]
        df = pd.DataFrame(X, columns=features)
        df[TREATMENT_COL] = treatment

        df_c = df.loc[treatment == 0]
        df_t = df.loc[treatment == 1]

        # Include the control rows twice so the frame holds two copies of
        # every control observation next to a single copy of each treated one.
        df = pd.concat([df_t, df_c, df_c], axis=0, ignore_index=True)

        pm = ElasticNetPropensityModel(random_state=RANDOM_SEED)
        df[SCORE_COL] = pm.fit_predict(df[features], df[TREATMENT_COL])

        # Group labels are 0 or 1, drawn from NumPy's global RNG.
        # NOTE(review): reproducibility here depends on whatever seeded the
        # global RNG earlier (presumably the upstream fixture) - confirm.
        df[GROUP_COL] = np.random.randint(0, 2, size=df.shape[0])

        return df, features

    yield _generate_data
def test_nearest_neighbor_match_by_group(generate_unmatched_data):
    """Group-wise matching must return as many control rows as non-control rows."""
    frame, _ = generate_unmatched_data()

    matcher = NearestNeighborMatch(
        replace=False,
        ratio=1.0,
        random_state=RANDOM_SEED,
    )
    match_kwargs = dict(
        data=frame,
        treatment_col=TREATMENT_COL,
        score_cols=[SCORE_COL],
        groupby_col=GROUP_COL,
    )
    matched = matcher.match_by_group(**match_kwargs)

    # Count rows on each side of the treatment indicator and compare.
    n_control = sum(matched[TREATMENT_COL] == 0)
    n_treated = sum(matched[TREATMENT_COL] != 0)
    assert n_control == n_treated
def test_match_optimizer(generate_unmatched_data):
    """The optimizer's best match must hold equal control and non-control counts."""
    frame, _ = generate_unmatched_data()

    optimizer_kwargs = dict(
        treatment_col=TREATMENT_COL,
        ps_col=SCORE_COL,
        matching_covariates=[SCORE_COL],
        min_users_per_group=100,
        smd_cols=[SCORE_COL],
        dev_cols_transformations={SCORE_COL: np.mean},
    )
    matched = MatchOptimizer(**optimizer_kwargs).search_best_match(frame)

    # Count rows on each side of the treatment indicator and compare.
    n_control = sum(matched[TREATMENT_COL] == 0)
    n_treated = sum(matched[TREATMENT_COL] != 0)
    assert n_control == n_treated
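# NOTE: Gitee web-page residue captured together with this file; it is not part of the test module.
#   "This section may contain content unsuitable for display, so the page does not show it. You can review and modify it yourself using the editing features."
#   "If you confirm the content involves no inappropriate language, pure advertising, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or anything violating national laws and regulations, you may click submit to appeal and we will handle it as soon as possible."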