@ -26,7 +26,7 @@ import tensorflow as tf
|
|||||||
from tensorflow.python.ops import summary_ops_v2
|
from tensorflow.python.ops import summary_ops_v2
|
||||||
|
|
||||||
from .model import Decoder, Encoder, XDiscriminator, ZDiscriminator
|
from .model import Decoder, Encoder, XDiscriminator, ZDiscriminator
|
||||||
from .util import save_image
|
from .util import prepare_image
|
||||||
|
|
||||||
# shortcuts for tensorflow sub packages and classes
|
# shortcuts for tensorflow sub packages and classes
|
||||||
k = tf.keras.backend
|
k = tf.keras.backend
|
||||||
@ -191,8 +191,13 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
if not os.path.exists(directory):
|
if not os.path.exists(directory):
|
||||||
os.makedirs(directory)
|
os.makedirs(directory)
|
||||||
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
||||||
save_image(comparison.cpu(),
|
grid = prepare_image(comparison.cpu(), nrow=64)
|
||||||
'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png', nrow=64)
|
summary_ops_v2.image(name='reconstruction_' + str(epoch), tensor=grid, max_images=1)
|
||||||
|
from PIL import Image
|
||||||
|
filename = 'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png'
|
||||||
|
ndarr = grid.cpu().numpy()
|
||||||
|
im = Image.fromarray(ndarr)
|
||||||
|
im.save(filename)
|
||||||
|
|
||||||
batch_iteration.assign_add(1)
|
batch_iteration.assign_add(1)
|
||||||
|
|
||||||
@ -214,8 +219,14 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
resultsample = decoder(sample).cpu()
|
resultsample = decoder(sample).cpu()
|
||||||
directory = 'results' + str(inlier_classes[0])
|
directory = 'results' + str(inlier_classes[0])
|
||||||
os.makedirs(directory, exist_ok=True)
|
os.makedirs(directory, exist_ok=True)
|
||||||
save_image(resultsample,
|
grid = prepare_image(resultsample)
|
||||||
'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png')
|
summary_ops_v2.image(name='sample_' + str(epoch), tensor=grid, max_images=1)
|
||||||
|
from PIL import Image
|
||||||
|
filename = 'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png'
|
||||||
|
ndarr = grid.cpu().numpy()
|
||||||
|
im = Image.fromarray(ndarr)
|
||||||
|
im.save(filename)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Training finish!... save training results")
|
print("Training finish!... save training results")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user