Removed obsolete loss history code
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -111,12 +111,6 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
global_step_xd = k.variable(0, dtype=tf.int64)
|
global_step_xd = k.variable(0, dtype=tf.int64)
|
||||||
global_step_zd = k.variable(0, dtype=tf.int64)
|
global_step_zd = k.variable(0, dtype=tf.int64)
|
||||||
|
|
||||||
encoder_loss_history = []
|
|
||||||
decoder_loss_history = []
|
|
||||||
enc_dec_loss_history = []
|
|
||||||
zd_loss_history = []
|
|
||||||
xd_loss_history = []
|
|
||||||
|
|
||||||
for epoch in range(train_epoch):
|
for epoch in range(train_epoch):
|
||||||
# define loss variables
|
# define loss variables
|
||||||
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
||||||
@ -213,12 +207,6 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
save_image(comparison.cpu(),
|
save_image(comparison.cpu(),
|
||||||
'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png', nrow=64)
|
'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png', nrow=64)
|
||||||
|
|
||||||
encoder_loss_history.append(encoder_loss_avg.result())
|
|
||||||
decoder_loss_history.append(decoder_loss_avg.result())
|
|
||||||
enc_dec_loss_history.append(enc_dec_loss_avg.result())
|
|
||||||
xd_loss_history.append(xd_loss_avg.result())
|
|
||||||
zd_loss_history.append(zd_loss_avg.result())
|
|
||||||
|
|
||||||
epoch_end_time = time.time()
|
epoch_end_time = time.time()
|
||||||
per_epoch_time = epoch_end_time - epoch_start_time
|
per_epoch_time = epoch_end_time - epoch_start_time
|
||||||
|
|
||||||
@ -243,16 +231,6 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
decoder.save_weights("./weights/decoder")
|
decoder.save_weights("./weights/decoder")
|
||||||
z_discriminator.save_weights("./weights/z_discriminator")
|
z_discriminator.save_weights("./weights/z_discriminator")
|
||||||
x_discriminator.save_weights("./weights/x_discriminator")
|
x_discriminator.save_weights("./weights/x_discriminator")
|
||||||
with open("./results0/losses/encoder_loss.txt", "wb") as file:
|
|
||||||
pickle.dump(encoder_loss_history, file)
|
|
||||||
with open("./results0/losses/decoder_loss.txt", "wb") as file:
|
|
||||||
pickle.dump(decoder_loss_history, file)
|
|
||||||
with open("./results0/losses/enc_dec_loss.txt", "wb") as file:
|
|
||||||
pickle.dump(enc_dec_loss_history, file)
|
|
||||||
with open("./results0/losses/xd_loss.txt", "wb") as file:
|
|
||||||
pickle.dump(xd_loss_history, file)
|
|
||||||
with open("./results0/losses/zd_loss.txt", "wb") as file:
|
|
||||||
pickle.dump(zd_loss_history, file)
|
|
||||||
|
|
||||||
|
|
||||||
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||||
|
|||||||
Reference in New Issue
Block a user