I trained a bidirectional LSTM RNN for almost 24 hours and, because the loss was fluctuating, I decided to pause training before letting it continue. Since the model is saved with saver.save(sess, file) at the end of every epoch, I stopped training once the CTC loss had come down to about 115.
Now, after restoring the model, the initial loss I get is around 162, which matches neither the values I was seeing in the 7th epoch nor even what I got in the first epoch. So it seems that either restore is not working, or, if it is, something else is preventing it from taking effect.
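To rule out the checkpoint itself, here is a small diagnostic sketch (not part of my training script; it assumes a TF 1.x-style API and reuses the graph, saver, tvars, and checkpoint_dir objects defined in the code below). It lists what the checkpoint file contains and checks whether restore actually overwrites freshly initialized variables:

import numpy as np
import tensorflow as tf

ckpt = tf.train.latest_checkpoint(checkpoint_dir)
reader = tf.train.NewCheckpointReader(ckpt)
for name, shp in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shp)  # every trainable variable should show up here

with graph.as_default():
    init_op = tf.global_variables_initializer()
with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    before = sess.run(tvars[0])
    saver.restore(sess, ckpt)
    after = sess.run(tvars[0])
    # if restore works, the fresh random values must be overwritten
    print("restore changed tvars[0]:", not np.allclose(before, after))

If tvars[0] does not change, the restore itself is broken; if it does change but the loss is still around 162, something after the restore must be responsible.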
Here is my code:
import sys
import time

import numpy as np
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    graph_start = time.time()
    # Inputs are time-major: [max_time, batch_size, frame_length]
    seq_inputs = tf.placeholder(tf.float32, shape=[None, batch_size, frame_length],
                                name="sequence_inputs")
    seq_lens = tf.placeholder(shape=[batch_size], dtype=tf.int32)
    seq_inputs = seq_bn(seq_inputs, seq_lens)
    initializer = tf.truncated_normal_initializer(mean=0, stddev=0.1)
    forward = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                                      num_proj=hidden_size,
                                      use_peepholes=use_peephole,
                                      initializer=initializer,
                                      state_is_tuple=True)
    forward = tf.nn.rnn_cell.MultiRNNCell([forward] * n_layers, state_is_tuple=True)
    backward = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                                       num_proj=hidden_size,
                                       use_peepholes=use_peephole,
                                       initializer=initializer,
                                       state_is_tuple=True)
    backward = tf.nn.rnn_cell.MultiRNNCell([backward] * n_layers, state_is_tuple=True)
    [fw_out, bw_out], _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=forward, cell_bw=backward, inputs=seq_inputs, time_major=True,
        dtype=tf.float32, sequence_length=tf.cast(seq_lens, tf.int64))
    # Normalize each direction's outputs with statistics of the current batch
    mew, var_ = tf.nn.moments(fw_out, axes=[0])
    fw_out = tf.nn.batch_normalization(fw_out, mew, var_, 0.1, 1, 1e-6)
    mew, var_ = tf.nn.moments(bw_out, axes=[0])
    bw_out = tf.nn.batch_normalization(bw_out, mew, var_, 0.1, 1, 1e-6)
    fw_out = tf.reshape(fw_out, [-1, hidden_size])
    bw_out = tf.reshape(bw_out, [-1, hidden_size])
    W_fw = tf.Variable(tf.truncated_normal(shape=[hidden_size, n_chars],
                                           stddev=np.sqrt(2.0 / hidden_size)))
    W_bw = tf.Variable(tf.truncated_normal(shape=[hidden_size, n_chars],
                                           stddev=np.sqrt(2.0 / hidden_size)))
    b_out = tf.constant(0.1, shape=[n_chars])
    logits = tf.add(tf.add(tf.matmul(fw_out, W_fw), tf.matmul(bw_out, W_bw)), b_out)
    logits = tf.reshape(logits, [-1, batch_size, n_chars])
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_lens)
    indices = tf.placeholder(dtype=tf.int64, shape=[None, 2])
    values = tf.placeholder(dtype=tf.int32, shape=[None])
    shape = tf.placeholder(dtype=tf.int64, shape=[2])
    targets = tf.SparseTensor(indices, values, shape)
    loss = tf.reduce_mean(tf.nn.ctc_loss(logits, targets, seq_lens))
    predicted = tf.to_int32(decoded[0])
    error_rate = tf.reduce_sum(tf.edit_distance(predicted, targets, normalize=False)) / \
                 tf.to_float(tf.size(targets.values))
    tvars = tf.trainable_variables()
    grad, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), max_grad_norm)
    optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=momentum)
    train_step = optimizer.apply_gradients(zip(grad, tvars))
    graph_end = time.time()
    print("Time elapsed for creating graph: %.3f" % round(graph_end - graph_start, 3))
    start_time = 0
    steps = int(np.ceil(len(data_train.files) / batch_size))
    loss_tr = []
    log_tr = []
    loss_vl = []
    log_vl = []
    err_tr = []
    err_vl = []
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        checkpt_path = tf.train.latest_checkpoint(checkpoint_dir)
        print("Restoring from: %s" % checkpt_path)
        saver.restore(sess, checkpt_path)  # returns None, so there is nothing useful to print
        print("Model restored from 7th epoch, 188th step")
        feed = None
        epoch = None
        step = None
        try:
            for epoch in range(7, epochs + 1):
                if epoch == 7:
                    initial_step = 189
                else:
                    initial_step = 0
                transcript = []
                loss_val = 0
                l_pr = 0
                start_time = time.time()
                for step in range(initial_step, steps):
                    train_data, transcript, \
                    targ_indices, targ_values, \
                    targ_shape, n_frames = data_train.next_batch()
                    n_frames = np.reshape(n_frames, [-1])
                    feed = {seq_inputs: train_data, indices: targ_indices, values: targ_values,
                            shape: targ_shape, seq_lens: n_frames}
                    del train_data, targ_indices, targ_values, targ_shape, n_frames
                    _, loss_val, deco, l_pr, err_rt_tr = sess.run(
                        [train_step, loss, decoded, log_prob, error_rate], feed_dict=feed)
                    del feed
                    loss_tr.append(loss_val)
                    log_tr.append(l_pr)
                    err_tr.append(err_rt_tr)
                    val_data, val_transcript, \
                    targ_indices, targ_values, \
                    targ_shape, n_frames = data_val.next_batch()
                    n_frames = np.reshape(n_frames, [-1])
                    feed = {seq_inputs: val_data, indices: targ_indices, values: targ_values,
                            shape: targ_shape, seq_lens: n_frames}
                    del val_data, val_transcript, targ_indices, targ_values, targ_shape, n_frames
                    vl_loss, l_val_pr, err_rt_vl = sess.run([loss, log_prob, error_rate],
                                                            feed_dict=feed)
                    del feed
                    loss_vl.append(vl_loss)
                    log_vl.append(l_val_pr)
                    err_vl.append(err_rt_vl)
                    print("epoch %d, step: %d, tr_loss: %.2f, vl_loss: %.2f, tr_err: %.2f, vl_err: %.2f"
                          % (epoch, step, np.mean(loss_tr), np.mean(loss_vl), err_rt_tr, err_rt_vl))
                end_time = time.time()
                elapsed = round(end_time - start_time, 3)
                # Decode one random sample from the last training mini-batch
                sample_index = np.random.randint(0, batch_size)
                actual_str = [data_train.reverse_map[i] for i in transcript[sample_index]]
                pred_sparse = tf.SparseTensor(deco[0].indices, deco[0].values, deco[0].shape)
                pred_dense = tf.sparse_tensor_to_dense(pred_sparse)
                ans = pred_dense.eval()
                pred = []
                for i in ans[sample_index, :]:
                    if i == n_chars - 1:
                        pred.append(data_train.reverse_map[0])
                    else:
                        pred.append(data_train.reverse_map[i])
                print("time_elapsed for 200 steps: %.3f" % elapsed)
                if epoch % 2 == 0:
                    print("Sample mini-batch results: \n"
                          "predicted string: ", np.array(pred))
                    print("actual string: ", np.array(actual_str))
                    print("On training set, the loss: %.2f, log_pr: %.3f, error rate: %.3f" % (loss_val, np.mean(l_pr), err_rt_tr))
                    print("On validation set, the loss: %.2f, log_pr: %.3f, error rate: %.3f" % (vl_loss, np.mean(l_val_pr), err_rt_vl))
                if epoch > 7:
                    path = saver.save(sess, 'model_%d' % epoch)
                    print("Session saved at: %s" % path)
                np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object))
        except (KeyboardInterrupt, SystemExit, Exception) as e:
            print("Error/Interruption: %s" % str(e))
            exc_type, exc_obj, exc_tb = sys.exc_info()
            print("Line no: %d" % exc_tb.tb_lineno)
            if epoch > 7:
                print("Saving model: %s" % saver.save(sess, 'Last.cpkt'))
            print("Current batch: %d" % data_train.b_id)
            print("Current epoch: %d" % epoch)
            print("Current step: %d" % step)
            np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object))
            print("Closing TF Session...")
            sess.close()
            print("Terminating Program...")
            sys.exit(0)
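One more thing I should mention: the number I compare against also depends on which mini-batch data_train.next_batch() happens to return after the restart (restoring the weights does not restore the reader's position), and tf.nn.moments recomputes the normalization statistics from every incoming batch. Here is a sketch of how the before/after numbers could be made comparable by freezing one batch to disk. It reuses the placeholders, ops, graph, and saver from the script above; "fixed_batch.npz" is a hypothetical file name, written once before stopping training:

import numpy as np
import tensorflow as tf

# written once before stopping, e.g.:
# np.savez("fixed_batch.npz", data=train_data, idx=targ_indices,
#          vals=targ_values, shp=targ_shape, lens=n_frames)
batch = np.load("fixed_batch.npz")
feed = {seq_inputs: batch["data"], indices: batch["idx"], values: batch["vals"],
        shape: batch["shp"], seq_lens: np.reshape(batch["lens"], [-1])}

with tf.Session(graph=graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
    # the same frozen batch should give (nearly) the same loss as before stopping,
    # since the per-batch tf.nn.moments statistics depend only on this batch
    print("loss on the frozen batch:", sess.run(loss, feed_dict=feed))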