Removed unnecessary layers and fixed the names of the remaining ones

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-12 14:26:43 +02:00
parent 32aecaea36
commit 1764e10da4

View File

@ -49,29 +49,21 @@ class Encoder(keras.Model):
self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=7, strides=1, name='conv1', self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=7, strides=1, name='conv1',
padding='same', kernel_initializer=weight_init) padding='same', kernel_initializer=weight_init)
self.conv1_a = keras.layers.ReLU() self.conv1_a = keras.layers.ReLU()
self.dropout = keras.layers.Dropout(rate=0.25) self.conv2 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=7, strides=1, name='conv2',
self.pool1 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool1') padding='same', kernel_initializer=weight_init)
self.conv3 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=7, strides=1, name='conv3', self.conv2_a = keras.layers.ReLU()
self.conv3 = keras.layers.Conv2D(filters=zsize, kernel_size=7, strides=1, name='conv3',
padding='same', kernel_initializer=weight_init) padding='same', kernel_initializer=weight_init)
self.conv3_a = keras.layers.ReLU() self.conv3_a = keras.layers.ReLU()
self.pool3 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool3')
self.conv4 = keras.layers.Conv2D(filters=zsize, kernel_size=7, strides=1, name='conv4',
padding='same', kernel_initializer=weight_init)
self.conv4_a = keras.layers.ReLU()
self.pool4 = keras.layers.MaxPool2D(pool_size=(2, 2), padding='same', name='pool4')
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
"""See base class.""" """See base class."""
result = self.conv1(inputs) result = self.conv1(inputs)
result = self.conv1_a(result) result = self.conv1_a(result)
# result = self.dropout(result) result = self.conv2(result)
# result = self.pool1(result) result = self.conv2_a(result)
result = self.conv3(result) result = self.conv3(result)
result = self.conv3_a(result) result = self.conv3_a(result)
# result = self.pool3(result)
result = self.conv4(result)
result = self.conv4_a(result)
# result = self.pool4(result)
return result return result
@ -91,30 +83,24 @@ class Decoder(keras.Model):
self.deconv1 = keras.layers.Conv2D(filters=zsize, kernel_size=7, strides=1, name='deconv1', self.deconv1 = keras.layers.Conv2D(filters=zsize, kernel_size=7, strides=1, name='deconv1',
padding='same', kernel_initializer=weight_init) padding='same', kernel_initializer=weight_init)
self.deconv1_a = keras.layers.ReLU() self.deconv1_a = keras.layers.ReLU()
self.upsample1 = keras.layers.UpSampling2D(size=(2, 2), name='upsample1')
self.deconv2 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=7, strides=1, name='deconv2', self.deconv2 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=7, strides=1, name='deconv2',
padding='same', kernel_initializer=weight_init) padding='same', kernel_initializer=weight_init)
self.deconv2_a = keras.layers.ReLU() self.deconv2_a = keras.layers.ReLU()
self.upsample2 = keras.layers.UpSampling2D(size=(2, 2), name='upsample2')
self.deconv3 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=7, strides=1, name='deconv3', self.deconv3 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=7, strides=1, name='deconv3',
padding='same', kernel_initializer=weight_init) padding='same', kernel_initializer=weight_init)
self.deconv3_a = keras.layers.ReLU() self.deconv3_a = keras.layers.ReLU()
self.upsample3 = keras.layers.UpSampling2D(size=(2, 2), name='upsample3') self.deconv4 = keras.layers.Conv2D(filters=channels, kernel_size=7, strides=1, name='deconv4',
self.deconv5 = keras.layers.Conv2D(filters=channels, kernel_size=7, strides=1, name='deconv5',
padding='same', kernel_initializer=weight_init) padding='same', kernel_initializer=weight_init)
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
"""See base class.""" """See base class."""
result = self.deconv1(inputs) result = self.deconv1(inputs)
result = self.deconv1_a(result) result = self.deconv1_a(result)
# result = self.upsample1(result)
result = self.deconv2(result) result = self.deconv2(result)
result = self.deconv2_a(result) result = self.deconv2_a(result)
# result = self.upsample2(result)
result = self.deconv3(result) result = self.deconv3(result)
result = self.deconv3_a(result) result = self.deconv3_a(result)
# result = self.upsample3(result) result = self.deconv4(result)
result = self.deconv5(result)
result = k.sigmoid(result) result = k.sigmoid(result)
return result return result