# coding=utf-8
# Import dependencies
import json
import os
import sys
import time
import tensorflow as tf
import seq2seqModel
from config import getConfig
import io
# Initialize the hyperparameter dictionary and read the configured values
gConfig = {}
gConfig = getConfig.get_config()
vocab_inp_size = gConfig['vocab_inp_size']
vocab_tar_size = gConfig['vocab_tar_size']
embedding_dim = gConfig['embedding_dim']
units = gConfig['layer_size']
BATCH_SIZE = gConfig['batch_size']

max_length_inp = gConfig['max_length']
max_length_tar = gConfig['max_length']
log_dir = gConfig['log_dir']
writer = tf.summary.create_file_writer(log_dir)
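# The summary writer above is used only in train(): the per-epoch average loss is
# written to log_dir so it can be inspected with TensorBoard.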
# Preprocess a training utterance: wrap the context and the reply with start/end markers
def preprocess_sentence(w):
    w = 'start ' + w + ' end'
    return w
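# For example, preprocess_sentence('hello there') returns 'start hello there end',
# so every utterance the tokenizers see carries explicit start/end markers.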
# Read the training corpus, convert words to ids with the pre-built vocabularies,
# and pad the resulting sequences
def read_data(path):
    path = os.getcwd() + '/' + path
    if not os.path.exists(path):
        path = os.path.dirname(os.getcwd()) + '/' + path
    lines = io.open(path, encoding='UTF-8').read().strip().split('\n')
    word_pairs = [[preprocess_sentence(w) for w in l.split('\t')] for l in lines]
    input_lang, target_lang = zip(*word_pairs)
    input_tokenizer = tokenize(gConfig['vocab_inp_path'])
    target_tokenizer = tokenize(gConfig['vocab_tar_path'])
    input_tensor = input_tokenizer.texts_to_sequences(input_lang)
    target_tensor = target_tokenizer.texts_to_sequences(target_lang)
    input_tensor = tf.keras.preprocessing.sequence.pad_sequences(input_tensor, maxlen=max_length_inp,
                                                                 padding='post')
    target_tensor = tf.keras.preprocessing.sequence.pad_sequences(target_tensor, maxlen=max_length_tar,
                                                                  padding='post')
    return input_tensor, input_tokenizer, target_tensor, target_tokenizer
# Rebuild a tokenizer (the word2number vocabulary) from its saved JSON config
def tokenize(vocab_file):
    # Read the pre-generated tokenizer config from the vocabulary file
    with open(vocab_file, 'r', encoding='utf-8') as f:
        tokenize_config = json.dumps(json.load(f), ensure_ascii=False)
        lang_tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(tokenize_config)
    return lang_tokenizer
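# Module-level data pipeline: tokenize the whole corpus once, then build a
# tf.data.Dataset that shuffles and batches it (dropping the last partial batch)
# for the training loop below.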
input_tensor, input_token, target_tensor, target_token = read_data(gConfig['seq_data'])
steps_per_epoch = len(input_tensor) // gConfig['batch_size']
BUFFER_SIZE = len(input_tensor)
dataset = tf.data.Dataset.from_tensor_slices((input_tensor, target_tensor)).shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
enc_hidden = seq2seqModel.encoder.initialize_hidden_state()
# Training function
def train():
    # Report which corpus is used and how many steps each epoch takes
    print("Preparing data in %s" % gConfig['train_data'])
    print('Training steps per epoch: {}'.format(steps_per_epoch))
    # If a pretrained model already exists, restore it and continue training from it
    checkpoint_dir = gConfig['model_data']
    ckpt = tf.io.gfile.listdir(checkpoint_dir)
    if ckpt:
        print("reload pretrained model")
        seq2seqModel.checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    start_time = time.time()
    current_steps = 0
    epoch = 0
    train_epoch = gConfig['train_epoch']
    # Train for train_epoch epochs; each epoch consumes steps_per_epoch batches
    # from the Dataset built above
    while epoch < train_epoch:
        start_time_epoch = time.time()
        total_loss = 0
        # Run one epoch of steps_per_epoch training steps
        for batch, (inp, targ) in enumerate(dataset.take(steps_per_epoch)):
            batch_loss = seq2seqModel.training_step(inp, targ, target_token, enc_hidden)
            total_loss += batch_loss
            print('epoch: {} batch: {} batch_loss: {}'.format(epoch, batch, batch_loss))
        # After the epoch, compute the average time and loss per training step
        step_time_epoch = (time.time() - start_time_epoch) / steps_per_epoch
        step_loss = total_loss / steps_per_epoch
        current_steps += steps_per_epoch
        epoch_time_total = time.time() - start_time
        print('Total steps: {} Total time: {} Avg time per step this epoch: {} Avg loss per step: {:.4f}'
              .format(current_steps, epoch_time_total, step_time_epoch, step_loss))
        # Save a checkpoint for this epoch, updating the model files
        seq2seqModel.checkpoint.save(file_prefix=checkpoint_prefix)
        sys.stdout.flush()
        epoch = epoch + 1
        # Log the per-epoch average loss for TensorBoard
        with writer.as_default():
            tf.summary.scalar('loss', step_loss, step=epoch)
# Prediction function: given the previous utterance, generate the reply
def predict(sentence):
    # Rebuild the tokenizers from their saved configs
    input_tokenizer = tokenize(gConfig['vocab_inp_path'])
    target_tokenizer = tokenize(gConfig['vocab_tar_path'])
    # Restore the trained model from the latest checkpoint
    checkpoint_dir = gConfig['model_data']
    seq2seqModel.checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
    # Wrap the input sentence with start/end markers
    sentence = preprocess_sentence(sentence)
    # Convert words to ids; texts_to_sequences expects a list of texts
    inputs = input_tokenizer.texts_to_sequences([sentence])
    # Pad the sequence to the maximum input length
    inputs = tf.keras.preprocessing.sequence.pad_sequences(inputs, maxlen=max_length_inp, padding='post')
    inputs = tf.convert_to_tensor(inputs)
    result = ''
    # Initialize the encoder hidden state
    hidden = [tf.zeros((1, units))]
    # Encode the input utterance to extract its features
    enc_out, enc_hidden = seq2seqModel.encoder(inputs, hidden)
    dec_hidden = enc_hidden
    # Start decoding from the id of the 'start' token
    dec_input = tf.expand_dims([target_tokenizer.word_index['start']], 0)
    # Decode step by step, up to the maximum target length
    for t in range(max_length_tar):
        # Get this step's prediction and pick the id with the highest probability
        predictions, dec_hidden, attention_weights = seq2seqModel.decoder(dec_input, dec_hidden, enc_out)
        predicted_id = tf.argmax(predictions[0]).numpy()
        # Stop when the 'end' token is produced; otherwise append the decoded word
        if target_tokenizer.index_word[predicted_id] == 'end':
            break
        result += str(target_tokenizer.index_word[predicted_id]) + ' '
        # Feed the predicted id back as the next decoder input
        dec_input = tf.expand_dims([predicted_id], 0)
    return result
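# In 'serve' mode, predict() is the entry point the accompanying web program is
# expected to call with the user's sentence; it displays the returned reply string.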
# Entry point: start the executor in the mode given by the hyperparameters
if __name__ == '__main__':
    # If a config file is passed on the command line, read the hyperparameters
    # from it; otherwise fall back to the default config file
    if len(sys.argv) > 1:
        gConfig = getConfig.get_config(sys.argv[1])
    else:
        gConfig = getConfig.get_config()
    print('\n>> Executor mode : %s\n' % (gConfig['mode']))
    if gConfig['mode'] == 'train':
        print('Starting model training')
        train()
    elif gConfig['mode'] == 'serve':
        print('Serve mode: please run the web program for interactive chat')