Modified model as per instructions from Keras blog
https://blog.keras.io/building-autoencoders-in-keras.html Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -46,37 +46,45 @@ class Encoder(keras.Model):
|
|||||||
def __init__(self, zsize: int) -> None:
|
def __init__(self, zsize: int) -> None:
|
||||||
super().__init__(name='encoder')
|
super().__init__(name='encoder')
|
||||||
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
||||||
self.x_padded = keras.layers.ZeroPadding2D(padding=1)
|
# self.x_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||||
self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=4, strides=2, name='conv1',
|
self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=3, strides=1, name='conv1',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='same', kernel_initializer=weight_init)
|
||||||
self.conv1_a = keras.layers.LeakyReLU(alpha=0.2)
|
self.conv1_a = keras.layers.ReLU()
|
||||||
self.conv1_a_padded = keras.layers.ZeroPadding2D(padding=1)
|
# self.conv1_a_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||||
self.conv2 = keras.layers.Conv2D(filters=zsize * 4, kernel_size=4, strides=2, name='conv2',
|
self.pool1 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool1')
|
||||||
padding='valid', kernel_initializer=weight_init)
|
self.conv2 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=1, name='conv2',
|
||||||
|
padding='same', kernel_initializer=weight_init)
|
||||||
self.conv2_bn = keras.layers.BatchNormalization()
|
self.conv2_bn = keras.layers.BatchNormalization()
|
||||||
self.conv2_a = keras.layers.LeakyReLU(alpha=0.2)
|
self.conv2_a = keras.layers.ReLU()
|
||||||
self.conv2_a_padded = keras.layers.ZeroPadding2D(padding=1)
|
self.pool2 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool2')
|
||||||
self.conv3 = keras.layers.Conv2D(filters=zsize * 8, kernel_size=4, strides=2, name='conv3',
|
# self.conv2_a_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||||
padding='valid', kernel_initializer=weight_init)
|
self.conv3 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=1, name='conv3',
|
||||||
|
padding='same', kernel_initializer=weight_init)
|
||||||
self.conv3_bn = keras.layers.BatchNormalization()
|
self.conv3_bn = keras.layers.BatchNormalization()
|
||||||
self.conv3_a = keras.layers.LeakyReLU(alpha=0.2)
|
self.conv3_a = keras.layers.ReLU()
|
||||||
self.conv4 = keras.layers.Conv2D(filters=zsize, kernel_size=4, strides=1, name='conv4',
|
self.pool3 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool3')
|
||||||
padding='valid', kernel_initializer=weight_init)
|
# self.conv4 = keras.layers.Conv2D(filters=zsize, kernel_size=4, strides=1, name='conv4',
|
||||||
|
# padding='same', kernel_initializer=weight_init)
|
||||||
|
# self.pool4 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool4')
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
result = self.x_padded(inputs)
|
# result = self.x_padded(inputs)
|
||||||
result = self.conv1(result)
|
result = self.conv1(inputs)
|
||||||
result = self.conv1_a(result)
|
result = self.conv1_a(result)
|
||||||
result = self.conv1_a_padded(result)
|
# result = self.conv1_a_padded(result)
|
||||||
|
result = self.pool1(result)
|
||||||
result = self.conv2(result)
|
result = self.conv2(result)
|
||||||
result = self.conv2_bn(result)
|
result = self.conv2_bn(result)
|
||||||
result = self.conv2_a(result)
|
result = self.conv2_a(result)
|
||||||
result = self.conv2_a_padded(result)
|
result = self.pool2(result)
|
||||||
|
# result = self.conv2_a_padded(result)
|
||||||
result = self.conv3(result)
|
result = self.conv3(result)
|
||||||
result = self.conv3_bn(result)
|
result = self.conv3_bn(result)
|
||||||
result = self.conv3_a(result)
|
result = self.conv3_a(result)
|
||||||
result = self.conv4(result)
|
result = self.pool3(result)
|
||||||
|
# result = self.conv4(result)
|
||||||
|
# result = self.pool4(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -93,40 +101,47 @@ class Decoder(keras.Model):
|
|||||||
def __init__(self, channels: int, zsize: int) -> None:
|
def __init__(self, channels: int, zsize: int) -> None:
|
||||||
super().__init__(name='decoder')
|
super().__init__(name='decoder')
|
||||||
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
||||||
self.deconv1 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=1, name='deconv1',
|
self.deconv1 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=1, name='deconv1',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='same', kernel_initializer=weight_init)
|
||||||
self.deconv1_bn = keras.layers.BatchNormalization()
|
self.deconv1_bn = keras.layers.BatchNormalization()
|
||||||
self.deconv1_a = keras.layers.ReLU()
|
self.deconv1_a = keras.layers.ReLU()
|
||||||
self.deconv2 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=2, name='deconv2',
|
self.upsample1 = keras.layers.UpSampling2D(size=(2, 2), name='upsample1')
|
||||||
padding='valid', kernel_initializer=weight_init)
|
self.deconv2 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=1, name='deconv2',
|
||||||
self.deconv2_cropped = keras.layers.Cropping2D(cropping=1)
|
padding='same', kernel_initializer=weight_init)
|
||||||
|
# self.deconv2_cropped = keras.layers.Cropping2D(cropping=1)
|
||||||
self.deconv2_bn = keras.layers.BatchNormalization()
|
self.deconv2_bn = keras.layers.BatchNormalization()
|
||||||
self.deconv2_a = keras.layers.ReLU()
|
self.deconv2_a = keras.layers.ReLU()
|
||||||
self.deconv3 = keras.layers.Conv2DTranspose(filters=zsize * 4, kernel_size=4, strides=2, name='deconv3',
|
self.upsample2 = keras.layers.UpSampling2D(size=(2, 2), name='upsample2')
|
||||||
padding='valid', kernel_initializer=weight_init)
|
self.deconv3 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=3, strides=1, name='deconv3',
|
||||||
self.deconv3_cropped = keras.layers.Cropping2D(cropping=1)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
|
# self.deconv3_cropped = keras.layers.Cropping2D(cropping=1)
|
||||||
self.deconv3_bn = keras.layers.BatchNormalization()
|
self.deconv3_bn = keras.layers.BatchNormalization()
|
||||||
self.deconv3_a = keras.layers.ReLU()
|
self.deconv3_a = keras.layers.ReLU()
|
||||||
self.deconv4 = keras.layers.Conv2DTranspose(filters=channels, kernel_size=4, strides=2, name='deconv4',
|
self.upsample3 = keras.layers.UpSampling2D(size=(2, 2), name='upsample3')
|
||||||
padding='valid', kernel_initializer=weight_init)
|
self.deconv4 = keras.layers.Conv2D(filters=channels, kernel_size=4, strides=1, name='deconv4',
|
||||||
self.deconv4_cropped = keras.layers.Cropping2D(cropping=1)
|
padding='same', kernel_initializer=weight_init)
|
||||||
|
# self.deconv4_cropped = keras.layers.Cropping2D(cropping=1)
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
result = self.deconv1(inputs)
|
result = self.deconv1(inputs)
|
||||||
result = self.deconv1_bn(result)
|
result = self.deconv1_bn(result)
|
||||||
result = self.deconv1_a(result)
|
result = self.deconv1_a(result)
|
||||||
|
result = self.upsample1(result)
|
||||||
result = self.deconv2(result)
|
result = self.deconv2(result)
|
||||||
result = self.deconv2_cropped(result)
|
# result = self.deconv2_cropped(result)
|
||||||
result = self.deconv2_bn(result)
|
result = self.deconv2_bn(result)
|
||||||
result = self.deconv2_a(result)
|
result = self.deconv2_a(result)
|
||||||
|
result = self.upsample2(result)
|
||||||
result = self.deconv3(result)
|
result = self.deconv3(result)
|
||||||
result = self.deconv3_cropped(result)
|
# result = self.deconv3_cropped(result)
|
||||||
result = self.deconv3_bn(result)
|
result = self.deconv3_bn(result)
|
||||||
result = self.deconv3_a(result)
|
result = self.deconv3_a(result)
|
||||||
|
result = self.upsample3(result)
|
||||||
result = self.deconv4(result)
|
result = self.deconv4(result)
|
||||||
result = self.deconv4_cropped(result)
|
# result = self.deconv4_cropped(result)
|
||||||
result = k.tanh(result) * 0.5 + 0.5
|
# result = k.tanh(result) * 0.5 + 0.5
|
||||||
|
result = k.sigmoid(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user