Simplified use of debug
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -42,8 +42,7 @@ def run_simple(dataset: tf.data.Dataset,
|
||||
channels: int = 3,
|
||||
zsize: int = 64,
|
||||
batch_size: int = 16,
|
||||
verbose: bool = False,
|
||||
debug: bool = False) -> None:
|
||||
verbose: bool = False) -> None:
|
||||
"""
|
||||
Runs the trained auto-encoder for given data set.
|
||||
|
||||
@ -57,7 +56,6 @@ def run_simple(dataset: tf.data.Dataset,
|
||||
zsize: size of the intermediary z (default: 64)
|
||||
batch_size: size of each batch (default: 16)
|
||||
verbose: if True training progress is printed to console (default: False)
|
||||
debug: if True summaries are collected (default: False)
|
||||
"""
|
||||
|
||||
# checkpointed tensors and variables
|
||||
@ -79,7 +77,6 @@ def run_simple(dataset: tf.data.Dataset,
|
||||
outputs = _run_one_epoch_simple(dataset,
|
||||
batch_size=batch_size,
|
||||
global_step=global_step,
|
||||
debug=debug,
|
||||
**checkpointables)
|
||||
|
||||
if verbose:
|
||||
@ -91,7 +88,6 @@ def run_simple(dataset: tf.data.Dataset,
|
||||
|
||||
def _run_one_epoch_simple(dataset: tf.data.Dataset,
|
||||
batch_size: int,
|
||||
debug: bool,
|
||||
encoder: model.Encoder,
|
||||
decoder: model.Decoder,
|
||||
global_step: tf.Variable) -> Dict[str, float]:
|
||||
@ -104,11 +100,10 @@ def _run_one_epoch_simple(dataset: tf.data.Dataset,
|
||||
reconstruction_loss, x_decoded, z = _run_enc_dec_step_simple(encoder=encoder,
|
||||
decoder=decoder,
|
||||
inputs=x,
|
||||
global_step=global_step,
|
||||
debug=debug)
|
||||
global_step=global_step)
|
||||
enc_dec_loss_avg(reconstruction_loss)
|
||||
|
||||
if int(global_step % train.LOG_FREQUENCY) == 0 and debug:
|
||||
if int(global_step % train.LOG_FREQUENCY) == 0:
|
||||
comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2)],
|
||||
z[:int(batch_size / 2)]], axis=0)
|
||||
grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size / 2))
|
||||
@ -131,8 +126,7 @@ def _run_one_epoch_simple(dataset: tf.data.Dataset,
|
||||
|
||||
def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
|
||||
inputs: tf.Tensor,
|
||||
global_step: tf.Variable,
|
||||
debug: bool) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
||||
global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
||||
"""
|
||||
Runs the encoder and decoder jointly for one step (one batch).
|
||||
|
||||
@ -141,7 +135,6 @@ def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
|
||||
decoder: instance of decoder model
|
||||
inputs: inputs from data set
|
||||
global_step: the global step variable
|
||||
debug: if True summaries are collected
|
||||
|
||||
Returns:
|
||||
tuple of reconstruction loss, reconstructed input, latent space value
|
||||
@ -151,7 +144,7 @@ def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
|
||||
|
||||
reconstruction_loss = tf.losses.log_loss(inputs, x_decoded)
|
||||
|
||||
if int(global_step % train.LOG_FREQUENCY) == 0 and debug:
|
||||
if int(global_step % train.LOG_FREQUENCY) == 0:
|
||||
summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss,
|
||||
step=global_step)
|
||||
|
||||
|
||||
@ -52,8 +52,7 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
lr: float = 0.0001,
|
||||
train_epoch: int = 1,
|
||||
batch_size: int = 16,
|
||||
verbose: bool = False,
|
||||
debug: bool = False) -> None:
|
||||
verbose: bool = False) -> None:
|
||||
"""
|
||||
Trains auto-encoder for given data set.
|
||||
|
||||
@ -74,7 +73,6 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
train_epoch: number of epochs to train (default: 1)
|
||||
batch_size: size of each batch (default: 16)
|
||||
verbose: if True training progress is printed to console (default: False)
|
||||
debug: if True summaries are collected (default: False)
|
||||
"""
|
||||
|
||||
# checkpointed tensors and variables
|
||||
@ -117,7 +115,6 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
_epoch = epoch + previous_epochs
|
||||
outputs = _train_one_epoch_simple(_epoch, dataset,
|
||||
verbose=verbose,
|
||||
debug=debug,
|
||||
batch_size=batch_size,
|
||||
**checkpointables)
|
||||
|
||||
@ -141,7 +138,6 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
def _train_one_epoch_simple(epoch: int,
|
||||
dataset: tf.data.Dataset,
|
||||
verbose: bool,
|
||||
debug: bool,
|
||||
batch_size: int,
|
||||
learning_rate_var: tf.Variable,
|
||||
decoder: model.Decoder,
|
||||
@ -170,11 +166,10 @@ def _train_one_epoch_simple(epoch: int,
|
||||
optimizer=enc_dec_optimizer,
|
||||
inputs=x,
|
||||
global_step_enc_dec=global_step_enc_dec,
|
||||
global_step=global_step,
|
||||
debug=debug)
|
||||
global_step=global_step)
|
||||
enc_dec_loss_avg(reconstruction_loss)
|
||||
|
||||
if int(global_step % LOG_FREQUENCY) == 0 and debug:
|
||||
if int(global_step % LOG_FREQUENCY) == 0:
|
||||
comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2)],
|
||||
z[:int(batch_size/2)]], axis=0)
|
||||
grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2))
|
||||
@ -211,7 +206,6 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
|
||||
inputs: inputs from data set
|
||||
global_step: the global step variable
|
||||
global_step_enc_dec: global step variable for enc_dec
|
||||
debug: if True summaries are collected
|
||||
|
||||
Returns:
|
||||
tuple of reconstruction loss, reconstructed input, z value
|
||||
@ -224,7 +218,7 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
|
||||
|
||||
enc_dec_grads = tape.gradient(reconstruction_loss,
|
||||
encoder.trainable_variables + decoder.trainable_variables)
|
||||
if int(global_step % LOG_FREQUENCY) == 0 and debug:
|
||||
if int(global_step % LOG_FREQUENCY) == 0:
|
||||
summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss,
|
||||
step=global_step)
|
||||
for grad, variable in zip(enc_dec_grads, encoder.trainable_variables + decoder.trainable_variables):
|
||||
|
||||
@ -128,8 +128,13 @@ def _val(args: argparse.Namespace) -> None:
|
||||
use_summary_writer = summary_ops_v2.create_file_writer(
|
||||
f"{args.summary_path}/val/category-{category}/{args.iteration}"
|
||||
)
|
||||
with use_summary_writer.as_default():
|
||||
run.run_simple(coco_data, iteration=args.iteration_trained, debug=args.debug,
|
||||
if args.debug:
|
||||
with use_summary_writer.as_default():
|
||||
run.run_simple(coco_data, iteration=args.iteration_trained,
|
||||
weights_prefix=f"{args.weights_path}/category-{category_trained}",
|
||||
zsize=64, verbose=args.verbose, channels=3, batch_size=batch_size)
|
||||
else:
|
||||
run.run_simple(coco_data, iteration=args.iteration_trained,
|
||||
weights_prefix=f"{args.weights_path}/category-{category_trained}",
|
||||
zsize=64, verbose=args.verbose, channels=3, batch_size=batch_size)
|
||||
|
||||
@ -149,10 +154,16 @@ def _auto_encoder_train(args: argparse.Namespace) -> None:
|
||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||
f"{args.summary_path}/train/category-{category}/{args.iteration}"
|
||||
)
|
||||
with train_summary_writer.as_default():
|
||||
if args.debug:
|
||||
with train_summary_writer.as_default():
|
||||
train.train_simple(coco_data, iteration=args.iteration,
|
||||
weights_prefix=f"{args.weights_path}/category-{category}",
|
||||
zsize=64, lr=0.0001, verbose=args.verbose,
|
||||
channels=3, train_epoch=args.num_epochs, batch_size=batch_size)
|
||||
else:
|
||||
train.train_simple(coco_data, iteration=args.iteration,
|
||||
weights_prefix=f"{args.weights_path}/category-{category}",
|
||||
zsize=64, lr=0.0001, verbose=args.verbose, debug=args.debug,
|
||||
zsize=64, lr=0.0001, verbose=args.verbose,
|
||||
channels=3, train_epoch=args.num_epochs, batch_size=batch_size)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user