Implemented learning rate decay with a variable
Signed-off-by: Jim Martens <github@2martens.de>
@@ -88,10 +88,11 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
     x_discriminator = XDiscriminator()
 
     # define optimizers
-    decoder_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999)
-    enc_dec_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999)
-    z_discriminator_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999)
-    x_discriminator_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999)
+    learning_rate_var = k.variable(lr)
+    decoder_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
+    enc_dec_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
+    z_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
+    x_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
 
     # train
     y_real = k.ones(batch_size)
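
For reference, a minimal sketch of the setup this hunk introduces, assuming `k` is `tf.keras.backend` and `AdamOptimizer` is TF 1.x's `tf.train.AdamOptimizer` (the actual imports are outside this excerpt, and the `lr` value here is illustrative). All four optimizers share one mutable variable, so a single later `assign()` retunes them all at once:

    # sketch only, not the original file
    import tensorflow as tf

    k = tf.keras.backend
    lr = 0.002  # illustrative; the real lr is a parameter of train_mnist

    learning_rate_var = k.variable(lr)
    decoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
    x_discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
    # enc_dec_optimizer and z_discriminator_optimizer are built the same way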
@@ -128,14 +129,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
 
         # update learning rate
         if (epoch + 1) % 30 == 0:
-            decoder_optimizer._lr /= 4
-            decoder_optimizer._lr_t = tf.convert_to_tensor(decoder_optimizer._lr, name="learning_rate")
-            enc_dec_optimizer._lr /= 4
-            enc_dec_optimizer._lr_t = tf.convert_to_tensor(enc_dec_optimizer._lr, name="learning_rate")
-            x_discriminator_optimizer._lr /= 4
-            x_discriminator_optimizer._lr_t = tf.convert_to_tensor(x_discriminator_optimizer._lr, name="learning_rate")
-            z_discriminator_optimizer._lr /= 4
-            z_discriminator_optimizer._lr_t = tf.convert_to_tensor(z_discriminator_optimizer._lr, name="learning_rate")
+            learning_rate_var.assign(learning_rate_var.value() / 4)
             print("learning rate change!")
 
         nr_batches = len(mnist_train_x) // batch_size
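
The removed lines poked the optimizers' private `_lr`/`_lr_t` attributes and rebuilt the tensor by hand, which depends on Adam's internals and breaks silently if they change. Assigning to the shared variable is enough, because the optimizer reads the variable's current value the next time it applies gradients. A self-contained sketch of the resulting schedule (assumes a recent TF 1.x with eager execution; the toy quadratic loss, the single `weight`, and the epoch count are illustrative, not from the original script):

    # runnable sketch under the assumptions above
    import tensorflow as tf

    tf.enable_eager_execution()
    k = tf.keras.backend

    learning_rate_var = k.variable(0.002)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
    weight = tf.Variable(1.0)

    for epoch in range(90):
        with tf.GradientTape() as tape:
            loss = tf.square(weight)  # stand-in for the real training loss
        grads = tape.gradient(loss, [weight])
        optimizer.apply_gradients(zip(grads, [weight]))

        # quarter the learning rate every 30 epochs; the optimizer picks up
        # the variable's new value on its next apply_gradients call, so no
        # private attributes (_lr, _lr_t) need to be touched
        if (epoch + 1) % 30 == 0:
            learning_rate_var.assign(learning_rate_var.value() / 4)
            print("learning rate change!")

Over the 90 epochs this yields effective rates of lr, lr/4, and lr/16, one step per 30-epoch block.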