Save predictions batch-wise to prevent memory issues

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-03-21 17:49:34 +01:00
parent ae109fd131
commit db0d275bdd

View File

@ -151,16 +151,7 @@ def predict(dataset: tf.data.Dataset,
checkpoint = tf.train.Checkpoint(**checkpointables)
checkpoint.restore(latest_checkpoint)
outputs = _predict_one_epoch(dataset, use_dropout, forward_passes_per_image, **checkpointables)
# save predictions
filename = 'ssd_predictions.npy'
if use_dropout:
filename = 'dropout-' + filename
output_file = os.path.join(output_path, filename)
with open(output_file, 'wb') as file:
np.save(file, outputs['decoded_predictions'], allow_pickle=False, fix_imports=False)
outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image, **checkpointables)
if verbose:
print((
@ -171,14 +162,20 @@ def predict(dataset: tf.data.Dataset,
def _predict_one_epoch(dataset: tf.data.Dataset,
use_dropout: bool,
output_path: str,
forward_passes_per_image: int,
ssd: tf.keras.Model) -> Dict[str, Union[float, List[List[np.ndarray]]]]:
ssd: tf.keras.Model) -> Dict[str, float]:
with summary_ops_v2.always_record_summaries():
epoch_start_time = time.time()
decoded_predictions = []
# prepare filename
filename = 'ssd_predictions'
if use_dropout:
filename = 'dropout-' + filename
output_file = os.path.join(output_path, filename)
# go through the data set
counter = 0
for inputs, _ in dataset:
decoded_predictions_batch = []
if use_dropout:
@ -188,7 +185,11 @@ def _predict_one_epoch(dataset: tf.data.Dataset,
else:
decoded_predictions_batch.append(ssd(inputs))
decoded_predictions.append(decoded_predictions_batch)
# save predictions batch-wise to prevent memory problems
with open(output_file + '-' + counter + '.npy', 'wb') as file:
np.save(file, decoded_predictions_batch, allow_pickle=False, fix_imports=False)
counter += 1
epoch_end_time = time.time()
per_epoch_time = epoch_end_time - epoch_start_time
@ -196,7 +197,6 @@ def _predict_one_epoch(dataset: tf.data.Dataset,
# outputs for epoch
outputs = {
'per_epoch_time': per_epoch_time,
'decoded_predictions': decoded_predictions
}
return outputs