Switched to keras binary_crossentropy loss
Signed-off-by: Jim Martens <github@2martens.de>
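Background on the change: the two expressions reduce over different axes. k.binary_crossentropy is element-wise, so the old k.mean(..., axis=0) averaged over the batch axis and left a per-feature loss vector; tf.keras.losses.binary_crossentropy instead averages over the last axis and returns one loss value per sample, with no batch reduction. A minimal sketch of the difference (the tensors are made-up illustration data; assumes TF 1.x with eager execution enabled, matching the tf.contrib.eager usage below):

    import tensorflow as tf

    tf.enable_eager_execution()  # TF 1.x; in TF 2.x eager is on by default
    k = tf.keras.backend

    # made-up predictions for a batch of 3 samples with 2 features each
    y_true = k.constant([[1., 0.], [0., 1.], [1., 1.]])
    y_pred = k.constant([[0.9, 0.2], [0.4, 0.7], [0.6, 0.8]])

    old = k.mean(k.binary_crossentropy(y_true, y_pred), axis=0)  # shape (2,): mean per feature
    new = tf.keras.losses.binary_crossentropy(y_true, y_pred)    # shape (3,): mean per sample

Neither form is a scalar; in the code below the results are combined and handed to the gradient tape as-is.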
@@ -31,6 +31,7 @@ from .util import save_image
 k = tf.keras.backend
 AdamOptimizer = tf.train.AdamOptimizer
 tfe = tf.contrib.eager
+binary_crossentropy = tf.keras.losses.binary_crossentropy
 
 
 def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
@@ -138,13 +139,13 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         # x discriminator
         with tf.GradientTape() as tape:
             xd_result = tf.squeeze(x_discriminator(x))
-            xd_real_loss = k.mean(k.binary_crossentropy(y_real, xd_result), axis=0)
+            xd_real_loss = binary_crossentropy(y_real, xd_result)
             z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
             z = k.variable(z)
 
             x_fake = decoder(z)
             xd_result = tf.squeeze(x_discriminator(x_fake))
-            xd_fake_loss = k.mean(k.binary_crossentropy(y_fake, xd_result), axis=0)
+            xd_fake_loss = binary_crossentropy(y_fake, xd_result)
 
             _xd_train_loss = xd_real_loss + xd_fake_loss
 
@@ -161,7 +162,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
 
             x_fake = decoder(z)
             xd_result = tf.squeeze(x_discriminator(x_fake))
-            _decoder_train_loss = k.mean(k.binary_crossentropy(y_real, xd_result), axis=0)
+            _decoder_train_loss = binary_crossentropy(y_real, xd_result)
 
         decoder_grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
         decoder_optimizer.apply_gradients(zip(decoder_grads, decoder.trainable_variables),
@@ -175,11 +176,11 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
             z = k.variable(z)
 
             zd_result = tf.squeeze(z_discriminator(z))
-            zd_real_loss = k.mean(k.binary_crossentropy(y_real_z, zd_result), axis=0)
+            zd_real_loss = binary_crossentropy(y_real_z, zd_result)
 
             z = tf.squeeze(encoder(x))
             zd_result = tf.squeeze(z_discriminator(z))
-            zd_fake_loss = k.mean(k.binary_crossentropy(y_fake_z, zd_result), axis=0)
+            zd_fake_loss = binary_crossentropy(y_fake_z, zd_result)
 
             _zd_train_loss = zd_real_loss + zd_fake_loss
 
@@ -195,8 +196,8 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
             x_decoded = decoder(z)
 
             zd_result = tf.squeeze(z_discriminator(tf.squeeze(z)))
-            encoder_loss = k.mean(k.binary_crossentropy(y_real_z, zd_result), axis=0) * 2.0
-            recovery_loss = k.mean(k.binary_crossentropy(x, x_decoded))
+            encoder_loss = binary_crossentropy(y_real_z, zd_result) * 2.0
+            recovery_loss = binary_crossentropy(x, x_decoded)
             _enc_dec_train_loss = encoder_loss + recovery_loss
 
         enc_dec_grads = tape.gradient(_enc_dec_train_loss,
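A side note on the reduction (my reading of the TF docs, not something this commit addresses): since the Keras loss returns a per-sample vector rather than a scalar, GradientTape.gradient differentiates the sum of that vector's elements, so the implicit total now grows with batch size, whereas the old axis=0 mean did not. If scalar, batch-averaged loss values are wanted again (e.g. for logging), an explicit reduction would restore them; a hypothetical example for the x-discriminator loss above:

    _xd_train_loss = k.mean(xd_real_loss + xd_fake_loss)  # hypothetical: collapse per-sample losses to a scalar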