Moved to tf.data.Dataset API
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -87,6 +87,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
|
|
||||||
mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
|
mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
|
||||||
|
|
||||||
|
# get dataset
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
|
||||||
|
dataset = dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True).map(normalize).cache()
|
||||||
|
|
||||||
# get models
|
# get models
|
||||||
encoder = Encoder(zsize)
|
encoder = Encoder(zsize)
|
||||||
decoder = Decoder(channels)
|
decoder = Decoder(channels)
|
||||||
@ -122,26 +126,15 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
|
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
|
|
||||||
def shuffle(train_data: np.ndarray) -> None:
|
|
||||||
"""
|
|
||||||
Shuffles the given training data inplace.
|
|
||||||
|
|
||||||
:param train_data: numpy array of training data
|
|
||||||
"""
|
|
||||||
np.take(train_data, np.random.permutation(train_data.shape[0]), axis=0, out=train_data)
|
|
||||||
|
|
||||||
shuffle(mnist_train_x)
|
|
||||||
|
|
||||||
# update learning rate
|
# update learning rate
|
||||||
if (epoch + 1) % 30 == 0:
|
if (epoch + 1) % 30 == 0:
|
||||||
learning_rate_var.assign(learning_rate_var.value() / 4)
|
learning_rate_var.assign(learning_rate_var.value() / 4)
|
||||||
if verbose:
|
if verbose:
|
||||||
print("learning rate change!")
|
print("learning rate change!")
|
||||||
|
|
||||||
nr_batches = len(mnist_train_x) // batch_size
|
|
||||||
log_frequency = 10
|
log_frequency = 10
|
||||||
for it in range(nr_batches):
|
batch_iteration = k.variable(0, dtype=tf.int64)
|
||||||
x = k.expand_dims(extract_batch(mnist_train_x, it, batch_size))
|
for x, _ in dataset:
|
||||||
# x discriminator
|
# x discriminator
|
||||||
_xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
|
_xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
|
||||||
decoder=decoder,
|
decoder=decoder,
|
||||||
@ -187,7 +180,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
enc_dec_loss_avg(reconstruction_loss)
|
enc_dec_loss_avg(reconstruction_loss)
|
||||||
encoder_loss_avg(encoder_loss)
|
encoder_loss_avg(encoder_loss)
|
||||||
|
|
||||||
if it % log_frequency == 0:
|
if int(global_step_decoder % log_frequency) == 0:
|
||||||
# log the losses every log frequency batches
|
# log the losses every log frequency batches
|
||||||
summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(), step=global_step_enc_dec)
|
summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(), step=global_step_enc_dec)
|
||||||
summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(), step=global_step_decoder)
|
summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(), step=global_step_decoder)
|
||||||
@ -195,13 +188,13 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(), step=global_step_zd)
|
summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(), step=global_step_zd)
|
||||||
summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(), step=global_step_xd)
|
summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(), step=global_step_xd)
|
||||||
# reset the metrics states
|
# reset the metrics states
|
||||||
encoder_loss_avg.reset_states()
|
# encoder_loss_avg.init_variables()
|
||||||
decoder_loss_avg.reset_states()
|
# decoder_loss_avg.init_variables()
|
||||||
enc_dec_loss_avg.reset_states()
|
# enc_dec_loss_avg.init_variables()
|
||||||
zd_loss_avg.reset_states()
|
# zd_loss_avg.init_variables()
|
||||||
xd_loss_avg.reset_states()
|
# xd_loss_avg.init_variables()
|
||||||
|
|
||||||
if it == 0:
|
if int(batch_iteration) == 0:
|
||||||
directory = 'results' + str(inlier_classes[0])
|
directory = 'results' + str(inlier_classes[0])
|
||||||
if not os.path.exists(directory):
|
if not os.path.exists(directory):
|
||||||
os.makedirs(directory)
|
os.makedirs(directory)
|
||||||
@ -209,6 +202,8 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
save_image(comparison.cpu(),
|
save_image(comparison.cpu(),
|
||||||
'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png', nrow=64)
|
'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png', nrow=64)
|
||||||
|
|
||||||
|
batch_iteration.assign_add(1)
|
||||||
|
|
||||||
epoch_end_time = time.time()
|
epoch_end_time = time.time()
|
||||||
per_epoch_time = epoch_end_time - epoch_start_time
|
per_epoch_time = epoch_end_time - epoch_start_time
|
||||||
|
|
||||||
@ -385,7 +380,18 @@ def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
|
|||||||
return k.variable(z)
|
return k.variable(z)
|
||||||
|
|
||||||
|
|
||||||
def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tf.Variable:
|
def normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||||
|
"""
|
||||||
|
Normalizes a tensor from a 0-255 range to a 0-1 range and adds one dimension.
|
||||||
|
|
||||||
|
:param feature: tensor to be normalized
|
||||||
|
:param label: label tensor
|
||||||
|
:return: normalized tensor
|
||||||
|
"""
|
||||||
|
return k.expand_dims(tf.divide(feature, 255.0)), label
|
||||||
|
|
||||||
|
|
||||||
|
def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tfe.Variable:
|
||||||
"""
|
"""
|
||||||
Extracts a batch from data.
|
Extracts a batch from data.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user