1 Star 5 Fork 3

多读书多看报少玩游戏少睡觉/Easy SmartHome

Create your Gitee Account
Explore and code with more than 13.5 million developers. Free private repositories! :)
Sign up
文件
Clone or Download
DM_predict-checkpoint.ipynb 8.85 KB
Copy Edit Raw Blame History
import os

# Expose only the GPU with CUDA index 1 to TensorFlow. CUDA reads this variable
# when it initialises, so set it before tensorflow is imported instead of relying
# on TensorFlow initialising CUDA lazily.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

import tensorflow as tf  # noqa: E402  (must follow the environment setup above)
import json  # noqa: E402
import numpy as np  # noqa: E402

gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    # Allocate GPU memory on demand instead of reserving it all up front.
    tf.config.experimental.set_memory_growth(gpu, True)
# Training hyper-parameters. They are not referenced by the prediction code
# shown in this notebook.
params = dict(
    batch_size=64,
    lr=0.001,
    epochs=500,
    drops=[0.1],
)
# Label vocabularies. The *2id maps give the integer index each label takes in
# the multi-hot encoding; id2action maps stringified indices back to action names.
with open('./DM_char.json', mode='r', encoding='utf-8') as f:
    dicts = json.load(f)

action2id, intent2id, slots2id, id2action = (
    dicts['action2id'],
    dicts['intent2id'],
    dicts['entities2id'],
    dicts['id2action'],
)

# Vocabulary sizes fix the widths of the model's input vectors and action output.
previous_action_len, slots_len, user_intent_len = (
    len(action2id),
    len(slots2id),
    len(intent2id),
)
# Build the dialogue-management action classifier.
# Statement order matters here: Keras names unnamed layers by creation order
# (embedding, embedding_1, embedding_2, ... as shown in the summary), and the
# checkpoint loaded afterwards must line up with this exact architecture.
# Reset Keras global state (including the layer-name counters) before building.
tf.keras.backend.clear_session()
# One input per feature group; each receives the multi-hot row produced by trans2labelid.
previous_action_inputs = tf.keras.layers.Input(shape=(previous_action_len,), name = 'previous_action_inputs')
slots_inputs = tf.keras.layers.Input(shape = (slots_len,), name = 'slots_inputs')
user_intent_inputs = tf.keras.layers.Input(shape = (user_intent_len,), name = 'user_intent_inputs')

# Each vector position is treated as a token, giving (batch, vocab_size, 32).
# NOTE(review): the inputs hold 0/1 (or small count) values, so only the first
# few of the 256 embedding rows can ever be looked up. Presumably kept because
# the saved weights were trained this way; confirm before changing.
previous_action_embed = tf.keras.layers.Embedding(256,32)(previous_action_inputs)
slots_embed = tf.keras.layers.Embedding(256,32)(slots_inputs)
user_intent_embed = tf.keras.layers.Embedding(256,32)(user_intent_inputs)

# Join the three groups along the sequence axis: (batch, 36 + 13 + 40 = 89, 32) per the summary.
utter_inputs = tf.keras.layers.concatenate([previous_action_embed,slots_embed,user_intent_embed],axis=1)
# Despite the variable name, this is a bidirectional GRU (64 units per direction -> 128 features).
bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(64,return_sequences=True))(utter_inputs)
x_in = tf.keras.layers.LayerNormalization()(bilstm)
# Average over all sequence positions to get one 128-dim vector per example.
x_conv = tf.keras.layers.GlobalAveragePooling1D()(x_in)
# One independent sigmoid score per action; predict() takes the argmax.
pre_action = tf.keras.layers.Dense(previous_action_len, activation='sigmoid',name = 'pre_action')(x_conv)
model = tf.keras.Model([previous_action_inputs,slots_inputs,user_intent_inputs],pre_action)
model.summary()
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
previous_action_inputs (InputLa [(None, 36)]         0                                            
__________________________________________________________________________________________________
slots_inputs (InputLayer)       [(None, 13)]         0                                            
__________________________________________________________________________________________________
user_intent_inputs (InputLayer) [(None, 40)]         0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 36, 32)       8192        previous_action_inputs[0][0]     
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 13, 32)       8192        slots_inputs[0][0]               
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 40, 32)       8192        user_intent_inputs[0][0]         
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 89, 32)       0           embedding[0][0]                  
                                                                 embedding_1[0][0]                
                                                                 embedding_2[0][0]                
