Made filter sizes in encoder/decoder models dependent on variable

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-11 13:31:06 +02:00
parent 28ac1a2027
commit 0d0ea882dd
3 changed files with 16 additions and 15 deletions

View File

@ -47,16 +47,16 @@ class Encoder(keras.Model):
super().__init__(name='encoder') super().__init__(name='encoder')
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02) weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
self.x_padded = keras.layers.ZeroPadding2D(padding=1) self.x_padded = keras.layers.ZeroPadding2D(padding=1)
self.conv1 = keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, name='conv1', self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=4, strides=2, name='conv1',
padding='valid', kernel_initializer=weight_init) padding='valid', kernel_initializer=weight_init)
self.conv1_a = keras.layers.LeakyReLU(alpha=0.2) self.conv1_a = keras.layers.LeakyReLU(alpha=0.2)
self.conv1_a_padded = keras.layers.ZeroPadding2D(padding=1) self.conv1_a_padded = keras.layers.ZeroPadding2D(padding=1)
self.conv2 = keras.layers.Conv2D(filters=256, kernel_size=4, strides=2, name='conv2', self.conv2 = keras.layers.Conv2D(filters=zsize * 4, kernel_size=4, strides=2, name='conv2',
padding='valid', kernel_initializer=weight_init) padding='valid', kernel_initializer=weight_init)
self.conv2_bn = keras.layers.BatchNormalization() self.conv2_bn = keras.layers.BatchNormalization()
self.conv2_a = keras.layers.LeakyReLU(alpha=0.2) self.conv2_a = keras.layers.LeakyReLU(alpha=0.2)
self.conv2_a_padded = keras.layers.ZeroPadding2D(padding=1) self.conv2_a_padded = keras.layers.ZeroPadding2D(padding=1)
self.conv3 = keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, name='conv3', self.conv3 = keras.layers.Conv2D(filters=zsize * 8, kernel_size=4, strides=2, name='conv3',
padding='valid', kernel_initializer=weight_init) padding='valid', kernel_initializer=weight_init)
self.conv3_bn = keras.layers.BatchNormalization() self.conv3_bn = keras.layers.BatchNormalization()
self.conv3_a = keras.layers.LeakyReLU(alpha=0.2) self.conv3_a = keras.layers.LeakyReLU(alpha=0.2)
@ -87,21 +87,22 @@ class Decoder(keras.Model):
Args: Args:
channels: number of channels in the input image channels: number of channels in the input image
zsize: size of the latent space
""" """
def __init__(self, channels: int) -> None: def __init__(self, channels: int, zsize: int) -> None:
super().__init__(name='decoder') super().__init__(name='decoder')
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02) weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
self.deconv1 = keras.layers.Conv2DTranspose(filters=256, kernel_size=4, strides=1, name='deconv1', self.deconv1 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=1, name='deconv1',
padding='valid', kernel_initializer=weight_init) padding='valid', kernel_initializer=weight_init)
self.deconv1_bn = keras.layers.BatchNormalization() self.deconv1_bn = keras.layers.BatchNormalization()
self.deconv1_a = keras.layers.ReLU() self.deconv1_a = keras.layers.ReLU()
self.deconv2 = keras.layers.Conv2DTranspose(filters=256, kernel_size=4, strides=2, name='deconv2', self.deconv2 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=2, name='deconv2',
padding='valid', kernel_initializer=weight_init) padding='valid', kernel_initializer=weight_init)
self.deconv2_cropped = keras.layers.Cropping2D(cropping=1) self.deconv2_cropped = keras.layers.Cropping2D(cropping=1)
self.deconv2_bn = keras.layers.BatchNormalization() self.deconv2_bn = keras.layers.BatchNormalization()
self.deconv2_a = keras.layers.ReLU() self.deconv2_a = keras.layers.ReLU()
self.deconv3 = keras.layers.Conv2DTranspose(filters=128, kernel_size=4, strides=2, name='deconv3', self.deconv3 = keras.layers.Conv2DTranspose(filters=zsize * 4, kernel_size=4, strides=2, name='deconv3',
padding='valid', kernel_initializer=weight_init) padding='valid', kernel_initializer=weight_init)
self.deconv3_cropped = keras.layers.Cropping2D(cropping=1) self.deconv3_cropped = keras.layers.Cropping2D(cropping=1)
self.deconv3_bn = keras.layers.BatchNormalization() self.deconv3_bn = keras.layers.BatchNormalization()

View File

@ -82,7 +82,7 @@ def train_simple(dataset: tf.data.Dataset,
checkpointables.update({ checkpointables.update({
# get models # get models
'encoder': model.Encoder(zsize), 'encoder': model.Encoder(zsize),
'decoder': model.Decoder(channels), 'decoder': model.Decoder(channels, zsize),
# define optimizers # define optimizers
'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], 'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
beta1=0.5, beta2=0.999), beta1=0.5, beta2=0.999),

View File

@ -121,14 +121,14 @@ def train(dataset: tf.data.Dataset,
} }
checkpointables.update({ checkpointables.update({
# get models # get models
'encoder': model.Encoder(zsize), 'encoder': model.Encoder(zsize),
'decoder': model.Decoder(channels), 'decoder': model.Decoder(channels, zsize),
'z_discriminator': model.ZDiscriminator(), 'z_discriminator': model.ZDiscriminator(),
'x_discriminator': model.XDiscriminator(), 'x_discriminator': model.XDiscriminator(),
# define optimizers # define optimizers
'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], 'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
beta1=0.5, beta2=0.999), beta1=0.5, beta2=0.999),
'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], 'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
beta1=0.5, beta2=0.999), beta1=0.5, beta2=0.999),
'z_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], 'z_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
beta1=0.5, beta2=0.999), beta1=0.5, beta2=0.999),