diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 71a11cd..525bd86 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -154,12 +154,13 @@ def predict(dataset: tf.data.Dataset, outputs = _predict_one_epoch(dataset, use_dropout, forward_passes_per_image, **checkpointables) # save predictions - filename = 'ssd_predictions.bin' + filename = 'ssd_predictions.npy' if use_dropout: filename = 'dropout-' + filename - - with open(output_path + filename, 'wb') as file: - pickle.dump(outputs, file) + output_file = os.path.join(output_path, filename) + + with open(output_file, 'wb') as file: + np.save(file, outputs['decoded_predictions'], allow_pickle=False, fix_imports=False) if verbose: print((