diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py
index 55dcca5..11d047e 100644
--- a/src/twomartens/masterthesis/aae/train.py
+++ b/src/twomartens/masterthesis/aae/train.py
@@ -86,6 +86,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int)
 
     mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
+
+    # get dataset: normalize once, cache, then shuffle so that every epoch sees a new order
+    dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
+    dataset = dataset.map(normalize).cache().shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True)
 
     # get models
     encoder = Encoder(zsize)
@@ -122,26 +126,15 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
 
         epoch_start_time = time.time()
 
-        def shuffle(train_data: np.ndarray) -> None:
-            """
-            Shuffles the given training data inplace.
-
-            :param train_data: numpy array of training data
-            """
-            np.take(train_data, np.random.permutation(train_data.shape[0]), axis=0, out=train_data)
-
-        shuffle(mnist_train_x)
-
         # update learning rate
         if (epoch + 1) % 30 == 0:
             learning_rate_var.assign(learning_rate_var.value() / 4)
             if verbose:
                 print("learning rate change!")
 
-        nr_batches = len(mnist_train_x) // batch_size
         log_frequency = 10
-        for it in range(nr_batches):
-            x = k.expand_dims(extract_batch(mnist_train_x, it, batch_size))
+        batch_iteration = k.variable(0, dtype=tf.int64)
+        for x, _ in dataset:
             # x discriminator
             _xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
                                                        decoder=decoder,
@@ -187,7 +180,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
             enc_dec_loss_avg(reconstruction_loss)
             encoder_loss_avg(encoder_loss)
 
-            if it % log_frequency == 0:
+            if int(global_step_decoder % log_frequency) == 0:
                 # log the losses every log frequency batches
                 summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(), step=global_step_enc_dec)
                 summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(), step=global_step_decoder)
@@ -195,19 +188,21 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
                 summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(), step=global_step_zd)
                 summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(), step=global_step_xd)
                 # reset the metrics states
-                encoder_loss_avg.reset_states()
-                decoder_loss_avg.reset_states()
-                enc_dec_loss_avg.reset_states()
-                zd_loss_avg.reset_states()
-                xd_loss_avg.reset_states()
-
-            if it == 0:
+                # encoder_loss_avg.init_variables()
+                # decoder_loss_avg.init_variables()
+                # enc_dec_loss_avg.init_variables()
+                # zd_loss_avg.init_variables()
+                # xd_loss_avg.init_variables()
+
+            if int(batch_iteration) == 0:
                 directory = 'results' + str(inlier_classes[0])
                 if not os.path.exists(directory):
                     os.makedirs(directory)
                 comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
                 save_image(comparison.cpu(),
                            'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png', nrow=64)
+
+            batch_iteration.assign_add(1)
 
         epoch_end_time = time.time()
         per_epoch_time = epoch_end_time - epoch_start_time
@@ -385,7 +380,18 @@ def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
     return k.variable(z)
 
 
-def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tf.Variable:
+def normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    """
+    Normalizes a tensor from a 0-255 range to a 0-1 range and adds one dimension.
+
+    :param feature: tensor to be normalized
+    :param label: label tensor
+    :return: tuple of the normalized feature tensor and the unchanged label
+    """
+    return k.expand_dims(tf.divide(feature, 255.0)), label
+
+
+def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tfe.Variable:
     """
     Extracts a batch from data.
 