Moved to tf.data.Dataset API
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -86,6 +86,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int)
|
||||
|
||||
mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
|
||||
|
||||
# get dataset
|
||||
dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
|
||||
dataset = dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True).map(normalize).cache()
|
||||
|
||||
# get models
|
||||
encoder = Encoder(zsize)
|
||||
@ -122,26 +126,15 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
|
||||
epoch_start_time = time.time()
|
||||
|
||||
def shuffle(train_data: np.ndarray) -> None:
|
||||
"""
|
||||
Shuffles the given training data inplace.
|
||||
|
||||
:param train_data: numpy array of training data
|
||||
"""
|
||||
np.take(train_data, np.random.permutation(train_data.shape[0]), axis=0, out=train_data)
|
||||
|
||||
shuffle(mnist_train_x)
|
||||
|
||||
# update learning rate
|
||||
if (epoch + 1) % 30 == 0:
|
||||
learning_rate_var.assign(learning_rate_var.value() / 4)
|
||||
if verbose:
|
||||
print("learning rate change!")
|
||||
|
||||
nr_batches = len(mnist_train_x) // batch_size
|
||||
log_frequency = 10
|
||||
for it in range(nr_batches):
|
||||
x = k.expand_dims(extract_batch(mnist_train_x, it, batch_size))
|
||||
batch_iteration = k.variable(0, dtype=tf.int64)
|
||||
for x, _ in dataset:
|
||||
# x discriminator
|
||||
_xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
|
||||
decoder=decoder,
|
||||
@ -187,7 +180,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
enc_dec_loss_avg(reconstruction_loss)
|
||||
encoder_loss_avg(encoder_loss)
|
||||
|
||||
if it % log_frequency == 0:
|
||||
if int(global_step_decoder % log_frequency) == 0:
|
||||
# log the losses every log frequency batches
|
||||
summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(), step=global_step_enc_dec)
|
||||
summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(), step=global_step_decoder)
|
||||
@ -195,19 +188,21 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(), step=global_step_zd)
|
||||
summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(), step=global_step_xd)
|
||||
# reset the metrics states
|
||||
encoder_loss_avg.reset_states()
|
||||
decoder_loss_avg.reset_states()
|
||||
enc_dec_loss_avg.reset_states()
|
||||
zd_loss_avg.reset_states()
|
||||
xd_loss_avg.reset_states()
|
||||
|
||||
if it == 0:
|
||||
# encoder_loss_avg.init_variables()
|
||||
# decoder_loss_avg.init_variables()
|
||||
# enc_dec_loss_avg.init_variables()
|
||||
# zd_loss_avg.init_variables()
|
||||
# xd_loss_avg.init_variables()
|
||||
|
||||
if int(batch_iteration) == 0:
|
||||
directory = 'results' + str(inlier_classes[0])
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
||||
save_image(comparison.cpu(),
|
||||
'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png', nrow=64)
|
||||
|
||||
batch_iteration.assign_add(1)
|
||||
|
||||
epoch_end_time = time.time()
|
||||
per_epoch_time = epoch_end_time - epoch_start_time
|
||||
@ -385,7 +380,18 @@ def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
|
||||
return k.variable(z)
|
||||
|
||||
|
||||
def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tf.Variable:
|
||||
def normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
"""
|
||||
Normalizes a tensor from a 0-255 range to a 0-1 range and adds one dimension.
|
||||
|
||||
:param feature: tensor to be normalized
|
||||
:param label: label tensor
|
||||
:return: normalized tensor
|
||||
"""
|
||||
return k.expand_dims(tf.divide(feature, 255.0)), label
|
||||
|
||||
|
||||
def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tfe.Variable:
|
||||
"""
|
||||
Extracts a batch from data.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user