Added dense layer as latent space
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -46,7 +46,7 @@ class Encoder(keras.Model):
|
||||
def __init__(self, zsize: int) -> None:
|
||||
super().__init__(name='encoder')
|
||||
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
||||
self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=3, strides=2, name='conv1',
|
||||
self.conv1 = keras.layers.Conv2D(filters=zsize * 4, kernel_size=3, strides=2, name='conv1',
|
||||
padding='same', kernel_initializer=weight_init)
|
||||
self.conv1_a = keras.layers.ReLU()
|
||||
self.conv2 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=3, strides=2, name='conv2',
|
||||
@ -55,6 +55,8 @@ class Encoder(keras.Model):
|
||||
self.conv3 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=2, name='conv3',
|
||||
padding='same', kernel_initializer=weight_init)
|
||||
self.conv3_a = keras.layers.ReLU()
|
||||
self.flatten = keras.layers.Flatten(name='flatten')
|
||||
self.latent = keras.layers.Dense(units=zsize * (2 ** 5), name='latent')
|
||||
|
||||
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
||||
"""See base class."""
|
||||
@ -64,6 +66,8 @@ class Encoder(keras.Model):
|
||||
result = self.conv2_a(result)
|
||||
result = self.conv3(result)
|
||||
result = self.conv3_a(result)
|
||||
result = self.flatten(result)
|
||||
result = self.latent(result)
|
||||
|
||||
return result
|
||||
|
||||
@ -71,31 +75,47 @@ class Encoder(keras.Model):
|
||||
class Decoder(keras.Model):
    """
    Generates input data from latent space values.

    The latent vector is first projected by a dense layer to the size of a
    feature map, reshaped to that feature map and then upsampled by
    transposed convolutions back to the input image size.

    Args:
        channels: number of channels in the input image
        zsize: size of the latent space
        image_size: size of height/width of input image
    """

    def __init__(self, channels: int, zsize: int, image_size: int) -> None:
        """
        Initializes the Decoder class.

        Args:
            channels: number of channels in the input image
            zsize: size of the latent space
            image_size: size of height/width of input image; has to be a
                positive multiple of 8 so the three stride-2 upsampling
                layers reproduce exactly this size

        Raises:
            ValueError: if image_size is not a positive multiple of 8
        """
        super().__init__(name='decoder')
        weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)

        # calculate dimension of last conv layer in encoder
        # (the comment above and the factor 2 ** 3 presumably mirror the three
        # stride-2 convolutions of the encoder - confirm if the encoder changes)
        downscale_factor = 2 ** 3
        if image_size <= 0 or image_size % downscale_factor != 0:
            raise ValueError(
                f"image_size must be a positive multiple of {downscale_factor}, got {image_size}"
            )
        # integer division: the value is used as a tensor dimension and as
        # number of dense units, both of which have to be integers
        conv_image_size = image_size // downscale_factor
        # the dense layer has to deliver one value per element of the feature
        # map it is reshaped into: height * width * filters
        dimensions = conv_image_size * conv_image_size * zsize
        self.conv_shape = (-1, conv_image_size, conv_image_size, zsize)

        self.transform = keras.layers.Dense(units=dimensions, name='input_transform')
        # stride 1: keeps the spatial size, only the following three layers upsample
        self.deconv1 = keras.layers.Conv2DTranspose(filters=zsize, kernel_size=3, strides=1, name='deconv1',
                                                    padding='same', kernel_initializer=weight_init)
        self.deconv1_a = keras.layers.ReLU()
        self.deconv2 = keras.layers.Conv2DTranspose(filters=zsize * 2, kernel_size=3, strides=2, name='deconv2',
                                                    padding='same', kernel_initializer=weight_init)
        self.deconv2_a = keras.layers.ReLU()
        self.deconv3 = keras.layers.Conv2DTranspose(filters=zsize * 4, kernel_size=3, strides=2, name='deconv3',
                                                    padding='same', kernel_initializer=weight_init)
        self.deconv3_a = keras.layers.ReLU()
        self.deconv4 = keras.layers.Conv2DTranspose(filters=channels, kernel_size=3, strides=2, name='deconv4',
                                                    padding='same', kernel_initializer=weight_init)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """See base class."""
        # project the latent vector and restore the feature map layout
        result = self.transform(inputs)
        result = tf.reshape(result, self.conv_shape)
        result = self.deconv1(result)
        result = self.deconv1_a(result)
        result = self.deconv2(result)
        result = self.deconv2_a(result)
        result = self.deconv3(result)
        result = self.deconv3_a(result)
        result = self.deconv4(result)
        # squash the output into the range (0, 1)
        result = k.sigmoid(result)

        return result
