diff --git a/NL2SQL.rar b/NL2SQL.rar
new file mode 100644
index 0000000000000000000000000000000000000000..cbad3df9ccda4c2bb6b6a67543eb1b758dc9c92a
Binary files /dev/null and b/NL2SQL.rar differ
diff --git a/code/.gitignore b/code/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..3a5c32b519411d391f465cbbed763eae840f49e2
--- /dev/null
+++ b/code/.gitignore
@@ -0,0 +1,15 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# log files
+*.log
+
+# data folders
+data/*
+database/*
+embedding/*
+
+# pycharm folders
+.idea/
diff --git a/code/README.md b/code/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..06494fb95ce4fe07157720fc032a4cae4c24fd5e
--- /dev/null
+++ b/code/README.md
@@ -0,0 +1,91 @@
+# CSpider: A Large-Scale Chinese Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task
+
+CSpider is a large-scale Chinese dataset for the complex and cross-domain semantic parsing and text-to-SQL task (natural language interfaces for relational databases). It is released with our EMNLP 2019 paper: [A Pilot Study for Chinese SQL Semantic Parsing](https://arxiv.org/abs/1909.13293). This repo contains the code for evaluation, preprocessing, and all baselines used in our paper. Please refer to [the task site](https://taolusi.github.io/CSpider-explorer/) for a more general introduction and the leaderboard.
+
+### Changelog
+- `10/2019` We started a Chinese text-to-SQL task with the full dataset translated from [Spider](https://yale-lily.github.io/spider). The submission tutorial and our dataset can be found at our [task site](https://taolusi.github.io/CSpider-explorer/). Please follow it to get your results on the unreleased test data. Thanks to [Tao Yu](https://taoyds.github.io/) for sharing the test set with us.
+- `9/2019` The dataset used in our EMNLP 2019 paper is redivided based on the training and development sets from Spider. It can be downloaded from [here](https://drive.google.com/drive/folders/1SVAdUQqZ2UjjcSCSxhVXRPcXxIMu1r_C?usp=sharing). This split is released only to reproduce the results in our paper. To join the CSpider leaderboard and to compare more directly with the original English results, please refer to our [task site](https://taolusi.github.io/CSpider-explorer/) for the full dataset.
+
+### Citation
+When you use the CSpider dataset, we would appreciate it if you cite the following:
+```
+@inproceedings{min2019pilot,
+  title={A Pilot Study for Chinese SQL Semantic Parsing},
+  author={Min, Qingkai and Shi, Yuefeng and Zhang, Yue},
+  booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
+  pages={3643--3649},
+  year={2019}
+}
+```
+Our dataset is based on [Spider](https://github.com/taoyds/spider/); please cite it as well.
+
+### Baseline models
+
+#### Environment Setup
+
+1. The code uses Python 2.7 and [PyTorch 0.2.0](https://pytorch.org/get-started/previous-versions/) with GPU support; we will update to newer Python and PyTorch versions soon.
+2. Install PyTorch via conda: `conda install pytorch=0.2.0 -c pytorch`
+3. Install the Python dependencies: `pip install -r requirements.txt`
+
+#### Prepare Data, Embeddings, and Pretrained Models
+1. 
Download the data, embeddings, and database:
+   - To use the full dataset (recommended), download the train/dev data from [Google Drive](https://drive.google.com/drive/folders/1TxCUq1ydPuBdDdHF3MkHT-8zixluQuLa?usp=sharing) or [BaiduNetDisk](https://pan.baidu.com/s/1aDYht6eSIyBgceCjeuSmpw) (code: cgh1) and evaluate on the unreleased test data by following the submission tutorial on our [task site](https://taolusi.github.io/CSpider-explorer/). Specifically,
+     - Put the downloaded `train.json` and `dev.json` under the `chisp/data/char/` directory. To use word-based methods, perform word segmentation first and put the resulting json files under the `chisp/data/word/` directory.
+     - Put the downloaded `char_emb.txt` under the `chisp/embedding/` directory. It is generated from the Tencent multilingual embeddings for the cross-lingual word embedding scheme. To use the monolingual embedding scheme, step 2 is necessary.
+     - Put the downloaded `database` directory under the `chisp/` directory.
+     - Put the downloaded `train_gold.sql` and `dev_gold.sql` under the `chisp/data/` directory.
+   - To use the dataset redivided from the original train and dev data in our paper, download the train/dev/test data from [here](https://drive.google.com/drive/folders/1SVAdUQqZ2UjjcSCSxhVXRPcXxIMu1r_C?usp=sharing). This split is released only to reproduce the results in our paper; results based on it cannot join the leaderboard. Specifically,
+     - Put the downloaded `data`, `database` and `embedding` directories under the `chisp/` directory. You can then run all the experiments shown in our paper (step 2 is necessary).
+     - The `models` directory contains all the pretrained models.
+2. (optional) Download the pretrained [GloVe](https://nlp.stanford.edu/data/wordvecs/glove.42B.300d.zip) embeddings and save them as `chisp/embedding/glove.%dB.%dd.txt`
+3. Generate the training files for each module: `python preprocess_data.py -s char|word`
+
+#### Folder/File Description
+- ``data/`` contains:
+  - ``char/`` for the character-based raw train/dev/test data; the corresponding processed datasets and saved models can be found at ``char/generated_datasets``.
+  - ``word/`` for the word-based raw train/dev/test data; the corresponding processed datasets and saved models can be found at ``word/generated_datasets``.
+- ``train.py`` is the main file for training. Use ``train_all.sh`` to train all the modules (see below).
+- ``test.py`` is the main file for testing. It uses ``supermodel.py`` to call the trained modules and generate SQL queries. In practice, use ``test_gen.sh`` to generate SQL queries.
+- ``evaluation.py`` is for evaluation. It uses ``process_sql.py``. In practice, use ``evaluation.sh`` to evaluate the generated SQL queries.
+
+
+#### Training
+Run ``train_all.sh`` to train all the modules.
+It looks like:
+```
+python train.py \
+    --data_root path/to/char/or/word/based/generated_data \
+    --save_dir path/to/save/trained/module \
+    --train_component \
+    --emb_path path/to/embeddings \
+    --col_emb_path path/to/corresponding/embeddings/for/column
+```
+
+#### Testing
+Run ``test_gen.sh`` to generate SQL queries.
+``test_gen.sh`` looks like:
+```
+python test.py \
+    --test_data_path path/to/char/or/word/based/raw/dev/or/test/data \
+    --models path/to/trained/module \
+    --output_path path/to/print/generated/SQL \
+    --emb_path path/to/embeddings \
+    --col_emb_path path/to/corresponding/embeddings/for/column
+```
+
+#### Evaluation
+Run ``evaluation.sh`` to evaluate the generated SQL queries. 
+``evaluation.sh`` looks like:
+```
+python evaluation.py \
+    --gold path/to/gold/dev/or/test/queries \
+    --pred path/to/predicted/dev/or/test/queries \
+    --etype evaluation/metric \
+    --db path/to/database \
+    --table path/to/tables \
+```
+``evaluation.py`` follows the general evaluation process from [the Spider GitHub page](https://github.com/taoyds/spider).
+
+#### Acknowledgement
+
+The implementation is based on [SyntaxSQLNet](https://github.com/taoyds/syntaxSQL). Please cite it as well if you use this code.
diff --git a/code/config.py b/code/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c0deaf001f9ea59073b63dd4c2b32b5ac61562
--- /dev/null
+++ b/code/config.py
@@ -0,0 +1,26 @@
+class _Config:
+    def __init__(self):
+        self.emb_size = 200
+        self.col_emb_size = 300
+        self.hidden_size = 400
+        self.batch_size = 10
+        self.epoch = 600
+        self.dropout = 0.5
+        self.num_layers = 2
+        self.learning_rate = 1e-4
+        self.toy = False
+        self.train_emb = False
+        self.history_type = 'full'  # full, part or no
+        self.nogpu = False
+        self.table_type = 'std'  # choices=['std', 'no']: standard or no table info
+
+    def _char_init(self):
+        self.data_root = "./data/char/generated_datasets"
+        self.sep_emb = "./embedding/char/separate_emb.txt"
+        self.comb_emb = "./embedding/char/combine_emb.txt"
+    def _word_init(self):
+        self.data_root = "./data/word/generated_datasets"
+        self.sep_emb = "./embedding/word/separate_emb.txt"
+        self.comb_emb = "./embedding/word/combine_emb.txt"
+
+global_config = _Config()
diff --git a/code/config/chisp-config.yml b/code/config/chisp-config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..acf8b00a2ae51ae6973c22087452b580b11f4061
--- /dev/null
+++ b/code/config/chisp-config.yml
@@ -0,0 +1,43 @@
+# Allows at most this many submissions per user per period, where a period is 24 hours by default.
+max_submissions_per_period: 2
+
+# UUID of the worksheet where prediction and evaluation bundles are created for submissions.
+log_worksheet_uuid: '0xd2d37ae7db5b40d09aa52850ed34ee1e'
+
+# Configure the tag that participants use to submit to the competition.
+# Here, any bundle with the tag `cspider-test-submit` is considered
+# an official submission.
+submission_tag: cspider-test-submit
+
+# Configure how to mimic the submitted prediction bundles. When evaluating a submission,
+# the `new` bundle will replace the `old` bundle.
+# For a machine learning competition, the `old` bundle might be the dev set and the `new` bundle
+# might be the hidden test set.
+predict:
+  mimic:
+  - {new: '0x064a87a9db764d7a814726c77b86ff73', old: '0xd2e9047706aa44a38df48bf16b4385f6'}
+
+# Configure how to evaluate the new prediction bundles.
+# In this example, evaluate.py is a script that takes in the paths of the test labels and
+# predicted labels and outputs the evaluation results.
+evaluate:
+  # Essentially
+  #   cl run evaluate.py:0x089063eb85b64b239b342405b5ebab57 \
+  #     test.json:0x5538cba32e524fad8b005cd19abb9f95 \
+  #     predictions.json:{predict}/predictions.json --- \
+  #     python evaluate.py test.json predictions.json
+  # where {predict} gets filled in with the uuid of the mimicked bundle above. 
+ dependencies: + - {child_path: evaluation.py, parent_uuid: '0xed9c9d64b6e74056a98a5a592d9286c9'} + - {child_path: dev_gold.txt, parent_uuid: '0x7c4006535b2d40288931afd71cc4e8e5'} + - {child_path: predicted_sql.txt, parent_uuid: '0x7c4006535b2d40288931afd71cc4e8e5'} + - {child_path: tables.json, parent_path: data, parent_uuid: '0xd2e9047706aa44a38df48bf16b4385f6'} + - {child_path: database, parent_uuid: '0x794231a420384b6aa5086407ac21286b'} + command: python evaluation.py --gold dev_gold.txt --pred predicted_sql.txt --etype match --db database --table data/tables.json + +# Define how to extract the scores from the evaluation bundle. +# In this example, result.json is a JSON file outputted from the evaluation step +# with F1 and exact match metrics (e.g. {"f1": 91, "exact_match": 92}). +score_specs: +- {key: '/result.json:dev_f1', name: dev_f1} +- {key: '/result.json:test_f1', name: test_f1} diff --git a/code/config/leaderboard.json b/code/config/leaderboard.json new file mode 100644 index 0000000000000000000000000000000000000000..c620cbf022a1e2c903f91035f9efcb6a77725690 --- /dev/null +++ b/code/config/leaderboard.json @@ -0,0 +1,72 @@ +{ + "config": { + "allow_multiple_models": false, + "allow_orphans": true, + "count_failed_submissions": true, + "evaluate": { + "command": "python evaluation.py --gold dev_gold.txt --pred predicted_sql.txt --etype match --db database --table data/tables.json", + "dependencies": [ + { + "child_path": "evaluation.py", + "parent_path": "", + "parent_uuid": "0xed9c9d64b6e74056a98a5a592d9286c9" + }, + { + "child_path": "dev_gold.txt", + "parent_path": "", + "parent_uuid": "0x7c4006535b2d40288931afd71cc4e8e5" + }, + { + "child_path": "predicted_sql.txt", + "parent_path": "", + "parent_uuid": "0x7c4006535b2d40288931afd71cc4e8e5" + }, + { + "child_path": "tables.json", + "parent_path": "data", + "parent_uuid": "0xd2e9047706aa44a38df48bf16b4385f6" + }, + { + "child_path": "database", + "parent_path": "", + "parent_uuid": "0x794231a420384b6aa5086407ac21286b" + } + ], + "metadata": {}, + "tag": "competition-evaluate" + }, + "host": "https://worksheets.codalab.org", + "log_worksheet_uuid": "0xd2d37ae7db5b40d09aa52850ed34ee1e", + "make_predictions_public": false, + "max_leaderboard_size": 10000, + "max_submissions_per_period": 2, + "max_submissions_total": 10000, + "metadata": {}, + "predict": { + "depth": 10, + "metadata": {}, + "mimic": [ + { + "new": "0x064a87a9db764d7a814726c77b86ff73", + "old": "0xd2e9047706aa44a38df48bf16b4385f6" + } + ], + "tag": "competition-predict" + }, + "quota_period_seconds": 86400, + "refresh_period_seconds": 60, + "score_specs": [ + { + "key": "/result.json:dev_f1", + "name": "dev_f1" + }, + { + "key": "/result.json:test_f1", + "name": "test_f1" + } + ], + "submission_tag": "cspider-test-submit" + }, + "leaderboard": [], + "updated": 1615514134.3646703 +} \ No newline at end of file diff --git a/code/evaluation.py b/code/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3292cb39d60ca5955f03fe7ba844a3e4495968 --- /dev/null +++ b/code/evaluation.py @@ -0,0 +1,866 @@ +################################ +# val: number(float)/string(str)/sql(dict) +# col_unit: (agg_id, col_id, isDistinct(bool)) +# val_unit: (unit_op, col_unit1, col_unit2) +# table_unit: (table_type, col_unit/sql) +# cond_unit: (not_op, op_id, val_unit, val1, val2) +# condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 
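+#
+# Illustrative example (schematic; `age` is a hypothetical column and the value is shown as a
+# plain literal, the exact representation follows get_sql below):
+#   WHERE age > 20  =>  condition = [(False, 3, (0, (0, col_id, False), None), 20, None)]
+#   i.e. not_op=False, op_id=3 ('>'), val_unit with unit_op=0 ('none'), col_unit with agg_id=0 ('none')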
+# sql { +# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) +# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} +# 'where': condition +# 'groupBy': [col_unit1, col_unit2, ...] +# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) +# 'having': condition +# 'limit': None/limit value +# 'intersect': None/sql +# 'except': None/sql +# 'union': None/sql +# } +################################ + +import os, sys +import json +import sqlite3 +import traceback +import argparse + +from utils.process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql + +# Flag to disable value evaluation +DISABLE_VALUE = True +# Flag to disable distinct in select evaluation +DISABLE_DISTINCT = True + + +CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') +JOIN_KEYWORDS = ('join', 'on', 'as') + +WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') +UNIT_OPS = ('none', '-', '+', "*", '/') +AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') +TABLE_TYPE = { + 'sql': "sql", + 'table_unit': "table_unit", +} + +COND_OPS = ('and', 'or') +SQL_OPS = ('intersect', 'union', 'except') +ORDER_OPS = ('desc', 'asc') + + +HARDNESS = { + "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'), + "component2": ('except', 'union', 'intersect') +} + + +def condition_has_or(conds): + return 'or' in conds[1::2] + + +def condition_has_like(conds): + return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]] + + +def condition_has_sql(conds): + for cond_unit in conds[::2]: + val1, val2 = cond_unit[3], cond_unit[4] + if val1 is not None and type(val1) is dict: + return True + if val2 is not None and type(val2) is dict: + return True + return False + + +def val_has_op(val_unit): + return val_unit[0] != UNIT_OPS.index('none') + + +def has_agg(unit): + return unit[0] != AGG_OPS.index('none') + + +def accuracy(count, total): + if count == total: + return 1 + return 0 + + +def recall(count, total): + if count == total: + return 1 + return 0 + + +def F1(acc, rec): + if (acc + rec) == 0: + return 0 + return (2. 
* acc * rec) / (acc + rec) + + +def get_scores(count, pred_total, label_total): + if pred_total != label_total: + return 0,0,0 + elif count == pred_total: + return 1,1,1 + return 0,0,0 + + +def eval_sel(pred, label): + pred_sel = pred['select'][1] + label_sel = label['select'][1] + label_wo_agg = [unit[1] for unit in label_sel] + pred_total = len(pred_sel) + label_total = len(label_sel) + cnt = 0 + cnt_wo_agg = 0 + + for unit in pred_sel: + if unit in label_sel: + cnt += 1 + label_sel.remove(unit) + if unit[1] in label_wo_agg: + cnt_wo_agg += 1 + label_wo_agg.remove(unit[1]) + + return label_total, pred_total, cnt, cnt_wo_agg + + +def eval_where(pred, label): + pred_conds = [unit for unit in pred['where'][::2]] + label_conds = [unit for unit in label['where'][::2]] + label_wo_agg = [unit[2] for unit in label_conds] + pred_total = len(pred_conds) + label_total = len(label_conds) + cnt = 0 + cnt_wo_agg = 0 + + for unit in pred_conds: + if unit in label_conds: + cnt += 1 + label_conds.remove(unit) + if unit[2] in label_wo_agg: + cnt_wo_agg += 1 + label_wo_agg.remove(unit[2]) + + return label_total, pred_total, cnt, cnt_wo_agg + + +def eval_group(pred, label): + pred_cols = [unit[1] for unit in pred['groupBy']] + label_cols = [unit[1] for unit in label['groupBy']] + pred_total = len(pred_cols) + label_total = len(label_cols) + cnt = 0 + pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols] + label_cols = [label.split(".")[1] if "." in label else label for label in label_cols] + for col in pred_cols: + if col in label_cols: + cnt += 1 + label_cols.remove(col) + return label_total, pred_total, cnt + + +def eval_having(pred, label): + pred_total = label_total = cnt = 0 + if len(pred['groupBy']) > 0: + pred_total = 1 + if len(label['groupBy']) > 0: + label_total = 1 + + pred_cols = [unit[1] for unit in pred['groupBy']] + label_cols = [unit[1] for unit in label['groupBy']] + if pred_total == label_total == 1 \ + and pred_cols == label_cols \ + and pred['having'] == label['having']: + cnt = 1 + + return label_total, pred_total, cnt + + +def eval_order(pred, label): + pred_total = label_total = cnt = 0 + if len(pred['orderBy']) > 0: + pred_total = 1 + if len(label['orderBy']) > 0: + label_total = 1 + if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \ + ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)): + cnt = 1 + return label_total, pred_total, cnt + + +def eval_and_or(pred, label): + pred_ao = pred['where'][1::2] + label_ao = label['where'][1::2] + pred_ao = set(pred_ao) + label_ao = set(label_ao) + + if pred_ao == label_ao: + return 1,1,1 + return len(pred_ao),len(label_ao),0 + + +def get_nestedSQL(sql): + nested = [] + for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]: + if type(cond_unit[3]) is dict: + nested.append(cond_unit[3]) + if type(cond_unit[4]) is dict: + nested.append(cond_unit[4]) + if sql['intersect'] is not None: + nested.append(sql['intersect']) + if sql['except'] is not None: + nested.append(sql['except']) + if sql['union'] is not None: + nested.append(sql['union']) + return nested + + +def eval_nested(pred, label): + label_total = 0 + pred_total = 0 + cnt = 0 + if pred is not None: + pred_total += 1 + if label is not None: + label_total += 1 + if pred is not None and label is not None: + cnt += Evaluator().eval_exact_match(pred, label) + return label_total, pred_total, cnt + + +def eval_IUEN(pred, label): + lt1, pt1, cnt1 = 
eval_nested(pred['intersect'], label['intersect']) + lt2, pt2, cnt2 = eval_nested(pred['except'], label['except']) + lt3, pt3, cnt3 = eval_nested(pred['union'], label['union']) + label_total = lt1 + lt2 + lt3 + pred_total = pt1 + pt2 + pt3 + cnt = cnt1 + cnt2 + cnt3 + return label_total, pred_total, cnt + + +def get_keywords(sql): + res = set() + if len(sql['where']) > 0: + res.add('where') + if len(sql['groupBy']) > 0: + res.add('group') + if len(sql['having']) > 0: + res.add('having') + if len(sql['orderBy']) > 0: + res.add(sql['orderBy'][0]) + res.add('order') + if sql['limit'] is not None: + res.add('limit') + if sql['except'] is not None: + res.add('except') + if sql['union'] is not None: + res.add('union') + if sql['intersect'] is not None: + res.add('intersect') + + # or keyword + ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] + if len([token for token in ao if token == 'or']) > 0: + res.add('or') + + cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] + # not keyword + if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0: + res.add('not') + + # in keyword + if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0: + res.add('in') + + # like keyword + if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0: + res.add('like') + + return res + + +def eval_keywords(pred, label): + pred_keywords = get_keywords(pred) + label_keywords = get_keywords(label) + pred_total = len(pred_keywords) + label_total = len(label_keywords) + cnt = 0 + + for k in pred_keywords: + if k in label_keywords: + cnt += 1 + return label_total, pred_total, cnt + + +def count_agg(units): + return len([unit for unit in units if has_agg(unit)]) + + +def count_component1(sql): + count = 0 + if len(sql['where']) > 0: + count += 1 + if len(sql['groupBy']) > 0: + count += 1 + if len(sql['orderBy']) > 0: + count += 1 + if sql['limit'] is not None: + count += 1 + if len(sql['from']['table_units']) > 0: # JOIN + count += len(sql['from']['table_units']) - 1 + + ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] + count += len([token for token in ao if token == 'or']) + cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] + count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) + + return count + + +def count_component2(sql): + nested = get_nestedSQL(sql) + return len(nested) + + +def count_others(sql): + count = 0 + # number of aggregation + agg_count = count_agg(sql['select'][1]) + agg_count += count_agg(sql['where'][::2]) + agg_count += count_agg(sql['groupBy']) + if len(sql['orderBy']) > 0: + agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] + + [unit[2] for unit in sql['orderBy'][1] if unit[2]]) + agg_count += count_agg(sql['having']) + if agg_count > 1: + count += 1 + + # number of select columns + if len(sql['select'][1]) > 1: + count += 1 + + # number of where conditions + if len(sql['where']) > 1: + count += 1 + + # number of group by clauses + if len(sql['groupBy']) > 1: + count += 1 + + return count + + +class Evaluator: + """A simple evaluator""" + def __init__(self): + self.partial_scores = None + + def eval_hardness(self, sql): + count_comp1_ = count_component1(sql) + count_comp2_ = count_component2(sql) + count_others_ = count_others(sql) + + if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0: + return "easy" + elif (count_others_ <= 2 
and count_comp1_ <= 1 and count_comp2_ == 0) or \ + (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0): + return "medium" + elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \ + (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \ + (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1): + return "hard" + else: + return "extra" + + def eval_exact_match(self, pred, label): + partial_scores = self.eval_partial_match(pred, label) + self.partial_scores = partial_scores + + for _, score in partial_scores.items(): + if score['f1'] != 1: + return 0 + if len(label['from']['table_units']) > 0: + label_tables = sorted(label['from']['table_units']) + pred_tables = sorted(pred['from']['table_units']) + return label_tables == pred_tables + return 1 + + def eval_partial_match(self, pred, label): + res = {} + + label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) + res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) + res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_group(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_having(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_order(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_and_or(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_IUEN(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_keywords(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + return res + + +def isValidSQL(sql, db): + conn = sqlite3.connect(db) + cursor = conn.cursor() + try: + cursor.execute(sql) + except: + return False + return True + + +def print_scores(scores, etype): + levels = ['easy', 'medium', 'hard', 'extra', 'all'] + partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', + 'group', 'order', 'and/or', 'IUEN', 'keywords'] + + print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels)) + counts = [scores[level]['count'] for level in levels] + print("{:20} 
{:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts)) + + if etype in ["all", "exec"]: + print('===================== EXECUTION ACCURACY =====================') + this_scores = [scores[level]['exec'] for level in levels] + print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores)) + + if etype in ["all", "match"]: + print('\n====================== EXACT MATCHING ACCURACY =====================') + exact_scores = [scores[level]['exact'] for level in levels] + print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores)) + print('\n---------------------PARTIAL MATCHING ACCURACY----------------------') + for type_ in partial_types: + this_scores = [scores[level]['partial'][type_]['acc'] for level in levels] + print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) + + print('---------------------- PARTIAL MATCHING RECALL ----------------------') + for type_ in partial_types: + this_scores = [scores[level]['partial'][type_]['rec'] for level in levels] + print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) + + print('---------------------- PARTIAL MATCHING F1 --------------------------') + for type_ in partial_types: + this_scores = [scores[level]['partial'][type_]['f1'] for level in levels] + print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) + + +def evaluate(gold, predict, db_dir, etype, kmaps): + with open(gold) as f: + glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0] + + with open(predict) as f: + plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0] + # plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")] + # glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")] + evaluator = Evaluator() + + levels = ['easy', 'medium', 'hard', 'extra', 'all'] + partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', + 'group', 'order', 'and/or', 'IUEN', 'keywords'] + entries = [] + scores = {} + + for level in levels: + scores[level] = {'count': 0, 'partial': {}, 'exact': 0.} + scores[level]['exec'] = 0 + for type_ in partial_types: + scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0} + + eval_err_num = 0 + for p, g in zip(plist, glist): + p_str = p[0] + g_str, db = g + db_name = db + db = os.path.join(db_dir, db, db + ".sqlite") + schema = Schema(get_schema(db)) + g_sql = get_sql(schema, g_str) + hardness = evaluator.eval_hardness(g_sql) + scores[hardness]['count'] += 1 + scores['all']['count'] += 1 + + try: + p_sql = get_sql(schema, p_str) + except: + # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql + p_sql = { + "except": None, + "from": { + "conds": [], + "table_units": [] + }, + "groupBy": [], + "having": [], + "intersect": None, + "limit": None, + "orderBy": [], + "select": [ + False, + [] + ], + "union": None, + "where": [] + } + eval_err_num += 1 + print("eval_err_num:{}".format(eval_err_num)) + + # rebuild sql for value evaluation + kmap = kmaps[db_name] + g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema) + g_sql = rebuild_sql_val(g_sql) + g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) + p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema) + p_sql = 
rebuild_sql_val(p_sql) + p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) + + if etype in ["all", "exec"]: + exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql) + if exec_score: + scores[hardness]['exec'] += 1 + + if etype in ["all", "match"]: + exact_score = evaluator.eval_exact_match(p_sql, g_sql) + partial_scores = evaluator.partial_scores + if exact_score == 0: + print("{} pred: {}".format(hardness,p_str)) + print("{} gold: {}".format(hardness,g_str)) + print("") + scores[hardness]['exact'] += exact_score + scores['all']['exact'] += exact_score + for type_ in partial_types: + if partial_scores[type_]['pred_total'] > 0: + scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc'] + scores[hardness]['partial'][type_]['acc_count'] += 1 + if partial_scores[type_]['label_total'] > 0: + scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec'] + scores[hardness]['partial'][type_]['rec_count'] += 1 + scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1'] + if partial_scores[type_]['pred_total'] > 0: + scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc'] + scores['all']['partial'][type_]['acc_count'] += 1 + if partial_scores[type_]['label_total'] > 0: + scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec'] + scores['all']['partial'][type_]['rec_count'] += 1 + scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1'] + + entries.append({ + 'predictSQL': p_str, + 'goldSQL': g_str, + 'hardness': hardness, + 'exact': exact_score, + 'partial': partial_scores + }) + + for level in levels: + if scores[level]['count'] == 0: + continue + if etype in ["all", "exec"]: + scores[level]['exec'] /= scores[level]['count'] + + if etype in ["all", "match"]: + scores[level]['exact'] /= scores[level]['count'] + for type_ in partial_types: + if scores[level]['partial'][type_]['acc_count'] == 0: + scores[level]['partial'][type_]['acc'] = 0 + else: + scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \ + scores[level]['partial'][type_]['acc_count'] * 1.0 + if scores[level]['partial'][type_]['rec_count'] == 0: + scores[level]['partial'][type_]['rec'] = 0 + else: + scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \ + scores[level]['partial'][type_]['rec_count'] * 1.0 + if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0: + scores[level]['partial'][type_]['f1'] = 1 + else: + scores[level]['partial'][type_]['f1'] = \ + 2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / ( + scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc']) + + print_scores(scores, etype) + + +def eval_exec_match(db, p_str, g_str, pred, gold): + """ + return 1 if the values between prediction and gold are matching + in the corresponding index. Currently not support multiple col_unit(pairs). 
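+    Execution match is decided by comparing the fetched result columns, keyed by the
+    corresponding select val_units, between the predicted and gold queries.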
+ """ + conn = sqlite3.connect(db) + cursor = conn.cursor() + try: + cursor.execute(p_str) + p_res = cursor.fetchall() + except: + return False + + cursor.execute(g_str) + q_res = cursor.fetchall() + + def res_map(res, val_units): + rmap = {} + for idx, val_unit in enumerate(val_units): + key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2])) + rmap[key] = [r[idx] for r in res] + return rmap + + p_val_units = [unit[1] for unit in pred['select'][1]] + q_val_units = [unit[1] for unit in gold['select'][1]] + return res_map(p_res, p_val_units) == res_map(q_res, q_val_units) + + +# Rebuild SQL functions for value evaluation +def rebuild_cond_unit_val(cond_unit): + if cond_unit is None or not DISABLE_VALUE: + return cond_unit + + not_op, op_id, val_unit, val1, val2 = cond_unit + if type(val1) is not dict: + val1 = None + else: + val1 = rebuild_sql_val(val1) + if type(val2) is not dict: + val2 = None + else: + val2 = rebuild_sql_val(val2) + return not_op, op_id, val_unit, val1, val2 + + +def rebuild_condition_val(condition): + if condition is None or not DISABLE_VALUE: + return condition + + res = [] + for idx, it in enumerate(condition): + if idx % 2 == 0: + res.append(rebuild_cond_unit_val(it)) + else: + res.append(it) + return res + + +def rebuild_sql_val(sql): + if sql is None or not DISABLE_VALUE: + return sql + + sql['from']['conds'] = rebuild_condition_val(sql['from']['conds']) + sql['having'] = rebuild_condition_val(sql['having']) + sql['where'] = rebuild_condition_val(sql['where']) + sql['intersect'] = rebuild_sql_val(sql['intersect']) + sql['except'] = rebuild_sql_val(sql['except']) + sql['union'] = rebuild_sql_val(sql['union']) + + return sql + + +# Rebuild SQL functions for foreign key evaluation +def build_valid_col_units(table_units, schema): + col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']] + prefixs = [col_id[:-2] for col_id in col_ids] + valid_col_units= [] + for value in schema.idMap.values(): + if '.' 
in value and value[:value.index('.')] in prefixs: + valid_col_units.append(value) + return valid_col_units + + +def rebuild_col_unit_col(valid_col_units, col_unit, kmap): + if col_unit is None: + return col_unit + + agg_id, col_id, distinct = col_unit + if col_id in kmap and col_id in valid_col_units: + col_id = kmap[col_id] + if DISABLE_DISTINCT: + distinct = None + return agg_id, col_id, distinct + + +def rebuild_val_unit_col(valid_col_units, val_unit, kmap): + if val_unit is None: + return val_unit + + unit_op, col_unit1, col_unit2 = val_unit + col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap) + col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap) + return unit_op, col_unit1, col_unit2 + + +def rebuild_table_unit_col(valid_col_units, table_unit, kmap): + if table_unit is None: + return table_unit + + table_type, col_unit_or_sql = table_unit + if isinstance(col_unit_or_sql, tuple): + col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap) + return table_type, col_unit_or_sql + + +def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap): + if cond_unit is None: + return cond_unit + + not_op, op_id, val_unit, val1, val2 = cond_unit + val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap) + return not_op, op_id, val_unit, val1, val2 + + +def rebuild_condition_col(valid_col_units, condition, kmap): + for idx in range(len(condition)): + if idx % 2 == 0: + condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap) + return condition + + +def rebuild_select_col(valid_col_units, sel, kmap): + if sel is None: + return sel + distinct, _list = sel + new_list = [] + for it in _list: + agg_id, val_unit = it + new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap))) + if DISABLE_DISTINCT: + distinct = None + return distinct, new_list + + +def rebuild_from_col(valid_col_units, from_, kmap): + if from_ is None: + return from_ + + from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']] + from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap) + return from_ + + +def rebuild_group_by_col(valid_col_units, group_by, kmap): + if group_by is None: + return group_by + + return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by] + + +def rebuild_order_by_col(valid_col_units, order_by, kmap): + if order_by is None or len(order_by) == 0: + return order_by + + direction, val_units = order_by + new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units] + return direction, new_val_units + + +def rebuild_sql_col(valid_col_units, sql, kmap): + if sql is None: + return sql + + sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap) + sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap) + sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap) + sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap) + sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap) + sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap) + sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap) + sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap) + sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap) + + return sql + + +def build_foreign_key_map(entry): + cols_orig = entry["column_names_original"] + tables_orig = 
entry["table_names_original"] + + # rebuild cols corresponding to idmap in Schema + cols = [] + for col_orig in cols_orig: + if col_orig[0] >= 0: + t = tables_orig[col_orig[0]] + c = col_orig[1] + cols.append("__" + t.lower() + "." + c.lower() + "__") + else: + cols.append("__all__") + + def keyset_in_list(k1, k2, k_list): + for k_set in k_list: + if k1 in k_set or k2 in k_set: + return k_set + new_k_set = set() + k_list.append(new_k_set) + return new_k_set + + foreign_key_list = [] + foreign_keys = entry["foreign_keys"] + for fkey in foreign_keys: + key1, key2 = fkey + key_set = keyset_in_list(key1, key2, foreign_key_list) + key_set.add(key1) + key_set.add(key2) + + foreign_key_map = {} + for key_set in foreign_key_list: + sorted_list = sorted(list(key_set)) + midx = sorted_list[0] + for idx in sorted_list: + foreign_key_map[cols[idx]] = cols[midx] + + return foreign_key_map + + +def build_foreign_key_map_from_json(table): + with open(table) as f: + data = json.load(f) + tables = {} + for entry in data: + tables[entry['db_id']] = build_foreign_key_map(entry) + return tables + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--gold', dest='gold', type=str) + parser.add_argument('--pred', dest='pred', type=str) + parser.add_argument('--db', dest='db', type=str) + parser.add_argument('--table', dest='table', type=str) + parser.add_argument('--etype', dest='etype', type=str) + args = parser.parse_args() + + gold = args.gold + pred = args.pred + db_dir = args.db + table = args.table + etype = args.etype + + assert etype in ["all", "exec", "match"], "Unknown evaluation method" + + kmaps = build_foreign_key_map_from_json(table) + + evaluate(gold, pred, db_dir, etype, kmaps) \ No newline at end of file diff --git a/code/evaluation.sh b/code/evaluation.sh new file mode 100644 index 0000000000000000000000000000000000000000..b62bcbbe31583d1a95ee1c56a342ce75ad81583d --- /dev/null +++ b/code/evaluation.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +schema="char" # char or word +#schema="word" + +embedding="multi" # multi for multi-lingual or mono for monolingual +#embedding="mono" + +SAVE_PATH="data/${schema}/generated_datasets/saved_models_multi_2022-12-05-06:04:27" + +# evaluation +python evaluation.py \ + --gold "data/dev_gold.sql" \ + --pred "${SAVE_PATH}/dev_result.txt" \ + --etype "match" \ + --db "database" \ + --table "data/tables.json" \ + > "result_${schema}_${embedding}.log" \ + 2>&1 & diff --git a/code/models/__init__.py b/code/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/code/models/agg_predictor.py b/code/models/agg_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..f52510b3a3f9872b00e1f98df7208278cf511ec3 --- /dev/null +++ b/code/models/agg_predictor.py @@ -0,0 +1,163 @@ +import json +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from models.net_utils import run_lstm, col_name_encode + + +class AggPredictor(nn.Module): + def __init__(self, N_word, N_col, N_h, N_depth, dropout, gpu, use_hs): + super(AggPredictor, self).__init__() + self.N_h = N_h + self.gpu = gpu + self.use_hs = use_hs + + self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + if N_col: + N_word = N_col + + self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, 
batch_first=True, + dropout=dropout, bidirectional=True) + + self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.q_num_att = nn.Linear(N_h, N_h) + self.hs_num_att = nn.Linear(N_h, N_h) + self.agg_num_out_q = nn.Linear(N_h, N_h) + self.agg_num_out_hs = nn.Linear(N_h, N_h) + self.agg_num_out_c = nn.Linear(N_h, N_h) + self.agg_num_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 4)) #for 0-3 agg num + + self.q_att = nn.Linear(N_h, N_h) + self.hs_att = nn.Linear(N_h, N_h) + self.agg_out_q = nn.Linear(N_h, N_h) + self.agg_out_hs = nn.Linear(N_h, N_h) + self.agg_out_c = nn.Linear(N_h, N_h) + self.agg_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 5)) #for 1-5 aggregators + + self.softmax = nn.Softmax(dim=1) #dim=1 + self.CE = nn.CrossEntropyLoss() + self.log_softmax = nn.LogSoftmax() + self.mlsml = nn.MultiLabelSoftMarginLoss() + self.bce_logit = nn.BCEWithLogitsLoss() + self.sigm = nn.Sigmoid() + if gpu: + self.cuda() + + + def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col): + max_q_len = max(q_len) + max_hs_len = max(hs_len) + max_col_len = max(col_len) + B = len(q_len) + + q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) + hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) + col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) + + col_emb = [] + for b in range(B): + col_emb.append(col_enc[b, gt_col[b]]) + col_emb = torch.stack(col_emb) + + # Predict agg number + att_val_qc_num = torch.bmm(col_emb.unsqueeze(1), self.q_num_att(q_enc).transpose(1, 2)).view(B, -1) + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qc_num[idx, num:] = -100 + att_prob_qc_num = self.softmax(att_val_qc_num) + q_weighted_num = (q_enc * att_prob_qc_num.unsqueeze(2)).sum(1) + + # Same as the above, compute SQL history embedding weighted by column attentions + att_val_hc_num = torch.bmm(col_emb.unsqueeze(1), self.hs_num_att(hs_enc).transpose(1, 2)).view(B, -1) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hc_num[idx, num:] = -100 + att_prob_hc_num = self.softmax(att_val_hc_num) + hs_weighted_num = (hs_enc * att_prob_hc_num.unsqueeze(2)).sum(1) + # agg_num_score: (B, 4) + agg_num_score = self.agg_num_out(self.agg_num_out_q(q_weighted_num) + int(self.use_hs)* self.agg_num_out_hs(hs_weighted_num) + self.agg_num_out_c(col_emb)) + + # Predict aggregators + att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B, -1) + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qc[idx, num:] = -100 + att_prob_qc = self.softmax(att_val_qc) + q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1) + + # Same as the above, compute SQL history embedding weighted by column attentions + att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B, -1) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hc[idx, num:] = -100 + att_prob_hc = self.softmax(att_val_hc) + hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1) + # agg_score: (B, 5) + agg_score = self.agg_out(self.agg_out_q(q_weighted) + int(self.use_hs)* self.agg_out_hs(hs_weighted) + self.agg_out_c(col_emb)) + + score = (agg_num_score, agg_score) + + return score + + + def loss(self, score, truth): + loss = 0 + B = len(truth) + agg_num_score, agg_score = score + #loss for the column number + truth_num = [len(t) for t in truth] # double check truth format and for test cases 
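+        # truth[b] holds the gold aggregator indices for example b: its length is the target
+        # for the aggregator-number cross-entropy below, and the indices themselves form the
+        # multi-label targets of the weighted BCE term further down.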
+ data = torch.from_numpy(np.array(truth_num)) + truth_num_var = Variable(data.cuda()) + loss += self.CE(agg_num_score, truth_num_var) + #loss for the key words + T = len(agg_score[0]) + truth_prob = np.zeros((B, T), dtype=np.float32) + for b in range(B): + truth_prob[b][truth[b]] = 1 + data = torch.from_numpy(truth_prob) + truth_var = Variable(data.cuda()) + #loss += self.mlsml(agg_score, truth_var) + #loss += self.bce_logit(agg_score, truth_var) # double check no sigmoid + pred_prob = self.sigm(agg_score) + bce_loss = -torch.mean( 3*(truth_var * \ + torch.log(pred_prob+1e-10)) + \ + (1-truth_var) * torch.log(1-pred_prob+1e-10) ) + loss += bce_loss + + return loss + + + def check_acc(self, score, truth): + num_err, err, tot_err = 0, 0, 0 + B = len(truth) + pred = [] + agg_num_score, agg_score = [x.data.cpu().numpy() for x in score] + for b in range(B): + cur_pred = {} + agg_num = np.argmax(agg_num_score[b]) #double check + cur_pred['agg_num'] = agg_num + cur_pred['agg'] = np.argsort(-agg_score[b])[:agg_num] + pred.append(cur_pred) + + for b, (p, t) in enumerate(zip(pred, truth)): + agg_num, agg = p['agg_num'], p['agg'] + flag = True + if agg_num != len(t): # double check truth format and for test cases + num_err += 1 + flag = False + if flag and set(agg) != set(t): + err += 1 + flag = False + if not flag: + tot_err += 1 + + return np.array((num_err, err, tot_err)) diff --git a/code/models/andor_predictor.py b/code/models/andor_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..49c6240732d80d78ee79eb160814825b436c545e --- /dev/null +++ b/code/models/andor_predictor.py @@ -0,0 +1,94 @@ +import json +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from models.net_utils import run_lstm, col_name_encode + + +class AndOrPredictor(nn.Module): + def __init__(self, N_word, N_col, N_h, N_depth, dropout, gpu, use_hs): + super(AndOrPredictor, self).__init__() + self.N_h = N_h + self.gpu = gpu + self.use_hs = use_hs + + self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + if N_col: + N_word = N_col + + self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.q_att = nn.Linear(N_h, N_h) + self.hs_att = nn.Linear(N_h, N_h) + self.ao_out_q = nn.Linear(N_h, N_h) + self.ao_out_hs = nn.Linear(N_h, N_h) + self.ao_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 2)) #for and/or + + self.softmax = nn.Softmax(dim=1) #dim=1 + self.CE = nn.CrossEntropyLoss() + self.log_softmax = nn.LogSoftmax() + self.mlsml = nn.MultiLabelSoftMarginLoss() + self.bce_logit = nn.BCEWithLogitsLoss() + self.sigm = nn.Sigmoid() + if gpu: + self.cuda() + + def forward(self, q_emb_var, q_len, hs_emb_var, hs_len): + max_q_len = max(q_len) + max_hs_len = max(hs_len) + B = len(q_len) + + q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) + hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) + + att_np_q = np.ones((B, max_q_len)) + att_val_q = torch.from_numpy(att_np_q).float() + att_val_q = Variable(att_val_q.cuda()) + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_q[idx, num:] = -100 + att_prob_q = self.softmax(att_val_q) + q_weighted = (q_enc * att_prob_q.unsqueeze(2)).sum(1) + + # Same as the above, compute SQL history embedding weighted by column attentions + att_np_h = np.ones((B, max_hs_len)) + att_val_h = 
torch.from_numpy(att_np_h).float() + att_val_h = Variable(att_val_h.cuda()) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_h[idx, num:] = -100 + att_prob_h = self.softmax(att_val_h) + hs_weighted = (hs_enc * att_prob_h.unsqueeze(2)).sum(1) + # ao_score: (B, 2) + ao_score = self.ao_out(self.ao_out_q(q_weighted) + int(self.use_hs)* self.ao_out_hs(hs_weighted)) + + return ao_score + + + def loss(self, score, truth): + loss = 0 + data = torch.from_numpy(np.array(truth)) + truth_var = Variable(data.cuda()) + loss = self.CE(score, truth_var) + + return loss + + + def check_acc(self, score, truth): + err = 0 + B = len(score) + pred = [] + for b in range(B): + pred.append(np.argmax(score[b].data.cpu().numpy())) + for b, (p, t) in enumerate(zip(pred, truth)): + if p != t: + err += 1 + + return err diff --git a/code/models/col_predictor.py b/code/models/col_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..d3ada978e69bf24d9bccad2ee2aadc3edb82a70e --- /dev/null +++ b/code/models/col_predictor.py @@ -0,0 +1,200 @@ +import json +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from models.net_utils import run_lstm, col_name_encode + + +class ColPredictor(nn.Module): + def __init__(self, N_word, N_col, N_h, N_depth, dropout, gpu, use_hs): + super(ColPredictor, self).__init__() + self.N_h = N_h + self.gpu = gpu + self.use_hs = use_hs + + self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + if N_col: + N_word = N_col + + self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.q_num_att = nn.Linear(N_h, N_h) + self.hs_num_att = nn.Linear(N_h, N_h) + self.col_num_out_q = nn.Linear(N_h, N_h) + self.col_num_out_hs = nn.Linear(N_h, N_h) + self.col_num_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 6)) # num of cols: 1-3 + + self.q_att = nn.Linear(N_h, N_h) + self.hs_att = nn.Linear(N_h, N_h) + self.col_out_q = nn.Linear(N_h, N_h) + self.col_out_c = nn.Linear(N_h, N_h) + self.col_out_hs = nn.Linear(N_h, N_h) + self.col_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1)) + + self.softmax = nn.Softmax(dim=1) #dim=1 + self.CE = nn.CrossEntropyLoss() + self.log_softmax = nn.LogSoftmax() + self.mlsml = nn.MultiLabelSoftMarginLoss() + self.bce_logit = nn.BCEWithLogitsLoss() + self.sigm = nn.Sigmoid() + if gpu: + self.cuda() + + def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len): + + max_q_len = max(q_len) + max_hs_len = max(hs_len) + max_col_len = max(col_len) + B = len(q_len) + + q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) + hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) + col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) + + # Predict column number: 1-3 + # att_val_qc_num: (B, max_col_len, max_q_len) + att_val_qc_num = torch.bmm(col_enc, self.q_num_att(q_enc).transpose(1, 2)) + for idx, num in enumerate(col_len): + if num < max_col_len: + att_val_qc_num[idx, num:, :] = -100 + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qc_num[idx, :, num:] = -100 + att_prob_qc_num = self.softmax(att_val_qc_num.view((-1, max_q_len))).view(B, -1, max_q_len) + # 
q_weighted_num: (B, hid_dim) + q_weighted_num = (q_enc.unsqueeze(1) * att_prob_qc_num.unsqueeze(3)).sum(2).sum(1) + + # Same as the above, compute SQL history embedding weighted by column attentions + # att_val_hc_num: (B, max_col_len, max_hs_len) + att_val_hc_num = torch.bmm(col_enc, self.hs_num_att(hs_enc).transpose(1, 2)) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hc_num[idx, :, num:] = -100 + for idx, num in enumerate(col_len): + if num < max_col_len: + att_val_hc_num[idx, num:, :] = -100 + att_prob_hc_num = self.softmax(att_val_hc_num.view((-1, max_hs_len))).view(B, -1, max_hs_len) + hs_weighted_num = (hs_enc.unsqueeze(1) * att_prob_hc_num.unsqueeze(3)).sum(2).sum(1) + # self.col_num_out: (B, 3) + col_num_score = self.col_num_out(self.col_num_out_q(q_weighted_num) + int(self.use_hs)* self.col_num_out_hs(hs_weighted_num)) + + # Predict columns. + att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2)) + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qc[idx, :, num:] = -100 + att_prob_qc = self.softmax(att_val_qc.view((-1, max_q_len))).view(B, -1, max_q_len) + # q_weighted: (B, max_col_len, hid_dim) + q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2) + + # Same as the above, compute SQL history embedding weighted by column attentions + att_val_hc = torch.bmm(col_enc, self.hs_att(hs_enc).transpose(1, 2)) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hc[idx, :, num:] = -100 + att_prob_hc = self.softmax(att_val_hc.view((-1, max_hs_len))).view(B, -1, max_hs_len) + hs_weighted = (hs_enc.unsqueeze(1) * att_prob_hc.unsqueeze(3)).sum(2) + # Compute prediction scores + # self.col_out.squeeze(): (B, max_col_len) + col_score = self.col_out(self.col_out_q(q_weighted) + int(self.use_hs)* self.col_out_hs(hs_weighted) + self.col_out_c(col_enc)).view(B,-1) + + for idx, num in enumerate(col_len): + if num < max_col_len: + col_score[idx, num:] = -100 + + score = (col_num_score, col_score) + + return score + + def loss(self, score, truth): + #here suppose truth looks like [[[1, 4], 3], [], ...] 
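+        # In this format a nested list such as [1, 4] marks a foreign-key column pair for which
+        # predicting either id is accepted (see check_acc below), while bare ids are required columns.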
+ loss = 0 + B = len(truth) + col_num_score, col_score = score + #loss for the column number + truth_num = [len(t) - 1 for t in truth] # double check truth format and for test cases + data = torch.from_numpy(np.array(truth_num)) + truth_num_var = Variable(data.cuda()) + loss += self.CE(col_num_score, truth_num_var) + #loss for the key words + T = len(col_score[0]) + # print("T {}".format(T)) + truth_prob = np.zeros((B, T), dtype=np.float32) + for b in range(B): + gold_l = [] + for t in truth[b]: + if isinstance(t, list): + gold_l.extend(t) + else: + gold_l.append(t) + truth_prob[b][gold_l] = 1 + data = torch.from_numpy(truth_prob) + # print("data {}".format(data)) + # print("data {}".format(data.cuda())) + truth_var = Variable(data.cuda()) + #loss += self.mlsml(col_score, truth_var) + #loss += self.bce_logit(col_score, truth_var) # double check no sigmoid + pred_prob = self.sigm(col_score) + bce_loss = -torch.mean( 3*(truth_var * \ + torch.log(pred_prob+1e-10)) + \ + (1-truth_var) * torch.log(1-pred_prob+1e-10) ) + loss += bce_loss + + return loss + + + def check_acc(self, score, truth): + num_err, err, tot_err = 0, 0, 0 + B = len(truth) + pred = [] + col_num_score, col_score = [x.data.cpu().numpy() for x in score] + for b in range(B): + cur_pred = {} + col_num = np.argmax(col_num_score[b]) + 1 #double check + cur_pred['col_num'] = col_num + cur_pred['col'] = np.argsort(-col_score[b])[:col_num] + pred.append(cur_pred) + + for b, (p, t) in enumerate(zip(pred, truth)): + col_num, col = p['col_num'], p['col'] + flag = True + if col_num != len(t): # double check truth format and for test cases + num_err += 1 + flag = False + #to eval col predicts, if the gold sql has JOIN and foreign key col, then both fks are acceptable + fk_list = [] + regular = [] + for l in t: + if isinstance(l, list): + fk_list.append(l) + else: + regular.append(l) + + if flag: #double check + for c in col: + for fk in fk_list: + if c in fk: + fk_list.remove(fk) + for r in regular: + if c == r: + regular.remove(r) + + if len(fk_list) != 0 or len(regular) != 0: + err += 1 + flag = False + + if not flag: + tot_err += 1 + + return np.array((num_err, err, tot_err)) diff --git a/code/models/desasc_limit_predictor.py b/code/models/desasc_limit_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..1bec67c805c61aa2865887cc9389bb3b24004208 --- /dev/null +++ b/code/models/desasc_limit_predictor.py @@ -0,0 +1,105 @@ +import json +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from models.net_utils import run_lstm, col_name_encode + + +class DesAscLimitPredictor(nn.Module): + def __init__(self, N_word, N_col, N_h, N_depth, dropout, gpu, use_hs): + super(DesAscLimitPredictor, self).__init__() + self.N_h = N_h + self.gpu = gpu + self.use_hs = use_hs + + self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + if N_col: + N_word = N_col + + self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + + self.q_att = nn.Linear(N_h, N_h) + self.hs_att = nn.Linear(N_h, N_h) + self.dat_out_q = nn.Linear(N_h, N_h) + self.dat_out_hs = nn.Linear(N_h, N_h) + self.dat_out_c = nn.Linear(N_h, N_h) + self.dat_out = 
nn.Sequential(nn.Tanh(), nn.Linear(N_h, 4)) #for 4 desc/asc limit/none combinations + + self.softmax = nn.Softmax(dim=1) #dim=1 + self.CE = nn.CrossEntropyLoss() + self.log_softmax = nn.LogSoftmax() + self.mlsml = nn.MultiLabelSoftMarginLoss() + self.bce_logit = nn.BCEWithLogitsLoss() + self.sigm = nn.Sigmoid() + if gpu: + self.cuda() + + def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col): + max_q_len = max(q_len) + max_hs_len = max(hs_len) + max_col_len = max(col_len) + B = len(q_len) + + q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) + hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) + col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) + + # get target/predicted column's embedding + # col_emb: (B, hid_dim) + col_emb = [] + for b in range(B): + col_emb.append(col_enc[b, gt_col[b]]) + col_emb = torch.stack(col_emb) # [B, dim] + # self.q_att(q_enc).transpose(1, 2): [B, dim, max_q_len] + att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B, -1) + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qc[idx, num:] = -100 + att_prob_qc = self.softmax(att_val_qc) + q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1) + + # Same as the above, compute SQL history embedding weighted by column attentions + att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B, -1) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hc[idx, num:] = -100 + att_prob_hc = self.softmax(att_val_hc) + hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1) + # dat_score: (B, 4) + dat_score = self.dat_out(self.dat_out_q(q_weighted) + int(self.use_hs)* self.dat_out_hs(hs_weighted) + self.dat_out_c(col_emb)) + + return dat_score + + + def loss(self, score, truth): + loss = 0 + data = torch.from_numpy(np.array(truth)) + truth_var = Variable(data.cuda()) + loss = self.CE(score, truth_var) + + return loss + + + def check_acc(self, score, truth): + err = 0 + B = len(score) + pred = [] + for b in range(B): + pred.append(np.argmax(score[b].data.cpu().numpy())) + for b, (p, t) in enumerate(zip(pred, truth)): + if p != t: + err += 1 + + return err diff --git a/code/models/having_predictor.py b/code/models/having_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..b6355c6b680cb61ab323945534b5187b501dae0c --- /dev/null +++ b/code/models/having_predictor.py @@ -0,0 +1,103 @@ +import json +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from models.net_utils import run_lstm, col_name_encode + + +class HavingPredictor(nn.Module): + def __init__(self, N_word, N_col, N_h, N_depth, dropout, gpu, use_hs): + super(HavingPredictor, self).__init__() + self.N_h = N_h + self.gpu = gpu + self.use_hs = use_hs + + self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + if N_col: + N_word = N_col + + self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.q_att = nn.Linear(N_h, N_h) + self.hs_att = nn.Linear(N_h, N_h) + self.hv_out_q = nn.Linear(N_h, N_h) + self.hv_out_hs = nn.Linear(N_h, N_h) + self.hv_out_c = 
nn.Linear(N_h, N_h) + self.hv_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 2)) #for having/none + + self.softmax = nn.Softmax(dim=1) #dim=1 + self.CE = nn.CrossEntropyLoss() + self.log_softmax = nn.LogSoftmax() + self.mlsml = nn.MultiLabelSoftMarginLoss() + self.bce_logit = nn.BCEWithLogitsLoss() + self.sigm = nn.Sigmoid() + if gpu: + self.cuda() + + def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col): + max_q_len = max(q_len) + max_hs_len = max(hs_len) + max_col_len = max(col_len) + B = len(q_len) + + q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) + hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) + col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) + + # get target/predicted column's embedding + # col_emb: (B, hid_dim) + col_emb = [] + for b in range(B): + col_emb.append(col_enc[b, gt_col[b]]) + col_emb = torch.stack(col_emb) + att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B,-1) + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qc[idx, num:] = -100 + att_prob_qc = self.softmax(att_val_qc) + q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1) + + # Same as the above, compute SQL history embedding weighted by column attentions + att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B,-1) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hc[idx, num:] = -100 + att_prob_hc = self.softmax(att_val_hc) + hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1) + # hv_score: (B, 2) + hv_score = self.hv_out(self.hv_out_q(q_weighted) + int(self.use_hs)* self.hv_out_hs(hs_weighted) + self.hv_out_c(col_emb)) + + return hv_score + + + def loss(self, score, truth): + loss = 0 + data = torch.from_numpy(np.array(truth)) + truth_var = Variable(data.cuda()) + loss = self.CE(score, truth_var) + + return loss + + + def check_acc(self, score, truth): + err = 0 + B = len(score) + pred = [] + for b in range(B): + pred.append(np.argmax(score[b].data.cpu().numpy())) + for b, (p, t) in enumerate(zip(pred, truth)): + if p != t: + err += 1 + + return err diff --git a/code/models/keyword_predictor.py b/code/models/keyword_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..daf6547e627bdc189e0ea103c04e50d92396fd9d --- /dev/null +++ b/code/models/keyword_predictor.py @@ -0,0 +1,160 @@ +import json +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from models.net_utils import run_lstm, col_name_encode + + +class KeyWordPredictor(nn.Module): + '''Predict if the next token is (SQL key words): + WHERE, GROUP BY, ORDER BY. 
excluding SELECT (it is a must)''' + def __init__(self, N_word, N_col, N_h, N_depth, dropout, gpu, use_hs): + super(KeyWordPredictor, self).__init__() + self.N_h = N_h + self.gpu = gpu + self.use_hs = use_hs + + self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + if N_col: + N_word = N_col + + self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.kw_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.q_num_att = nn.Linear(N_h, N_h) + self.hs_num_att = nn.Linear(N_h, N_h) + self.kw_num_out_q = nn.Linear(N_h, N_h) + self.kw_num_out_hs = nn.Linear(N_h, N_h) + self.kw_num_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 4)) # num of key words: 0-3 + + self.q_att = nn.Linear(N_h, N_h) + self.hs_att = nn.Linear(N_h, N_h) + self.kw_out_q = nn.Linear(N_h, N_h) + self.kw_out_hs = nn.Linear(N_h, N_h) + self.kw_out_kw = nn.Linear(N_h, N_h) + self.kw_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1)) + + self.softmax = nn.Softmax(dim=1) #dim=1 + self.CE = nn.CrossEntropyLoss() + self.log_softmax = nn.LogSoftmax() + self.mlsml = nn.MultiLabelSoftMarginLoss() + self.bce_logit = nn.BCEWithLogitsLoss() + self.sigm = nn.Sigmoid() + if gpu: + self.cuda() + + def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, kw_emb_var, kw_len): + max_q_len = max(q_len) + max_hs_len = max(hs_len) + B = len(q_len) + + q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) + hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) + kw_enc, _ = run_lstm(self.kw_lstm, kw_emb_var, kw_len) + + # Predict key words number: 0-3 + att_val_qkw_num = torch.bmm(kw_enc, self.q_num_att(q_enc).transpose(1, 2)) + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qkw_num[idx, :, num:] = -100 + att_prob_qkw_num = self.softmax(att_val_qkw_num.view((-1, max_q_len))).view(B, -1, max_q_len) + # q_weighted: (B, hid_dim) + q_weighted_num = (q_enc.unsqueeze(1) * att_prob_qkw_num.unsqueeze(3)).sum(2).sum(1) + + # Same as the above, compute SQL history embedding weighted by key words attentions + att_val_hskw_num = torch.bmm(kw_enc, self.hs_num_att(hs_enc).transpose(1, 2)) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hskw_num[idx, :, num:] = -100 + att_prob_hskw_num = self.softmax(att_val_hskw_num.view((-1, max_hs_len))).view(B, -1, max_hs_len) + hs_weighted_num = (hs_enc.unsqueeze(1) * att_prob_hskw_num.unsqueeze(3)).sum(2).sum(1) + # Compute prediction scores + # self.kw_num_out: (B, 4) + kw_num_score = self.kw_num_out(self.kw_num_out_q(q_weighted_num) + int(self.use_hs)* self.kw_num_out_hs(hs_weighted_num)) + + # Predict key words: WHERE, GROUP BY, ORDER BY. 
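+ # att_val_qkw: (B, 3, max_q_len) -- one row per candidate keyword; padded question
+ # positions are masked with -100 below so they get (near-)zero weight after the softmax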
+ att_val_qkw = torch.bmm(kw_enc, self.q_att(q_enc).transpose(1, 2)) + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qkw[idx, :, num:] = -100 + att_prob_qkw = self.softmax(att_val_qkw.view((-1, max_q_len))).view(B, -1, max_q_len) + # q_weighted: (B, 3, hid_dim) + q_weighted = (q_enc.unsqueeze(1) * att_prob_qkw.unsqueeze(3)).sum(2) + + # Same as the above, compute SQL history embedding weighted by key words attentions + att_val_hskw = torch.bmm(kw_enc, self.hs_att(hs_enc).transpose(1, 2)) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hskw[idx, :, num:] = -100 + att_prob_hskw = self.softmax(att_val_hskw.view((-1, max_hs_len))).view(B, -1, max_hs_len) + hs_weighted = (hs_enc.unsqueeze(1) * att_prob_hskw.unsqueeze(3)).sum(2) + # Compute prediction scores + # self.kw_out.squeeze(): (B, 3) + kw_score = self.kw_out(self.kw_out_q(q_weighted) + int(self.use_hs)* self.kw_out_hs(hs_weighted) + self.kw_out_kw(kw_enc)).view(B,-1) + + score = (kw_num_score, kw_score) + + return score + + def loss(self, score, truth): + loss = 0 + B = len(truth) + kw_num_score, kw_score = score + #loss for the key word number + truth_num = [len(t) for t in truth] # double check to exclude select + data = torch.from_numpy(np.array(truth_num)) + truth_num_var = Variable(data.cuda()) + loss += self.CE(kw_num_score, truth_num_var) + #loss for the key words + T = len(kw_score[0]) + truth_prob = np.zeros((B, T), dtype=np.float32) + for b in range(B): + truth_prob[b][truth[b]] = 1 + data = torch.from_numpy(truth_prob) + truth_var = Variable(data.cuda()) + #loss += self.mlsml(kw_score, truth_var) + #loss += self.bce_logit(kw_score, truth_var) # double check no sigmoid for kw + pred_prob = self.sigm(kw_score) + bce_loss = -torch.mean( 3*(truth_var * \ + torch.log(pred_prob+1e-10)) + \ + (1-truth_var) * torch.log(1-pred_prob+1e-10) ) + loss += bce_loss + + return loss + + + def check_acc(self, score, truth): + num_err, err, tot_err = 0, 0, 0 + B = len(truth) + pred = [] + kw_num_score, kw_score = [x.data.cpu().numpy() for x in score] + for b in range(B): + cur_pred = {} + kw_num = np.argmax(kw_num_score[b]) + cur_pred['kw_num'] = kw_num + cur_pred['kw'] = np.argsort(-kw_score[b])[:kw_num] + pred.append(cur_pred) + + for b, (p, t) in enumerate(zip(pred, truth)): + kw_num, kw = p['kw_num'], p['kw'] + flag = True + if kw_num != len(t): # double check to excluding select + num_err += 1 + flag = False + if flag and set(kw) != set(t): + err += 1 + flag = False + if not flag: + tot_err += 1 + + return np.array((num_err, err, tot_err)) diff --git a/code/models/multisql_predictor.py b/code/models/multisql_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff4e2ac19e620942a7b41e62036a75071070a7f --- /dev/null +++ b/code/models/multisql_predictor.py @@ -0,0 +1,113 @@ +import json +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from models.net_utils import run_lstm, col_name_encode + + +class MultiSqlPredictor(nn.Module): + '''Predict if the next token is (multi SQL key words): + NONE, EXCEPT, INTERSECT, or UNION.''' + def __init__(self, N_word, N_col, N_h, N_depth, dropout, gpu, use_hs): + super(MultiSqlPredictor, self).__init__() + self.N_h = N_h + self.gpu = gpu + self.use_hs = use_hs + + self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + if N_col: + N_word = N_col + + self.hs_lstm = 
nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.mkw_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.q_att = nn.Linear(N_h, N_h) + self.hs_att = nn.Linear(N_h, N_h) + self.multi_out_q = nn.Linear(N_h, N_h) + self.multi_out_hs = nn.Linear(N_h, N_h) + self.multi_out_c = nn.Linear(N_h, N_h) + self.multi_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1)) + + self.softmax = nn.Softmax(dim=1) #dim=1 + self.CE = nn.CrossEntropyLoss() + self.log_softmax = nn.LogSoftmax() + self.mlsml = nn.MultiLabelSoftMarginLoss() + self.bce_logit = nn.BCEWithLogitsLoss() + self.sigm = nn.Sigmoid() + + if gpu: + self.cuda() + + def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var, mkw_len): + # print("q_emb_shape:{} hs_emb_shape:{}".format(q_emb_var.size(), hs_emb_var.size())) + max_q_len = max(q_len) + max_hs_len = max(hs_len) + B = len(q_len) + + # q_enc: (B, max_q_len, hid_dim) + # hs_enc: (B, max_hs_len, hid_dim) + # mkw: (B, 4, hid_dim) + q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) + hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) + mkw_enc, _ = run_lstm(self.mkw_lstm, mkw_emb_var, mkw_len) + + # Compute attention values between multi SQL key words and question tokens. + # qmkw_att(q_enc).transpose(1, 2): (B, hid_dim, max_q_len) + # att_val_qmkw: (B, 4, max_q_len) + # print("mkw_enc {} q_enc {}".format(mkw_enc.size(), self.q_att(q_enc).transpose(1, 2).size())) + att_val_qmkw = torch.bmm(mkw_enc, self.q_att(q_enc).transpose(1, 2)) + # assign appended positions values -100 + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qmkw[idx, :, num:] = -100 + # att_prob_qmkw: (B, 4, max_q_len) + att_prob_qmkw = self.softmax(att_val_qmkw.view((-1, max_q_len))).view(B, -1, max_q_len) + # q_enc.unsqueeze(1): (B, 1, max_q_len, hid_dim) + # att_prob_qmkw.unsqueeze(3): (B, 4, max_q_len, 1) + # q_weighted: (B, 4, hid_dim) + q_weighted = (q_enc.unsqueeze(1) * att_prob_qmkw.unsqueeze(3)).sum(2) + + # Same as the above, compute SQL history embedding weighted by key words attentions + att_val_hsmkw = torch.bmm(mkw_enc, self.hs_att(hs_enc).transpose(1, 2)) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hsmkw[idx, :, num:] = -100 + att_prob_hsmkw = self.softmax(att_val_hsmkw.view((-1, max_hs_len))).view(B, -1, max_hs_len) + hs_weighted = (hs_enc.unsqueeze(1) * att_prob_hsmkw.unsqueeze(3)).sum(2) + + # Compute prediction scores + # self.multi_out.squeeze(): (B, 4, 1) -> (B, 4) + mulit_score = self.multi_out(self.multi_out_q(q_weighted) + int(self.use_hs)* self.multi_out_hs(hs_weighted) + self.multi_out_c(mkw_enc)).view(B,-1) + + return mulit_score + + + def loss(self, score, truth): + data = torch.from_numpy(np.array(truth)) + truth_var = Variable(data.cuda()) + loss = self.CE(score, truth_var) + + return loss + + + def check_acc(self, score, truth): + err = 0 + B = len(score) + pred = [] + for b in range(B): + pred.append(np.argmax(score[b].data.cpu().numpy())) + for b, (p, t) in enumerate(zip(pred, truth)): + if p != t: + err += 1 + + return err diff --git a/code/models/net_utils.py b/code/models/net_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc9815c7687bf230b71d3b96ec917ea9eb2a9d6e --- /dev/null +++ b/code/models/net_utils.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.autograd import Variable + +def 
run_lstm(lstm, inp, inp_len, hidden=None): + # Run the LSTM using packed sequence. + # This requires to first sort the input according to its length. + sort_perm = np.array(sorted(range(len(inp_len)), + key=lambda k:inp_len[k], reverse=True)) + sort_inp_len = inp_len[sort_perm] + sort_perm_inv = np.argsort(sort_perm) + if inp.is_cuda: + sort_perm = torch.LongTensor(sort_perm).cuda() + sort_perm_inv = torch.LongTensor(sort_perm_inv).cuda() + + lstm_inp = nn.utils.rnn.pack_padded_sequence(inp[sort_perm], + sort_inp_len, batch_first=True) + if hidden is None: + lstm_hidden = None + else: + lstm_hidden = (hidden[0][:, sort_perm], hidden[1][:, sort_perm]) + + sort_ret_s, sort_ret_h = lstm(lstm_inp, lstm_hidden) + ret_s = nn.utils.rnn.pad_packed_sequence( + sort_ret_s, batch_first=True)[0][sort_perm_inv] + ret_h = (sort_ret_h[0][:, sort_perm_inv], sort_ret_h[1][:, sort_perm_inv]) + return ret_s, ret_h + + +def col_name_encode(name_inp_var, name_len, col_len, enc_lstm): + #Encode the columns. + #The embedding of a column name is the last state of its LSTM output. + name_hidden, _ = run_lstm(enc_lstm, name_inp_var, name_len) + name_out = name_hidden[tuple(range(len(name_len))), name_len-1] + ret = torch.FloatTensor( + len(col_len), max(col_len), name_out.size()[1]).zero_() + if name_out.is_cuda: + ret = ret.cuda() + + st = 0 + for idx, cur_len in enumerate(col_len): + ret[idx, :cur_len] = name_out.data[st:st+cur_len] + st += cur_len + ret_var = Variable(ret) + + return ret_var, col_len + diff --git a/code/models/op_predictor.py b/code/models/op_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb6ae759a54425ccd9b4a5c38f0f0fd8d7e9996 --- /dev/null +++ b/code/models/op_predictor.py @@ -0,0 +1,176 @@ +import json +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from models.net_utils import run_lstm, col_name_encode + + +class OpPredictor(nn.Module): + def __init__(self, N_word, N_col, N_h, N_depth, dropout, gpu, use_hs): + super(OpPredictor, self).__init__() + self.N_h = N_h + self.gpu = gpu + self.use_hs = use_hs + + self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + if N_col: + N_word = N_col + + self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.q_num_att = nn.Linear(N_h, N_h) + self.hs_num_att = nn.Linear(N_h, N_h) + self.op_num_out_q = nn.Linear(N_h, N_h) + self.op_num_out_hs = nn.Linear(N_h, N_h) + self.op_num_out_c = nn.Linear(N_h, N_h) + self.op_num_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 2)) #for 1-2 op num, could be changed + + self.q_att = nn.Linear(N_h, N_h) + self.hs_att = nn.Linear(N_h, N_h) + self.op_out_q = nn.Linear(N_h, N_h) + self.op_out_hs = nn.Linear(N_h, N_h) + self.op_out_c = nn.Linear(N_h, N_h) + self.op_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 11)) #for 11 operators + + self.softmax = nn.Softmax(dim=1) #dim=1 + self.CE = nn.CrossEntropyLoss() + self.log_softmax = nn.LogSoftmax() + self.mlsml = nn.MultiLabelSoftMarginLoss() + self.bce_logit = nn.BCEWithLogitsLoss() + self.sigm = nn.Sigmoid() + if gpu: + self.cuda() + + def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, 
gt_col): + max_q_len = max(q_len) + max_hs_len = max(hs_len) + max_col_len = max(col_len) + B = len(q_len) + + q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) + hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) + col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) + + # get target/predicted column's embedding + # col_emb: (B, hid_dim) + col_emb = [] + for b in range(B): + col_emb.append(col_enc[b, gt_col[b]]) + col_emb = torch.stack(col_emb) + + # Predict op number + att_val_qc_num = torch.bmm(col_emb.unsqueeze(1), self.q_num_att(q_enc).transpose(1, 2)).view(B,-1) + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qc_num[idx, num:] = -100 + att_prob_qc_num = self.softmax(att_val_qc_num) + q_weighted_num = (q_enc * att_prob_qc_num.unsqueeze(2)).sum(1) + + # Same as the above, compute SQL history embedding weighted by column attentions + att_val_hc_num = torch.bmm(col_emb.unsqueeze(1), self.hs_num_att(hs_enc).transpose(1, 2)).view(B,-1) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hc_num[idx, num:] = -100 + att_prob_hc_num = self.softmax(att_val_hc_num) + hs_weighted_num = (hs_enc * att_prob_hc_num.unsqueeze(2)).sum(1) + # op_num_score: (B, 2) + op_num_score = self.op_num_out(self.op_num_out_q(q_weighted_num) + int(self.use_hs)* self.op_num_out_hs(hs_weighted_num) + self.op_num_out_c(col_emb)) + + # Compute attention values between selected column and question tokens. + # q_enc.transpose(1, 2): (B, hid_dim, max_q_len) + # col_emb.unsqueeze(1): (B, 1, hid_dim) + # att_val_qc: (B, max_q_len) + # print("col_emb {} q_enc {}".format(col_emb.unsqueeze(1).size(),self.q_att(q_enc).transpose(1, 2).size())) + att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B,-1) + # assign appended positions values -100 + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qc[idx, num:] = -100 + # att_prob_qc: (B, max_q_len) + att_prob_qc = self.softmax(att_val_qc) + # q_enc: (B, max_q_len, hid_dim) + # att_prob_qc.unsqueeze(2): (B, max_q_len, 1) + # q_weighted: (B, hid_dim) + q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1) + + # Same as the above, compute SQL history embedding weighted by column attentions + att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B,-1) + for idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hc[idx, num:] = -100 + att_prob_hc = self.softmax(att_val_hc) + hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1) + + # Compute prediction scores + # op_score: (B, 10) + op_score = self.op_out(self.op_out_q(q_weighted) + int(self.use_hs)* self.op_out_hs(hs_weighted) + self.op_out_c(col_emb)) + + score = (op_num_score, op_score) + + return score + + + def loss(self, score, truth): + loss = 0 + B = len(truth) + op_num_score, op_score = score + truth = [t if len(t) <= 2 else t[:2] for t in truth] + # loss for the op number + truth_num = [len(t)-1 for t in truth] #num_score 0 maps to 1 in truth + data = torch.from_numpy(np.array(truth_num)) + truth_num_var = Variable(data.cuda()) + loss += self.CE(op_num_score, truth_num_var) + # loss for op + T = len(op_score[0]) + truth_prob = np.zeros((B, T), dtype=np.float32) + for b in range(B): + truth_prob[b][truth[b]] = 1 + data = torch.from_numpy(np.array(truth_prob)) + truth_var = Variable(data.cuda()) + #loss += self.mlsml(op_score, truth_var) + #loss += self.bce_logit(op_score, truth_var) + pred_prob = self.sigm(op_score) + bce_loss = -torch.mean( 3*(truth_var * \ + 
torch.log(pred_prob+1e-10)) + \ + (1-truth_var) * torch.log(1-pred_prob+1e-10) ) + loss += bce_loss + + return loss + + + def check_acc(self, score, truth): + num_err, err, tot_err = 0, 0, 0 + B = len(truth) + pred = [] + op_num_score, op_score = [x.data.cpu().numpy() for x in score] + for b in range(B): + cur_pred = {} + op_num = np.argmax(op_num_score[b]) + 1 #num_score 0 maps to 1 in truth, must have at least one op + cur_pred['op_num'] = op_num + cur_pred['op'] = np.argsort(-op_score[b])[:op_num] + pred.append(cur_pred) + + for b, (p, t) in enumerate(zip(pred, truth)): + op_num, op = p['op_num'], p['op'] + flag = True + if op_num != len(t): + num_err += 1 + flag = False + if flag and set(op) != set(t): + err += 1 + flag = False + if not flag: + tot_err += 1 + + return np.array((num_err, err, tot_err)) diff --git a/code/models/root_teminal_predictor.py b/code/models/root_teminal_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..730a3e074936738c1034f42e50e5155a4c48e6aa --- /dev/null +++ b/code/models/root_teminal_predictor.py @@ -0,0 +1,103 @@ +import json +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from models.net_utils import run_lstm, col_name_encode + + +class RootTeminalPredictor(nn.Module): + def __init__(self, N_word, N_col, N_h, N_depth, dropout, gpu, use_hs): + super(RootTeminalPredictor, self).__init__() + self.N_h = N_h + self.gpu = gpu + self.use_hs = use_hs + + self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + if N_col: + N_word = N_col + + self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.col_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), + num_layers=N_depth, batch_first=True, + dropout=dropout, bidirectional=True) + + self.q_att = nn.Linear(N_h, N_h) + self.hs_att = nn.Linear(N_h, N_h) + self.rt_out_q = nn.Linear(N_h, N_h) + self.rt_out_hs = nn.Linear(N_h, N_h) + self.rt_out_c = nn.Linear(N_h, N_h) + self.rt_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 2)) #for 2 operators + + self.softmax = nn.Softmax(dim=1) #dim=1 + self.CE = nn.CrossEntropyLoss() + self.log_softmax = nn.LogSoftmax() + self.mlsml = nn.MultiLabelSoftMarginLoss() + self.bce_logit = nn.BCEWithLogitsLoss() + self.sigm = nn.Sigmoid() + if gpu: + self.cuda() + + def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col): + max_q_len = max(q_len) + max_hs_len = max(hs_len) + max_col_len = max(col_len) + B = len(q_len) + + q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len) + hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len) + col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm) + + # get target/predicted column's embedding + # col_emb: (B, hid_dim) + col_emb = [] + for b in range(B): + col_emb.append(col_enc[b, gt_col[b]]) + col_emb = torch.stack(col_emb) + att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B,-1) + for idx, num in enumerate(q_len): + if num < max_q_len: + att_val_qc[idx, num:] = -100 + att_prob_qc = self.softmax(att_val_qc) + q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1) + + # Same as the above, compute SQL history embedding weighted by column attentions + att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B,-1) + for 
idx, num in enumerate(hs_len): + if num < max_hs_len: + att_val_hc[idx, num:] = -100 + att_prob_hc = self.softmax(att_val_hc) + hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1) + # rt_score: (B, 2) + rt_score = self.rt_out(self.rt_out_q(q_weighted) + int(self.use_hs)* self.rt_out_hs(hs_weighted) + self.rt_out_c(col_emb)) + + return rt_score + + + def loss(self, score, truth): + loss = 0 + data = torch.from_numpy(np.array(truth)) + truth_var = Variable(data.cuda()) + loss = self.CE(score, truth_var) + + return loss + + + def check_acc(self, score, truth): + err = 0 + B = len(score) + pred = [] + for b in range(B): + pred.append(np.argmax(score[b].data.cpu().numpy())) + for b, (p, t) in enumerate(zip(pred, truth)): + if p != t: + err += 1 + + return err diff --git a/code/models/supermodel.py b/code/models/supermodel.py new file mode 100644 index 0000000000000000000000000000000000000000..0142a0ff4a8b623b7d8deb2c4ba48b10f4cd1c31 --- /dev/null +++ b/code/models/supermodel.py @@ -0,0 +1,893 @@ +import json +import torch +import datetime +import time +import argparse +import numpy as np +import torch.nn as nn +import traceback +from collections import defaultdict + +from utils.word_embedding import WordEmbedding +from models.agg_predictor import AggPredictor +from models.col_predictor import ColPredictor +from models.desasc_limit_predictor import DesAscLimitPredictor +from models.having_predictor import HavingPredictor +from models.keyword_predictor import KeyWordPredictor +from models.multisql_predictor import MultiSqlPredictor +from models.root_teminal_predictor import RootTeminalPredictor +from models.andor_predictor import AndOrPredictor +from models.op_predictor import OpPredictor +from preprocess_data import index_to_column_name + + +SQL_OPS = ('none', 'intersect', 'union', 'except') +KW_OPS = ('where', 'groupBy', 'orderBy') +AGG_OPS = ('max', 'min', 'count', 'sum', 'avg') +ROOT_TERM_OPS = ("root", "terminal") +COND_OPS = ("and", "or") +DEC_ASC_OPS = (("asc", True), ("asc", False), ("desc", True), ("desc", False)) +NEW_WHERE_OPS = ( + '=', + '>', + '<', + '>=', + '<=', + '!=', + 'like', + 'not in', + 'in', + 'between') +KW_WITH_COL = ("select", "where", "groupBy", "orderBy", "having") + + +class Stack: + def __init__(self): + self.items = [] + + def isEmpty(self): + return self.items == [] + + def push(self, item): + self.items.append(item) + + def pop(self): + return self.items.pop() + + def peek(self): + return self.items[len(self.items) - 1] + + def size(self): + return len(self.items) + + def insert(self, i, x): + return self.items.insert(i, x) + + +def to_batch_tables(tables, B, table_type): + # col_lens = [] + col_seq = [] + ts = [ + tables["table_names"], + tables["column_names"], + tables["column_types"]] + tname_toks = [x.split(" ") for x in ts[0]] + col_type = ts[2] + cols = [x.split(" ") for xid, x in ts[1]] + tab_seq = [xid for xid, x in ts[1]] + cols_add = [] + for tid, col, ct in zip(tab_seq, cols, col_type): + col_one = [ct] + if tid == -1: + tabn = ["all"] + else: + if table_type == "no": + tabn = [] + else: + tabn = tname_toks[tid] + for t in tabn: + if t not in col: + col_one.append(t) + col_one.extend(col) + cols_add.append(col_one) + + col_seq = [cols_add] * B + + return col_seq + + +class SuperModel(nn.Module): + def __init__( + self, + word_emb, + col_emb, + N_word, + N_col, + N_h, + N_depth, + dropout, + gpu=True, + trainable_emb=False, + table_type="std", + use_hs=True): + super(SuperModel, self).__init__() + self.gpu = gpu + self.N_h = N_h + self.N_depth = 
N_depth + self.dropout = dropout + self.trainable_emb = trainable_emb + self.table_type = table_type + self.use_hs = use_hs + self.SQL_TOK = [ + '', + '', + 'WHERE', + 'AND', + 'EQL', + 'GT', + 'LT', + ''] + + # word embedding layer + self.embed_layer = WordEmbedding(word_emb, N_word, gpu, + self.SQL_TOK, trainable=trainable_emb) + self.q_embed_layer = self.embed_layer + + if not col_emb: + N_col = None + else: + self.embed_layer = WordEmbedding( + col_emb, N_col, gpu, self.SQL_TOK, trainable=trainable_emb) + + # initial all modules + self.multi_sql = MultiSqlPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=gpu, + use_hs=use_hs) + self.multi_sql.eval() + + self.key_word = KeyWordPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=gpu, + use_hs=use_hs) + self.key_word.eval() + + self.col = ColPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=gpu, + use_hs=use_hs) + self.col.eval() + + self.op = OpPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=gpu, + use_hs=use_hs) + self.op.eval() + + self.agg = AggPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=gpu, + use_hs=use_hs) + self.agg.eval() + + self.root_teminal = RootTeminalPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=gpu, + use_hs=use_hs) + self.root_teminal.eval() + + self.des_asc = DesAscLimitPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=gpu, + use_hs=use_hs) + self.des_asc.eval() + + self.having = HavingPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=gpu, + use_hs=use_hs) + self.having.eval() + + self.andor = AndOrPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=gpu, + use_hs=use_hs) + self.andor.eval() + + self.softmax = nn.Softmax(dim=1) # dim=1 + self.CE = nn.CrossEntropyLoss() + self.log_softmax = nn.LogSoftmax() + self.mlsml = nn.MultiLabelSoftMarginLoss() + self.bce_logit = nn.BCEWithLogitsLoss() + self.sigm = nn.Sigmoid() + if gpu: + self.cuda() + self.path_not_found = 0 + + def forward(self, q_seq, history, tables): + # if self.part: + # return self.part_forward(q_seq,history,tables) + # else: + return self.full_forward(q_seq, history, tables) + + def full_forward(self, q_seq, history, tables): + B = len(q_seq) + # print("q_seq:{}".format(q_seq)) + # print("Batch size:{}".format(B)) + q_emb_var, q_len = self.q_embed_layer.gen_x_q_batch(q_seq) + col_seq = to_batch_tables(tables, B, self.table_type) + + col_emb_var, col_name_len, col_len = self.embed_layer.gen_col_batch( + col_seq) + + mkw_emb_var = self.embed_layer.gen_word_list_embedding( + ["none", "except", "intersect", "union"], (B)) + mkw_len = np.full(q_len.shape, 4, dtype=np.int64) + kw_emb_var = self.embed_layer.gen_word_list_embedding( + ["where", "group by", "order by"], (B)) + kw_len = np.full(q_len.shape, 3, dtype=np.int64) + + stack = Stack() + stack.push(("root", None)) + history = [["root"]] * B + andor_cond = "" + has_limit = False + # sql = {} + current_sql = {} + sql_stack = [] + idx_stack = [] + kw_stack = [] + kw = "" + nested_label = "" + has_having = False + + timeout = time.time() + 2 # set timer to prevent infinite recursion in SQL generation + failed = False + while not stack.isEmpty(): + if time.time() > timeout: + 
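+ # decoding exceeded the 2-second budget set above; give up on this query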
failed = True + break + vet = stack.pop() + # print(vet) + hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history) + if len(idx_stack) > 0 and stack.size() < idx_stack[-1]: + # print("pop!!!!!!!!!!!!!!!!!!!!!!") + idx_stack.pop() + current_sql = sql_stack.pop() + kw = kw_stack.pop() + # current_sql = current_sql["sql"] + # history.append(vet) + # print("hs_emb:{} hs_len:{}".format(hs_emb_var.size(),hs_len.size())) + if isinstance(vet, tuple) and vet[0] == "root": + if history[0][-1] != "root": + history[0].append("root") + hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch( + history) + if vet[1] != "original": + idx_stack.append(stack.size()) + sql_stack.append(current_sql) + kw_stack.append(kw) + else: + idx_stack.append(stack.size()) + sql_stack.append(sql_stack[-1]) + kw_stack.append(kw) + if "sql" in current_sql: + current_sql["nested_sql"] = {} + current_sql["nested_label"] = nested_label + current_sql = current_sql["nested_sql"] + elif isinstance(vet[1], dict): + vet[1]["sql"] = {} + current_sql = vet[1]["sql"] + elif vet[1] != "original": + current_sql["sql"] = {} + current_sql = current_sql["sql"] + # print("q_emb_var:{} hs_emb_var:{} mkw_emb_var:{}".format(q_emb_var.size(),hs_emb_var.size(),mkw_emb_var.size())) + if vet[1] == "nested" or vet[1] == "original": + stack.push("none") + history[0].append("none") + else: + score = self.multi_sql.forward( + q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var, mkw_len) + label = np.argmax(score[0].data.cpu().numpy()) + label = SQL_OPS[label] + history[0].append(label) + stack.push(label) + if label != "none": + nested_label = label + + elif vet in ('intersect', 'except', 'union'): + stack.push(("root", "nested")) + stack.push(("root", "original")) + # history[0].append("root") + elif vet == "none": + score = self.key_word.forward( + q_emb_var, q_len, hs_emb_var, hs_len, kw_emb_var, kw_len) + kw_num_score, kw_score = [x.data.cpu().numpy() for x in score] + # print("kw_num_score:{}".format(kw_num_score)) + # print("kw_score:{}".format(kw_score)) + num_kw = np.argmax(kw_num_score[0]) + kw_score = list(np.argsort(-kw_score[0])[:num_kw]) + kw_score.sort(reverse=True) + # print("num_kw:{}".format(num_kw)) + for kw in kw_score: + stack.push(KW_OPS[kw]) + stack.push("select") + elif vet in ("select", "orderBy", "where", "groupBy", "having"): + kw = vet + current_sql[kw] = [] + history[0].append(vet) + stack.push(("col", vet)) + # score = self.andor.forward(q_emb_var,q_len,hs_emb_var,hs_len) + # label = score[0].data.cpu().numpy() + # andor_cond = COND_OPS[label] + # history.append("") + # elif vet == "groupBy": + # score = self.having.forward(q_emb_var,q_len,hs_emb_var,hs_len,col_emb_var,col_len,) + elif isinstance(vet, tuple) and vet[0] == "col": + # print("q_emb_var:{} hs_emb_var:{} col_emb_var:{}".format(q_emb_var.size(), hs_emb_var.size(),col_emb_var.size())) + score = self.col.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len) + col_num_score, col_score = [ + x.data.cpu().numpy() for x in score] + col_num = np.argmax(col_num_score[0]) + 1 # double check + cols = np.argsort(-col_score[0])[:col_num] + # print(col_num) + # print("col_num_score:{}".format(col_num_score)) + # print("col_score:{}".format(col_score)) + for col in cols: + if vet[1] == "where": + stack.push(("op", "where", col)) + elif vet[1] != "groupBy": + stack.push(("agg", vet[1], col)) + elif vet[1] == "groupBy": + history[0].append(index_to_column_name(col, tables)) + current_sql[kw].append( + 
index_to_column_name(col, tables)) + # predict and or or when there is multi col in where condition + if col_num > 1 and vet[1] == "where": + score = self.andor.forward( + q_emb_var, q_len, hs_emb_var, hs_len) + label = np.argmax(score[0].data.cpu().numpy()) + andor_cond = COND_OPS[label] + current_sql[kw].append(andor_cond) + if vet[1] == "groupBy" and col_num > 0: + score = self.having.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + np.full( + B, + cols[0], + dtype=np.int64)) + label = np.argmax(score[0].data.cpu().numpy()) + if label == 1: + has_having = (label == 1) + # stack.insert(-col_num,"having") + stack.push("having") + # history.append(index_to_column_name(cols[-1], tables[0])) + elif isinstance(vet, tuple) and vet[0] == "agg": + history[0].append(index_to_column_name(vet[2], tables)) + if vet[1] not in ("having", "orderBy"): # DEBUG-ed 20180817 + try: + current_sql[kw].append( + index_to_column_name(vet[2], tables)) + except Exception as e: + # print(e) + traceback.print_exc() + print( + "history:{},current_sql:{} stack:{}".format( + history[0], current_sql, stack.items)) + print("idx_stack:{}".format(idx_stack)) + print("sql_stack:{}".format(sql_stack)) + exit(1) + hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch( + history) + + score = self.agg.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + np.full( + B, + vet[2], + dtype=np.int64)) + agg_num_score, agg_score = [ + x.data.cpu().numpy() for x in score] + agg_num = np.argmax(agg_num_score[0]) # double check + agg_idxs = np.argsort(-agg_score[0])[:agg_num] + # print("agg:{}".format([AGG_OPS[agg] for agg in agg_idxs])) + if len(agg_idxs) > 0: + history[0].append(AGG_OPS[agg_idxs[0]]) + if vet[1] not in ("having", "orderBy"): + current_sql[kw].append(AGG_OPS[agg_idxs[0]]) + elif vet[1] == "orderBy": + # DEBUG-ed 20180817 + stack.push(("des_asc", vet[2], AGG_OPS[agg_idxs[0]])) + else: + stack.push( + ("op", "having", vet[2], AGG_OPS[agg_idxs[0]])) + for agg in agg_idxs[1:]: + history[0].append(index_to_column_name(vet[2], tables)) + history[0].append(AGG_OPS[agg]) + if vet[1] not in ("having", "orderBy"): + current_sql[kw].append( + index_to_column_name(vet[2], tables)) + current_sql[kw].append(AGG_OPS[agg]) + elif vet[1] == "orderBy": + stack.push(("des_asc", vet[2], AGG_OPS[agg])) + else: + stack.push(("op", "having", vet[2], agg_idxs)) + if len(agg_idxs) == 0: + if vet[1] not in ("having", "orderBy"): + current_sql[kw].append("none_agg") + elif vet[1] == "orderBy": + stack.push(("des_asc", vet[2], "none_agg")) + else: + stack.push(("op", "having", vet[2], "none_agg")) + # current_sql[kw].append([AGG_OPS[agg] for agg in agg_idxs]) + # if vet[1] == "having": + # stack.push(("op","having",vet[2],agg_idxs)) + # if vet[1] == "orderBy": + # stack.push(("des_asc",vet[2],agg_idxs)) + # if vet[1] == "groupBy" and has_having: + # stack.push("having") + elif isinstance(vet, tuple) and vet[0] == "op": + if vet[1] == "where": + # current_sql[kw].append(index_to_column_name(vet[2], tables)) + history[0].append(index_to_column_name(vet[2], tables)) + hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch( + history) + + score = self.op.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + np.full( + B, + vet[2], + dtype=np.int64)) + + op_num_score, op_score = [x.data.cpu().numpy() for x in score] + # num_score 0 maps to 1 in truth, must have at least one op + op_num = 
np.argmax(op_num_score[0]) + 1 + ops = np.argsort(-op_score[0])[:op_num] + # current_sql[kw].append([NEW_WHERE_OPS[op] for op in ops]) + if op_num > 0: + history[0].append(NEW_WHERE_OPS[ops[0]]) + if vet[1] == "having": + stack.push(("root_teminal", vet[2], vet[3], ops[0])) + else: + stack.push(("root_teminal", vet[2], ops[0])) + # current_sql[kw].append(NEW_WHERE_OPS[ops[0]]) + for op in ops[1:]: + history[0].append(index_to_column_name(vet[2], tables)) + history[0].append(NEW_WHERE_OPS[op]) + # current_sql[kw].append(index_to_column_name(vet[2], tables)) + # current_sql[kw].append(NEW_WHERE_OPS[op]) + if vet[1] == "having": + stack.push(("root_teminal", vet[2], vet[3], op)) + else: + stack.push(("root_teminal", vet[2], op)) + # stack.push(("root_teminal",vet[2])) + elif isinstance(vet, tuple) and vet[0] == "root_teminal": + score = self.root_teminal.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + np.full( + B, + vet[1], + dtype=np.int64)) + + label = np.argmax(score[0].data.cpu().numpy()) + label = ROOT_TERM_OPS[label] + if len(vet) == 4: + current_sql[kw].append( + index_to_column_name(vet[1], tables)) + current_sql[kw].append(vet[2]) + current_sql[kw].append(NEW_WHERE_OPS[vet[3]]) + else: + # print("kw:{}".format(kw)) + try: + current_sql[kw].append( + index_to_column_name(vet[1], tables)) + except Exception as e: + # print(e) + traceback.print_exc() + print( + "history:{},current_sql:{} stack:{}".format( + history[0], current_sql, stack.items)) + print("idx_stack:{}".format(idx_stack)) + print("sql_stack:{}".format(sql_stack)) + exit(1) + current_sql[kw].append(NEW_WHERE_OPS[vet[2]]) + if label == "root": + history[0].append("root") + current_sql[kw].append({}) + # current_sql = current_sql[kw][-1] + stack.push(("root", current_sql[kw][-1])) + else: + current_sql[kw].append("terminal") + elif isinstance(vet, tuple) and vet[0] == "des_asc": + current_sql[kw].append(index_to_column_name(vet[1], tables)) + current_sql[kw].append(vet[2]) + score = self.des_asc.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + np.full( + B, + vet[1], + dtype=np.int64)) + label = np.argmax(score[0].data.cpu().numpy()) + dec_asc, has_limit = DEC_ASC_OPS[label] + history[0].append(dec_asc) + current_sql[kw].append(dec_asc) + current_sql[kw].append(has_limit) + # print("{}".format(current_sql)) + + if failed: + return None + print("history:{}".format(history[0])) + if len(sql_stack) > 0: + current_sql = sql_stack[0] + # print("{}".format(current_sql)) + return current_sql + + def gen_col(self, col, table, table_alias_dict): + colname = table["column_names_original"][col[2]][1] + table_idx = table["column_names_original"][col[2]][0] + if table_idx not in table_alias_dict: + return colname + return "T{}.{}".format(table_alias_dict[table_idx], colname) + + def gen_group_by(self, sql, kw, table, table_alias_dict): + ret = [] + for i in range(0, len(sql)): + # if len(sql[i+1]) == 0: + # if sql[i+1] == "none_agg": + ret.append(self.gen_col(sql[i], table, table_alias_dict)) + # else: + # ret.append("{}({})".format(sql[i+1], self.gen_col(sql[i], table, table_alias_dict))) + # for agg in sql[i+1]: + # ret.append("{}({})".format(agg,gen_col(sql[i],table,table_alias_dict))) + return "{} {}".format(kw, ",".join(ret)) + + def gen_select(self, sql, kw, table, table_alias_dict): + ret = [] + for i in range(0, len(sql), 2): + # if len(sql[i+1]) == 0: + if sql[i + + 1] == "none_agg" or not isinstance(sql[i + + 1], str): # 
DEBUG-ed 20180817 + ret.append(self.gen_col(sql[i], table, table_alias_dict)) + else: + ret.append("{}({})".format( + sql[i + 1], self.gen_col(sql[i], table, table_alias_dict))) + # for agg in sql[i+1]: + # ret.append("{}({})".format(agg,gen_col(sql[i],table,table_alias_dict))) + return "{} {}".format(kw, ",".join(ret)) + + def gen_where(self, sql, table, table_alias_dict): + if len(sql) == 0: + return "" + start_idx = 0 + andor = "and" + if isinstance(sql[0], str): + start_idx += 1 + andor = sql[0] + ret = [] + for i in range(start_idx, len(sql), 3): + # rewrite to stop a bug + if i + 2 < len(sql): + col = self.gen_col(sql[i], table, table_alias_dict) + op = sql[i + 1] + val = sql[i + 2] + where_item = "" + else: + break + if val == "terminal": + where_item = "{} {} '{}'".format(col, op, val) + else: + val = self.gen_sql(val, table) + where_item = "{} {} ({})".format(col, op, val) + if op == "between": + # TODO temprarily fixed + where_item += " and 'terminal'" + ret.append(where_item) + return "where {}".format(" {} ".format(andor).join(ret)) + + def gen_orderby(self, sql, table, table_alias_dict): + ret = [] + limit = "" + if sql[-1]: + limit = "limit 1" + for i in range(0, len(sql), 4): + if sql[i + + 1] == "none_agg" or not isinstance(sql[i + + 1], str): # DEBUG-ed 20180817 + ret.append("{} {}".format(self.gen_col( + sql[i], table, table_alias_dict), sql[i + 2])) + else: + ret.append("{}({}) {}".format( + sql[i + 1], self.gen_col(sql[i], table, table_alias_dict), sql[i + 2])) + return "order by {} {}".format(",".join(ret), limit) + + def gen_having(self, sql, table, table_alias_dict): + ret = [] + for i in range(0, len(sql), 4): + if sql[i + 1] == "none_agg": + col = self.gen_col(sql[i], table, table_alias_dict) + else: + col = "{}({})".format( + sql[i + 1], self.gen_col(sql[i], table, table_alias_dict)) + op = sql[i + 2] + val = sql[i + 3] + if val == "terminal": + ret.append("{} {} '{}'".format(col, op, val)) + else: + val = self.gen_sql(val, table) + ret.append("{} {} ({})".format(col, op, val)) + return "having {}".format(",".join(ret)) + + def find_shortest_path(self, start, end, graph): + stack = [[start, []]] + visited = set() + while len(stack) > 0: + ele, history = stack.pop() + if ele == end: + return history + for node in graph[ele]: + if node[0] not in visited: + stack.append((node[0], history + [(node[0], node[1])])) + visited.add(node[0]) + print("table {} table {}".format(start, end)) + # print("could not find path!!!!!{}".format(self.path_not_found)) + self.path_not_found += 1 + # return [] + + def gen_from(self, candidate_tables, table): + def find(d, col): + if d[col] == -1: + return col + return find(d, d[col]) + + def union(d, c1, c2): + r1 = find(d, c1) + r2 = find(d, c2) + if r1 == r2: + return + d[r1] = r2 + + ret = "" + if len(candidate_tables) <= 1: + if len(candidate_tables) == 1: + ret = "from {}".format( + table["table_names_original"][list(candidate_tables)[0]]) + else: + ret = "from {}".format(table["table_names_original"][0]) + # TODO: temporarily settings + return {}, ret + # print("candidate:{}".format(candidate_tables)) + table_alias_dict = {} + uf_dict = {} + for t in candidate_tables: + uf_dict[t] = -1 + idx = 1 + graph = defaultdict(list) + for acol, bcol in table["foreign_keys"]: + t1 = table["column_names"][acol][0] + t2 = table["column_names"][bcol][0] + graph[t1].append((t2, (acol, bcol))) + graph[t2].append((t1, (bcol, acol))) + # if t1 in candidate_tables and t2 in candidate_tables: + # r1 = find(uf_dict,t1) + # r2 = find(uf_dict,t2) + # if 
r1 == r2: + # continue + # union(uf_dict,t1,t2) + # if len(ret) == 0: + # ret = "from {} as T{} join {} as T{} on T{}.{}=T{}.{}".format(table["table_names"][t1],idx,table["table_names"][t2], + # idx+1,idx,table["column_names_original"][acol][1],idx+1, + # table["column_names_original"][bcol][1]) + # table_alias_dict[t1] = idx + # table_alias_dict[t2] = idx+1 + # idx += 2 + # else: + # if t1 in table_alias_dict: + # old_t = t1 + # new_t = t2 + # acol,bcol = bcol,acol + # elif t2 in table_alias_dict: + # old_t = t2 + # new_t = t1 + # else: + # ret = "{} join {} as T{} join {} as T{} on T{}.{}=T{}.{}".format(ret,table["table_names"][t1], idx, + # table["table_names"][t2], + # idx + 1, idx, + # table["column_names_original"][acol][1], + # idx + 1, + # table["column_names_original"][bcol][1]) + # table_alias_dict[t1] = idx + # table_alias_dict[t2] = idx + 1 + # idx += 2 + # continue + # ret = "{} join {} as T{} on T{}.{}=T{}.{}".format(ret,new_t,idx,idx,table["column_names_original"][acol][1], + # table_alias_dict[old_t],table["column_names_original"][bcol][1]) + # table_alias_dict[new_t] = idx + # idx += 1 + # visited = set() + candidate_tables = list(candidate_tables) + start = candidate_tables[0] + table_alias_dict[start] = idx + idx += 1 + ret = "from {} as T1".format(table["table_names_original"][start]) + try: + for end in candidate_tables[1:]: + if end in table_alias_dict: + continue + path = self.find_shortest_path(start, end, graph) + prev_table = start + if not path: + table_alias_dict[end] = idx + idx += 1 + ret = "{} join {} as T{}".format( + ret, table["table_names_original"][end], table_alias_dict[end], ) + continue + for node, (acol, bcol) in path: + if node in table_alias_dict: + prev_table = node + continue + table_alias_dict[node] = idx + idx += 1 + ret = "{} join {} as T{} on T{}.{} = T{}.{}".format( + ret, + table["table_names_original"][node], + table_alias_dict[node], + table_alias_dict[prev_table], + table["column_names_original"][acol][1], + table_alias_dict[node], + table["column_names_original"][bcol][1]) + prev_table = node + except BaseException: + traceback.print_exc() + print("db:{}".format(table["db_id"])) + # print(table["db_id"]) + return table_alias_dict, ret + # if len(candidate_tables) != len(table_alias_dict): + # print("error in generate from clause!!!!!") + return table_alias_dict, ret + + def gen_sql(self, sql, table): + select_clause = "" + from_clause = "" + groupby_clause = "" + orderby_clause = "" + having_clause = "" + where_clause = "" + nested_clause = "" + cols = {} + candidate_tables = set() + nested_sql = {} + nested_label = "" + parent_sql = sql + # if "sql" in sql: + # sql = sql["sql"] + if "nested_label" in sql: + nested_label = sql["nested_label"] + nested_sql = sql["nested_sql"] + sql = sql["sql"] + elif "sql" in sql: + sql = sql["sql"] + for key in sql: + if key not in KW_WITH_COL: + continue + for item in sql[key]: + if isinstance(item, tuple) and len(item) == 3: + if table["column_names"][item[2]][0] != -1: + candidate_tables.add(table["column_names"][item[2]][0]) + table_alias_dict, from_clause = self.gen_from(candidate_tables, table) + ret = [] + if "select" in sql: + select_clause = self.gen_select( + sql["select"], "select", table, table_alias_dict) + if len(select_clause) > 0: + ret.append(select_clause) + else: + print("select not found:{}".format(parent_sql)) + else: + print("select not found:{}".format(parent_sql)) + if len(from_clause) > 0: + ret.append(from_clause) + if "where" in sql: + where_clause = self.gen_where( + 
sql["where"], table, table_alias_dict) + if len(where_clause) > 0: + ret.append(where_clause) + if "groupBy" in sql: # DEBUG-ed order + groupby_clause = self.gen_group_by( + sql["groupBy"], "group by", table, table_alias_dict) + if len(groupby_clause) > 0: + ret.append(groupby_clause) + if "orderBy" in sql: + orderby_clause = self.gen_orderby( + sql["orderBy"], table, table_alias_dict) + if len(orderby_clause) > 0: + ret.append(orderby_clause) + if "having" in sql: + having_clause = self.gen_having( + sql["having"], table, table_alias_dict) + if len(having_clause) > 0: + ret.append(having_clause) + if len(nested_label) > 0: + nested_clause = "{} {}".format( + nested_label, self.gen_sql( + nested_sql, table)) + if len(nested_clause) > 0: + ret.append(nested_clause) + return " ".join(ret) + + def check_acc(self, pred_sql, gt_sql): + pass diff --git a/code/preprocess_data.py b/code/preprocess_data.py new file mode 100644 index 0000000000000000000000000000000000000000..8dbd9daf38ba632ed5102ad2d1a028dde028f7b7 --- /dev/null +++ b/code/preprocess_data.py @@ -0,0 +1,789 @@ +# _*_ coding: utf_8 _* + +import argparse +import codecs +import json +import os +import sys +from collections import defaultdict + +defaultencoding = 'utf-8' +if sys.getdefaultencoding() != defaultencoding: + reload(sys) + sys.setdefaultencoding(defaultencoding) + + +OLD_WHERE_OPS = ( + 'not', + 'between', + '=', + '>', + '<', + '>=', + '<=', + '!=', + 'in', + 'like', + 'is', + 'exists') +NEW_WHERE_OPS = ( + '=', + '>', + '<', + '>=', + '<=', + '!=', + 'like', + 'not in', + 'in', + 'between', + 'is') +NEW_WHERE_DICT = { + '=': 0, + '>': 1, + '<': 2, + '>=': 3, + '<=': 4, + '!=': 5, + 'like': 6, + 'not in': 7, + 'in': 8, + 'between': 9, + 'is': 10 +} +# SQL_OPS = ('none','intersect', 'union', 'except') +SQL_OPS = { + 'none': 0, + 'intersect': 1, + 'union': 2, + 'except': 3 +} +KW_DICT = { + 'where': 0, + 'groupBy': 1, + 'orderBy': 2 +} +ORDER_OPS = { + 'desc': 0, + 'asc': 1} +AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') + +COND_OPS = { + 'and': 0, + 'or': 1 +} + + +def convert_to_op_index(is_not, op): + op = OLD_WHERE_OPS[op] + if is_not and op == "in": + return 7 + try: + return NEW_WHERE_DICT[op] + except BaseException: + print("Unsupport op: {}".format(op)) + return -1 + + +def index_to_column_name(index, table): + column_name = table["column_names"][index][1] + table_index = table["column_names"][index][0] + table_name = table["table_names"][table_index] + return table_name, column_name, index + + +def get_label_cols(with_join, fk_dict, labels): + # list(set([l[1][i][0][2] for i in range(min(len(l[1]), 3))])) + cols = set() + ret = [] + for i in range(len(labels)): + cols.add(labels[i][0][2]) + if len(cols) > 3: + break + for col in cols: + # ret.append([col]) + if with_join and len(fk_dict[col]) > 0: + ret.append([col] + fk_dict[col]) + else: + ret.append(col) + return ret + + +class MultiSqlPredictor: + def __init__(self, question, sql, history): + self.sql = sql + self.question = question + self.history = history + self.keywords = ('intersect', 'except', 'union') + + def generate_output(self): + for key in self.sql: + if key in self.keywords and self.sql[key]: + return self.history + ['root'], key, self.sql[key] + return self.history + ['root'], 'none', self.sql + + +class KeyWordPredictor: + def __init__(self, question, sql, history): + self.sql = sql + self.question = question + self.history = history + self.keywords = ( + 'select', + 'where', + 'groupBy', + 'orderBy', + 'limit', + 'having') + + def 
generate_output(self): + sql_keywords = [] + for key in self.sql: + if key in self.keywords and self.sql[key]: + sql_keywords.append(key) + return self.history, [len(sql_keywords), sql_keywords], self.sql + + +class ColPredictor: + def __init__(self, question, sql, table, history, kw=None): + self.sql = sql + self.question = question + self.history = history + self.table = table + self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'having') + self.kw = kw + + def generate_output(self): + ret = [] + candidate_keys = self.sql.keys() + if self.kw: + candidate_keys = [self.kw] + for key in candidate_keys: + if key in self.keywords and self.sql[key]: + cols = [] + sqls = [] + if key == 'groupBy': + sql_cols = self.sql[key] + for col in sql_cols: + cols.append( + (index_to_column_name( + col[1], self.table), col[2])) + sqls.append(col) + elif key == 'orderBy': + sql_cols = self.sql[key][1] + for col in sql_cols: + cols.append( + (index_to_column_name( + col[1][1], self.table), col[1][2])) + sqls.append(col) + elif key == 'select': + sql_cols = self.sql[key][1] + for col in sql_cols: + cols.append( + (index_to_column_name( + col[1][1][1], + self.table), + col[1][1][2])) + sqls.append(col) + elif key == 'where' or key == 'having': + sql_cols = self.sql[key] + for col in sql_cols: + if not isinstance(col, list): + continue + try: + cols.append( + (index_to_column_name( + col[2][1][1], + self.table), + col[2][1][2])) + except BaseException: + print( + "Key:{} Col:{} Question:{}".format( + key, col, self.question)) + sqls.append(col) + ret.append(( + self.history + [key], (len(cols), cols), sqls + )) + return ret + # ret.append(history+[key],) + + +class OpPredictor: + def __init__(self, question, sql, history): + self.sql = sql + self.question = question + self.history = history + # self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'having') + + def generate_output(self): + return self.history, convert_to_op_index( + self.sql[0], self.sql[1]), (self.sql[3], self.sql[4]) + + +class AggPredictor: + def __init__(self, question, sql, history, kw=None): + self.sql = sql + self.question = question + self.history = history + self.kw = kw + + def generate_output(self): + label = -1 + if self.kw: + key = self.kw + else: + key = self.history[-2] + if key == 'select': + label = self.sql[0] + elif key == 'orderBy': + label = self.sql[1][0] + elif key == 'having': + label = self.sql[2][1][0] + return self.history, label + + +# class RootTemPredictor: +# def __init__(self, question, sql): +# self.sql = sql +# self.question = question +# self.keywords = ('intersect', 'except', 'union') +# +# def generate_output(self): +# for key in self.sql: +# if key in self.keywords: +# return ['ROOT'], key, self.sql[key] +# return ['ROOT'], 'none', self.sql + + +class DesAscPredictor: + def __init__(self, question, sql, table, history): + self.sql = sql + self.question = question + self.history = history + self.table = table + + def generate_output(self): + for key in self.sql: + if key == "orderBy" and self.sql[key]: + # self.history.append(key) + try: + col = self.sql[key][1][0][1][1] + except BaseException: + print("question:{} sql:{}".format(self.question, self.sql)) + # self.history.append(index_to_column_name(col, self.table)) + # self.history.append(self.sql[key][1][0][1][0]) + if self.sql[key][0] == "asc" and self.sql["limit"]: + label = 0 + elif self.sql[key][0] == "asc" and not self.sql["limit"]: + label = 1 + elif self.sql[key][0] == "desc" and self.sql["limit"]: + label = 2 + else: + label = 3 + 
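+ # labels 0-3 follow the DEC_ASC_OPS order used in supermodel.py:
+ # (asc, limit), (asc, no limit), (desc, limit), (desc, no limit)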
return self.history + \ + [index_to_column_name(col, self.table), self.sql[key][1][0][1][0]], label + + +class AndOrPredictor: + def __init__(self, question, sql, table, history): + self.sql = sql + self.question = question + self.history = history + self.table = table + + def generate_output(self): + if 'where' in self.sql and self.sql['where'] and len( + self.sql['where']) > 1: + return self.history, COND_OPS[self.sql['where'][1]] + return self.history, -1 + + +def parser_item_with_long_history( + question_tokens, + sql, + table, + history, + dataset): + table_schema = [ + table["table_names"], + table["column_names"], + table["column_types"] + ] + stack = [("root", sql)] + with_join = False + fk_dict = defaultdict(list) + for fk in table["foreign_keys"]: + fk_dict[fk[0]].append(fk[1]) + fk_dict[fk[1]].append(fk[0]) + while len(stack) > 0: + node = stack.pop() + if node[0] == "root": + history, label, sql = MultiSqlPredictor( + question_tokens, node[1], history).generate_output() + dataset['multi_sql_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "label": SQL_OPS[label] + }) + history.append(label) + if label == "none": + stack.append((label, sql)) + else: + node[1][label] = None + stack.append((label, node[1], sql)) + # if label != "none": + # stack.append(("none",node[1])) + elif node[0] in ('intersect', 'except', 'union'): + stack.append(("root", node[1])) + stack.append(("root", node[2])) + elif node[0] == "none": + with_join = len(node[1]["from"]["table_units"]) > 1 + history, label, sql = KeyWordPredictor( + question_tokens, node[1], history).generate_output() + label_idxs = [] + for item in label[1]: + if item in KW_DICT: + label_idxs.append(KW_DICT[item]) + label_idxs.sort() + dataset['keyword_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "label": label_idxs + }) + if "having" in label[1]: + stack.append(("having", node[1])) + if "orderBy" in label[1]: + stack.append(("orderBy", node[1])) + if "groupBy" in label[1]: + if "having" in label[1]: + dataset['having_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "gt_col": node[1]["groupBy"][0][1], + "label": 1 + }) + else: + dataset['having_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "gt_col": node[1]["groupBy"][0][1], + "label": 0 + }) + stack.append(("groupBy", node[1])) + if "where" in label[1]: + stack.append(("where", node[1])) + if "select" in label[1]: + stack.append(("select", node[1])) + elif node[0] in ("select", "having", "orderBy"): + # if node[0] != "orderBy": + history.append(node[0]) + if node[0] == "orderBy": + orderby_ret = DesAscPredictor( + question_tokens, node[1], table, history).generate_output() + if orderby_ret: + dataset['des_asc_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": orderby_ret[0], + "gt_col": node[1]["orderBy"][1][0][1][1], + "label": orderby_ret[1] + }) + # history.append(orderby_ret[1]) + col_ret = ColPredictor( + question_tokens, + node[1], + table, + history, + node[0]).generate_output() + agg_col_dict = dict() + op_col_dict = dict() + for h, l, s in col_ret: + if l[0] == 0: + print("Warning: predicted 0 columns!") + continue + dataset['col_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "label": get_label_cols(with_join, fk_dict, l[1]) + }) + for col, 
sql_item in zip(l[1], s): + key = "{}{}{}".format(col[0][0], col[0][1], col[0][2]) + if key not in agg_col_dict: + agg_col_dict[key] = [(sql_item, col[0])] + else: + agg_col_dict[key].append((sql_item, col[0])) + if key not in op_col_dict: + op_col_dict[key] = [(sql_item, col[0])] + else: + op_col_dict[key].append((sql_item, col[0])) + for key in agg_col_dict: + stack.append( + ("col", node[0], agg_col_dict[key], op_col_dict[key])) + elif node[0] == "col": + history.append(node[2][0][1]) + if node[1] == "where": + stack.append(("op", node[2], "where")) + else: + labels = [] + for sql_item, col in node[2]: + _, label = AggPredictor( + question_tokens, sql_item, history, node[1]).generate_output() + if label - 1 >= 0: + labels.append(label - 1) + + # print(node[2][0][1][2]) + dataset['agg_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "gt_col": node[2][0][1][2], + "label": labels[:min(len(labels), 3)] + }) + if node[1] == "having": + stack.append(("op", node[2], "having")) + # if len(labels) == 0: + # history.append("none") + # else: + if len(labels) > 0: + history.append(AGG_OPS[labels[0] + 1]) + elif node[0] == "op": + # history.append(node[1][0][1]) + labels = [] + # if len(labels) > 2: + # print(question_tokens) + dataset['op_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "gt_col": node[1][0][1][2], + "label": labels + }) + + for sql_item, col in node[1]: + _, label, s = OpPredictor( + question_tokens, sql_item, history).generate_output() + if label != -1: + labels.append(label) + history.append(NEW_WHERE_OPS[label]) + if isinstance(s[0], dict): + stack.append(("root", s[0])) + # history.append("root") + dataset['root_tem_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "gt_col": node[1][0][1][2], + "label": 0 + }) + else: + dataset['root_tem_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "gt_col": node[1][0][1][2], + "label": 1 + }) + # history.append("terminal") + if len(labels) > 2: + print(question_tokens) + dataset['op_dataset'][-1]["label"] = labels + elif node[0] == "where": + history.append(node[0]) + hist, label = AndOrPredictor( + question_tokens, node[1], table, history).generate_output() + if label != -1: + dataset['andor_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "label": label + }) + col_ret = ColPredictor( + question_tokens, + node[1], + table, + history, + "where").generate_output() + op_col_dict = dict() + for h, l, s in col_ret: + if l[0] == 0: + print("Warning: predicted 0 columns!") + continue + dataset['col_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "label": get_label_cols(with_join, fk_dict, l[1]) + }) + for col, sql_item in zip(l[1], s): + key = "{}{}{}".format(col[0][0], col[0][1], col[0][2]) + if key not in op_col_dict: + op_col_dict[key] = [(sql_item, col[0])] + else: + op_col_dict[key].append((sql_item, col[0])) + for key in op_col_dict: + stack.append(("col", "where", op_col_dict[key])) + elif node[0] == "groupBy": + history.append(node[0]) + col_ret = ColPredictor( + question_tokens, + node[1], + table, + history, + node[0]).generate_output() + agg_col_dict = dict() + for h, l, s in col_ret: + if l[0] == 0: + print("Warning: predicted 0 columns!") + continue + dataset['col_dataset'].append({ + 
"question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "label": get_label_cols(with_join, fk_dict, l[1]) + }) + for col, sql_item in zip(l[1], s): + key = "{}{}{}".format(col[0][0], col[0][1], col[0][2]) + if key not in agg_col_dict: + agg_col_dict[key] = [(sql_item, col[0])] + else: + agg_col_dict[key].append((sql_item, col[0])) + for key in agg_col_dict: + stack.append(("col", node[0], agg_col_dict[key])) + + +def parser_item(question_tokens, sql, table, history, dataset): + # try: + # question_tokens = item['question_toks'] + # except: + # print(item) + # sql = item['sql'] + table_schema = [ + table["table_names"], + table["column_names"], + table["column_types"] + ] + history, label, sql = MultiSqlPredictor( + question_tokens, sql, history).generate_output() + dataset['multi_sql_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "label": SQL_OPS[label] + }) + history.append(label) + history, label, sql = KeyWordPredictor( + question_tokens, sql, history).generate_output() + label_idxs = [] + for item in label[1]: + if item in KW_DICT: + label_idxs.append(KW_DICT[item]) + label_idxs.sort() + dataset['keyword_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": history[:], + "label": label_idxs + }) + hist, label = AndOrPredictor( + question_tokens, sql, table, history).generate_output() + if label != -1: + dataset['andor_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": hist[:] + ["where"], + "label": label + }) + orderby_ret = DesAscPredictor( + question_tokens, + sql, + table, + history).generate_output() + if orderby_ret: + dataset['des_asc_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": orderby_ret[0][:], + "label": orderby_ret[1] + }) + col_ret = ColPredictor( + question_tokens, + sql, + table, + history).generate_output() + agg_candidates = [] + op_candidates = [] + for h, l, s in col_ret: + if l[0] == 0: + print("Warning: predicted 0 columns!") + continue + dataset['col_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": h[:], + "label": list(set([l[1][i][0][2] for i in range(min(len(l[1]), 3))])) + }) + for col, sql_item in zip(l[1], s): + if h[-1] in ('where', 'having'): + op_candidates.append((h + [col[0]], sql_item)) + if h[-1] in ('select', 'orderBy', 'having'): + agg_candidates.append((h + [col[0]], sql_item)) + if h[-1] == "groupBy": + label = 0 + if sql["having"]: + label = 1 + dataset['having_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": h[:] + [col[0]], + "label": label + }) + + op_col_dict = dict() + for h, sql_item in op_candidates: + _, label, s = OpPredictor( + question_tokens, sql_item, h).generate_output() + if label == -1: + continue + key = "{}{}".format(h[-2], h[-1][2]) + label = NEW_WHERE_OPS[label] + if key in op_col_dict: + op_col_dict[key][1].append(label) + else: + op_col_dict[key] = [h[:], [label]] + # dataset['op_dataset'].append({ + # "question_tokens": question_tokens, + # "ts": table_schema, + # "history": h[:], + # "label": label + # }) + if isinstance(s[0], dict): + dataset['root_tem_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": h[:] + [label], + "label": 0 + }) + parser_item(question_tokens, s[0], table, h[:] + [label], dataset) + else: + dataset['root_tem_dataset'].append({ + "question_tokens": question_tokens, 
+ "ts": table_schema, + "history": h[:] + [label], + "label": 1 + }) + for key in op_col_dict: + # if len(op_col_dict[key][1]) > 1: + # print("same col has mult op ") + dataset['op_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": op_col_dict[key][0], + "label": op_col_dict[key][1] + }) + agg_col_dict = dict() + for h, sql_item in agg_candidates: + _, label = AggPredictor(question_tokens, sql_item, h).generate_output() + if label != 5: + key = "{}{}".format(h[-2], h[-1][2]) + if key in agg_col_dict: + agg_col_dict[key][1].append(label) + else: + agg_col_dict[key] = [h[:], [label]] + for key in agg_col_dict: + # if 5 in agg_col_dict[key][1]: + # print("none in agg label!!!") + dataset['agg_dataset'].append({ + "question_tokens": question_tokens, + "ts": table_schema, + "history": agg_col_dict[key][0], + "label": agg_col_dict[key][1] + }) + + +def get_table_dict(table_data_path): + data = json.load(open(table_data_path)) + table = dict() + for item in data: + table[item["db_id"]] = item + return table + + +def parse_data(data, table_path, gen_data_path, history_option, part): + dataset = { + "multi_sql_dataset": [], + "keyword_dataset": [], + "col_dataset": [], + "op_dataset": [], + "agg_dataset": [], + "root_tem_dataset": [], + "des_asc_dataset": [], + "having_dataset": [], + "andor_dataset": [] + } + table_dict = get_table_dict(table_path) + for item in data: + if history_option == "full": + # parser_item(item["question_toks"], item["sql"], table_dict[item["db_id"]], [], dataset) + parser_item_with_long_history( + item["question_toks"], item["sql"], table_dict[item["db_id"]], [], dataset) + else: + parser_item(item["question_toks"], item["sql"], + table_dict[item["db_id"]], [], dataset) + print("\nfinished preprocess %s part" % (part)) + for key in dataset: + print("dataset:{} size:{}".format(key, len(dataset[key]))) + with open(os.path.join(gen_data_path, + "{}_{}_{}.json".format( + history_option, + part, + key)),"w", encoding='utf-8') as json_file: + json.dump( + dataset[key], + json_file, + indent=2, + ensure_ascii=False) + #print('done') + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '-s', + '--schema', + type=str, + default='char', + choices=[ + 'char', + 'word'], + help='char for char-based schema and word for word-based schema.') + + parser.add_argument( + '--history_option', + type=str, + default='full', + choices=[ + 'full', + 'part', + 'no'], + help='full, part, or no history') + + args = parser.parse_args() + + table_path = "./data/tables.json" + + schema = args.schema + history_option = args.history_option + train_dev_test = ['train', 'dev'] + for part in train_dev_test: + data_path = os.path.join('data', args.schema, part + '.json') + data = json.load(codecs.open(data_path, 'r', encoding='utf-8')) + gen_data_path = os.path.join('data', args.schema, 'generated_datasets') + print(gen_data_path) + if not os.path.exists(gen_data_path): + os.mkdir(gen_data_path) + parse_data(data, table_path, gen_data_path, history_option, part) diff --git a/code/requirements.txt b/code/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6fa2de44417ce85838334618e6430c5e1abc6a5d --- /dev/null +++ b/code/requirements.txt @@ -0,0 +1 @@ +nltk \ No newline at end of file diff --git a/code/test.py b/code/test.py new file mode 100644 index 0000000000000000000000000000000000000000..55893cdc52d57cf52312e834b1005bae86269e8c --- /dev/null +++ b/code/test.py @@ -0,0 +1,102 @@ +import torch 
+import argparse +from utils.util import * +from models.supermodel import SuperModel + +from config import global_config as cfg + +if __name__ == '__main__': + N_word = cfg.emb_size + N_col = cfg.col_emb_size + N_h = cfg.hidden_size + N_depth = cfg.num_layers + dropout = cfg.dropout + BATCH_SIZE = cfg.batch_size + learning_rate = cfg.learning_rate + train_emb = cfg.train_emb + + history_type = cfg.history_type + table_type = cfg.table_type + + parser = argparse.ArgumentParser() + parser.add_argument('--models', type=str, help='path to saved model') + parser.add_argument('--test_data_path', type=str) + parser.add_argument('--output_path', type=str) + parser.add_argument( + '--emb_path', + type=str, + default='', + help='embedding path, multi-lingual or monolingual') + parser.add_argument( + '--col_emb_path', + type=str, + default='', + help='column embedding path') + parser.add_argument( + '--toy', + action='store_true', + help='If set, use small data; used for fast debugging.') + args = parser.parse_args() + + use_hs = True + if history_type == "no": + use_hs = False + + # default to use GPU, but have to check if GPU exists + GPU = True + if not cfg.nogpu: + if torch.cuda.device_count() == 0: + GPU = False + + toy = args.toy + if toy: + USE_SMALL = True + else: + USE_SMALL = False + + data = json.load(open(args.test_data_path)) + + emb_path = args.emb_path + word_emb = load_emb(emb_path, load_used=train_emb, use_small=USE_SMALL) + col_emb_path = args.col_emb_path + col_emb = None + if col_emb_path != 'None': + col_emb = load_emb( + col_emb_path, + load_used=train_emb, + use_small=USE_SMALL) + print("Finished load word embedding") + + model = SuperModel( + word_emb, + col_emb, + N_word=N_word, + N_col = N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=GPU, + trainable_emb=train_emb, + table_type=table_type, + use_hs=use_hs) + + print("Loading from modules...") + model.multi_sql.load_state_dict(torch.load( + "{}/multi_sql_models.dump".format(args.models))) + model.key_word.load_state_dict(torch.load( + "{}/keyword_models.dump".format(args.models))) + model.col.load_state_dict(torch.load( + "{}/col_models.dump".format(args.models))) + model.op.load_state_dict(torch.load( + "{}/op_models.dump".format(args.models))) + model.agg.load_state_dict(torch.load( + "{}/agg_models.dump".format(args.models))) + model.root_teminal.load_state_dict(torch.load( + "{}/root_tem_models.dump".format(args.models))) + model.des_asc.load_state_dict(torch.load( + "{}/des_asc_models.dump".format(args.models))) + model.having.load_state_dict(torch.load( + "{}/having_models.dump".format(args.models))) + + test_acc(model, BATCH_SIZE, data, args.output_path) + # test_exec_acc() diff --git a/code/test_gen.sh b/code/test_gen.sh new file mode 100644 index 0000000000000000000000000000000000000000..7a489d8f1093a04d559cde62f9073fd419e2f6ac --- /dev/null +++ b/code/test_gen.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +#export CUDA_VISIBLE_DEVICES=1 +schema="char" # char or word +#schema="word" + +toy="" + +embedding="multi" # multi for multi-lingual or mono for monolingual +#embedding="mono" + +emb_path="embedding/${schema}_emb.txt" +col_emb_path="embedding/glove.42B.300d.txt" +if [[ ${embedding} == "multi" ]]; then col_emb_path="None"; fi + +TEST_DATA="data/${schema}/dev.json" + +SAVE_PATH="data/${schema}/generated_datasets/saved_models_multi_2022-12-05-06:04:27" +python -u test.py \ + --test_data_path ${TEST_DATA} \ + --models ${SAVE_PATH} \ + --output_path ${SAVE_PATH}/test_result.txt \ + --emb_path ${emb_path} \ + 
--col_emb_path ${col_emb_path} \ + ${toy} \ + > "${SAVE_PATH}/test_result.out.txt" \ + 2>&1 & diff --git a/code/train.py b/code/train.py new file mode 100644 index 0000000000000000000000000000000000000000..538a71a6e2bba8e010d21c18a02e39b4a72696af --- /dev/null +++ b/code/train.py @@ -0,0 +1,273 @@ +import argparse +import datetime +import torch +from config import global_config as cfg +from utils.util import * +from utils.word_embedding import WordEmbedding +from models.agg_predictor import AggPredictor +from models.col_predictor import ColPredictor +from models.desasc_limit_predictor import DesAscLimitPredictor +from models.having_predictor import HavingPredictor +from models.keyword_predictor import KeyWordPredictor +from models.multisql_predictor import MultiSqlPredictor +from models.op_predictor import OpPredictor +from models.root_teminal_predictor import RootTeminalPredictor +from models.andor_predictor import AndOrPredictor +from tqdm import * + +TRAIN_COMPONENTS = ( + 'multi_sql', + 'keyword', + 'col', + 'op', + 'agg', + 'root_tem', + 'des_asc', + 'having', + 'andor') +SQL_TOK = ['', '', 'WHERE', 'AND', 'EQL', 'GT', 'LT', ''] + + +def lr_decay(optimizer, epoch, decay_rate, init_lr): + lr = init_lr * ((1 - decay_rate)**epoch) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return optimizer + + +if __name__ == '__main__': + np.random.seed(100) + torch.manual_seed(100) + torch.cuda.manual_seed(100) + torch.backends.cudnn.enabled = False + torch.backends.cudnn.deterministic = True + + device = 'cuda:2' + + N_word = cfg.emb_size + N_col = cfg.col_emb_size + N_h = cfg.hidden_size + N_depth = cfg.num_layers + dropout = cfg.dropout + BATCH_SIZE = cfg.batch_size + learning_rate = cfg.learning_rate + train_emb = cfg.train_emb + epoch = cfg.epoch + + history_type = cfg.history_type + table_type = cfg.table_type + + use_hs = True + if history_type == "no": + use_hs = False + + # default to use GPU, but have to check if GPU exists + GPU = True + if not cfg.nogpu: + if torch.cuda.device_count() == 0: + GPU = False + + parser = argparse.ArgumentParser() + parser.add_argument( + '--data_root', + type=str, + default='', + help='root path for generated_data') + parser.add_argument( + '--save_dir', + type=str, + default='', + help='set model save directory.') + parser.add_argument( + '--train_component', + type=str, + default='', + help='set train components,available:[multi_sql,keyword,col,op,agg,root_tem,des_asc,having,andor].') + parser.add_argument( + '--emb_path', + type=str, + default='', + help='embedding path, multi-lingual or monolingual') + parser.add_argument( + '--col_emb_path', + type=str, + default='', + help='column embedding path') + parser.add_argument( + '--toy', + action='store_true', + help='If set, use small data; used for fast debugging.') + args = parser.parse_args() + + toy = args.toy + if toy: + USE_SMALL = True + else: + USE_SMALL = False + + data_root = args.data_root + save_dir = args.save_dir + + train_component = args.train_component + if train_component not in TRAIN_COMPONENTS: + print("Invalid train component") + exit(1) + train_data = load_dataset( + train_component, + "train", + history_type, + data_root) + dev_data = load_dataset( + train_component, + "dev", + history_type, + data_root) + + emb_path = args.emb_path + word_emb = load_emb(emb_path, load_used=train_emb, use_small=USE_SMALL) + col_emb_path = args.col_emb_path + col_emb = None + if col_emb_path != 'None': + col_emb = load_emb( + col_emb_path, + load_used=train_emb, + 
use_small=USE_SMALL) + print("Finished load word embedding") + + embed_layer = WordEmbedding( + word_emb, + N_word, + gpu=GPU, + SQL_TOK=SQL_TOK, + trainable=train_emb) + q_embed_layer = embed_layer + + if not col_emb: + N_col = None + else: + embed_layer = WordEmbedding( + col_emb, N_col, gpu=GPU, SQL_TOK=SQL_TOK, trainable=train_emb) + + model = None + if train_component == "multi_sql": + model = MultiSqlPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + gpu=GPU, + dropout=dropout, + use_hs=use_hs) + elif train_component == "keyword": + model = KeyWordPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=GPU, + use_hs=use_hs) + elif train_component == "col": + model = ColPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=GPU, + use_hs=use_hs) + elif train_component == "op": + model = OpPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=GPU, + use_hs=use_hs) + elif train_component == "agg": + model = AggPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=GPU, + use_hs=use_hs) + elif train_component == "root_tem": + model = RootTeminalPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=GPU, + use_hs=use_hs) + elif train_component == "des_asc": + model = DesAscLimitPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=GPU, + use_hs=use_hs) + elif train_component == "having": + model = HavingPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=GPU, + use_hs=use_hs) + elif train_component == "andor": + model = AndOrPredictor( + N_word=N_word, + N_col=N_col, + N_h=N_h, + N_depth=N_depth, + dropout=dropout, + gpu=GPU, + use_hs=use_hs) + parameters = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = torch.optim.Adam( + model.parameters(), + lr=learning_rate, + weight_decay=0) + print("Finished build model") + + print_flag = False + + model.to(device) + + print("Start training") + best_acc = 0.0 + for i in tqdm(range(epoch)): + print('Epoch %d @ %s' % (i + 1, datetime.datetime.now())) + print( + ' Loss = %s' % + epoch_train( + model, + optimizer, + BATCH_SIZE, + train_component, + train_data, + table_type, + q_embed_layer, + embed_layer)) + acc = epoch_acc( + model, + BATCH_SIZE, + train_component, + dev_data, + table_type, + q_embed_layer, + embed_layer) + if acc > best_acc: + best_acc = acc + print("Save model...") + torch.save( + model.state_dict(), + save_dir + "/{}_models.dump".format( + train_component)) diff --git a/code/train_all.sh b/code/train_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..e1bb04cde3f4b9af0bf5a7149a6f57f79e12d55a --- /dev/null +++ b/code/train_all.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +DATE=`date '+%Y-%m-%d-%H:%M:%S'` + +schema="char" # char or word +#schema="word" + +toy="" + +embedding="multi" # multi for multi-lingual or mono for monolingual +#embedding="mono" + +data_root="data/${schema}/generated_datasets" +emb_path="embedding/${schema}_emb.txt" +col_emb_path="embedding/glove.42B.300d.txt" +if [[ ${embedding} == "multi" ]]; then col_emb_path="None"; fi +save_dir="${data_root}/saved_models_${embedding}_${DATE}" +log_dir="${save_dir}/train_log" +mkdir -p ${save_dir} +mkdir -p ${log_dir} + +export CUDA_VISIBLE_DEVICES=0 +for module in col +do + nohup python -u train.py \ + --data_root 
${data_root} \ + --save_dir ${save_dir} \ + --train_component ${module} \ + --emb_path ${emb_path} \ + --col_emb_path ${col_emb_path} \ + ${toy} \ + > "${log_dir}/train_${module}_${DATE}.txt" \ + 2>&1 & +done + +export CUDA_VISIBLE_DEVICES=1 +for module in keyword op des_asc multi_sql agg having root_tem andor +do + nohup python -u train.py \ + --data_root ${data_root} \ + --save_dir ${save_dir} \ + --train_component ${module} \ + --emb_path ${emb_path} \ + --col_emb_path ${col_emb_path} \ + ${toy} \ + > "${log_dir}/train_${module}_${DATE}.txt" \ + 2>&1 & +done + diff --git a/code/utils/__init__.py b/code/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/code/utils/process_sql.py b/code/utils/process_sql.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce1f6c350f7950fcb0ac3b9016a0b9de85ce19a --- /dev/null +++ b/code/utils/process_sql.py @@ -0,0 +1,564 @@ +################################ +# Assumptions: +# 1. sql is correct +# 2. only table name has alias +# 3. only one intersect/union/except +# +# val: number(float)/string(str)/sql(dict) +# col_unit: (agg_id, col_id, isDistinct(bool)) +# val_unit: (unit_op, col_unit1, col_unit2) +# table_unit: (table_type, col_unit/sql) +# cond_unit: (not_op, op_id, val_unit, val1, val2) +# condition: [cond_unit1, 'and'/'or', cond_unit2, ...] +# sql { +# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) +# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} +# 'where': condition +# 'groupBy': [col_unit1, col_unit2, ...] +# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) +# 'having': condition +# 'limit': None/limit value +# 'intersect': None/sql +# 'except': None/sql +# 'union': None/sql +# } +################################ + +import json +import sqlite3 +from nltk import word_tokenize +import pdb + +CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') +JOIN_KEYWORDS = ('join', 'on', 'as') + +WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') +UNIT_OPS = ('none', '-', '+', "*", '/') +AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') +TABLE_TYPE = { + 'sql': "sql", + 'table_unit': "table_unit", +} + +COND_OPS = ('and', 'or') +SQL_OPS = ('intersect', 'union', 'except') +ORDER_OPS = ('desc', 'asc') + + + +class Schema: + """ + Simple schema which maps table&column to a unique identifier + """ + def __init__(self, schema): + self._schema = schema + self._idMap = self._map(self._schema) + + @property + def schema(self): + return self._schema + + @property + def idMap(self): + return self._idMap + + def _map(self, schema): + idMap = {'*': "__all__"} + id = 1 + for key, vals in schema.items(): + for val in vals: + idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." 
+ val.lower() + "__" + id += 1 + + for key in schema: + idMap[key.lower()] = "__" + key.lower() + "__" + id += 1 + + return idMap + + +def get_schema(db): + """ + Get database's schema, which is a dict with table name as key + and list of column names as value + :param db: database path + :return: schema dict + """ + + schema = {} + conn = sqlite3.connect(db) + cursor = conn.cursor() + + # fetch table names + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + tables = [str(table[0].lower()) for table in cursor.fetchall()] + + # fetch table info + for table in tables: + cursor.execute("PRAGMA table_info({})".format(table)) + schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] + + return schema + + +def get_schema_from_json(fpath): + with open(fpath) as f: + data = json.load(f) + + schema = {} + for entry in data: + table = str(entry['table'].lower()) + cols = [str(col['column_name'].lower()) for col in entry['col_data']] + schema[table] = cols + + return schema + + +def tokenize(string): + string = str(string) + string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? + quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] + assert len(quote_idxs) % 2 == 0, "Unexpected quote" + + # keep string value as token + vals = {} + for i in range(len(quote_idxs)-1, -1, -2): + qidx1 = quote_idxs[i-1] + qidx2 = quote_idxs[i] + val = string[qidx1: qidx2+1] + key = "__val_{}_{}__".format(qidx1, qidx2) + string = string[:qidx1] + key + string[qidx2+1:] + vals[key] = val + + toks = [word.lower() for word in word_tokenize(string)] + # replace with string value token + for i in range(len(toks)): + if toks[i] in vals: + toks[i] = vals[toks[i]] + + # find if there exists !=, >=, <= + eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] + eq_idxs.reverse() + prefix = ('!', '>', '<') + for eq_idx in eq_idxs: + pre_tok = toks[eq_idx-1] + if pre_tok in prefix: + toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] + + return toks + + +def scan_alias(toks): + """Scan the index of 'as' and build the map for all alias""" + as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] + alias = {} + for idx in as_idxs: + alias[toks[idx+1]] = toks[idx-1] + return alias + + +def get_tables_with_alias(schema, toks): + tables = scan_alias(toks) + for key in schema: + assert key not in tables, "Alias {} has the same name in table".format(key) + tables[key] = key + return tables + + +def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): + """ + :returns next idx, column id + """ + tok = toks[start_idx] + if tok == "*": + return start_idx + 1, schema.idMap[tok] + + if '.' in tok: # if token is a composite + alias, col = tok.split('.') + key = tables_with_alias[alias] + "." + col + return start_idx+1, schema.idMap[key] + + assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" + + for alias in default_tables: + table = tables_with_alias[alias] + if tok in schema.schema[table]: + key = table + "." 
+ tok + return start_idx+1, schema.idMap[key] + + assert False, "Error col: {}".format(tok) + + +def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): + """ + :returns next idx, (agg_op id, col_id) + """ + idx = start_idx + len_ = len(toks) + isBlock = False + isDistinct = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + if toks[idx] in AGG_OPS: + agg_id = AGG_OPS.index(toks[idx]) + idx += 1 + assert idx < len_ and toks[idx] == '(' + idx += 1 + if toks[idx] == "distinct": + idx += 1 + isDistinct = True + idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) + assert idx < len_ and toks[idx] == ')' + idx += 1 + return idx, (agg_id, col_id, isDistinct) + + if toks[idx] == "distinct": + idx += 1 + isDistinct = True + agg_id = AGG_OPS.index("none") + idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) + + if isBlock: + assert toks[idx] == ')' + idx += 1 # skip ')' + + return idx, (agg_id, col_id, isDistinct) + + +def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + isBlock = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + col_unit1 = None + col_unit2 = None + unit_op = UNIT_OPS.index('none') + + idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + if idx < len_ and toks[idx] in UNIT_OPS: + unit_op = UNIT_OPS.index(toks[idx]) + idx += 1 + idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + + if isBlock: + assert toks[idx] == ')' + idx += 1 # skip ')' + + return idx, (unit_op, col_unit1, col_unit2) + + +def parse_table_unit(toks, start_idx, tables_with_alias, schema): + """ + :returns next idx, table id, table name + """ + idx = start_idx + len_ = len(toks) + #pdb.set_trace() + key = tables_with_alias[toks[idx]] + + if idx + 1 < len_ and toks[idx+1] == "as": + idx += 3 + else: + idx += 1 + + return idx, schema.idMap[key], key + + +def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + + isBlock = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + if toks[idx] == 'select': + idx, val = parse_sql(toks, idx, tables_with_alias, schema) + elif "\"" in toks[idx]: # token is a string value + val = toks[idx] + idx += 1 + else: + try: + val = float(toks[idx]) + idx += 1 + except: + end_idx = idx + while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ + and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: + end_idx += 1 + + idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) + idx = end_idx + + if isBlock: + assert toks[idx] == ')' + idx += 1 + + return idx, val + + +def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + conds = [] + + while idx < len_: + idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + not_op = False + if toks[idx] == 'not': + not_op = True + idx += 1 + + assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) + op_id = WHERE_OPS.index(toks[idx]) + idx += 1 + val1 = val2 = None + if op_id == WHERE_OPS.index('between'): # between..and... 
special case: dual values + idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) + assert toks[idx] == 'and' + idx += 1 + idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) + else: # normal case: single value + idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) + val2 = None + + conds.append((not_op, op_id, val_unit, val1, val2)) + + if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): + break + + if idx < len_ and toks[idx] in COND_OPS: + conds.append(toks[idx]) + idx += 1 # skip and/or + + return idx, conds + + +def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + + assert toks[idx] == 'select', "'select' not found" + idx += 1 + isDistinct = False + if idx < len_ and toks[idx] == 'distinct': + idx += 1 + isDistinct = True + val_units = [] + + while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: + agg_id = AGG_OPS.index("none") + if toks[idx] in AGG_OPS: + agg_id = AGG_OPS.index(toks[idx]) + idx += 1 + idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + val_units.append((agg_id, val_unit)) + if idx < len_ and toks[idx] == ',': + idx += 1 # skip ',' + + return idx, (isDistinct, val_units) + + +def parse_from(toks, start_idx, tables_with_alias, schema): + """ + Assume in the from clause, all table units are combined with join + """ + assert 'from' in toks[start_idx:], "'from' not found" + + len_ = len(toks) + idx = toks.index('from', start_idx) + 1 + default_tables = [] + table_units = [] + conds = [] + + while idx < len_: + isBlock = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + if toks[idx] == 'select': + idx, sql = parse_sql(toks, idx, tables_with_alias, schema) + table_units.append((TABLE_TYPE['sql'], sql)) + else: + if idx < len_ and toks[idx] == 'join': + idx += 1 # skip join + idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) + table_units.append((TABLE_TYPE['table_unit'],table_unit)) + default_tables.append(table_name) + if idx < len_ and toks[idx] == "on": + idx += 1 # skip on + idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) + if len(conds) > 0: + conds.append('and') + conds.extend(this_conds) + + if isBlock: + assert toks[idx] == ')' + idx += 1 + if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): + break + + return idx, table_units, conds, default_tables + + +def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + + if idx >= len_ or toks[idx] != 'where': + return idx, [] + + idx += 1 + idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) + return idx, conds + + +def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + col_units = [] + + if idx >= len_ or toks[idx] != 'group': + return idx, col_units + + idx += 1 + assert toks[idx] == 'by' + idx += 1 + + while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): + idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + col_units.append(col_unit) + if idx < len_ and toks[idx] == ',': + idx += 1 # skip ',' + else: + break + + return idx, col_units + + +def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + val_units 
= [] + order_type = 'asc' # default type is 'asc' + + if idx >= len_ or toks[idx] != 'order': + return idx, val_units + + idx += 1 + assert toks[idx] == 'by' + idx += 1 + + while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): + idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + val_units.append(val_unit) + if idx < len_ and toks[idx] in ORDER_OPS: + order_type = toks[idx] + idx += 1 + if idx < len_ and toks[idx] == ',': + idx += 1 # skip ',' + else: + break + + return idx, (order_type, val_units) + + +def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + + if idx >= len_ or toks[idx] != 'having': + return idx, [] + + idx += 1 + idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) + return idx, conds + + +def parse_limit(toks, start_idx): + idx = start_idx + len_ = len(toks) + + if idx < len_ and toks[idx] == 'limit': + idx += 2 + return idx, int(toks[idx-1]) + + return idx, None + + +def parse_sql(toks, start_idx, tables_with_alias, schema): + isBlock = False # indicate whether this is a block of sql/sub-sql + len_ = len(toks) + idx = start_idx + + sql = {} + if toks[idx] == '(': + isBlock = True + idx += 1 + + # parse from clause in order to get default tables + from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema) + sql['from'] = {'table_units': table_units, 'conds': conds} + # select clause + _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables) + idx = from_end_idx + sql['select'] = select_col_units + # where clause + idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) + sql['where'] = where_conds + # group by clause + idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) + sql['groupBy'] = group_col_units + # having clause + idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) + sql['having'] = having_conds + # order by clause + idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) + sql['orderBy'] = order_col_units + # limit clause + idx, limit_val = parse_limit(toks, idx) + sql['limit'] = limit_val + + idx = skip_semicolon(toks, idx) + if isBlock: + assert toks[idx] == ')' + idx += 1 # skip ')' + idx = skip_semicolon(toks, idx) + + # intersect/union/except clause + for op in SQL_OPS: # initialize IUE + sql[op] = None + if idx < len_ and toks[idx] in SQL_OPS: + sql_op = toks[idx] + idx += 1 + idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema) + sql[sql_op] = IUE_sql + return idx, sql + + +def load_data(fpath): + with open(fpath) as f: + data = json.load(f) + return data + + +def get_sql(schema, query): + toks = tokenize(query) + tables_with_alias = get_tables_with_alias(schema.schema, toks) + _, sql = parse_sql(toks, 0, tables_with_alias, schema) + + return sql + + +def skip_semicolon(toks, start_idx): + idx = start_idx + while idx < len(toks) and toks[idx] == ";": + idx += 1 + return idx diff --git a/code/utils/util.py b/code/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad4932d155b0b90d05fab239b46c96d49700b4d --- /dev/null +++ b/code/utils/util.py @@ -0,0 +1,523 @@ +import re +import io +import json +import numpy as np +import os +import signal + + +def get_table_dict(table_data_path): + data = json.load(open(table_data_path)) + table = dict() + for 
item in data: + table[item["db_id"]] = item + return table + + +def load_dataset(component, train_dev, history, root): + return json.load( + open("{}/{}_{}_{}_dataset.json".format(root, history, train_dev, component))) + + +def to_batch_seq(data, idxes, st, ed): + q_seq = [] + history = [] + label = [] + for i in range(st, ed): + q_seq.append(data[idxes[i]]['question_tokens']) + history.append(data[idxes[i]]["history"]) + label.append(data[idxes[i]]["label"]) + return q_seq, history, label + +# CHANGED + + +def to_batch_tables(data, idxes, st, ed, table_type): + # col_lens = [] + col_seq = [] + for i in range(st, ed): + ts = data[idxes[i]]["ts"] + tname_toks = [x.split(" ") for x in ts[0]] + col_type = ts[2] + cols = [x.split(" ") for xid, x in ts[1]] + tab_seq = [xid for xid, x in ts[1]] + cols_add = [] + for tid, col, ct in zip(tab_seq, cols, col_type): + col_one = [ct] + if tid == -1: + tabn = ["all"] + else: + if table_type == "no": + tabn = [] + else: + tabn = tname_toks[tid] + for t in tabn: + if t not in col: + col_one.append(t) + col_one.extend(col) + cols_add.append(col_one) + col_seq.append(cols_add) + + return col_seq + +# used for training in train.py + + +def epoch_train( + model, + optimizer, + batch_size, + component, + data, + table_type, + q_embed_layer, + embed_layer): + model.train() + perm = np.random.permutation(len(data)) + cum_loss = 0.0 + st = 0 + + while st < len(data): + ed = st + batch_size if st + batch_size < len(perm) else len(perm) + q_seq, history, label = to_batch_seq(data, perm, st, ed) + q_emb_var, q_len = q_embed_layer.gen_x_q_batch(q_seq) + hs_emb_var, hs_len = embed_layer.gen_x_history_batch(history) + score = 0.0 + loss = 0.0 + if component == "multi_sql": + mkw_emb_var = embed_layer.gen_word_list_embedding( + ["none", "except", "intersect", "union"], (ed - st)) + mkw_len = np.full(q_len.shape, 4, dtype=np.int64) + # print("mkw_emb:{}".format(mkw_emb_var.size())) + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + mkw_emb_var=mkw_emb_var, + mkw_len=mkw_len) + elif component == "keyword": + # where group by order by + # [[0,1,2]] + kw_emb_var = embed_layer.gen_word_list_embedding( + ["where", "group by", "order by"], (ed - st)) + mkw_len = np.full(q_len.shape, 3, dtype=np.int64) + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + kw_emb_var=kw_emb_var, + kw_len=mkw_len) + elif component == "col": + # col word embedding + # [[0,1,3]] + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len) + + elif component == "op": + # B*index + gt_col = np.zeros(q_len.shape, dtype=np.int64) + index = 0 + for i in range(st, ed): + # print(i) + gt_col[index] = data[perm[i]]["gt_col"] + index += 1 + + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + gt_col=gt_col) + + elif component == "agg": + # [[0,1,3]] + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + gt_col = np.zeros(q_len.shape, dtype=np.int64) + # print(ed) + index = 0 + for i in range(st, ed): + # print(i) + gt_col[index] = data[perm[i]]["gt_col"] + index += 1 + score = model.forward( + 
q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + gt_col=gt_col) + + elif component == "root_tem": + # B*0/1 + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + gt_col = np.zeros(q_len.shape, dtype=np.int64) + # print(ed) + index = 0 + for i in range(st, ed): + # print(data[perm[i]]["history"]) + gt_col[index] = data[perm[i]]["gt_col"] + index += 1 + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + gt_col=gt_col) + + elif component == "des_asc": + # B*0/1 + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + gt_col = np.zeros(q_len.shape, dtype=np.int64) + # print(ed) + index = 0 + for i in range(st, ed): + # print(i) + gt_col[index] = data[perm[i]]["gt_col"] + index += 1 + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + gt_col=gt_col) + + elif component == 'having': + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + gt_col = np.zeros(q_len.shape, dtype=np.int64) + # print(ed) + index = 0 + for i in range(st, ed): + # print(i) + gt_col[index] = data[perm[i]]["gt_col"] + index += 1 + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + gt_col=gt_col) + + elif component == "andor": + score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len) + # score = model.forward(q_seq, col_seq, col_num, pred_entry, + # gt_where=gt_where_seq, gt_cond=gt_cond_seq, gt_sel=gt_sel_seq) + # print("label {}".format(label)) + loss = model.loss(score, label) + # print("loss {}".format(loss.data.cpu().numpy())) + cum_loss += loss.cpu().data.numpy() * (ed - st) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + st = ed + + return cum_loss / len(data) + +# used for development evaluation in train.py + + +def epoch_acc( + model, + batch_size, + component, + data, + table_type, + q_embed_layer, + embed_layer, + error_print=False, + train_flag=False): + model.eval() + perm = list(range(len(data))) + st = 0 + total_number_error = 0.0 + total_p_error = 0.0 + total_error = 0.0 + print("dev data size {}".format(len(data))) + while st < len(data): + ed = st + batch_size if st + batch_size < len(perm) else len(perm) + + q_seq, history, label = to_batch_seq(data, perm, st, ed) + q_emb_var, q_len = q_embed_layer.gen_x_q_batch(q_seq) + hs_emb_var, hs_len = embed_layer.gen_x_history_batch(history) + score = 0.0 + + if component == "multi_sql": + # none, except, intersect,union + # truth B*index(0,1,2,3) + # print("hs_len:{}".format(hs_len)) + # print("q_emb_shape:{} hs_emb_shape:{}".format(q_emb_var.size(), hs_emb_var.size())) + mkw_emb_var = embed_layer.gen_word_list_embedding( + ["none", "except", "intersect", "union"], (ed - st)) + mkw_len = np.full(q_len.shape, 4, dtype=np.int64) + # print("mkw_emb:{}".format(mkw_emb_var.size())) + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + mkw_emb_var=mkw_emb_var, + mkw_len=mkw_len) + elif component == "keyword": + # where group by order by + # [[0,1,2]] + kw_emb_var = embed_layer.gen_word_list_embedding( + ["where", "group by", "order by"], (ed - st)) + mkw_len = np.full(q_len.shape, 3, dtype=np.int64) + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + 
hs_len, + kw_emb_var=kw_emb_var, + kw_len=mkw_len) + elif component == "col": + # col word embedding + # [[0,1,3]] + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len) + elif component == "op": + # B*index + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + gt_col = np.zeros(q_len.shape, dtype=np.int64) + # print(ed) + index = 0 + for i in range(st, ed): + # print(i) + gt_col[index] = data[perm[i]]["gt_col"] + index += 1 + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + gt_col=gt_col) + + elif component == "agg": + # [[0,1,3]] + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + gt_col = np.zeros(q_len.shape, dtype=np.int64) + # print(ed) + index = 0 + for i in range(st, ed): + # print(i) + gt_col[index] = data[perm[i]]["gt_col"] + index += 1 + + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + gt_col=gt_col) + + elif component == "root_tem": + # B*0/1 + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + gt_col = np.zeros(q_len.shape, dtype=np.int64) + # print(ed) + index = 0 + for i in range(st, ed): + # print(data[perm[i]]["history"]) + gt_col[index] = data[perm[i]]["gt_col"] + index += 1 + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + gt_col=gt_col) + + elif component == "des_asc": + # B*0/1 + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + gt_col = np.zeros(q_len.shape, dtype=np.int64) + # print(ed) + index = 0 + for i in range(st, ed): + # print(i) + gt_col[index] = data[perm[i]]["gt_col"] + index += 1 + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + gt_col=gt_col) + + elif component == 'having': + col_seq = to_batch_tables(data, perm, st, ed, table_type) + col_emb_var, col_name_len, col_len = embed_layer.gen_col_batch( + col_seq) + gt_col = np.zeros(q_len.shape, dtype=np.int64) + # print(ed) + index = 0 + for i in range(st, ed): + # print(i) + gt_col[index] = data[perm[i]]["gt_col"] + index += 1 + score = model.forward( + q_emb_var, + q_len, + hs_emb_var, + hs_len, + col_emb_var, + col_len, + col_name_len, + gt_col=gt_col) + + elif component == "andor": + score = model.forward(q_emb_var, q_len, hs_emb_var, hs_len) + # print("label {}".format(label)) + if component in ("agg", "col", "keyword", "op"): + num_err, p_err, err = model.check_acc(score, label) + total_number_error += num_err + total_p_error += p_err + total_error += err + else: + err = model.check_acc(score, label) + total_error += err + st = ed + + if component in ("agg", "col", "keyword", "op"): + print( + "Dev {} acc number predict acc:{} partial acc: {} total acc: {}".format( + component, + 1 - + total_number_error * + 1.0 / + len(data), + 1 - + total_p_error * + 1.0 / + len(data), + 1 - + total_error * + 1.0 / + len(data))) + return 1 - total_error * 1.0 / len(data) + else: + print( + "Dev {} acc total acc: {}".format( + 
component, + 1 - + total_error * + 1.0 / + len(data))) + return 1 - total_error * 1.0 / len(data) + + +def timeout_handler(num, stack): + print("Received SIGALRM") + raise Exception("Timeout") + +## used in test.py + + +def test_acc(model, batch_size, data, output_path): + table_dict = get_table_dict("./data/tables.json") + f = open(output_path, "w") + for item in data[:]: + db_id = item["db_id"] + if db_id not in table_dict: + print("Error %s not in table_dict" % db_id) + # signal.signal(signal.SIGALRM, timeout_handler) + # signal.alarm(2) # set timer to prevent infinite recursion in SQL + # generation + sql = model.forward([item["question_toks"]] * + batch_size, [], table_dict[db_id]) + if sql is not None: + print(sql) + sql = model.gen_sql(sql, table_dict[db_id]) + else: + sql = "select a from b" + print(sql) + print("") + f.write("{}\n".format(sql)) + f.close() + + +def load_emb(file_name, load_used=False, use_small=False): + if not load_used: + print('Loading word embedding from %s' % file_name) + ret = {} + with open(file_name) as inf: + for idx, line in enumerate(inf): + if (use_small and idx >= 5000): + break + info = line.strip().split(' ') + info[0] = info[0].encode('utf-8').decode('utf-8') + if info[0].lower() not in ret: + ret[info[0]] = np.array(list(map(lambda x: float(x), info[1:]))) + return ret + else: + print('Load used word embedding') + with open('../alt/glove/word2idx.json') as inf: + w2i = json.load(inf) + with open('../alt/glove/usedwordemb.npy') as inf: + word_emb_val = np.load(inf) + return w2i, word_emb_val diff --git a/code/utils/word_embedding.py b/code/utils/word_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5fe5a4a3c3d2e5c394cf774bc6e5fef9922712 --- /dev/null +++ b/code/utils/word_embedding.py @@ -0,0 +1,261 @@ +import json +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np + +AGG_OPS = ('none', 'maximum', 'minimum', 'count', 'sum', 'average') + + +class WordEmbedding(nn.Module): + def __init__(self, word_emb, N_word, gpu, SQL_TOK, + trainable=False): + super(WordEmbedding, self).__init__() + self.trainable = trainable + self.N_word = N_word + self.gpu = gpu + self.SQL_TOK = SQL_TOK + + if trainable: + print("Using trainable embedding") + self.w2i, word_emb_val = word_emb + # tranable when using pretrained model, init embedding weights + # using prev embedding + self.embedding = nn.Embedding(len(self.w2i), N_word) + self.embedding.weight = nn.Parameter( + torch.from_numpy(word_emb_val.astype(np.float32))) + else: + # else use word2vec or glove + self.word_emb = word_emb + print("Using fixed embedding") + + def gen_x_q_batch(self, q): + B = len(q) + val_embs = [] + val_len = np.zeros(B, dtype=np.int64) + for i, one_q in enumerate(q): + q_val = [] + for ws in one_q: + q_val.append( + self.word_emb.get( + ws, + np.zeros( + self.N_word, + dtype=np.float32))) + + val_embs.append([np.zeros(self.N_word, + dtype=np.float32)] + q_val + [np.zeros(self.N_word, + dtype=np.float32)]) # and + val_len[i] = 1 + len(q_val) + 1 + max_len = max(val_len) + + val_emb_array = np.zeros((B, max_len, self.N_word), dtype=np.float32) + for i in range(B): + for t in range(len(val_embs[i])): + val_emb_array[i, t, :] = val_embs[i][t] + val_inp = torch.from_numpy(val_emb_array) + if self.gpu: + val_inp = val_inp.cuda() + val_inp_var = Variable(val_inp) + + return val_inp_var, val_len + + def gen_x_history_batch(self, history): + B = len(history) + val_embs = [] + val_len = 
np.zeros(B, dtype=np.int64) + for i, one_history in enumerate(history): + history_val = [] + for item in one_history: + # col + if isinstance(item, list) or isinstance(item, tuple): + emb_list = [] + ws = item[0].split() + item[1].split() + ws_len = len(ws) + for w in ws: + emb_list.append( + self.word_emb.get( + w, np.zeros( + self.N_word, dtype=np.float32))) + if ws_len == 0: + raise Exception("word list should not be empty!") + elif ws_len == 1: + history_val.append(emb_list[0]) + else: + history_val.append(sum(emb_list) / float(ws_len)) + # ROOT + elif isinstance(item, str): + if item == "ROOT": + item = "root" + elif item == "asc": + item = "ascending" + elif item == "desc": + item == "descending" + if item in ( + "none", + "select", + "from", + "where", + "having", + "limit", + "intersect", + "except", + "union", + 'not', + 'between', + '=', + '>', + '<', + 'in', + 'like', + 'is', + 'exists', + 'root', + 'ascending', + 'descending'): + history_val.append( + self.word_emb.get( + item, np.zeros( + self.N_word, dtype=np.float32))) + elif item == "orderBy": + history_val.append( + (self.word_emb.get( + "order", + np.zeros( + self.N_word, + dtype=np.float32)) + + self.word_emb.get( + "by", + np.zeros( + self.N_word, + dtype=np.float32))) / + 2) + elif item == "groupBy": + history_val.append( + (self.word_emb.get( + "group", + np.zeros( + self.N_word, + dtype=np.float32)) + + self.word_emb.get( + "by", + np.zeros( + self.N_word, + dtype=np.float32))) / + 2) + elif item in ('>=', '<=', '!='): + history_val.append( + (self.word_emb.get( + item[0], + np.zeros( + self.N_word, + dtype=np.float32)) + + self.word_emb.get( + item[1], + np.zeros( + self.N_word, + dtype=np.float32))) / + 2) + elif isinstance(item, int): + history_val.append( + self.word_emb.get( + AGG_OPS[item], np.zeros( + self.N_word, dtype=np.float32))) + else: + print( + "Warning: unsupported data type in history! {}".format(item)) + + val_embs.append(history_val) + val_len[i] = len(history_val) + max_len = max(val_len) + + val_emb_array = np.zeros((B, max_len, self.N_word), dtype=np.float32) + for i in range(B): + for t in range(len(val_embs[i])): + val_emb_array[i, t, :] = val_embs[i][t] + val_inp = torch.from_numpy(val_emb_array) + if self.gpu: + val_inp = val_inp.cuda() + val_inp_var = Variable(val_inp) + + return val_inp_var, val_len + + def gen_word_list_embedding(self, words, B): + val_emb_array = np.zeros( + (B, len(words), self.N_word), dtype=np.float32) + for i, word in enumerate(words): + if len(word.split()) == 1: + emb = self.word_emb.get( + word, np.zeros( + self.N_word, dtype=np.float32)) + else: + word = word.split() + emb = ( + self.word_emb.get( + word[0], + np.zeros( + self.N_word, + dtype=np.float32)) + self.word_emb.get( + word[1], + np.zeros( + self.N_word, + dtype=np.float32))) / 2 + for b in range(B): + val_emb_array[b, i, :] = emb + val_inp = torch.from_numpy(val_emb_array) + if self.gpu: + val_inp = val_inp.cuda() + val_inp_var = Variable(val_inp) + return val_inp_var + + def gen_col_batch(self, cols): + ret = [] + col_len = np.zeros(len(cols), dtype=np.int64) + + names = [] + for b, one_cols in enumerate(cols): + names = names + one_cols + col_len[b] = len(one_cols) + # TODO: what is the diff bw name_len and col_len? 
+        name_inp_var, name_len = self.str_list_to_batch(names)
+        return name_inp_var, name_len, col_len
+
+    def str_list_to_batch(self, str_list):
+        """get a list var of wemb of words in each column name in current batch"""
+        B = len(str_list)
+
+        val_embs = []
+        val_len = np.zeros(B, dtype=np.int64)
+        for i, one_str in enumerate(str_list):
+            if self.trainable:
+                val = [self.w2i.get(x, 0) for x in one_str]
+            else:
+                val = [self.word_emb.get(x, np.zeros(
+                    self.N_word, dtype=np.float32)) for x in one_str]
+            val_embs.append(val)
+            val_len[i] = len(val)
+        max_len = max(val_len)
+
+        if self.trainable:
+            val_tok_array = np.zeros((B, max_len), dtype=np.int64)
+            for i in range(B):
+                for t in range(len(val_embs[i])):
+                    val_tok_array[i, t] = val_embs[i][t]
+            val_tok = torch.from_numpy(val_tok_array)
+            if self.gpu:
+                val_tok = val_tok.cuda()
+            val_tok_var = Variable(val_tok)
+            val_inp_var = self.embedding(val_tok_var)
+        else:
+            val_emb_array = np.zeros(
+                (B, max_len, self.N_word), dtype=np.float32)
+            for i in range(B):
+                for t in range(len(val_embs[i])):
+                    val_emb_array[i, t, :] = val_embs[i][t]
+            val_inp = torch.from_numpy(val_emb_array)
+            if self.gpu:
+                val_inp = val_inp.cuda()
+            val_inp_var = Variable(val_inp)
+
+        return val_inp_var, val_len
diff --git "a/\346\212\200\346\234\257\346\226\207\346\241\243.md" "b/\346\212\200\346\234\257\346\226\207\346\241\243.md"
new file mode 100644
index 0000000000000000000000000000000000000000..4cd005d67a314fb466383170f1b48e638c9cea10
--- /dev/null
+++ "b/\346\212\200\346\234\257\346\226\207\346\241\243.md"
@@ -0,0 +1,11 @@
+# Natural Language Query Interface for openGauss
+
+### Data Preprocessing
+The questions in the training corpus are translated into English and then back into Chinese to produce additional training samples, which serves as a form of data augmentation.
+
+
+### Model
+The model is built on the officially provided baseline, ported to Python 3 and updated to the current torch API. During training, the input question is encoded with an LSTM; keyword and column-header representations are then combined with the question through an attention mechanism to obtain the final encoding. The model is split into several modules, each trained separately for its own sub-task.
+
+### Future Work
+A pretrained encoder such as Multilingual BERT could be incorporated. Owing to time constraints and insufficient GPU resources, this exploration did not yield results before the competition deadline; using pretrained models to generate SQL queries will be investigated in follow-up work.
\ No newline at end of file
diff --git "a/\350\257\264\346\230\216\346\226\207\346\241\243.md" "b/\350\257\264\346\230\216\346\226\207\346\241\243.md"
new file mode 100644
index 0000000000000000000000000000000000000000..4a52adcd9bc6eab9d1275b12f5fe58bed9904c59
--- /dev/null
+++ "b/\350\257\264\346\230\216\346\226\207\346\241\243.md"
@@ -0,0 +1,47 @@
+# Natural Language Query Interface for openGauss
+
+### Model
+
+#### Data, Embeddings and Pretrained Models
+
+1. Download the data, embeddings and the database:
+   - Put ``train.json`` and ``dev.json`` under ``./data/char/``
+   - Put ``char_emb.txt`` under ``./embedding/``
+   - Put the ``database`` directory under ``./``
+2. Run ``python preprocess_data.py -s char|word`` to generate the training files for each module
+
+#### Training
+Run ``train_all.sh`` to train all the modules:
+```
+python train.py \
+  --data_root path/to/char/or/word/based/generated_data \
+  --save_dir path/to/save/trained/module \
+  --train_component \
+  --emb_path path/to/embeddings \
+  --col_emb_path path/to/corresponding/embeddings/for/column
+```
+
+#### Testing
+Run ``test_gen.sh`` to generate SQL queries:
+```
+python test.py \
+  --test_data_path path/to/char/or/word/based/raw/dev/or/test/data \
+  --models path/to/trained/module \
+  --output_path path/to/print/generated/SQL \
+  --emb_path path/to/embeddings \
+  --col_emb_path path/to/corresponding/embeddings/for/column
+```
+
+#### Evaluation
+Run ``evaluation.sh`` to evaluate the generated SQL queries:
+```
+python evaluation.py \
+  --gold path/to/gold/dev/or/test/queries \
+  --pred path/to/predicted/dev/or/test/queries \
+  --etype evaluation/metric \
+  --db path/to/database \
+  --table path/to/tables
+```
+
+
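+For reference, a concrete invocation might look like the following sketch. It assumes the predictions were written to ``test_result.txt`` by ``test_gen.sh``, that the gold queries are stored as ``data/dev_gold.sql``, and that ``evaluation.py`` follows the Spider-style interface in which ``--etype`` is one of ``match``, ``exec`` or ``all``; adjust the paths to your own setup.
+```
+python evaluation.py \
+  --gold data/dev_gold.sql \
+  --pred data/char/generated_datasets/saved_models_multi_<DATE>/test_result.txt \
+  --etype match \
+  --db database/ \
+  --table data/tables.json
+```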