You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# Pseudocode: advance one decoding step and choose the decoder's next inputs.
# time, sequence_length, input_tas and state come from an enclosing scope that
# is not shown here; the step returns (finished, next_inputs, state).
time += 1
# Element-wise comparison: one finished flag per batch entry.
# NOTE(review): with `>` an entry is flagged one step after time reaches its
# length; TF's TrainingHelper uses `next_time >= sequence_length` here —
# confirm which is intended.
finished = (time > sequence_length) # check if each batch is completed
if all(finished):
# Every entry is done, so there is nothing left to read: feed zeros shaped
# like the regular inputs.
next_inputs = zero-tensor with same shape
else:
# Otherwise take the entry at index `time` from input_tas.
next_inputs = input_tas[time]
return finished, next_inputs, state
DecoderOutput
Wraps the decoder's attributes, such as output_size and output_dtype.
# Pseudocode: run the decoder step by step until every batch entry is finished
# (or maximum_iterations is reached), collect each step's outputs, then stack
# them. maximum_iterations, impute_finished and output_time_major come from an
# enclosing signature that is not shown here.
finished, inputs, state = decoder.initialize()
time = 0
# NOTE(review): decoder.output_size presumably describes one step's output, not
# the number of decoding steps (which is unknown until the loop below ends);
# this likely should be a dynamically sized TensorArray, as in TF's
# dynamic_decode — confirm.
outputs_ta = TensorArray(size=decoder.output_size)
# Keep stepping until every batch entry reports finished.
while not all(finished):
outputs, state, inputs, finished = decoder.step(time, inputs, state)
# Hard cap on decoding length when maximum_iterations is given.
if maximum_iterations is not None and time + 1 >= maximum_iterations:
# NOTE(review): `finished` is treated as per-entry elsewhere (all(finished));
# this presumably means "mark every entry finished".
finished = True
# if finished!
# => zero out all remaining outputs
if impute_finished:
# NOTE(review): as written this replaces the whole step's outputs; TF's
# dynamic_decode zeroes only entries that are already finished
# (a per-entry select such as tf.where(finished, zeros, outputs)) — confirm.
outputs = zero_outputs # zero-tensor with same shape
# Record this step's outputs at index `time`.
outputs_ta[time] = outputs
time += 1
# Stack the per-step outputs along a leading time axis (see shape note below).
final_outputs = outputs_ta.stack()
final_state = state
# time_major => batch_major
# [max_dec_len x batch_size x hidden_size] => [batch_size x max_dec_len x hidden_size]
if not output_time_major:
# Swap the first two axes; assumes rank-3 outputs as in the shape comment above.
final_outputs = tf.transpose(final_outputs, [1,0,2])
return final_outputs, final_state