Realized learning rate decay with variable

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-07 17:55:51 +01:00
parent ebfa67abb0
commit cce0975a13

View File

@ -88,10 +88,11 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
x_discriminator = XDiscriminator() x_discriminator = XDiscriminator()
# define optimizers # define optimizers
decoder_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999) learning_rate_var = k.variable(lr)
enc_dec_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999) decoder_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
z_discriminator_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999) enc_dec_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
x_discriminator_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999) z_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
x_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
# train # train
y_real = k.ones(batch_size) y_real = k.ones(batch_size)
@ -128,14 +129,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
# update learning rate # update learning rate
if (epoch + 1) % 30 == 0: if (epoch + 1) % 30 == 0:
decoder_optimizer._lr /= 4 learning_rate_var.assign(learning_rate_var.value() / 4)
decoder_optimizer._lr_t = tf.convert_to_tensor(decoder_optimizer._lr, name="learning_rate")
enc_dec_optimizer._lr /= 4
enc_dec_optimizer._lr_t = tf.convert_to_tensor(enc_dec_optimizer._lr, name="learning_rate")
x_discriminator_optimizer._lr /= 4
x_discriminator_optimizer._lr_t = tf.convert_to_tensor(x_discriminator_optimizer._lr, name="learning_rate")
z_discriminator_optimizer._lr /= 4
z_discriminator_optimizer._lr_t = tf.convert_to_tensor(z_discriminator_optimizer._lr, name="learning_rate")
print("learning rate change!") print("learning rate change!")
nr_batches = len(mnist_train_x) // batch_size nr_batches = len(mnist_train_x) // batch_size