Specified correct datatype for global step variable
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -24,6 +24,7 @@ from typing import Callable, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.ops import summary_ops_v2
|
||||
|
||||
from .model import Decoder, Encoder, XDiscriminator, ZDiscriminator
|
||||
from .util import save_image
|
||||
@ -105,10 +106,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
|
||||
z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
|
||||
|
||||
global_step_decoder = k.variable(0)
|
||||
global_step_enc_dec = k.variable(0)
|
||||
global_step_xd = k.variable(0)
|
||||
global_step_zd = k.variable(0)
|
||||
global_step_decoder = k.variable(0, dtype=tf.int64)
|
||||
global_step_enc_dec = k.variable(0, dtype=tf.int64)
|
||||
global_step_xd = k.variable(0, dtype=tf.int64)
|
||||
global_step_zd = k.variable(0, dtype=tf.int64)
|
||||
|
||||
encoder_loss_history = []
|
||||
decoder_loss_history = []
|
||||
|
||||
Reference in New Issue
Block a user