Realized learning rate decay with variable
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -88,10 +88,11 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
x_discriminator = XDiscriminator()
|
x_discriminator = XDiscriminator()
|
||||||
|
|
||||||
# define optimizers
|
# define optimizers
|
||||||
decoder_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999)
|
learning_rate_var = k.variable(lr)
|
||||||
enc_dec_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999)
|
decoder_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
|
||||||
z_discriminator_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999)
|
enc_dec_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
|
||||||
x_discriminator_optimizer = AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.999)
|
z_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
|
||||||
|
x_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
|
||||||
|
|
||||||
# train
|
# train
|
||||||
y_real = k.ones(batch_size)
|
y_real = k.ones(batch_size)
|
||||||
@ -128,14 +129,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
|
|
||||||
# update learning rate
|
# update learning rate
|
||||||
if (epoch + 1) % 30 == 0:
|
if (epoch + 1) % 30 == 0:
|
||||||
decoder_optimizer._lr /= 4
|
learning_rate_var.assign(learning_rate_var.value() / 4)
|
||||||
decoder_optimizer._lr_t = tf.convert_to_tensor(decoder_optimizer._lr, name="learning_rate")
|
|
||||||
enc_dec_optimizer._lr /= 4
|
|
||||||
enc_dec_optimizer._lr_t = tf.convert_to_tensor(enc_dec_optimizer._lr, name="learning_rate")
|
|
||||||
x_discriminator_optimizer._lr /= 4
|
|
||||||
x_discriminator_optimizer._lr_t = tf.convert_to_tensor(x_discriminator_optimizer._lr, name="learning_rate")
|
|
||||||
z_discriminator_optimizer._lr /= 4
|
|
||||||
z_discriminator_optimizer._lr_t = tf.convert_to_tensor(z_discriminator_optimizer._lr, name="learning_rate")
|
|
||||||
print("learning rate change!")
|
print("learning rate change!")
|
||||||
|
|
||||||
nr_batches = len(mnist_train_x) // batch_size
|
nr_batches = len(mnist_train_x) // batch_size
|
||||||
|
|||||||
Reference in New Issue
Block a user