From 41fe940d5bb27cbe99765cb17ee381e85ec22949 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Thu, 11 Apr 2019 15:19:50 +0200 Subject: [PATCH] Modified model as per instructions from Keras blog https://blog.keras.io/building-autoencoders-in-keras.html Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/model.py | 83 ++++++++++++++---------- 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/src/twomartens/masterthesis/aae/model.py b/src/twomartens/masterthesis/aae/model.py index 7bb0496..976b119 100644 --- a/src/twomartens/masterthesis/aae/model.py +++ b/src/twomartens/masterthesis/aae/model.py @@ -46,37 +46,45 @@ class Encoder(keras.Model): def __init__(self, zsize: int) -> None: super().__init__(name='encoder') weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02) - self.x_padded = keras.layers.ZeroPadding2D(padding=1) - self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=4, strides=2, name='conv1', - padding='valid', kernel_initializer=weight_init) - self.conv1_a = keras.layers.LeakyReLU(alpha=0.2) - self.conv1_a_padded = keras.layers.ZeroPadding2D(padding=1) - self.conv2 = keras.layers.Conv2D(filters=zsize * 4, kernel_size=4, strides=2, name='conv2', - padding='valid', kernel_initializer=weight_init) + # self.x_padded = keras.layers.ZeroPadding2D(padding=1) + self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=3, strides=1, name='conv1', + padding='same', kernel_initializer=weight_init) + self.conv1_a = keras.layers.ReLU() + # self.conv1_a_padded = keras.layers.ZeroPadding2D(padding=1) + self.pool1 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool1') + self.conv2 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=1, name='conv2', + padding='same', kernel_initializer=weight_init) self.conv2_bn = keras.layers.BatchNormalization() - self.conv2_a = keras.layers.LeakyReLU(alpha=0.2) - self.conv2_a_padded = keras.layers.ZeroPadding2D(padding=1) - self.conv3 = keras.layers.Conv2D(filters=zsize * 8, kernel_size=4, strides=2, name='conv3', - padding='valid', kernel_initializer=weight_init) + self.conv2_a = keras.layers.ReLU() + self.pool2 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool2') + # self.conv2_a_padded = keras.layers.ZeroPadding2D(padding=1) + self.conv3 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=1, name='conv3', + padding='same', kernel_initializer=weight_init) self.conv3_bn = keras.layers.BatchNormalization() - self.conv3_a = keras.layers.LeakyReLU(alpha=0.2) - self.conv4 = keras.layers.Conv2D(filters=zsize, kernel_size=4, strides=1, name='conv4', - padding='valid', kernel_initializer=weight_init) + self.conv3_a = keras.layers.ReLU() + self.pool3 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool3') + # self.conv4 = keras.layers.Conv2D(filters=zsize, kernel_size=4, strides=1, name='conv4', + # padding='same', kernel_initializer=weight_init) + # self.pool4 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool4') def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: """See base class.""" - result = self.x_padded(inputs) - result = self.conv1(result) + # result = self.x_padded(inputs) + result = self.conv1(inputs) result = self.conv1_a(result) - result = self.conv1_a_padded(result) + # result = self.conv1_a_padded(result) + result = self.pool1(result) result = self.conv2(result) result = self.conv2_bn(result) result = self.conv2_a(result) - result = self.conv2_a_padded(result) + result = self.pool2(result) + # result = self.conv2_a_padded(result) result = self.conv3(result) result = self.conv3_bn(result) result = self.conv3_a(result) - result = self.conv4(result) + result = self.pool3(result) + # result = self.conv4(result) + # result = self.pool4(result) return result @@ -93,40 +101,47 @@ class Decoder(keras.Model): def __init__(self, channels: int, zsize: int) -> None: super().__init__(name='decoder') weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02) - self.deconv1 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=1, name='deconv1', - padding='valid', kernel_initializer=weight_init) + self.deconv1 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=1, name='deconv1', + padding='same', kernel_initializer=weight_init) self.deconv1_bn = keras.layers.BatchNormalization() self.deconv1_a = keras.layers.ReLU() - self.deconv2 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=2, name='deconv2', - padding='valid', kernel_initializer=weight_init) - self.deconv2_cropped = keras.layers.Cropping2D(cropping=1) + self.upsample1 = keras.layers.UpSampling2D(size=(2, 2), name='upsample1') + self.deconv2 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=1, name='deconv2', + padding='same', kernel_initializer=weight_init) + # self.deconv2_cropped = keras.layers.Cropping2D(cropping=1) self.deconv2_bn = keras.layers.BatchNormalization() self.deconv2_a = keras.layers.ReLU() - self.deconv3 = keras.layers.Conv2DTranspose(filters=zsize * 4, kernel_size=4, strides=2, name='deconv3', - padding='valid', kernel_initializer=weight_init) - self.deconv3_cropped = keras.layers.Cropping2D(cropping=1) + self.upsample2 = keras.layers.UpSampling2D(size=(2, 2), name='upsample2') + self.deconv3 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=3, strides=1, name='deconv3', + padding='valid', kernel_initializer=weight_init) + # self.deconv3_cropped = keras.layers.Cropping2D(cropping=1) self.deconv3_bn = keras.layers.BatchNormalization() self.deconv3_a = keras.layers.ReLU() - self.deconv4 = keras.layers.Conv2DTranspose(filters=channels, kernel_size=4, strides=2, name='deconv4', - padding='valid', kernel_initializer=weight_init) - self.deconv4_cropped = keras.layers.Cropping2D(cropping=1) + self.upsample3 = keras.layers.UpSampling2D(size=(2, 2), name='upsample3') + self.deconv4 = keras.layers.Conv2D(filters=channels, kernel_size=4, strides=1, name='deconv4', + padding='same', kernel_initializer=weight_init) + # self.deconv4_cropped = keras.layers.Cropping2D(cropping=1) def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: """See base class.""" result = self.deconv1(inputs) result = self.deconv1_bn(result) result = self.deconv1_a(result) + result = self.upsample1(result) result = self.deconv2(result) - result = self.deconv2_cropped(result) + # result = self.deconv2_cropped(result) result = self.deconv2_bn(result) result = self.deconv2_a(result) + result = self.upsample2(result) result = self.deconv3(result) - result = self.deconv3_cropped(result) + # result = self.deconv3_cropped(result) result = self.deconv3_bn(result) result = self.deconv3_a(result) + result = self.upsample3(result) result = self.deconv4(result) - result = self.deconv4_cropped(result) - result = k.tanh(result) * 0.5 + 0.5 + # result = self.deconv4_cropped(result) + # result = k.tanh(result) * 0.5 + 0.5 + result = k.sigmoid(result) return result