__________________________________________________________________________________________________
bidirectional (Bidirectional)   (None, 89, 128)      37632       concatenate[0][0]                
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 89, 128)      256         bidirectional[0][0]              
__________________________________________________________________________________________________
global_average_pooling1d (Globa (None, 128)          0           layer_normalization[0][0]        
__________________________________________________________________________________________________
pre_action (Dense)              (None, 36)           4644        global_average_pooling1d[0][0]   
==================================================================================================
Total params: 67,108
Trainable params: 67,108
Non-trainable params: 0
__________________________________________________________________________________________________
# Restore the trained checkpoint; its layout is expected to match the model built above.
model.load_weights(filepath='../DM_model_weight/DM_weight_625.h5')
def trans2labelid(vocab, x):
    """Encode a list of labels as a single multi-hot row vector.

    Args:
        vocab: mapping from label to integer index; indices are assumed to lie
            in ``range(len(vocab))``.
        x: iterable of labels, each of which must be a key of ``vocab``.
            A label that appears several times is counted several times.

    Returns:
        A float numpy array of shape ``(1, len(vocab))`` whose entry ``i`` is
        the number of labels in ``x`` mapped to index ``i``. An empty ``x``
        yields an all-zero row of the same shape.

    Raises:
        KeyError: if a label in ``x`` is not a key of ``vocab``.
    """
    vocab_size = len(vocab)
    # dtype=int keeps an empty label list usable as an index array.
    label_ids = np.asarray([vocab[label] for label in x], dtype=int)
    # Summing the one-hot rows along axis 0 keeps the (vocab_size,) shape even
    # when x is empty; the builtin sum() would collapse it to the scalar 0.
    counts = np.eye(vocab_size)[label_ids].sum(axis=0)
    # Add the leading batch dimension expected by model.predict.
    return np.expand_dims(counts, axis=0)
def predict(x, *, verbose=True):
    """Predict the next dialogue action for one turn and return its name.

    Args:
        x: sequence whose first three items are label lists, in the order
            (previous actions, slots, user intents). Their labels must be keys
            of ``action2id``, ``slots2id`` and ``intent2id`` respectively.
            Items after the third are ignored.
        verbose: when true (the default), print the three encoded input rows
            and the predicted action, as the original notebook did.

    Returns:
        The predicted action name, looked up in ``id2action``.

    Raises:
        KeyError: if a label is missing from its vocabulary.
    """
    x = list(x)
    # Distinct local names so the module-level Keras Input tensors are not shadowed.
    prev_action_labels, slot_labels, intent_labels = x[0], x[1], x[2]

    prev_action_vec = trans2labelid(action2id, prev_action_labels)
    if verbose:
        print(prev_action_vec)
    slot_vec = trans2labelid(slots2id, slot_labels)
    if verbose:
        print(slot_vec)
    intent_vec = trans2labelid(intent2id, intent_labels)
    if verbose:
        print(intent_vec)

    pre_data = model.predict([prev_action_vec, slot_vec, intent_vec])
    # Single-example batch: take the highest-scoring action of row 0.
    # id2action was loaded from JSON, so its keys are strings.
    pre_action = id2action[str(np.argmax(pre_data[0]))]
    if verbose:
        print(' text: {} \n action:{} \n '.format(x, pre_action))
    return pre_action
# Example turn: previous action 'PAD' (presumably "no previous action"),
# three filled slots, and a lamp-brightness control intent.
inputs = [
    ['PAD'],
    ['operation', 'device', 'address'],
    ['Control-Lamp_Lightness'],
]
predict(inputs)
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.]]
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
 text: [['PAD'], ['operation', 'device', 'address'], ['Control-Lamp_Lightness']] 
 action:action_controllamplightness 
 

 
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/xulucky/easy-smarthome.git
git@gitee.com:xulucky/easy-smarthome.git
xulucky
easy-smarthome
Easy SmartHome
master

Search