Specified correct datatype for global step variable

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 06:48:13 +01:00
parent 8aee9817f0
commit 65638974d3

View File

@ -24,6 +24,7 @@ from typing import Callable, Sequence, Tuple
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.ops import summary_ops_v2
from .model import Decoder, Encoder, XDiscriminator, ZDiscriminator from .model import Decoder, Encoder, XDiscriminator, ZDiscriminator
from .util import save_image from .util import save_image
@ -105,10 +106,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize) z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
global_step_decoder = k.variable(0) global_step_decoder = k.variable(0, dtype=tf.int64)
global_step_enc_dec = k.variable(0) global_step_enc_dec = k.variable(0, dtype=tf.int64)
global_step_xd = k.variable(0) global_step_xd = k.variable(0, dtype=tf.int64)
global_step_zd = k.variable(0) global_step_zd = k.variable(0, dtype=tf.int64)
encoder_loss_history = [] encoder_loss_history = []
decoder_loss_history = [] decoder_loss_history = []