
Owen/Python-causalml

forked from 连享会/Python-causalml 
test_cevae.py 1.55 KB
ppstacy committed on 2021-02-02 04:42 +08:00: add a wrapper for CEVAE (#276)
import pandas as pd
import torch

from causalml.inference.nn import CEVAE
from causalml.dataset import simulate_hidden_confounder
from causalml.metrics import get_cumgain


def test_CEVAE():
    # Simulated data with a hidden confounder: outcome y, covariates X,
    # binary treatment flag, true individual treatment effect tau,
    # expected outcome b and propensity score e
    y, X, treatment, tau, b, e = simulate_hidden_confounder(n=10000, p=5, sigma=1.0, adj=0.)

    # CEVAE hyperparameters
    outcome_dist = "normal"
    latent_dim = 20
    hidden_dim = 200
    num_epochs = 50
    batch_size = 100
    learning_rate = 1e-3
    learning_rate_decay = 0.1

    cevae = CEVAE(outcome_dist=outcome_dist,
                  latent_dim=latent_dim,
                  hidden_dim=hidden_dim,
                  num_epochs=num_epochs,
                  batch_size=batch_size,
                  learning_rate=learning_rate,
                  learning_rate_decay=learning_rate_decay)

    cevae.fit(X=torch.tensor(X, dtype=torch.float),
              treatment=torch.tensor(treatment, dtype=torch.float),
              y=torch.tensor(y, dtype=torch.float))

    # Predict individual treatment effects (ITE) to evaluate ranking quality
    ite = cevae.predict(X).flatten()

    # The true-effect column is named 'tau' to match treatment_effect_col below;
    # otherwise get_cumgain would treat it as another model's predictions
    auuc_metrics = pd.DataFrame({'ite': ite,
                                 'W': treatment,
                                 'y': y,
                                 'tau': tau})

    cumgain = get_cumgain(auuc_metrics,
                          outcome_col='y',
                          treatment_col='W',
                          treatment_effect_col='tau')

    # Check that the cumulative gain from ranking by the model's predictions
    # is higher than it would be under random targeting
    assert cumgain['ite'].sum() > cumgain['Random'].sum()
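The final assertion compares the area under two cumulative gain curves. The sketch below illustrates the idea in plain NumPy and is not causalml's `get_cumgain` implementation. It assumes a simplified definition: units are ranked by a score, and at each cutoff k the gain is (mean treated outcome minus mean control outcome among the top k) times k. The helpers `cumulative_gain` and `random_baseline` are hypothetical. The random baseline here is the expected straight line to the full-population gain, whereas causalml's `Random` column ranks by random scores.

```python
import numpy as np


def cumulative_gain(score, y, w):
    """Hypothetical helper: gain curve for units ranked by descending score.

    At cutoff k: lift = mean(y | treated, top k) - mean(y | control, top k),
    gain = lift * k. Cutoffs where one arm is still empty get a lift of 0.
    """
    order = np.argsort(-np.asarray(score, dtype=float), kind="stable")
    y = np.asarray(y, dtype=float)[order]
    w = np.asarray(w, dtype=float)[order]
    n_t, n_c = np.cumsum(w), np.cumsum(1 - w)          # arm sizes in the top k
    sum_t, sum_c = np.cumsum(y * w), np.cumsum(y * (1 - w))
    with np.errstate(divide="ignore", invalid="ignore"):
        lift = sum_t / n_t - sum_c / n_c               # NaN while an arm is empty
    lift = np.nan_to_num(lift)
    return lift * np.arange(1, len(y) + 1)


def random_baseline(gain):
    """Hypothetical helper: expected gain of random targeting (a straight line)."""
    n = len(gain)
    return gain[-1] * np.arange(1, n + 1) / n


# Toy data: the first four units respond to treatment (effect 1), the rest do not
tau = np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=float)
w = np.array([1, 0, 1, 0, 1, 0, 1, 0], dtype=float)
y = tau * w                                            # control outcome is 0 for everyone

good = cumulative_gain(np.arange(8, 0, -1), y, w)      # ranks responders first
bad = cumulative_gain(np.arange(1, 9), y, w)           # ranks responders last
base = random_baseline(good)
print(good.sum() > base.sum() > bad.sum())             # prints True
```

All three curves end at the same full-population gain (ATE times n), so only the ordering of units changes the area under the curve; a model that ranks responders early, as the test expects of CEVAE, beats random targeting.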
Repository: https://gitee.com/owen560/Python-causalml.git (SSH: git@gitee.com:owen560/Python-causalml.git), v0.12.0