Moved to tf.data.Dataset API

Signed-off-by: Jim Martens <github@2martens.de>
2019-02-08 08:28:41 +01:00
parent d9cd24f769
commit 758cdb5520


@@ -86,6 +86,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int)
 
     mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
 
+    # get dataset
+    dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
+    dataset = dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True).map(normalize).cache()
+
     # get models
     encoder = Encoder(zsize)
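
The new input pipeline shuffles across the whole training set, batches with `drop_remainder=True` so every step sees a full batch, maps `normalize` over the already-batched tensors, and caches the result. A minimal standalone sketch of the same pipeline, with synthetic stand-ins for the real MNIST arrays (shapes, dtypes, and the inline `normalize` are assumptions for illustration):

```python
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()  # the script runs in TF 1.x eager mode

# synthetic stand-ins for the real mnist_train_x / mnist_train_y arrays
mnist_train_x = np.random.randint(0, 256, size=(256, 28, 28)).astype(np.float32)
mnist_train_y = np.random.randint(0, 10, size=(256,)).astype(np.int32)
batch_size = 128

def normalize(feature, label):
    # 0-255 -> 0-1 plus a trailing channel axis (the helper added later in this commit)
    return tf.expand_dims(tf.divide(feature, 255.0), axis=-1), label

dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
dataset = (dataset.shuffle(mnist_train_x.shape[0])   # buffer covers the whole set
           .batch(batch_size, drop_remainder=True)   # every step sees a full batch
           .map(normalize)
           .cache())                                 # reuse mapped batches across epochs

for x, y in dataset:
    print(x.shape, y.shape)  # (128, 28, 28, 1) (128,)
```

One subtlety of this ordering: because `cache()` sits after `shuffle()`, the batch order produced during the first epoch is what gets cached, so later epochs replay the same batches; caching before shuffling would keep per-epoch reshuffling.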
@@ -122,26 +126,15 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         epoch_start_time = time.time()
 
-        def shuffle(train_data: np.ndarray) -> None:
-            """
-            Shuffles the given training data inplace.
-
-            :param train_data: numpy array of training data
-            """
-            np.take(train_data, np.random.permutation(train_data.shape[0]), axis=0, out=train_data)
-
-        shuffle(mnist_train_x)
-
         # update learning rate
         if (epoch + 1) % 30 == 0:
             learning_rate_var.assign(learning_rate_var.value() / 4)
             if verbose:
                 print("learning rate change!")
 
-        nr_batches = len(mnist_train_x) // batch_size
         log_frequency = 10
+        batch_iteration = k.variable(0, dtype=tf.int64)
 
-        for it in range(nr_batches):
-            x = k.expand_dims(extract_batch(mnist_train_x, it, batch_size))
+        for x, _ in dataset:
             # x discriminator
             _xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
                                                        decoder=decoder,
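
With the pipeline in place, the manual in-place `shuffle`, the `nr_batches` arithmetic, and the `extract_batch` indexing all collapse into direct eager iteration: each step now receives an already shuffled, batched, and normalized tensor. Roughly, the loop change amounts to this (sketch, using the names from this file):

```python
# before: manual shuffle + index arithmetic over the numpy array
# for it in range(len(mnist_train_x) // batch_size):
#     x = k.expand_dims(extract_batch(mnist_train_x, it, batch_size))

# after: eager iteration; one full pass over `dataset` is one epoch,
# since the pipeline has no .repeat(), and the labels are unused here
for x, _ in dataset:
    pass  # x arrives already shuffled, batched, and normalized
```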
@@ -187,7 +180,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
                 enc_dec_loss_avg(reconstruction_loss)
                 encoder_loss_avg(encoder_loss)
 
-            if it % log_frequency == 0:
+            if int(global_step_decoder % log_frequency) == 0:
                 # log the losses every log frequency batches
                 summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(), step=global_step_enc_dec)
                 summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(), step=global_step_decoder)
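
Because the batch index `it` no longer exists, the logging cadence is keyed to the decoder's global-step variable instead. A minimal sketch of this step-keyed logging pattern (the writer setup and constant loss value are assumptions; the real file configures its own writer and steps):

```python
import tensorflow as tf
from tensorflow.python.ops import summary_ops_v2

tf.enable_eager_execution()

global_step_decoder = tf.train.get_or_create_global_step()
writer = summary_ops_v2.create_file_writer('logs')
log_frequency = 10

with writer.as_default(), summary_ops_v2.always_record_summaries():
    for _ in range(100):
        global_step_decoder.assign_add(1)
        # log only every log_frequency-th step of the tracked variable
        if int(global_step_decoder % log_frequency) == 0:
            summary_ops_v2.scalar('decoder_loss', 0.5, step=global_step_decoder)
```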
@@ -195,19 +188,21 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
                 summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(), step=global_step_zd)
                 summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(), step=global_step_xd)
 
                 # reset the metrics states
-                encoder_loss_avg.reset_states()
-                decoder_loss_avg.reset_states()
-                enc_dec_loss_avg.reset_states()
-                zd_loss_avg.reset_states()
-                xd_loss_avg.reset_states()
-
-            if it == 0:
+                # encoder_loss_avg.init_variables()
+                # decoder_loss_avg.init_variables()
+                # enc_dec_loss_avg.init_variables()
+                # zd_loss_avg.init_variables()
+                # xd_loss_avg.init_variables()
+
+            if int(batch_iteration) == 0:
                 directory = 'results' + str(inlier_classes[0])
                 if not os.path.exists(directory):
                     os.makedirs(directory)
                 comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
                 save_image(comparison.cpu(),
                            'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png', nrow=64)
+            batch_iteration.assign_add(1)
 
         epoch_end_time = time.time()
         per_epoch_time = epoch_end_time - epoch_start_time
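
`for x, _ in dataset` provides no loop index, so the commit threads an explicit `batch_iteration` variable through the loop to detect the first batch of each epoch (for the reconstruction grid) and increments it at the bottom. Since the counter is only ever read back through `int(...)`, a plain Python counter would do the same; a sketch of that hypothetical simplification:

```python
# hypothetical alternative: enumerate() in place of the tf.int64 counter
for batch_iteration, (x, _) in enumerate(dataset):
    # ... discriminator/decoder/encoder training steps ...
    if batch_iteration == 0:
        pass  # save the reconstruction grid for the epoch's first batch, as above
```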
@@ -385,7 +380,18 @@ def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
     return k.variable(z)
 
 
-def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tf.Variable:
+def normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    """
+    Normalizes a tensor from a 0-255 range to a 0-1 range and adds one dimension.
+
+    :param feature: tensor to be normalized
+    :param label: label tensor
+    :return: normalized tensor and the unchanged label
+    """
+    return k.expand_dims(tf.divide(feature, 255.0)), label
+
+
+def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tfe.Variable:
     """
     Extracts a batch from data.
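
Because `normalize` is mapped after `.batch(...)` in the pipeline above, it operates on whole batches, and `k.expand_dims` (which defaults to `axis=-1`) appends what is presumably the channel axis the convolutional encoder expects. A quick standalone check of that behavior (the function is re-declared inline for the sketch):

```python
import tensorflow as tf

tf.enable_eager_execution()

def normalize(feature, label):
    # 0-255 -> 0-1 with a trailing channel axis (k.expand_dims defaults to axis=-1)
    return tf.expand_dims(tf.divide(feature, 255.0), axis=-1), label

batch = tf.fill([4, 28, 28], 255.0)        # fake batch of max-intensity pixels
labels = tf.zeros([4], dtype=tf.int32)
features, labels = normalize(batch, labels)
print(features.shape)                      # (4, 28, 28, 1)
print(float(tf.reduce_max(features)))      # 1.0
```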