Made filter sizes in encoder/decoder models dependent on variable
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -47,16 +47,16 @@ class Encoder(keras.Model):
|
||||
super().__init__(name='encoder')
|
||||
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
||||
self.x_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||
self.conv1 = keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, name='conv1',
|
||||
self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=4, strides=2, name='conv1',
|
||||
padding='valid', kernel_initializer=weight_init)
|
||||
self.conv1_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||
self.conv1_a_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||
self.conv2 = keras.layers.Conv2D(filters=256, kernel_size=4, strides=2, name='conv2',
|
||||
self.conv2 = keras.layers.Conv2D(filters=zsize * 4, kernel_size=4, strides=2, name='conv2',
|
||||
padding='valid', kernel_initializer=weight_init)
|
||||
self.conv2_bn = keras.layers.BatchNormalization()
|
||||
self.conv2_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||
self.conv2_a_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||
self.conv3 = keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, name='conv3',
|
||||
self.conv3 = keras.layers.Conv2D(filters=zsize * 8, kernel_size=4, strides=2, name='conv3',
|
||||
padding='valid', kernel_initializer=weight_init)
|
||||
self.conv3_bn = keras.layers.BatchNormalization()
|
||||
self.conv3_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||
@ -87,21 +87,22 @@ class Decoder(keras.Model):
|
||||
|
||||
Args:
|
||||
channels: number of channels in the input image
|
||||
zsize: size of the latent space
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int) -> None:
|
||||
def __init__(self, channels: int, zsize: int) -> None:
|
||||
super().__init__(name='decoder')
|
||||
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
||||
self.deconv1 = keras.layers.Conv2DTranspose(filters=256, kernel_size=4, strides=1, name='deconv1',
|
||||
self.deconv1 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=1, name='deconv1',
|
||||
padding='valid', kernel_initializer=weight_init)
|
||||
self.deconv1_bn = keras.layers.BatchNormalization()
|
||||
self.deconv1_a = keras.layers.ReLU()
|
||||
self.deconv2 = keras.layers.Conv2DTranspose(filters=256, kernel_size=4, strides=2, name='deconv2',
|
||||
self.deconv2 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=2, name='deconv2',
|
||||
padding='valid', kernel_initializer=weight_init)
|
||||
self.deconv2_cropped = keras.layers.Cropping2D(cropping=1)
|
||||
self.deconv2_bn = keras.layers.BatchNormalization()
|
||||
self.deconv2_a = keras.layers.ReLU()
|
||||
self.deconv3 = keras.layers.Conv2DTranspose(filters=128, kernel_size=4, strides=2, name='deconv3',
|
||||
self.deconv3 = keras.layers.Conv2DTranspose(filters=zsize * 4, kernel_size=4, strides=2, name='deconv3',
|
||||
padding='valid', kernel_initializer=weight_init)
|
||||
self.deconv3_cropped = keras.layers.Cropping2D(cropping=1)
|
||||
self.deconv3_bn = keras.layers.BatchNormalization()
|
||||
|
||||
@ -82,7 +82,7 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
checkpointables.update({
|
||||
# get models
|
||||
'encoder': model.Encoder(zsize),
|
||||
'decoder': model.Decoder(channels),
|
||||
'decoder': model.Decoder(channels, zsize),
|
||||
# define optimizers
|
||||
'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||
beta1=0.5, beta2=0.999),
|
||||
|
||||
@ -122,7 +122,7 @@ def train(dataset: tf.data.Dataset,
|
||||
checkpointables.update({
|
||||
# get models
|
||||
'encoder': model.Encoder(zsize),
|
||||
'decoder': model.Decoder(channels),
|
||||
'decoder': model.Decoder(channels, zsize),
|
||||
'z_discriminator': model.ZDiscriminator(),
|
||||
'x_discriminator': model.XDiscriminator(),
|
||||
# define optimizers
|
||||
|
||||
Reference in New Issue
Block a user