Added dense layer as latent space
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -46,7 +46,7 @@ class Encoder(keras.Model):
|
|||||||
def __init__(self, zsize: int) -> None:
|
def __init__(self, zsize: int) -> None:
|
||||||
super().__init__(name='encoder')
|
super().__init__(name='encoder')
|
||||||
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)
|
||||||
self.conv1 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=3, strides=2, name='conv1',
|
self.conv1 = keras.layers.Conv2D(filters=zsize * 4, kernel_size=3, strides=2, name='conv1',
|
||||||
padding='same', kernel_initializer=weight_init)
|
padding='same', kernel_initializer=weight_init)
|
||||||
self.conv1_a = keras.layers.ReLU()
|
self.conv1_a = keras.layers.ReLU()
|
||||||
self.conv2 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=3, strides=2, name='conv2',
|
self.conv2 = keras.layers.Conv2D(filters=zsize * 2, kernel_size=3, strides=2, name='conv2',
|
||||||
@ -55,6 +55,8 @@ class Encoder(keras.Model):
|
|||||||
self.conv3 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=2, name='conv3',
|
self.conv3 = keras.layers.Conv2D(filters=zsize, kernel_size=3, strides=2, name='conv3',
|
||||||
padding='same', kernel_initializer=weight_init)
|
padding='same', kernel_initializer=weight_init)
|
||||||
self.conv3_a = keras.layers.ReLU()
|
self.conv3_a = keras.layers.ReLU()
|
||||||
|
self.flatten = keras.layers.Flatten(name='flatten')
|
||||||
|
self.latent = keras.layers.Dense(units=zsize * (2 ** 5), name='latent')
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@ -64,6 +66,8 @@ class Encoder(keras.Model):
|
|||||||
result = self.conv2_a(result)
|
result = self.conv2_a(result)
|
||||||
result = self.conv3(result)
|
result = self.conv3(result)
|
||||||
result = self.conv3_a(result)
|
result = self.conv3_a(result)
|
||||||
|
result = self.flatten(result)
|
||||||
|
result = self.latent(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -71,31 +75,47 @@ class Encoder(keras.Model):
|
|||||||
class Decoder(keras.Model):
    """
    Generates input data from latent space values.

    The latent vector is first projected by a dense layer, reshaped into a
    feature map and then upsampled by transposed convolutions.
    """

    def __init__(self, channels: int, zsize: int, image_size: int) -> None:
        """
        Initializes the Decoder class.

        Args:
            channels: number of channels in the input image
            zsize: size of the latent space
            image_size: size of height/width of input image

        Raises:
            ValueError: if image_size is not a positive multiple of 8
        """
        super().__init__(name='decoder')
        weight_init = keras.initializers.RandomNormal(mean=0, stddev=0.02)

        # deconv2, deconv3 and deconv4 each double the spatial size (strides=2),
        # so the feature map fed into deconv1 must be image_size / 2**3.
        upsampling_factor = 2 ** 3
        if image_size <= 0 or image_size % upsampling_factor != 0:
            raise ValueError(
                f"image_size must be a positive multiple of {upsampling_factor}, got {image_size}"
            )

        # calculate dimension of last conv layer in encoder
        # integer division: the value is used inside a tensor shape and as a unit count
        conv_image_size = image_size // upsampling_factor
        # number of values per sample in a (conv_image_size, conv_image_size, zsize) feature map;
        # must match the non-batch part of self.conv_shape for the reshape in call() to succeed
        dimensions = conv_image_size * conv_image_size * zsize
        self.conv_shape = (-1, conv_image_size, conv_image_size, zsize)
        self.transform = keras.layers.Dense(units=dimensions, name='input_transform')

        self.deconv1 = keras.layers.Conv2DTranspose(filters=zsize, kernel_size=3, strides=1, name='deconv1',
                                                    padding='same', kernel_initializer=weight_init)
        self.deconv1_a = keras.layers.ReLU()
        self.deconv2 = keras.layers.Conv2DTranspose(filters=zsize * 2, kernel_size=3, strides=2, name='deconv2',
                                                    padding='same', kernel_initializer=weight_init)
        self.deconv2_a = keras.layers.ReLU()
        self.deconv3 = keras.layers.Conv2DTranspose(filters=zsize * 4, kernel_size=3, strides=2, name='deconv3',
                                                    padding='same', kernel_initializer=weight_init)
        self.deconv3_a = keras.layers.ReLU()
        self.deconv4 = keras.layers.Conv2DTranspose(filters=channels, kernel_size=3, strides=2, name='deconv4',
                                                    padding='same', kernel_initializer=weight_init)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """See base class."""
        # project the latent values and restore the spatial layout
        result = self.transform(inputs)
        result = tf.reshape(result, self.conv_shape)
        result = self.deconv1(result)
        result = self.deconv1_a(result)
        result = self.deconv2(result)
        result = self.deconv2_a(result)
        result = self.deconv3(result)
        result = self.deconv3_a(result)
        result = self.deconv4(result)
        # sigmoid is applied to the final layer's output instead of ReLU
        result = k.sigmoid(result)

        return result
|
||||||
|
|||||||
@ -39,6 +39,7 @@ tfe = tf.contrib.eager
|
|||||||
def run_simple(dataset: tf.data.Dataset,
|
def run_simple(dataset: tf.data.Dataset,
|
||||||
iteration: int,
|
iteration: int,
|
||||||
weights_prefix: str,
|
weights_prefix: str,
|
||||||
|
image_size: int,
|
||||||
channels: int = 3,
|
channels: int = 3,
|
||||||
zsize: int = 64,
|
zsize: int = 64,
|
||||||
batch_size: int = 16,
|
batch_size: int = 16,
|
||||||
@ -52,6 +53,7 @@ def run_simple(dataset: tf.data.Dataset,
|
|||||||
dataset: run dataset
|
dataset: run dataset
|
||||||
iteration: identifier for the used training run
|
iteration: identifier for the used training run
|
||||||
weights_prefix: prefix for trained weights directory
|
weights_prefix: prefix for trained weights directory
|
||||||
|
image_size: height/width of input image
|
||||||
channels: number of channels in input image (default: 3)
|
channels: number of channels in input image (default: 3)
|
||||||
zsize: size of the intermediary z (default: 64)
|
zsize: size of the intermediary z (default: 64)
|
||||||
batch_size: size of each batch (default: 16)
|
batch_size: size of each batch (default: 16)
|
||||||
@ -62,7 +64,7 @@ def run_simple(dataset: tf.data.Dataset,
|
|||||||
checkpointables = {
|
checkpointables = {
|
||||||
# get models
|
# get models
|
||||||
'encoder': model.Encoder(zsize),
|
'encoder': model.Encoder(zsize),
|
||||||
'decoder': model.Decoder(channels, zsize),
|
'decoder': model.Decoder(channels, zsize, image_size),
|
||||||
}
|
}
|
||||||
|
|
||||||
global_step = tf.train.get_or_create_global_step()
|
global_step = tf.train.get_or_create_global_step()
|
||||||
|
|||||||
@ -47,6 +47,7 @@ LOG_FREQUENCY: int = 10
|
|||||||
def train_simple(dataset: tf.data.Dataset,
|
def train_simple(dataset: tf.data.Dataset,
|
||||||
iteration: int,
|
iteration: int,
|
||||||
weights_prefix: str,
|
weights_prefix: str,
|
||||||
|
image_size: int,
|
||||||
channels: int = 3,
|
channels: int = 3,
|
||||||
zsize: int = 64,
|
zsize: int = 64,
|
||||||
lr: float = 0.0001,
|
lr: float = 0.0001,
|
||||||
@ -67,6 +68,7 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
dataset: train dataset
|
dataset: train dataset
|
||||||
iteration: identifier for the current training run
|
iteration: identifier for the current training run
|
||||||
weights_prefix: prefix for weights directory
|
weights_prefix: prefix for weights directory
|
||||||
|
image_size: height/width of input image
|
||||||
channels: number of channels in input image (default: 3)
|
channels: number of channels in input image (default: 3)
|
||||||
zsize: size of the intermediary z (default: 64)
|
zsize: size of the intermediary z (default: 64)
|
||||||
lr: initial learning rate (default: 0.0001)
|
lr: initial learning rate (default: 0.0001)
|
||||||
@ -82,7 +84,7 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
checkpointables.update({
|
checkpointables.update({
|
||||||
# get models
|
# get models
|
||||||
'encoder': model.Encoder(zsize),
|
'encoder': model.Encoder(zsize),
|
||||||
'decoder': model.Decoder(channels, zsize),
|
'decoder': model.Decoder(channels, zsize, image_size),
|
||||||
# define optimizers
|
# define optimizers
|
||||||
'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var']),
|
'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var']),
|
||||||
# global step counter
|
# global step counter
|
||||||
|
|||||||
@ -123,8 +123,9 @@ def _val(args: argparse.Namespace) -> None:
|
|||||||
category = args.category
|
category = args.category
|
||||||
category_trained = args.category_trained
|
category_trained = args.category_trained
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
|
image_size = 256
|
||||||
coco_data = data.load_coco_val(coco_path, category, num_epochs=1,
|
coco_data = data.load_coco_val(coco_path, category, num_epochs=1,
|
||||||
batch_size=batch_size, resized_shape=(256, 256))
|
batch_size=batch_size, resized_shape=(image_size, image_size))
|
||||||
use_summary_writer = summary_ops_v2.create_file_writer(
|
use_summary_writer = summary_ops_v2.create_file_writer(
|
||||||
f"{args.summary_path}/val/category-{category}/{args.iteration}"
|
f"{args.summary_path}/val/category-{category}/{args.iteration}"
|
||||||
)
|
)
|
||||||
@ -132,11 +133,13 @@ def _val(args: argparse.Namespace) -> None:
|
|||||||
with use_summary_writer.as_default():
|
with use_summary_writer.as_default():
|
||||||
run.run_simple(coco_data, iteration=args.iteration_trained,
|
run.run_simple(coco_data, iteration=args.iteration_trained,
|
||||||
weights_prefix=f"{args.weights_path}/category-{category_trained}",
|
weights_prefix=f"{args.weights_path}/category-{category_trained}",
|
||||||
zsize=16, verbose=args.verbose, channels=3, batch_size=batch_size)
|
zsize=16, verbose=args.verbose, channels=3, batch_size=batch_size,
|
||||||
|
image_size=image_size)
|
||||||
else:
|
else:
|
||||||
run.run_simple(coco_data, iteration=args.iteration_trained,
|
run.run_simple(coco_data, iteration=args.iteration_trained,
|
||||||
weights_prefix=f"{args.weights_path}/category-{category_trained}",
|
weights_prefix=f"{args.weights_path}/category-{category_trained}",
|
||||||
zsize=16, verbose=args.verbose, channels=3, batch_size=batch_size)
|
zsize=16, verbose=args.verbose, channels=3, batch_size=batch_size,
|
||||||
|
image_size=image_size)
|
||||||
|
|
||||||
|
|
||||||
def _auto_encoder_train(args: argparse.Namespace) -> None:
|
def _auto_encoder_train(args: argparse.Namespace) -> None:
|
||||||
@ -149,8 +152,9 @@ def _auto_encoder_train(args: argparse.Namespace) -> None:
|
|||||||
coco_path = args.coco_path
|
coco_path = args.coco_path
|
||||||
category = args.category
|
category = args.category
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
|
image_size = 256
|
||||||
coco_data = data.load_coco_train(coco_path, category, num_epochs=args.num_epochs, batch_size=batch_size,
|
coco_data = data.load_coco_train(coco_path, category, num_epochs=args.num_epochs, batch_size=batch_size,
|
||||||
resized_shape=(256, 256))
|
resized_shape=(image_size, image_size))
|
||||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||||
f"{args.summary_path}/train/category-{category}/{args.iteration}"
|
f"{args.summary_path}/train/category-{category}/{args.iteration}"
|
||||||
)
|
)
|
||||||
@ -158,12 +162,12 @@ def _auto_encoder_train(args: argparse.Namespace) -> None:
|
|||||||
with train_summary_writer.as_default():
|
with train_summary_writer.as_default():
|
||||||
train.train_simple(coco_data, iteration=args.iteration,
|
train.train_simple(coco_data, iteration=args.iteration,
|
||||||
weights_prefix=f"{args.weights_path}/category-{category}",
|
weights_prefix=f"{args.weights_path}/category-{category}",
|
||||||
zsize=16, lr=0.0001, verbose=args.verbose,
|
zsize=16, lr=0.0001, verbose=args.verbose, image_size=image_size,
|
||||||
channels=3, train_epoch=args.num_epochs, batch_size=batch_size)
|
channels=3, train_epoch=args.num_epochs, batch_size=batch_size)
|
||||||
else:
|
else:
|
||||||
train.train_simple(coco_data, iteration=args.iteration,
|
train.train_simple(coco_data, iteration=args.iteration,
|
||||||
weights_prefix=f"{args.weights_path}/category-{category}",
|
weights_prefix=f"{args.weights_path}/category-{category}",
|
||||||
zsize=16, lr=0.0001, verbose=args.verbose,
|
zsize=16, lr=0.0001, verbose=args.verbose, image_size=image_size,
|
||||||
channels=3, train_epoch=args.num_epochs, batch_size=batch_size)
|
channels=3, train_epoch=args.num_epochs, batch_size=batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user