diff --git a/src/twomartens/masterthesis/aae/model.py b/src/twomartens/masterthesis/aae/model.py
index 59d26c2..7bb0496 100644
--- a/src/twomartens/masterthesis/aae/model.py
+++ b/src/twomartens/masterthesis/aae/model.py
@@ -47,16 +47,16 @@ class Encoder(keras.Model):
         super().__init__(name='encoder')
         weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
         self.x_padded = keras.layers.ZeroPadding2D(padding=1)
-        self.conv1 = keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, name='conv1',
+        self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=4, strides=2, name='conv1',
                                          padding='valid', kernel_initializer=weight_init)
         self.conv1_a = keras.layers.LeakyReLU(alpha=0.2)
         self.conv1_a_padded = keras.layers.ZeroPadding2D(padding=1)
-        self.conv2 = keras.layers.Conv2D(filters=256, kernel_size=4, strides=2, name='conv2',
+        self.conv2 = keras.layers.Conv2D(filters=zsize * 4, kernel_size=4, strides=2, name='conv2',
                                          padding='valid', kernel_initializer=weight_init)
         self.conv2_bn = keras.layers.BatchNormalization()
         self.conv2_a = keras.layers.LeakyReLU(alpha=0.2)
         self.conv2_a_padded = keras.layers.ZeroPadding2D(padding=1)
-        self.conv3 = keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, name='conv3',
+        self.conv3 = keras.layers.Conv2D(filters=zsize * 8, kernel_size=4, strides=2, name='conv3',
                                          padding='valid', kernel_initializer=weight_init)
         self.conv3_bn = keras.layers.BatchNormalization()
         self.conv3_a = keras.layers.LeakyReLU(alpha=0.2)
@@ -87,21 +87,22 @@ class Decoder(keras.Model):
     Args:
         channels: number of channels in the input image
+        zsize: size of the latent space
     """
-    
+
-    def __init__(self, channels: int) -> None:
+    def __init__(self, channels: int, zsize: int) -> None:
         super().__init__(name='decoder')
         weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
-        self.deconv1 = keras.layers.Conv2DTranspose(filters=256, kernel_size=4, strides=1, name='deconv1',
+        self.deconv1 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=1, name='deconv1',
                                                     padding='valid', kernel_initializer=weight_init)
         self.deconv1_bn = keras.layers.BatchNormalization()
         self.deconv1_a = keras.layers.ReLU()
-        self.deconv2 = keras.layers.Conv2DTranspose(filters=256, kernel_size=4, strides=2, name='deconv2',
+        self.deconv2 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=2, name='deconv2',
                                                     padding='valid', kernel_initializer=weight_init)
         self.deconv2_cropped = keras.layers.Cropping2D(cropping=1)
         self.deconv2_bn = keras.layers.BatchNormalization()
         self.deconv2_a = keras.layers.ReLU()
-        self.deconv3 = keras.layers.Conv2DTranspose(filters=128, kernel_size=4, strides=2, name='deconv3',
+        self.deconv3 = keras.layers.Conv2DTranspose(filters=zsize * 4, kernel_size=4, strides=2, name='deconv3',
                                                     padding='valid', kernel_initializer=weight_init)
         self.deconv3_cropped = keras.layers.Cropping2D(cropping=1)
         self.deconv3_bn = keras.layers.BatchNormalization()
diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py
index eeeec6a..7d3de5f 100644
--- a/src/twomartens/masterthesis/aae/train.py
+++ b/src/twomartens/masterthesis/aae/train.py
@@ -82,7 +82,7 @@ def train_simple(dataset: tf.data.Dataset,
     checkpointables.update({
         # get models
         'encoder': model.Encoder(zsize),
-        'decoder': model.Decoder(channels),
+        'decoder': model.Decoder(channels, zsize),
         # define optimizers
         'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                     beta1=0.5, beta2=0.999),
diff --git a/src/twomartens/masterthesis/aae/train_aae.py b/src/twomartens/masterthesis/aae/train_aae.py
index b746bb1..4927f2a 100644
--- a/src/twomartens/masterthesis/aae/train_aae.py
+++ b/src/twomartens/masterthesis/aae/train_aae.py
@@ -121,14 +121,14 @@ def train(dataset: tf.data.Dataset,
     }
     checkpointables.update({
         # get models
-        'encoder': model.Encoder(zsize),
-        'decoder': model.Decoder(channels),
-        'z_discriminator': model.ZDiscriminator(),
-        'x_discriminator': model.XDiscriminator(),
+        'encoder': model.Encoder(zsize),
+        'decoder': model.Decoder(channels, zsize),
+        'z_discriminator': model.ZDiscriminator(),
+        'x_discriminator': model.XDiscriminator(),
         # define optimizers
-        'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
+        'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                     beta1=0.5, beta2=0.999),
-        'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
+        'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                     beta1=0.5, beta2=0.999),
         'z_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                             beta1=0.5, beta2=0.999),