Updated docstring for model module
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -14,56 +14,65 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""aae.model.py: contains model definitions"""
|
"""
|
||||||
|
Provides the models of my AAE implementation.
|
||||||
|
|
||||||
|
Classes:
|
||||||
|
``Encoder``: encodes an image input to a latent space
|
||||||
|
|
||||||
|
``Decoder``: decodes data from a latent space to resemble input data
|
||||||
|
|
||||||
|
``XDiscriminator``: differentiates between real input data and decoded input data
|
||||||
|
|
||||||
|
``ZDiscriminator``: differentiates between z values drawn from a normal distribution (real) and the encoded input
|
||||||
|
(fake)
|
||||||
|
|
||||||
|
"""
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
# shortcuts for tensorflow - quasi imports
|
# shortcuts for tensorflow - quasi imports
|
||||||
keras = tf.keras
|
keras = tf.keras
|
||||||
k = tf.keras.backend
|
k = tf.keras.backend
|
||||||
Model = keras.Model
|
|
||||||
sigmoid = keras.activations.sigmoid
|
|
||||||
RandomNormal = keras.initializers.RandomNormal
|
|
||||||
BatchNormalization = keras.layers.BatchNormalization
|
|
||||||
Conv2D = keras.layers.Conv2D
|
|
||||||
Conv2DTranspose = keras.layers.Conv2DTranspose
|
|
||||||
Dense = keras.layers.Dense
|
|
||||||
Cropping2D = keras.layers.Cropping2D
|
|
||||||
ZeroPadding2D = keras.layers.ZeroPadding2D
|
|
||||||
ReLU = keras.layers.ReLU
|
|
||||||
LeakyReLU = keras.layers.LeakyReLU
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder(Model):
|
class Encoder(keras.Model):
|
||||||
"""
|
"""
|
||||||
Encoder model.
|
Encodes input to a latent space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
zsize: size of the latent space
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, zsize: int) -> None:
|
def __init__(self, zsize: int) -> None:
|
||||||
super().__init__(name='encoder')
|
super().__init__(name='encoder')
|
||||||
weight_init = RandomNormal(mean=0, stddev=0.02)
|
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
||||||
self.x_padded = ZeroPadding2D(padding=1)
|
self.x_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||||
self.conv1 = Conv2D(filters=64, kernel_size=4, strides=2, name='conv1',
|
self.conv1 = keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, name='conv1',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
self.conv1_a = LeakyReLU(alpha=0.2)
|
self.conv1_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||||
self.conv1_a_padded = ZeroPadding2D(padding=1)
|
self.conv1_a_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||||
self.conv2 = Conv2D(filters=256, kernel_size=4, strides=2, name='conv2',
|
self.conv2 = keras.layers.Conv2D(filters=256, kernel_size=4, strides=2, name='conv2',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
self.conv2_bn = BatchNormalization()
|
self.conv2_bn = keras.layers.BatchNormalization()
|
||||||
self.conv2_a = LeakyReLU(alpha=0.2)
|
self.conv2_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||||
self.conv2_a_padded = ZeroPadding2D(padding=1)
|
self.conv2_a_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||||
self.conv3 = Conv2D(filters=512, kernel_size=4, strides=2, name='conv3',
|
self.conv3 = keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, name='conv3',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
self.conv3_bn = BatchNormalization()
|
self.conv3_bn = keras.layers.BatchNormalization()
|
||||||
self.conv3_a = LeakyReLU(alpha=0.2)
|
self.conv3_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||||
self.conv4 = Conv2D(filters=zsize, kernel_size=4, strides=1, name='conv4',
|
self.conv4 = keras.layers.Conv2D(filters=zsize, kernel_size=4, strides=1, name='conv4',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Performs the forward pass.
|
Overrides the ``call`` method and is called by ``__call__``.
|
||||||
:param inputs: input values
|
|
||||||
:param kwargs: additional keyword arguments - none are used
|
Args:
|
||||||
:return: result values
|
inputs: input values
|
||||||
|
``**kwargs``: additional keyword arguments - none are used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result values
|
||||||
"""
|
"""
|
||||||
result = self.x_padded(inputs)
|
result = self.x_padded(inputs)
|
||||||
result = self.conv1(result)
|
result = self.conv1(result)
|
||||||
@ -81,38 +90,45 @@ class Encoder(Model):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class Decoder(Model):
|
class Decoder(keras.Model):
|
||||||
"""
|
"""
|
||||||
Decoder model.
|
Generates input data from latent space values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels: number of channels in the input image
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels: int) -> None:
|
def __init__(self, channels: int) -> None:
|
||||||
super().__init__(name='decoder')
|
super().__init__(name='decoder')
|
||||||
weight_init = RandomNormal(mean=0, stddev=0.02)
|
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
||||||
self.deconv1 = Conv2DTranspose(filters=256, kernel_size=4, strides=1, name='deconv1',
|
self.deconv1 = keras.layers.Conv2DTranspose(filters=256, kernel_size=4, strides=1, name='deconv1',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
self.deconv1_bn = BatchNormalization()
|
self.deconv1_bn = keras.layers.BatchNormalization()
|
||||||
self.deconv1_a = ReLU()
|
self.deconv1_a = keras.layers.ReLU()
|
||||||
self.deconv2 = Conv2DTranspose(filters=256, kernel_size=4, strides=2, name='deconv2',
|
self.deconv2 = keras.layers.Conv2DTranspose(filters=256, kernel_size=4, strides=2, name='deconv2',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
self.deconv2_cropped = Cropping2D(cropping=1)
|
self.deconv2_cropped = keras.layers.Cropping2D(cropping=1)
|
||||||
self.deconv2_bn = BatchNormalization()
|
self.deconv2_bn = keras.layers.BatchNormalization()
|
||||||
self.deconv2_a = ReLU()
|
self.deconv2_a = keras.layers.ReLU()
|
||||||
self.deconv3 = Conv2DTranspose(filters=128, kernel_size=4, strides=2, name='deconv3',
|
self.deconv3 = keras.layers.Conv2DTranspose(filters=128, kernel_size=4, strides=2, name='deconv3',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
self.deconv3_cropped = Cropping2D(cropping=1)
|
self.deconv3_cropped = keras.layers.Cropping2D(cropping=1)
|
||||||
self.deconv3_bn = BatchNormalization()
|
self.deconv3_bn = keras.layers.BatchNormalization()
|
||||||
self.deconv3_a = ReLU()
|
self.deconv3_a = keras.layers.ReLU()
|
||||||
self.deconv4 = Conv2DTranspose(filters=channels, kernel_size=4, strides=2, name='deconv4',
|
self.deconv4 = keras.layers.Conv2DTranspose(filters=channels, kernel_size=4, strides=2, name='deconv4',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
self.deconv4_cropped = Cropping2D(cropping=1)
|
self.deconv4_cropped = keras.layers.Cropping2D(cropping=1)
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Performs the forward pass.
|
Overrides the ``call`` method and is called by ``__call__``.
|
||||||
:param inputs: input values
|
|
||||||
:param kwargs: additional keyword arguments - none are used
|
Args:
|
||||||
:return: result values
|
inputs: input values
|
||||||
|
``**kwargs``: additional keyword arguments - none are used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result values
|
||||||
"""
|
"""
|
||||||
result = self.deconv1(inputs)
|
result = self.deconv1(inputs)
|
||||||
result = self.deconv1_bn(result)
|
result = self.deconv1_bn(result)
|
||||||
@ -132,27 +148,34 @@ class Decoder(Model):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ZDiscriminator(Model):
|
class ZDiscriminator(keras.Model):
|
||||||
"""
|
"""
|
||||||
ZDiscriminator model
|
Discriminates between encoded inputs and latent space distribution.
|
||||||
|
|
||||||
|
The latent space value is drawn from a normal distribution with ``0`` mean
|
||||||
|
and a variance of ``1``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(name='zdiscriminator')
|
super().__init__(name='zdiscriminator')
|
||||||
weight_init = RandomNormal(mean=0, stddev=0.02)
|
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
||||||
self.zd1 = Dense(units=128, name='zd1', kernel_initializer=weight_init)
|
self.zd1 = keras.layers.Dense(units=128, name='zd1', kernel_initializer=weight_init)
|
||||||
self.zd1_a = LeakyReLU(alpha=0.2)
|
self.zd1_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||||
self.zd2 = Dense(units=128, name='zd2', kernel_initializer=weight_init)
|
self.zd2 = keras.layers.Dense(units=128, name='zd2', kernel_initializer=weight_init)
|
||||||
self.zd2_a = LeakyReLU(alpha=0.2)
|
self.zd2_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||||
self.zd3 = Dense(units=1, name='zd3', activation='sigmoid',
|
self.zd3 = keras.layers.Dense(units=1, name='zd3', activation='sigmoid',
|
||||||
kernel_initializer=weight_init)
|
kernel_initializer=weight_init)
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Performs the forward pass.
|
Overrides the ``call`` method and is called by ``__call__``.
|
||||||
:param inputs: input values
|
|
||||||
:param kwargs: additional keyword arguments - none are used
|
Args:
|
||||||
:return: result values
|
inputs: input values
|
||||||
|
``**kwargs``: additional keyword arguments - none are used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result values
|
||||||
"""
|
"""
|
||||||
result = self.zd1(inputs)
|
result = self.zd1(inputs)
|
||||||
result = self.zd1_a(result)
|
result = self.zd1_a(result)
|
||||||
@ -163,38 +186,42 @@ class ZDiscriminator(Model):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class XDiscriminator(Model):
|
class XDiscriminator(keras.Model):
|
||||||
"""
|
"""
|
||||||
XDiscriminator model
|
Discriminates between generated inputs and the actual inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(name='xdiscriminator')
|
super().__init__(name='xdiscriminator')
|
||||||
weight_init = RandomNormal(mean=0, stddev=0.02)
|
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
||||||
self.x_padded = ZeroPadding2D(padding=1)
|
self.x_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||||
self.xd1 = Conv2D(filters=64, kernel_size=4, strides=2, name='xd1',
|
self.xd1 = keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, name='xd1',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
self.xd1_a = LeakyReLU(alpha=0.2)
|
self.xd1_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||||
self.xd1_a_padded = ZeroPadding2D(padding=1)
|
self.xd1_a_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||||
self.xd2 = Conv2D(filters=256, kernel_size=4, strides=2, name='xd2',
|
self.xd2 = keras.layers.Conv2D(filters=256, kernel_size=4, strides=2, name='xd2',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
self.xd2_bn = BatchNormalization()
|
self.xd2_bn = keras.layers.BatchNormalization()
|
||||||
self.xd2_a = LeakyReLU(alpha=0.2)
|
self.xd2_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||||
self.xd2_a_padded = ZeroPadding2D(padding=1)
|
self.xd2_a_padded = keras.layers.ZeroPadding2D(padding=1)
|
||||||
self.xd3 = Conv2D(filters=512, kernel_size=4, strides=2, name='xd3',
|
self.xd3 = keras.layers.Conv2D(filters=512, kernel_size=4, strides=2, name='xd3',
|
||||||
padding='valid', kernel_initializer=weight_init)
|
padding='valid', kernel_initializer=weight_init)
|
||||||
self.xd3_bn = BatchNormalization()
|
self.xd3_bn = keras.layers.BatchNormalization()
|
||||||
self.xd3_a = LeakyReLU(alpha=0.2)
|
self.xd3_a = keras.layers.LeakyReLU(alpha=0.2)
|
||||||
self.xd4 = Conv2D(filters=1, kernel_size=4, strides=1, name='xd4',
|
self.xd4 = keras.layers.Conv2D(filters=1, kernel_size=4, strides=1, name='xd4',
|
||||||
padding='valid', kernel_initializer=weight_init,
|
padding='valid', kernel_initializer=weight_init,
|
||||||
activation='sigmoid')
|
activation='sigmoid')
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Performs the forward pass.
|
Overrides the ``call`` method and is called by ``__call__``.
|
||||||
:param inputs: input values
|
|
||||||
:param kwargs: additional keyword arguments - none are used
|
Args:
|
||||||
:return: result values
|
inputs: input values
|
||||||
|
``**kwargs``: additional keyword arguments - none are used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result values
|
||||||
"""
|
"""
|
||||||
result = self.x_padded(inputs)
|
result = self.x_padded(inputs)
|
||||||
result = self.xd1(result)
|
result = self.xd1(result)
|
||||||
|
|||||||
Reference in New Issue
Block a user