Specified correct datatype for global step variable
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -24,6 +24,7 @@ from typing import Callable, Sequence, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.ops import summary_ops_v2
|
||||||
|
|
||||||
from .model import Decoder, Encoder, XDiscriminator, ZDiscriminator
|
from .model import Decoder, Encoder, XDiscriminator, ZDiscriminator
|
||||||
from .util import save_image
|
from .util import save_image
|
||||||
@ -104,11 +105,11 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
|
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
|
||||||
|
|
||||||
z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
|
z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
|
||||||
|
|
||||||
global_step_decoder = k.variable(0)
|
global_step_decoder = k.variable(0, dtype=tf.int64)
|
||||||
global_step_enc_dec = k.variable(0)
|
global_step_enc_dec = k.variable(0, dtype=tf.int64)
|
||||||
global_step_xd = k.variable(0)
|
global_step_xd = k.variable(0, dtype=tf.int64)
|
||||||
global_step_zd = k.variable(0)
|
global_step_zd = k.variable(0, dtype=tf.int64)
|
||||||
|
|
||||||
encoder_loss_history = []
|
encoder_loss_history = []
|
||||||
decoder_loss_history = []
|
decoder_loss_history = []
|
||||||
|
|||||||
Reference in New Issue
Block a user