|
||||
|
||||
@ -39,6 +39,7 @@ tfe = tf.contrib.eager
|
||||
def run_simple(dataset: tf.data.Dataset,
|
||||
iteration: int,
|
||||
weights_prefix: str,
|
||||
image_size: int,
|
||||
channels: int = 3,
|
||||
zsize: int = 64,
|
||||
batch_size: int = 16,
|
||||
@ -52,6 +53,7 @@ def run_simple(dataset: tf.data.Dataset,
|
||||
dataset: run dataset
|
||||
iteration: identifier for the used training run
|
||||
weights_prefix: prefix for trained weights directory
|
||||
image_size: height/width of input image
|
||||
channels: number of channels in input image (default: 3)
|
||||
zsize: size of the intermediary z (default: 64)
|
||||
batch_size: size of each batch (default: 16)
|
||||
@ -62,7 +64,7 @@ def run_simple(dataset: tf.data.Dataset,
|
||||
checkpointables = {
|
||||
# get models
|
||||
'encoder': model.Encoder(zsize),
|
||||
'decoder': model.Decoder(channels, zsize),
|
||||
'decoder': model.Decoder(channels, zsize, image_size),
|
||||
}
|
||||
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
|
||||
@ -47,6 +47,7 @@ LOG_FREQUENCY: int = 10
|
||||
def train_simple(dataset: tf.data.Dataset,
|
||||
iteration: int,
|
||||
weights_prefix: str,
|
||||
image_size: int,
|
||||
channels: int = 3,
|
||||
zsize: int = 64,
|
||||
lr: float = 0.0001,
|
||||
@ -67,6 +68,7 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
dataset: train dataset
|
||||
iteration: identifier for the current training run
|
||||
weights_prefix: prefix for weights directory
|
||||
image_size: height/width of input image
|
||||
channels: number of channels in input image (default: 3)
|
||||
zsize: size of the intermediary z (default: 64)
|
||||
lr: initial learning rate (default: 0.0001)
|
||||
@ -82,7 +84,7 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
checkpointables.update({
|
||||
# get models
|
||||
'encoder': model.Encoder(zsize),
|
||||
'decoder': model.Decoder(channels, zsize),
|
||||
'decoder': model.Decoder(channels, zsize, image_size),
|
||||
# define optimizers
|
||||
'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var']),
|
||||
# global step counter
|
||||
|
||||
@ -123,8 +123,9 @@ def _val(args: argparse.Namespace) -> None:
|
||||
category = args.category
|
||||
category_trained = args.category_trained
|
||||
batch_size = 16
|
||||
image_size = 256
|
||||
coco_data = data.load_coco_val(coco_path, category, num_epochs=1,
|
||||
batch_size=batch_size, resized_shape=(256, 256))
|
||||
batch_size=batch_size, resized_shape=(image_size, image_size))
|
||||
use_summary_writer = summary_ops_v2.create_file_writer(
|
||||
f"{args.summary_path}/val/category-{category}/{args.iteration}"
|
||||
)
|
||||
@ -132,11 +133,13 @@ def _val(args: argparse.Namespace) -> None:
|
||||
with use_summary_writer.as_default():
|
||||
run.run_simple(coco_data, iteration=args.iteration_trained,
|
||||
weights_prefix=f"{args.weights_path}/category-{category_trained}",
|
||||
zsize=16, verbose=args.verbose, channels=3, batch_size=batch_size)
|
||||
zsize=16, verbose=args.verbose, channels=3, batch_size=batch_size,
|
||||
image_size=image_size)
|
||||
else:
|
||||
run.run_simple(coco_data, iteration=args.iteration_trained,
|
||||
weights_prefix=f"{args.weights_path}/category-{category_trained}",
|
||||
zsize=16, verbose=args.verbose, channels=3, batch_size=batch_size)
|
||||
zsize=16, verbose=args.verbose, channels=3, batch_size=batch_size,
|
||||
image_size=image_size)
|
||||
|
||||
|
||||
def _auto_encoder_train(args: argparse.Namespace) -> None:
|
||||
@ -149,8 +152,9 @@ def _auto_encoder_train(args: argparse.Namespace) -> None:
|
||||
coco_path = args.coco_path
|
||||
category = args.category
|
||||
batch_size = 16
|
||||
image_size = 256
|
||||
coco_data = data.load_coco_train(coco_path, category, num_epochs=args.num_epochs, batch_size=batch_size,
|
||||
resized_shape=(256, 256))
|
||||
resized_shape=(image_size, image_size))
|
||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||
f"{args.summary_path}/train/category-{category}/{args.iteration}"
|
||||
)
|
||||
@ -158,12 +162,12 @@ def _auto_encoder_train(args: argparse.Namespace) -> None:
|
||||
with train_summary_writer.as_default():
|
||||
train.train_simple(coco_data, iteration=args.iteration,
|
||||
weights_prefix=f"{args.weights_path}/category-{category}",
|
||||
zsize=16, lr=0.0001, verbose=args.verbose,
|
||||
zsize=16, lr=0.0001, verbose=args.verbose, image_size=image_size,
|
||||
channels=3, train_epoch=args.num_epochs, batch_size=batch_size)
|
||||
else:
|
||||
train.train_simple(coco_data, iteration=args.iteration,
|
||||
weights_prefix=f"{args.weights_path}/category-{category}",
|
||||
zsize=16, lr=0.0001, verbose=args.verbose,
|
||||
zsize=16, lr=0.0001, verbose=args.verbose, image_size=image_size,
|
||||
channels=3, train_epoch=args.num_epochs, batch_size=batch_size)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user