diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index 91a8a13..191cde8 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -86,12 +86,12 @@ def _ssd_train(args: argparse.Namespace) -> None: if args.debug: train_data = next(train_generator) - train_image = train_data[0][0] - print(train_image) train_length -= batch_size from PIL import Image - image = Image.fromarray(train_image) - image.save(f"{args.summary_path}/train/{args.network}/{args.iteration}/train_image.png") + train_images = train_data[0] + for i, train_image in enumerate(train_images): + image = Image.fromarray(train_image) + image.save(f"{args.summary_path}/train/{args.network}/{args.iteration}/train_image{i}.png") nr_batches_train = int(math.floor(train_length / batch_size)) nr_batches_val = int(math.floor(val_length / batch_size))