Made filter sizes in encoder/decoder models dependent on the zsize variable
Signed-off-by: Jim Martens <github@2martens.de>
@@ -47,16 +47,16 @@ class Encoder(keras.Model):
         super().__init__(name='encoder')
         weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
         self.x_padded = keras.layers.ZeroPadding2D(padding=1)
-        self.conv1 = keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, name='conv1',
+        self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=4, strides=2, name='conv1',
                                          padding='valid', kernel_initializer=weight_init)
         self.conv1_a = keras.layers.LeakyReLU(alpha=0.2)
         self.conv1_a_padded = keras.layers.ZeroPadding2D(padding=1)
-        self.conv2 = keras.layers.Conv2D(filters=256, kernel_size=4, strides=2, name='conv2',
+        self.conv2 = keras.layers.Conv2D(filters=zsize * 4, kernel_size=4, strides=2, name='conv2',
                                          padding='valid', kernel_initializer=weight_init)
         self.conv2_bn = keras.layers.BatchNormalization()
         self.conv2_a = keras.layers.LeakyReLU(alpha=0.2)
         self.conv2_a_padded = keras.layers.ZeroPadding2D(padding=1)
-        self.conv3 = keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, name='conv3',
+        self.conv3 = keras.layers.Conv2D(filters=zsize * 8, kernel_size=4, strides=2, name='conv3',
                                          padding='valid', kernel_initializer=weight_init)
         self.conv3_bn = keras.layers.BatchNormalization()
         self.conv3_a = keras.layers.LeakyReLU(alpha=0.2)
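The scaling rule in this hunk ties each conv width to the latent size: conv1 gets 2 * zsize filters, conv2 gets 4 * zsize, conv3 gets 8 * zsize. A quick arithmetic check of what that means for a couple of latent sizes (the zsize values below are assumed examples, not taken from the repo's config):

# Encoder widths as a function of zsize (multipliers 2, 4, 8 from this hunk).
# With zsize=32 only conv1 matches the old hardcoded width (64); conv2 and
# conv3 come out narrower than the previous 256 and 512.
for zsize in (32, 64):
    print(zsize, [zsize * m for m in (2, 4, 8)])
# 32 [64, 128, 256]
# 64 [128, 256, 512]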
@@ -87,21 +87,22 @@ class Decoder(keras.Model):

     Args:
         channels: number of channels in the input image
+        zsize: size of the latent space
     """

-    def __init__(self, channels: int) -> None:
+    def __init__(self, channels: int, zsize: int) -> None:
         super().__init__(name='decoder')
         weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
-        self.deconv1 = keras.layers.Conv2DTranspose(filters=256, kernel_size=4, strides=1, name='deconv1',
+        self.deconv1 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=1, name='deconv1',
                                                     padding='valid', kernel_initializer=weight_init)
         self.deconv1_bn = keras.layers.BatchNormalization()
         self.deconv1_a = keras.layers.ReLU()
-        self.deconv2 = keras.layers.Conv2DTranspose(filters=256, kernel_size=4, strides=2, name='deconv2',
+        self.deconv2 = keras.layers.Conv2DTranspose(filters=zsize * 8, kernel_size=4, strides=2, name='deconv2',
                                                     padding='valid', kernel_initializer=weight_init)
         self.deconv2_cropped = keras.layers.Cropping2D(cropping=1)
         self.deconv2_bn = keras.layers.BatchNormalization()
         self.deconv2_a = keras.layers.ReLU()
-        self.deconv3 = keras.layers.Conv2DTranspose(filters=128, kernel_size=4, strides=2, name='deconv3',
+        self.deconv3 = keras.layers.Conv2DTranspose(filters=zsize * 4, kernel_size=4, strides=2, name='deconv3',
                                                     padding='valid', kernel_initializer=weight_init)
         self.deconv3_cropped = keras.layers.Cropping2D(cropping=1)
         self.deconv3_bn = keras.layers.BatchNormalization()
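With the new signature, the decoder takes the latent size as well. A minimal usage sketch (channels=3 and zsize=32 are assumed example values, and `model` stands for the module that defines these classes); note that zsize=32 happens to reproduce the decoder's previous hardcoded widths exactly:

import model  # the module defining Encoder and Decoder (name assumed)

zsize, channels = 32, 3                   # assumed example hyperparameters
encoder = model.Encoder(zsize)            # conv widths: 64, 128, 256
decoder = model.Decoder(channels, zsize)  # deconv widths: 256, 256, 128
                                          # (unchanged from the old hardcoded values)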
@@ -82,7 +82,7 @@ def train_simple(dataset: tf.data.Dataset,
     checkpointables.update({
         # get models
         'encoder': model.Encoder(zsize),
-        'decoder': model.Decoder(channels),
+        'decoder': model.Decoder(channels, zsize),
         # define optimizers
         'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                     beta1=0.5, beta2=0.999),
@@ -121,14 +121,14 @@ def train(dataset: tf.data.Dataset,
     }
     checkpointables.update({
         # get models
         'encoder': model.Encoder(zsize),
-        'decoder': model.Decoder(channels),
+        'decoder': model.Decoder(channels, zsize),
         'z_discriminator': model.ZDiscriminator(),
         'x_discriminator': model.XDiscriminator(),
         # define optimizers
         'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                     beta1=0.5, beta2=0.999),
         'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                     beta1=0.5, beta2=0.999),
         'z_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                             beta1=0.5, beta2=0.999),
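Both training entry points now construct the models the same way. Condensed, the shared pattern is the sketch below; everything in it appears in the two hunks above, with `checkpointables` assumed to already hold `learning_rate_var`:

# Shared pattern in train_simple() and train(): models and optimizers are
# collected in one dict so they can be checkpointed together.
checkpointables.update({
    'encoder': model.Encoder(zsize),
    'decoder': model.Decoder(channels, zsize),  # zsize is now threaded through
    'enc_dec_optimizer': tf.train.AdamOptimizer(
        learning_rate=checkpointables['learning_rate_var'],
        beta1=0.5, beta2=0.999),
})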