code2seq Reproduction Notes (PyTorch Version)

Summary

This post documents my notes from studying the code generation paper code2seq: Generating Sequences from Structured Representations of Code.
Code repositories:
TensorFlow version: https://github.com/tech-srl/code2seq
Jupyter + PyTorch version: https://github.com/m3yrin/code2seq
This post runs the Jupyter version, with a few small modifications.

Note: if you just want a quick trial run, I recommend forking the Jupyter version of the project on GitHub directly; any code not shown in this post can be found in that repository.

Project Structure

(Screenshot: project directory structure)
The project folder contains four subfolders: code, dataset, logs, and runs.
Inside code there are three important subfolders: configs; notebooks, which holds the source code (preparation downloads and preprocesses the data, and code2seq is the main project code; the notebook (.ipynb) files shown in the screenshot are the originals from GitHub, while the .py files are my copies of that code into plain Python scripts so I could run them on a server); and src, which holds the utility code imported at the top of code2seq. A rough sketch of the layout follows.
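
Roughly, the layout looks like the sketch below (reconstructed from the description above and from the paths used in the code; the project root name is just a placeholder):

code2seq-pytorch/
├── code/
│   ├── configs/
│   │   └── config_code2seq.yml
│   ├── notebooks/
│   │   ├── preparation (.ipynb / .py)
│   │   └── code2seq (.ipynb / .py)
│   └── src/
│       ├── utils.py
│       └── messenger.py
├── dataset/    (created by preparation; holds java-small)
├── logs/       (one .log file per run)
└── runs/       (one folder per run; holds model.pth)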

Data Preparation (the preparation notebook)

# Before downloading the data, create the three empty folders required by the project structure (commands prefixed with ! are run in a terminal or a Jupyter cell)
!mkdir dataset runs logs
# Download the dataset into the dataset folder just created
!wget https://s3.amazonaws.com/code2seq/datasets/java-small-preprocessed.tar.gz -P dataset/
# Extract the downloaded archive
!tar -xvzf dataset/java-small-preprocessed.tar.gz -C dataset/
# Change into the java-small folder produced by the extraction
%cd dataset/java-small/
# for dev (I have not yet found where this split is actually used)
!head -20000 java-small.train.c2s > java-small.train_dev.c2s
# Create four folders inside java-small: train, train_dev, val, test
!mkdir train train_dev val test
# If split is not available in your shell, you can run it from Git Bash inside this folder. It takes a while and the resulting files are tiny (this step is a bit odd: it writes every example's path contexts into a separate .txt file, which presumably slows training down considerably, but since I only wanted a trial run I used it as-is). A quick sanity check of the output follows after this block.
!split -d -a 6 -l 1 --additional-suffix=.txt java-small.test.c2s test/
!split -d -a 6 -l 1 --additional-suffix=.txt java-small.val.c2s val/
!split -d -a 6 -l 1 --additional-suffix=.txt java-small.train.c2s train/
!split -d -a 6 -l 1 --additional-suffix=.txt java-small.train_dev.c2s train_dev/
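
As a quick sanity check (my own addition, not part of the original notebook), every numbered .txt file produced by split should hold exactly one example in the c2s format that the DataLoader below parses: a target name followed by space-separated path contexts of the form terminal1,node|path|nodes,terminal2.

# assumed to be run from dataset/java-small/ right after the split step
import os

test_dir = 'test'
files = sorted(f for f in os.listdir(test_dir) if f.endswith('.txt'))
print(len(files), 'test examples')

with open(os.path.join(test_dir, files[0])) as f:
    line = f.readline()

target, *contexts = line.split(' ')
contexts = [c for c in contexts if c not in ('', '\n')]
print('target subtokens:', target.split('|'))
print('first AST path  :', contexts[0].split(','))  # [terminal1, node|path|nodes, terminal2]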

Main Code (the code2seq notebook)

# To run this project, cd into code/notebooks first; otherwise importing the src package will fail. Everything path-related is set in the configs file (the paths in this post are ones I changed myself, so they differ slightly from the ones on GitHub).

import sys
sys.path.append('../')

import os
import time
import yaml
import random
import numpy as np
import warnings
import logging
import pickle
from datetime import datetime
from tqdm.auto import tqdm  # tqdm_notebook is deprecated; tqdm.auto works in both notebooks and plain terminals

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

import torch
from torch import einsum
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

from src import utils, messenger

config_file = '../configs/config_code2seq.yml'

config = yaml.load(open(config_file), Loader=yaml.FullLoader)

# Data source
DATA_HOME = config['data']['home']
DICT_FILE = DATA_HOME + config['data']['dict']
TRAIN_DIR = DATA_HOME + config['data']['train']
VALID_DIR = DATA_HOME + config['data']['valid']
TEST_DIR = DATA_HOME + config['data']['test']

# Training parameter
batch_size = config['training']['batch_size']
num_epochs = config['training']['num_epochs']
lr = config['training']['lr']
teacher_forcing_rate = config['training']['teacher_forcing_rate']
nesterov = config['training']['nesterov']
weight_decay = config['training']['weight_decay']
momentum = config['training']['momentum']
decay_ratio = config['training']['decay_ratio']
save_name = config['training']['save_name']
warm_up = config['training']['warm_up']
patience = config['training']['patience']



# Model parameter
token_size = config['model']['token_size']
hidden_size = config['model']['hidden_size']
num_layers = config['model']['num_layers']
bidirectional = config['model']['bidirectional']
rnn_dropout = config['model']['rnn_dropout']
embeddings_dropout = config['model']['embeddings_dropout']
num_k = config['model']['num_k']

# etc
slack_url_path = config['etc']['slack_url_path']
info_prefix = config['etc']['info_prefix']


slack_url = None
if os.path.exists(slack_url_path):
slack_url = yaml.load(open(slack_url_path), Loader=yaml.FullLoader)['slack_url']

warnings.filterwarnings('ignore')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(1)
random_state = 42

run_id = datetime.now().strftime('%Y-%m-%d--%H-%M-%S')
log_file = '../../logs/' + run_id + '.log'
exp_dir = '../../runs/' + run_id
os.mkdir(exp_dir)

logging.basicConfig(format='%(asctime)s | %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', filename=log_file, level=logging.DEBUG)
msgr = messenger.Info(info_prefix, slack_url)

msgr.print_msg('run_id : {}'.format(run_id))
msgr.print_msg('log_file : {}'.format(log_file))
msgr.print_msg('exp_dir : {}'.format(exp_dir))
msgr.print_msg('device : {}'.format(device))
msgr.print_msg(str(config))

PAD_TOKEN = '<PAD>'
BOS_TOKEN = '<S>'
EOS_TOKEN = '</S>'
UNK_TOKEN = '<UNK>'
PAD = 0
BOS = 1
EOS = 2
UNK = 3

# load vocab dict
with open(DICT_FILE, 'rb') as file:
subtoken_to_count = pickle.load(file)
node_to_count = pickle.load(file)
target_to_count = pickle.load(file)
max_contexts = pickle.load(file)
num_training_examples = pickle.load(file)
msgr.print_msg('Dictionaries loaded.')

# making vocab dicts for terminal subtoken, nonterminal node and target.

word2id = {
PAD_TOKEN: PAD,
BOS_TOKEN: BOS,
EOS_TOKEN: EOS,
UNK_TOKEN: UNK,
}

vocab_subtoken = utils.Vocab(word2id=word2id)
vocab_nodes = utils.Vocab(word2id=word2id)
vocab_target = utils.Vocab(word2id=word2id)

vocab_subtoken.build_vocab(list(subtoken_to_count.keys()), min_count=0)
vocab_nodes.build_vocab(list(node_to_count.keys()), min_count=0)
vocab_target.build_vocab(list(target_to_count.keys()), min_count=0)

vocab_size_subtoken = len(vocab_subtoken.id2word)
vocab_size_nodes = len(vocab_nodes.id2word)
vocab_size_target = len(vocab_target.id2word)


msgr.print_msg('vocab_size_subtoken:' + str(vocab_size_subtoken))
msgr.print_msg('vocab_size_nodes:' + str(vocab_size_nodes))
msgr.print_msg('vocab_size_target:' + str(vocab_size_target))

num_length_train = num_training_examples
msgr.print_msg('num_examples : ' + str(num_length_train))

class DataLoader(object):

def __init__(self, data_path, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, shuffle=True, batch_time = False):

"""
data_path : path for data
num_examples : total lines of data file
batch_size : batch size
        num_k : max number of AST paths included in one example
        vocab_subtoken : dict mapping each subtoken to its id
        vocab_nodes : dict mapping each node symbol to its id
        vocab_target : dict mapping each target symbol to its id
"""

self.data_path = data_path
self.batch_size = batch_size

self.num_examples = self.file_count(data_path)
self.num_k = num_k

self.vocab_subtoken = vocab_subtoken
self.vocab_nodes = vocab_nodes
self.vocab_target = vocab_target

self.index = 0
self.pointer = np.array(range(self.num_examples))
self.shuffle = shuffle

self.batch_time = batch_time

self.reset()


def __iter__(self):
return self

def __next__(self):

if self.batch_time:
t1 = time.time()

if self.index >= self.num_examples:
self.reset()
raise StopIteration()

ids = self.pointer[self.index: self.index + self.batch_size]
seqs_S, seqs_N, seqs_E, seqs_Y = self.read_batch(ids)

# length_k : (batch_size, k)
lengths_k = [len(ex) for ex in seqs_N]

# flattening (batch_size, k, l) to (batch_size * k, l)
# this is useful to make torch.tensor
seqs_S = [symbol for k in seqs_S for symbol in k]
seqs_N = [symbol for k in seqs_N for symbol in k]
seqs_E = [symbol for k in seqs_E for symbol in k]

# Padding
lengths_S = [len(s) for s in seqs_S]
lengths_N = [len(s) for s in seqs_N]
lengths_E = [len(s) for s in seqs_E]
lengths_Y = [len(s) for s in seqs_Y]

max_length_S = max(lengths_S)
max_length_N = max(lengths_N)
max_length_E = max(lengths_E)
max_length_Y = max(lengths_Y)

padded_S = [utils.pad_seq(s, max_length_S) for s in seqs_S]
padded_N = [utils.pad_seq(s, max_length_N) for s in seqs_N]
padded_E = [utils.pad_seq(s, max_length_E) for s in seqs_E]
padded_Y = [utils.pad_seq(s, max_length_Y) for s in seqs_Y]

# index for split (batch_size * k, l) into (batch_size, k, l)
index_N = range(len(lengths_N))

# sort for rnn
seq_pairs = sorted(zip(lengths_N, index_N, padded_N, padded_S, padded_E), key=lambda p: p[0], reverse=True)
lengths_N, index_N, padded_N, padded_S, padded_E = zip(*seq_pairs)

batch_S = torch.tensor(padded_S, dtype=torch.long, device=device)
batch_E = torch.tensor(padded_E, dtype=torch.long, device=device)

# transpose for rnn
batch_N = torch.tensor(padded_N, dtype=torch.long, device=device).transpose(0, 1)
batch_Y = torch.tensor(padded_Y, dtype=torch.long, device=device).transpose(0, 1)

# update index
self.index += self.batch_size

if self.batch_time:
t2 = time.time()
elapsed_time = t2-t1
            print(f"batching time: {elapsed_time:.4f}")

return batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N


def reset(self):
if self.shuffle:
self.pointer = shuffle(self.pointer)
self.index = 0

def file_count(self, path):
lst = [name for name in os.listdir(path) if os.path.isfile(os.path.join(path, name))]
return len(lst)

def read_batch(self, ids):

seqs_S = []
seqs_E = []
seqs_N = []
seqs_Y = []

for i in ids:
path = self.data_path + '/{:0>6d}.txt'.format(i)
with open(path, 'r') as f:
seq_S = []
seq_N = []
seq_E = []

target, *syntax_path = f.readline().split(' ')
target = target.split('|')
target = utils.sentence_to_ids(self.vocab_target, target)

# remove '' and '\n' in sequence, java-small dataset contains many '' in a line.
syntax_path = [s for s in syntax_path if s != '' and s != '\n']

                # if the number of AST paths exceeds num_k,
                # uniformly sample num_k paths, as described in the paper.
if len(syntax_path) > self.num_k:
sampled_path_index = random.sample(range(len(syntax_path)) , self.num_k)
else :
sampled_path_index = range(len(syntax_path))

for j in sampled_path_index:
terminal1, ast_path, terminal2 = syntax_path[j].split(',')

terminal1 = utils.sentence_to_ids(self.vocab_subtoken, terminal1.split('|'))
ast_path = utils.sentence_to_ids(self.vocab_nodes, ast_path.split('|'))
terminal2 = utils.sentence_to_ids(self.vocab_subtoken, terminal2.split('|'))

seq_S.append(terminal1)
seq_E.append(terminal2)
seq_N.append(ast_path)

seqs_S.append(seq_S)
seqs_E.append(seq_E)
seqs_N.append(seq_N)
seqs_Y.append(target)

return seqs_S, seqs_N, seqs_E, seqs_Y

class Encoder(nn.Module):
def __init__(self, input_size_subtoken, input_size_node, token_size, hidden_size, bidirectional = True, num_layers = 2, rnn_dropout = 0.5, embeddings_dropout = 0.25):

"""
input_size_subtoken : # of unique subtoken
input_size_node : # of unique node symbol
token_size : embedded token size
hidden_size : size of initial state of decoder
rnn_dropout = 0.5 : rnn drop out ratio
embeddings_dropout = 0.25 : dropout ratio for context vector
"""

super(Encoder, self).__init__()
self.hidden_size = hidden_size
self.token_size = token_size

self.embedding_subtoken = nn.Embedding(input_size_subtoken, token_size, padding_idx=PAD)
self.embedding_node = nn.Embedding(input_size_node, token_size, padding_idx=PAD)

self.lstm = nn.LSTM(token_size, token_size, num_layers = num_layers, bidirectional=bidirectional, dropout=rnn_dropout)
self.out = nn.Linear(token_size * 4, hidden_size)

self.dropout = nn.Dropout(embeddings_dropout)
self.num_directions = 2 if bidirectional else 1
self.num_layers = num_layers

    def forward(self, batch_S, batch_N, batch_E, lengths_N, lengths_k, index_N, hidden=None):

        """
        batch_S : (B * k, l) start terminals' subtokens of each AST path
        batch_N : (l, B * k) nonterminal nodes of each AST path
        batch_E : (B * k, l) end terminals' subtokens of each AST path

        lengths_N : lengths of the (sorted) node sequences, needed by pack_padded_sequence
        lengths_k : number of AST paths k in each example
        index_N : index for unsorting
        """

bk_size = batch_N.shape[1]
output_bag = []
hidden_batch = []

# (B * k, l, d)
encode_S = self.embedding_subtoken(batch_S)
encode_E = self.embedding_subtoken(batch_E)

# encode_S (B * k, d) token_representation of each ast path
encode_S = encode_S.sum(1)
encode_E = encode_E.sum(1)


"""
LSTM Outputs: output, (h_n, c_n)
output (seq_len, batch, num_directions * hidden_size)
h_n (num_layers * num_directions, batch, hidden_size) : tensor containing the hidden state for t = seq_len.
c_n (num_layers * num_directions, batch, hidden_size)
"""

# emb_N :(l, B*k, d)
emb_N = self.embedding_node(batch_N)
packed = pack_padded_sequence(emb_N, lengths_N)
output, (hidden, cell) = self.lstm(packed, hidden)
#output, _ = pad_packed_sequence(output)

# hidden (num_layers * num_directions, batch, hidden_size)
# only last layer, (num_directions, batch, hidden_size)
hidden = hidden[-self.num_directions:, :, :]

# -> (Bk, num_directions, hidden_size)
hidden = hidden.transpose(0, 1)

# -> (Bk, 1, hidden_size * num_directions)
hidden = hidden.contiguous().view(bk_size, 1, -1)

# encode_N (Bk, hidden_size * num_directions)
encode_N = hidden.squeeze(1)

        # encode_SNE : (B*k, token_size * (num_directions + 2))
encode_SNE = torch.cat([encode_N, encode_S, encode_E], dim=1)

# encode_SNE : (B*k, d)
encode_SNE = self.out(encode_SNE)

# unsort as example
#index = torch.tensor(index_N, dtype=torch.long, device=device)
#encode_SNE = torch.index_select(encode_SNE, dim=0, index=index)
index = np.argsort(index_N)
encode_SNE = encode_SNE[[index]]

# as is in https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L511
encode_SNE = self.dropout(encode_SNE)

# output_bag : [ B, (k, d) ]
output_bag = torch.split(encode_SNE, lengths_k, dim=0)

# hidden_0 : (1, B, d)
# for decoder initial state
hidden_0 = [ob.mean(0).unsqueeze(dim=0) for ob in output_bag]
hidden_0 = torch.cat(hidden_0, dim=0).unsqueeze(dim=0)

return output_bag, hidden_0

class Decoder(nn.Module):
def __init__(self, hidden_size, output_size, rnn_dropout):
"""
hidden_size : decoder unit size,
output_size : decoder output size,
rnn_dropout : dropout ratio for rnn
"""

super(Decoder, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size

self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD)
self.gru = nn.GRU(hidden_size, hidden_size, dropout=rnn_dropout)
self.out = nn.Linear(hidden_size * 2, output_size)

def forward(self, seqs, hidden, attn):
emb = self.embedding(seqs)
_, hidden = self.gru(emb, hidden)

output = torch.cat((hidden, attn), 2)
output = self.out(output)

return output, hidden

class EncoderDecoder_with_Attention(nn.Module):

    """Combine Encoder and Decoder"""

def __init__(self, input_size_subtoken, input_size_node, token_size, output_size, hidden_size, bidirectional = True, num_layers = 2, rnn_dropout = 0.5, embeddings_dropout = 0.25):

super(EncoderDecoder_with_Attention, self).__init__()
self.encoder = Encoder(input_size_subtoken, input_size_node, token_size, hidden_size, bidirectional = bidirectional, num_layers = num_layers, rnn_dropout = rnn_dropout, embeddings_dropout = embeddings_dropout)
self.decoder = Decoder(hidden_size, output_size, rnn_dropout)

self.W_a = torch.rand((hidden_size, hidden_size), dtype=torch.float,device=device , requires_grad=True)

nn.init.xavier_uniform_(self.W_a)


    def forward(self, batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S, max_length_N, max_length_E, max_length_Y, lengths_k, index_N, target_max_length, batch_Y=None, use_teacher_forcing=False):

# Encoder
        encoder_output_bag, encoder_hidden = \
            self.encoder(batch_S, batch_N, batch_E, lengths_N, lengths_k, index_N)

_batch_size = len(encoder_output_bag)
decoder_hidden = encoder_hidden

# make initial input for decoder
decoder_input = torch.tensor([BOS] * _batch_size, dtype=torch.long, device=device)
decoder_input = decoder_input.unsqueeze(0) # (1, batch_size)

# output holder
        decoder_outputs = torch.zeros(target_max_length, _batch_size, self.decoder.output_size, device=device)

#print('=' * 20)
        for t in range(target_max_length):

# ct
ct = self.attention(encoder_output_bag, decoder_hidden, lengths_k)

decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, ct)

#print(decoder_output.max(-1)[1])

decoder_outputs[t] = decoder_output

# Teacher Forcing
if use_teacher_forcing and batch_Y is not None:
decoder_input = batch_Y[t].unsqueeze(0)
else:
decoder_input = decoder_output.max(-1)[1]

return decoder_outputs

def attention(self, encoder_output_bag, hidden, lengths_k):

"""
encoder_output_bag : (batch, k, hidden_size) bag of embedded ast path
hidden : (1 , batch, hidden_size):
lengths_k : (batch, 1) length of k in each example
"""

# e_out : (batch * k, hidden_size)
e_out = torch.cat(encoder_output_bag, dim=0)

# e_out : (batch * k(i), hidden_size(j))
# self.W_a : [hidden_size(j), hidden_size(k)]
# ha -> : [batch * k(i), hidden_size(k)]
ha = einsum('ij,jk->ik', e_out, self.W_a)

# ha -> : [batch, (k, hidden_size)]
ha = torch.split(ha, lengths_k, dim=0)

# dh = [batch, (1, hidden_size)]
hd = hidden.transpose(0,1)
hd = torch.unbind(hd, dim = 0)

# _ha : (k(i), hidden_size(j))
# _hd : (1(k), hidden_size(j))
# at : [batch, ( k(i) ) ]
at = [F.softmax(torch.einsum('ij,kj->i', _ha, _hd), dim=0) for _ha, _hd in zip(ha, hd)]

# a : ( k(i) )
# e : ( k(i), hidden_size(j))
# ct : [batch, (hidden_size(j)) ] -> [batch, (1, hidden_size) ]
ct = [torch.einsum('i,ij->j', a, e).unsqueeze(0) for a, e in zip(at, encoder_output_bag)]

# ct [batch, hidden_size(k)]
# -> (1, batch, hidden_size)
ct = torch.cat(ct, dim=0).unsqueeze(0)


return ct

mce = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD)  # 'sum' replaces the deprecated size_average=False
def masked_cross_entropy(logits, target):
return mce(logits.view(-1, logits.size(-1)), target.view(-1))

batch_time = False
train_dataloader = DataLoader(TRAIN_DIR, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, batch_time=batch_time, shuffle=True)
valid_dataloader = DataLoader(VALID_DIR, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, shuffle=False)

model_args = {
'input_size_subtoken' : vocab_size_subtoken,
'input_size_node' : vocab_size_nodes,
'output_size' : vocab_size_target,
'hidden_size' : hidden_size,
'token_size' : token_size,
'bidirectional' : bidirectional,
'num_layers' : num_layers,
'rnn_dropout' : rnn_dropout,
'embeddings_dropout' : embeddings_dropout
}

model = EncoderDecoder_with_Attention(**model_args).to(device)

#optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov = nesterov)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda epoch: decay_ratio ** epoch)

fname = exp_dir + save_name
early_stopping = utils.EarlyStopping(fname, patience, warm_up, verbose=True)

def compute_loss(batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, model, optimizer=None, is_train=True):
model.train(is_train)

use_teacher_forcing = is_train and (random.random() < teacher_forcing_rate)

target_max_length = batch_Y.size(0)
pred_Y = model(batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, target_max_length, batch_Y, use_teacher_forcing)

loss = masked_cross_entropy(pred_Y.contiguous(), batch_Y.contiguous())

if is_train:
optimizer.zero_grad()
loss.backward()
optimizer.step()

batch_Y = batch_Y.transpose(0, 1).contiguous().data.cpu().tolist()
pred = pred_Y.max(dim=-1)[1].data.cpu().numpy().T.tolist()


return loss.item(), batch_Y, pred

#
# Training Loop
#
progress_bar = False  # note: this flag is not actually used below; the tqdm progress bars are always shown


for epoch in range(1, num_epochs+1):
    print('Starting epoch:')
print(epoch)
train_loss = 0.
train_refs = []
train_hyps = []
valid_loss = 0.
valid_refs = []
valid_hyps = []

# train

for batch in tqdm(train_dataloader, total=train_dataloader.num_examples // train_dataloader.batch_size + 1, desc='TRAIN'):
        print('Training batch started...')
batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S, max_length_N,max_length_E,max_length_Y, lengths_k, index_N = batch

loss, gold, pred = compute_loss(
batch_S, batch_N, batch_E, batch_Y,
lengths_S, lengths_N, lengths_E, lengths_Y,
max_length_S,max_length_N,max_length_E,max_length_Y,
lengths_k, index_N, model, optimizer,
is_train=True
)

train_loss += loss
train_refs += gold
train_hyps += pred

# valid
for batch in tqdm(valid_dataloader, total=valid_dataloader.num_examples // valid_dataloader.batch_size + 1, desc='VALID'):

batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N = batch

loss, gold, pred = compute_loss(
batch_S, batch_N, batch_E, batch_Y,
lengths_S, lengths_N, lengths_E, lengths_Y,
max_length_S,max_length_N,max_length_E,max_length_Y,
lengths_k, index_N, model, optimizer,
is_train=False
)

valid_loss += loss
valid_refs += gold
valid_hyps += pred


train_loss = np.sum(train_loss) / train_dataloader.num_examples
valid_loss = np.sum(valid_loss) / valid_dataloader.num_examples

# F1 etc
train_precision, train_recall, train_f1 = utils.calculate_results_set(train_refs, train_hyps)
valid_precision, valid_recall, valid_f1 = utils.calculate_results_set(valid_refs, valid_hyps)


early_stopping(valid_f1, model, epoch)
if early_stopping.early_stop:
msgr.print_msg("Early stopping")
break

msgr.print_msg('Epoch {}: train_loss: {:5.2f} train_f1: {:2.4f} valid_loss: {:5.2f} valid_f1: {:2.4f}'.format(
epoch, train_loss, train_f1, valid_loss, valid_f1))

print('-'*80)

scheduler.step()


# evaluation
print('Training finished, starting evaluation')
model = EncoderDecoder_with_Attention(**model_args).to(device)

fname = exp_dir + save_name
ckpt = torch.load(fname)
model.load_state_dict(ckpt)

model.eval()

test_dataloader = DataLoader(TEST_DIR, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, batch_time=batch_time, shuffle=True)

refs_list = []
hyp_list = []

for batch in tqdm(test_dataloader,
total=test_dataloader.num_examples // test_dataloader.batch_size + 1,
desc='TEST'):

batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N = batch
target_max_length = batch_Y.size(0)
use_teacher_forcing = False

pred_Y = model(batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, target_max_length, batch_Y, use_teacher_forcing)

refs = batch_Y.transpose(0, 1).contiguous().data.cpu().tolist()[0]
pred = pred_Y.max(dim=-1)[1].data.cpu().numpy().T.tolist()[0]

refs_list.append(refs)
hyp_list.append(pred)

msgr.print_msg('Tested model : ' + fname)

test_precision, test_recall, test_f1 = utils.calculate_results(refs_list, hyp_list)
msgr.print_msg('Test : precision {:1.5f}, recall {:1.5f}, f1 {:1.5f}'.format(test_precision, test_recall, test_f1))

test_precision, test_recall, test_f1 = utils.calculate_results_set(refs_list, hyp_list)
msgr.print_msg('Test(set) : precision {:1.5f}, recall {:1.5f}, f1 {:1.5f}'.format(test_precision, test_recall, test_f1))

batch_time = False
test_dataloader = DataLoader(TEST_DIR, 1, num_k, vocab_subtoken, vocab_nodes, vocab_target, batch_time=batch_time, shuffle=True)

model.eval()

batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N = next(test_dataloader)

sentence_Y = ' '.join(utils.ids_to_sentence(vocab_target, batch_Y.data.cpu().numpy()[:-1, 0]))
msgr.print_msg('tgt: {}'.format(sentence_Y))

target_max_length = batch_Y.size(0)
use_teacher_forcing = False
output = model(batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, target_max_length, batch_Y, use_teacher_forcing)

output = output.max(dim=-1)[1].view(-1).data.cpu().tolist()
output_sentence = ' '.join(utils.ids_to_sentence(vocab_target, utils.trim_eos(output)))
msgr.print_msg('out: {}'.format(output_sentence))
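
To sanity-check the shapes flowing through the attention method above, here is a tiny standalone sketch with dummy tensors (my own illustration, not from the original notebook): a batch of two examples with k = 3 and k = 2 AST paths and a hidden size of 4.

import torch
import torch.nn.functional as F
from torch import einsum

hidden_size = 4
lengths_k = [3, 2]                                     # number of AST paths per example
e_out = torch.randn(sum(lengths_k), hidden_size)       # bag of path encodings, (batch*k, hidden)
W_a = torch.randn(hidden_size, hidden_size)            # attention weight matrix
hidden = torch.randn(1, len(lengths_k), hidden_size)   # decoder hidden state, (1, batch, hidden)

ha = einsum('ij,jk->ik', e_out, W_a)                   # (batch*k, hidden)
ha = torch.split(ha, lengths_k, dim=0)                 # [batch, (k, hidden)]
hd = torch.unbind(hidden.transpose(0, 1), dim=0)       # [batch, (1, hidden)]
at = [F.softmax(torch.einsum('ij,kj->i', _ha, _hd), dim=0) for _ha, _hd in zip(ha, hd)]
e_bag = torch.split(e_out, lengths_k, dim=0)
ct = [torch.einsum('i,ij->j', a, e).unsqueeze(0) for a, e in zip(at, e_bag)]
ct = torch.cat(ct, dim=0).unsqueeze(0)                 # context vector, (1, batch, hidden)
print(ct.shape)                                        # torch.Size([1, 2, 4])

The context vector comes out as (1, batch, hidden_size), which is exactly what the decoder concatenates with its own hidden state before the output projection.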

Configuration File (config_code2seq.yml)

# The path-related settings here are read at the beginning of the main code2seq script
data:
home: ../../dataset
dict: /java-small.dict.c2s
train: /train
valid: /val
test: /test

training:
batch_size: 256
num_epochs: 50
lr: 0.001
teacher_forcing_rate: 0.4
nesterov: True
weight_decay: 0.01
momentum: 0.95
decay_ratio: 0.95
save_name: /model.pth
warm_up: 1
patience: 2

model:
token_size: 128
hidden_size: 64
num_layers: 1
bidirectional: True
rnn_dropout: 0.5
embeddings_dropout: 0.3
num_k : 200

etc:
info_prefix: code2seq
  # This path is only needed for optional Slack notifications; if the file does not exist it is simply ignored (see the minimal slack_url.yml example after this block)
slack_url_path: ../slack/slack_url.yml

comment: code2seq
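
If you do want the Slack notifications, the file pointed to by slack_url_path only needs a single key, judging from how the main script reads it (yaml.load(...)['slack_url']); the webhook URL below is just a placeholder:

# ../slack/slack_url.yml (optional; the script silently skips Slack if this file does not exist)
slack_url: https://hooks.slack.com/services/XXXX/XXXX/XXXX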

Utility Code (the src folder)

# messenger.py
import logging
import slackweb

class Info(object):
def __init__(self, info_prefix='', slack_url = None):

self.info_prefix = info_prefix
self.slack = None
if slack_url is not None:
self.slack = slackweb.Slack(url = slack_url)
self.slack.notify(text = "="*80)

def print_msg(self, msg):
text = self.info_prefix + ' ' + msg

print(text)
logging.info(text)
if self.slack is not None:
self.slack.notify(text = text)
# utils.py
import torch
from nltk import bleu_score

PAD = 0
BOS = 1
EOS = 2
UNK = 3

class Vocab(object):
def __init__(self, word2id={}):

self.word2id = dict(word2id)
self.id2word = {v: k for k, v in self.word2id.items()}

def build_vocab(self, sentences, min_count=1):
word_counter = {}
for word in sentences:
word_counter[word] = word_counter.get(word, 0) + 1

for word, count in sorted(word_counter.items(), key=lambda x: -x[1]):
if count < min_count:
break
_id = len(self.word2id)
self.word2id.setdefault(word, _id)
self.id2word[_id] = word

def sentence_to_ids(vocab, sentence):
ids = [vocab.word2id.get(word, UNK) for word in sentence]
ids += [EOS]
return ids

def ids_to_sentence(vocab, ids):
return [vocab.id2word[_id] for _id in ids]

def trim_eos(ids):
if EOS in ids:
return ids[:ids.index(EOS)]
else:
return ids

def calculate_results_set(refs, preds):
#calc precision, recall and F1
#same as https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L239

filterd_refs = [ref[:ref.index(EOS)] for ref in refs]
filterd_preds = [pred[:pred.index(EOS)] if EOS in pred else pred for pred in preds]

filterd_refs = [list(set(ref)) for ref in filterd_refs]
filterd_preds = [list(set(pred)) for pred in filterd_preds]

true_positive, false_positive, false_negative = 0, 0, 0

for filterd_pred, filterd_ref in zip(filterd_preds, filterd_refs):

for fp in filterd_pred:
if fp in filterd_ref:
true_positive += 1
else:
false_positive += 1

for fr in filterd_ref:
if not fr in filterd_pred:
false_negative += 1

# https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L282
if true_positive + false_positive > 0:
precision = true_positive / (true_positive + false_positive)
else:
precision = 0

if true_positive + false_negative > 0:
recall = true_positive / (true_positive + false_negative)
else:
recall = 0

if precision + recall > 0:
f1 = 2 * precision * recall / (precision + recall)
else:
f1 = 0

return precision, recall, f1

def calculate_results(refs, preds):
#calc precision, recall and F1
#same as https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L239

filterd_refs = [ref[:ref.index(EOS)] for ref in refs]
filterd_preds = [pred[:pred.index(EOS)] if EOS in pred else pred for pred in preds]

true_positive, false_positive, false_negative = 0, 0, 0

for filterd_pred, filterd_ref in zip(filterd_preds, filterd_refs):

if filterd_pred == filterd_ref:
true_positive += len(filterd_pred)
continue

for fp in filterd_pred:
if fp in filterd_ref:
true_positive += 1
else:
false_positive += 1

for fr in filterd_ref:
if not fr in filterd_pred:
false_negative += 1

# https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L282
if true_positive + false_positive > 0:
precision = true_positive / (true_positive + false_positive)
else:
precision = 0

if true_positive + false_negative > 0:
recall = true_positive / (true_positive + false_negative)
else:
recall = 0

if precision + recall > 0:
f1 = 2 * precision * recall / (precision + recall)
else:
f1 = 0

return precision, recall, f1

class EarlyStopping(object):
def __init__(self, filename = None, patience=3, warm_up=0, verbose=False):

self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.warm_up = warm_up
self.filename = filename

def __call__(self, score, model, epoch):

if self.best_score is None:
self.best_score = score
self.save_checkpoint(score, model)

elif (score <= self.best_score) and (epoch > self.warm_up) :
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
if (epoch <= self.warm_up):
print('Warming up until epoch', self.warm_up)

else:
if self.verbose:
print(f'Score improved. ({self.best_score:.6f} --> {score:.6f}).')

self.best_score = score
self.save_checkpoint(score, model)
self.counter = 0

def save_checkpoint(self, score, model):

if self.filename is not None:
torch.save(model.state_dict(), self.filename)

if self.verbose:
print('Model saved...')

def pad_seq(seq, max_length):
# pad tail of sequence to extend sequence length up to max_length
res = seq + [PAD for i in range(max_length - len(seq))]
return res

def calc_bleu(refs, hyps):
_refs = [[ref[:ref.index(EOS)]] for ref in refs]
_hyps = [hyp[:hyp.index(EOS)] if EOS in hyp else hyp for hyp in hyps]
return 100 * bleu_score.corpus_bleu(_refs, _hyps)
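
As a small illustration of the set-based metric above (my own toy example, not from the original repo): with one reference and one prediction that share a single subtoken id, precision, recall, and F1 all come out to 0.5.

# token ids; EOS = 2 marks the end of each method name
refs  = [[5, 6, 2, 0, 0]]   # reference subtokens {5, 6}
preds = [[5, 7, 2]]         # predicted subtokens {5, 7}
# TP = 1 (id 5), FP = 1 (id 7), FN = 1 (id 6)
print(calculate_results_set(refs, preds))  # (0.5, 0.5, 0.5)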

Running the Project

# Run from a terminal; remember to cd into code/notebooks first
!python code2seq.py
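
Since I run the converted .py file on a server, I usually detach the process and keep the console output as well (my own habit, not part of the original repo; the script already writes its own log under logs/):

nohup python code2seq.py > console.log 2>&1 &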

Results

Because I trimmed the original dataset and trained on only a subset of java-small, the final numbers are not great, but this is roughly how the whole pipeline runs end to end.
(Screenshot: final training and evaluation output)

References

1. Notes on code2seq: Generating Sequences from Structured Representations of Code
2. big code: code2seq paper reproduction: Generating Sequences from Structured Representations of Code
3. big code: code2seq: Generating Sequences from Structured Representations of Code