Ai
1 Star 0 Fork 1

Owen/Python-causalml

forked from 连享会/Python-causalml 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
test_match.py 2.17 KB
一键复制 编辑 原始数据 按行查看 历史
import numpy as np
import pandas as pd
import pytest
from causalml.match import NearestNeighborMatch, MatchOptimizer
from causalml.propensity import ElasticNetPropensityModel
from .const import RANDOM_SEED, TREATMENT_COL, SCORE_COL, GROUP_COL
@pytest.fixture
def generate_unmatched_data(generate_regression_data):
    """Provide a factory that builds an unmatched (imbalanced) dataset.

    Each call to the yielded factory:

    * draws ``(y, X, treatment, tau, b, e)`` from the
      ``generate_regression_data`` fixture and keeps only ``X`` and
      ``treatment``;
    * builds a DataFrame with covariate columns ``x0 .. x{p-1}`` plus
      ``TREATMENT_COL``;
    * stacks the treated rows once and the control rows twice, so every
      control row appears two times (rows whose treatment is neither 0 nor 1
      are dropped; treatment is presumably binary, confirm with the
      ``generate_regression_data`` fixture);
    * adds an ElasticNet propensity score in ``SCORE_COL``;
    * assigns each row to group 0 or 1 in ``GROUP_COL`` using a generator
      seeded with ``RANDOM_SEED``, so group assignment is reproducible.

    Returns (from the factory):
        tuple: ``(df, features)``, where ``features`` is the list of
        covariate column names.
    """

    def _generate_data():
        # Only the covariates and the treatment vector are needed here.
        _, X, treatment, _, _, _ = generate_regression_data()

        features = ['x{}'.format(i) for i in range(X.shape[1])]
        df = pd.DataFrame(X, columns=features)
        df[TREATMENT_COL] = treatment

        df_c = df.loc[df[TREATMENT_COL] == 0]
        df_t = df.loc[df[TREATMENT_COL] == 1]

        # Duplicate the control rows so there are more controls than treated
        # units, which is what the matching tests need.
        df = pd.concat([df_t, df_c, df_c], axis=0, ignore_index=True)

        pm = ElasticNetPropensityModel(random_state=RANDOM_SEED)
        df[SCORE_COL] = pm.fit_predict(df[features], df[TREATMENT_COL])

        # Use a local seeded generator rather than the global np.random state
        # so that group membership (and therefore by-group matching results)
        # is reproducible across runs.
        rng = np.random.RandomState(RANDOM_SEED)
        df[GROUP_COL] = rng.randint(0, 2, size=df.shape[0])

        return df, features

    yield _generate_data
def test_nearest_neighbor_match_by_group(generate_unmatched_data):
    """1:1 nearest-neighbour matching without replacement, run separately
    within each ``GROUP_COL`` group, must return equal numbers of control
    (treatment == 0) and non-control (treatment != 0) rows.
    """
    df, _ = generate_unmatched_data()

    psm = NearestNeighborMatch(replace=False,
                               ratio=1.,
                               random_state=RANDOM_SEED)
    matched = psm.match_by_group(data=df,
                                 treatment_col=TREATMENT_COL,
                                 score_cols=[SCORE_COL],
                                 groupby_col=GROUP_COL)

    is_control = matched[TREATMENT_COL] == 0
    n_control = int(is_control.sum())
    n_treated = int((~is_control).sum())

    # Guard against a vacuous pass: an empty result would satisfy 0 == 0.
    assert n_treated > 0
    assert n_control == n_treated
def test_match_optimizer(generate_unmatched_data):
    """The match chosen by ``MatchOptimizer.search_best_match`` must contain
    equal numbers of control (treatment == 0) and non-control
    (treatment != 0) rows.
    """
    df, _ = generate_unmatched_data()

    optimizer = MatchOptimizer(treatment_col=TREATMENT_COL,
                               ps_col=SCORE_COL,
                               matching_covariates=[SCORE_COL],
                               min_users_per_group=100,
                               smd_cols=[SCORE_COL],
                               dev_cols_transformations={SCORE_COL: np.mean})
    matched = optimizer.search_best_match(df)

    is_control = matched[TREATMENT_COL] == 0
    n_control = int(is_control.sum())
    n_treated = int((~is_control).sum())

    # Guard against a vacuous pass: an empty result would satisfy 0 == 0.
    assert n_treated > 0
    assert n_control == n_treated
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/owen560/Python-causalml.git
git@gitee.com:owen560/Python-causalml.git
owen560
Python-causalml
Python-causalml
v0.12.0

搜索帮助