From 65638974d3d16a48a706d4fb01ce817e3e2d638f Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 06:48:13 +0100 Subject: [PATCH] Specified correct datatype for global step variable Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 68ee846..8fa69f2 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -24,6 +24,7 @@ from typing import Callable, Sequence, Tuple import numpy as np import tensorflow as tf +from tensorflow.python.ops import summary_ops_v2 from .model import Decoder, Encoder, XDiscriminator, ZDiscriminator from .util import save_image @@ -104,11 +105,11 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1) z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize) - - global_step_decoder = k.variable(0) - global_step_enc_dec = k.variable(0) - global_step_xd = k.variable(0) - global_step_zd = k.variable(0) + + global_step_decoder = k.variable(0, dtype=tf.int64) + global_step_enc_dec = k.variable(0, dtype=tf.int64) + global_step_xd = k.variable(0, dtype=tf.int64) + global_step_zd = k.variable(0, dtype=tf.int64) encoder_loss_history = [] decoder_loss_history = []