tensorflow-DeepFM

This project includes a TensorFlow implementation of DeepFM [1].

NEWS

Usage

Input Format

This implementation requires the input data in the following format:

  • Xi: [[ind1_1, ind1_2, ...], [ind2_1, ind2_2, ...], ..., [indi_1, indi_2, ..., indi_j, ...], ...]
    • indi_j is the feature index of feature field j of sample i in the dataset
  • Xv: [[val1_1, val1_2, ...], [val2_1, val2_2, ...], ..., [vali_1, vali_2, ..., vali_j, ...], ...]
    • vali_j is the feature value of feature field j of sample i in the dataset
    • vali_j can be either binary (1/0, for binary/categorical features) or float (e.g., 10.24, for numerical features)
  • y: the target of each sample in the dataset (1/0 for classification, a numeric value for regression)
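The snippet below is a minimal sketch of how raw rows could be encoded into this layout. It is not the API of example/DataReader.py. The function name build_xi_xv, the toy rows and the field order (numeric fields first) are all hypothetical. Each numeric field gets one index and keeps its raw value in Xv. Each (field, category) pair gets its own index with value 1.0.

```python
from typing import Any, Dict, List, Sequence, Tuple


def build_xi_xv(
    rows: Sequence[Dict[str, Any]],
    categorical_fields: List[str],
    numeric_fields: List[str],
) -> Tuple[List[List[int]], List[List[float]], Dict[Any, int]]:
    """Hypothetical encoder: one feature index and one value per field and sample."""
    feat_dict: Dict[Any, int] = {}
    # A numeric field maps to a single index; its raw value goes into Xv.
    for field in numeric_fields:
        feat_dict[field] = len(feat_dict)
    # Every (field, category) pair gets its own index; its value in Xv is 1.0.
    for field in categorical_fields:
        for row in rows:
            key = (field, row[field])
            if key not in feat_dict:
                feat_dict[key] = len(feat_dict)

    Xi: List[List[int]] = []
    Xv: List[List[float]] = []
    for row in rows:
        Xi.append([feat_dict[f] for f in numeric_fields]
                  + [feat_dict[(f, row[f])] for f in categorical_fields])
        Xv.append([float(row[f]) for f in numeric_fields]
                  + [1.0] * len(categorical_fields))
    return Xi, Xv, feat_dict


# Toy data (made up for illustration).
rows = [
    {"color": "red", "price": 10.24, "label": 1},
    {"color": "blue", "price": 3.5, "label": 0},
    {"color": "red", "price": 7.0, "label": 0},
]
Xi, Xv, feat_dict = build_xi_xv(rows, categorical_fields=["color"], numeric_fields=["price"])
y = [row["label"] for row in rows]
print(Xi)  # [[0, 1], [0, 2], [0, 1]]
print(Xv)  # [[10.24, 1.0], [3.5, 1.0], [7.0, 1.0]]
```

The same feat_dict must be reused for validation and test data. In this sketch a category never seen while building the dictionary raises a KeyError, so a real pipeline needs a policy for unseen values.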

Please see example/DataReader.py for an example of how to prepare the data in the required format for DeepFM.
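To show how a model consumes Xi and Xv, here is a NumPy sketch of a DeepFM forward pass as described in the paper [1]. It is not the repository's TensorFlow code. The toy sizes, random weights and the single hidden layer are assumptions. Biases, dropout and batch normalization are omitted. Both parts share the same embeddings: the FM part uses them for pairwise interactions, and the deep part flattens them into an MLP input.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes and randomly initialised parameters (assumptions for illustration).
feature_size, field_size, k, hidden = 3, 2, 4, 8
w = rng.normal(scale=0.1, size=feature_size)       # first-order weight per feature index
V = rng.normal(scale=0.1, size=(feature_size, k))  # shared embedding per feature index
W1 = rng.normal(scale=0.1, size=(field_size * k, hidden))
w2 = rng.normal(scale=0.1, size=hidden)


def deepfm_forward(Xi, Xv):
    Xi = np.asarray(Xi, dtype=int)    # (n, field_size) feature indices
    Xv = np.asarray(Xv, dtype=float)  # (n, field_size) feature values
    E = V[Xi] * Xv[:, :, None]        # (n, field_size, k) embeddings scaled by values

    # FM first-order term: sum_j w[ind_j] * val_j
    y_first = (w[Xi] * Xv).sum(axis=1)
    # FM second-order term: sum over field pairs j < l of <e_j, e_l>,
    # via 0.5 * (square of sum - sum of squares), linear in field_size
    y_second = 0.5 * ((E.sum(axis=1) ** 2) - (E ** 2).sum(axis=1)).sum(axis=1)

    # Deep part: the same embeddings, flattened, through a one-layer ReLU MLP.
    h = np.maximum(E.reshape(len(Xi), -1) @ W1, 0.0)
    y_deep = h @ w2

    logits = y_first + y_second + y_deep
    return 1.0 / (1.0 + np.exp(-logits))  # sigmoid output for binary classification


Xi_toy = [[0, 1], [0, 2], [0, 1]]
Xv_toy = [[10.24, 1.0], [3.5, 1.0], [7.0, 1.0]]
probs = deepfm_forward(Xi_toy, Xv_toy)
print(probs.shape)  # (3,)
```

The sum-of-squares identity is the standard FM trick. It avoids looping over all field pairs explicitly.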

Init and train a model

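import tensorflow as tf  # the model code is written against the TensorFlow 1.x API
from sklearn.metrics import roc_auc_score

from DeepFM import DeepFM  # DeepFM.py from this repository must be importable

# params
dfm_params = {
    "use_fm": True,
    "use_deep": True,
    "embedding_size": 8,
    "dropout_fm": [1.0, 1.0],          # keep probabilities for the FM terms (1.0 = no dropout)
    "deep_layers": [32, 32],           # hidden units of each deep layer
    "dropout_deep": [0.5, 0.5, 0.5],   # keep probabilities: embedding input + one per deep layer
    "deep_layers_activation": tf.nn.relu,
    "epoch": 30,
    "batch_size": 1024,
    "learning_rate": 0.001,
    "optimizer_type": "adam",
    "batch_norm": 1,
    "batch_norm_decay": 0.995,
    "l2_reg": 0.01,
    "verbose": True,
    "eval_metric": roc_auc_score,
    "random_seed": 2017
}

# prepare training and validation data in the required format
# (prepare(...) stands for your own preprocessing, e.g. example/DataReader.py)
Xi_train, Xv_train, y_train = prepare(...)
Xi_valid, Xv_valid, y_valid = prepare(...)

# the model also needs the input dimensions
dfm_params["feature_size"] = 1 + max(max(row) for Xi in (Xi_train, Xi_valid) for row in Xi)  # embedding table size
dfm_params["field_size"] = len(Xi_train[0])  # number of feature fields per sample

# init a DeepFM model
dfm = DeepFM(**dfm_params)

# fit a DeepFM model
dfm.fit(Xi_train, Xv_train, y_train)

# make prediction
y_pred = dfm.predict(Xi_valid, Xv_valid)

# evaluate a trained model (returns eval_metric(y_valid, predictions))
score = dfm.evaluate(Xi_valid, Xv_valid, y_valid)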

You can enable early stopping during training by passing the validation data and early_stopping=True, as follows:

dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True)
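This README does not state the exact stopping rule used by DeepFM.py. Below is a minimal sketch of one common rule: stop once the validation metric has worsened for `patience` consecutive epochs. The function name should_stop and its parameters are hypothetical.

```python
from typing import List


def should_stop(valid_scores: List[float], patience: int = 5,
                greater_is_better: bool = True) -> bool:
    """True if the validation metric worsened on each of the last `patience` epochs."""
    if len(valid_scores) <= patience:
        return False  # not enough history yet
    recent = valid_scores[-(patience + 1):]
    steps = list(zip(recent, recent[1:]))  # (previous epoch, next epoch) pairs
    if greater_is_better:  # e.g. AUC: a drop is a worsening
        return all(nxt < prev for prev, nxt in steps)
    return all(nxt > prev for prev, nxt in steps)  # e.g. MSE: a rise is a worsening


history = [0.70, 0.71, 0.72, 0.71, 0.70, 0.69, 0.68, 0.67]
print(should_stop(history))              # True: five drops in a row
print(should_stop(history, patience=6))  # False: only five consecutive drops
```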

You can refit the model on the combined training and validation set as follows:

dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True, refit=True)
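The README does not describe what refit=True does internally. One common refit strategy is to find the best epoch on the validation set and then retrain on the merged data. The sketch below illustrates that strategy, not necessarily the repository's procedure. The helpers best_epoch and merge_sets are hypothetical and assume Xi, Xv and y are plain Python lists.

```python
from typing import Sequence, Tuple

Dataset = Tuple[list, list, list]  # (Xi, Xv, y) as plain Python lists


def best_epoch(valid_scores: Sequence[float], greater_is_better: bool = True) -> int:
    """1-based index of the epoch with the best validation score."""
    scores = list(valid_scores)
    best = max(scores) if greater_is_better else min(scores)
    return scores.index(best) + 1


def merge_sets(train: Dataset, valid: Dataset) -> Dataset:
    """Concatenate the training and validation (Xi, Xv, y) lists sample-wise."""
    Xi_t, Xv_t, y_t = train
    Xi_v, Xv_v, y_v = valid
    return Xi_t + Xi_v, Xv_t + Xv_v, y_t + y_v  # list concatenation, not element-wise


n_epochs = best_epoch([0.70, 0.74, 0.73])  # 2
Xi_all, Xv_all, y_all = merge_sets(([[0, 1]], [[10.24, 1.0]], [1]),
                                   ([[0, 2]], [[3.5, 1.0]], [0]))
# A model would then be retrained on (Xi_all, Xv_all, y_all) for n_epochs epochs.
```

With NumPy arrays instead of lists, `+` adds element-wise, so np.concatenate would be needed instead.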

You can train only the FM part or only the DNN part by setting the parameter use_deep or use_fm, respectively, to False.
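A small sketch of deriving the three configurations from one base dictionary, reusing the use_fm and use_deep keys from dfm_params. The helper model_variants is hypothetical. It copies the base dictionary, so the original is never mutated.

```python
from typing import Any, Dict


def model_variants(base_params: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Copy base_params into DeepFM, FM-only and DNN-only settings."""
    return {
        "DeepFM": dict(base_params, use_fm=True, use_deep=True),
        "FM": dict(base_params, use_fm=True, use_deep=False),   # drop the deep component
        "DNN": dict(base_params, use_fm=False, use_deep=True),  # drop the FM component
    }


base = {"use_fm": True, "use_deep": True, "embedding_size": 8, "epoch": 30}
variants = model_variants(base)
print(variants["FM"]["use_deep"], variants["DNN"]["use_fm"])  # False False
# Each dict could then be passed on as DeepFM(**variants["FM"]), and so on.
```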

Regression

This implementation also supports regression tasks. To use DeepFM for regression, set loss_type to "mse". Accordingly, set eval_metric to a regression metric, e.g., mean squared error (MSE) or mean absolute error (MAE).
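A minimal sketch of a regression setup, assuming eval_metric is called as metric(y_true, y_pred), the same convention as roc_auc_score. The rmse function and the regression_overrides dictionary are hypothetical names to merge into dfm_params. Error metrics are lower-is-better. If your copy of DeepFM.py has a greater_is_better argument (not shown in this README), it should likely be set to False so early stopping interprets the metric correctly.

```python
import numpy as np


def rmse(y_true, y_pred) -> float:
    """Root mean squared error; lower values are better."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


# Hypothetical overrides to merge into dfm_params for a regression run.
regression_overrides = {
    "loss_type": "mse",   # squared-error loss instead of a classification loss
    "eval_metric": rmse,  # any callable metric(y_true, y_pred) should work here
}

print(rmse([0.0, 0.0], [2.0, 2.0]))  # 2.0
```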

Example

The example folder shows how to use the DeepFM, FM, and DNN models on the Porto Seguro's Safe Driver Prediction competition on Kaggle.

Please download the data from the competition website and put them into the example/data folder.

To train a DeepFM model on this dataset, run:

$ cd example
$ python main.py

Please see example/DataReader.py for how to parse the raw dataset into the required format for DeepFM.

Performance

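DeepFM

[Figure: DeepFM results (image not available)]

FM

[Figure: FM results (image not available)]

DNN

[Figure: DNN results (image not available)]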

Some tips

  • You should tune the parameters for each model in order to get reasonable performance.
  • You can also try ensembling these models with each other or with other models (e.g., XGBoost or LightGBM).
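A simple way to ensemble is a weighted average of the models' predictions on the same samples. The sketch below assumes each model outputs one score per sample in the same order. The function name blend and the toy predictions are hypothetical. In practice, the weights would be tuned on a validation set.

```python
import numpy as np


def blend(predictions, weights=None):
    """Weighted average of per-model prediction vectors over the same samples."""
    P = np.vstack([np.asarray(p, dtype=float) for p in predictions])  # (n_models, n_samples)
    w = np.ones(P.shape[0]) if weights is None else np.asarray(weights, dtype=float)
    return w @ P / w.sum()


# Toy validation predictions from DeepFM and from another model.
p_deepfm = [0.25, 0.75]
p_other = [0.75, 0.25]
print(blend([p_deepfm, p_other]))                  # [0.5 0.5]
print(blend([p_deepfm, p_other], weights=[3, 1]))  # [0.375 0.625]
```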

Reference

[1] DeepFM: A Factorization-Machine based Neural Network for CTR Prediction. Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. IJCAI 2017.

Acknowledgments

This project draws inspiration from the following projects:

License

MIT

MIT License

Copyright (c) 2017 Chenglong Chen

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
