Save predictions batch-wise to prevent memory issues

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-03-21 17:49:34 +01:00
parent ae109fd131
commit db0d275bdd

View File

@ -151,16 +151,7 @@ def predict(dataset: tf.data.Dataset,
checkpoint = tf.train.Checkpoint(**checkpointables)
checkpoint.restore(latest_checkpoint)
outputs = _predict_one_epoch(dataset, use_dropout, forward_passes_per_image, **checkpointables)
# save predictions
filename = 'ssd_predictions.npy'
if use_dropout:
filename = 'dropout-' + filename
output_file = os.path.join(output_path, filename)
with open(output_file, 'wb') as file:
np.save(file, outputs['decoded_predictions'], allow_pickle=False, fix_imports=False)
outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image, **checkpointables)
if verbose:
print((
@ -171,35 +162,44 @@ def predict(dataset: tf.data.Dataset,
def _predict_one_epoch(dataset: tf.data.Dataset,
                       use_dropout: bool,
                       output_path: str,
                       forward_passes_per_image: int,
                       ssd: tf.keras.Model) -> Dict[str, float]:
    """Run one prediction epoch over ``dataset``, saving predictions per batch.

    Predictions are written to disk batch-wise (one ``.npy`` file per batch)
    instead of being accumulated in memory, to prevent memory problems on
    large datasets.

    Args:
        dataset: dataset yielding ``(inputs, labels)`` pairs; labels are ignored.
        use_dropout: if True, run ``forward_passes_per_image`` stochastic
            forward passes per batch (MC-dropout style) and prefix the output
            files with ``dropout-``.
        output_path: directory into which the per-batch ``.npy`` files are written.
        forward_passes_per_image: number of forward passes per batch when
            ``use_dropout`` is True; ignored otherwise.
        ssd: the SSD Keras model used to produce decoded predictions.

    Returns:
        Dict with a single key ``'per_epoch_time'`` mapping to the wall-clock
        duration of the epoch in seconds.
    """
    with summary_ops_v2.always_record_summaries():
        epoch_start_time = time.time()
        # Prepare the filename stem once, outside the batch loop.
        filename = 'ssd_predictions'
        if use_dropout:
            filename = 'dropout-' + filename
        output_file = os.path.join(output_path, filename)
        # Go through the data set batch by batch.
        for counter, (inputs, _) in enumerate(dataset):
            decoded_predictions_batch = []
            if use_dropout:
                # Multiple stochastic forward passes for the same batch.
                for _ in range(forward_passes_per_image):
                    decoded_predictions_pass = ssd(inputs)
                    decoded_predictions_batch.append(decoded_predictions_pass)
            else:
                decoded_predictions_batch.append(ssd(inputs))
            # Save predictions batch-wise to prevent memory problems.
            # Bug fix: the original concatenated an int counter onto a str
            # (output_file + '-' + counter), which raises TypeError.
            with open(f'{output_file}-{counter}.npy', 'wb') as file:
                np.save(file, decoded_predictions_batch,
                        allow_pickle=False, fix_imports=False)
        epoch_end_time = time.time()
        per_epoch_time = epoch_end_time - epoch_start_time
        # Outputs for the epoch; predictions themselves are already on disk.
        outputs = {
            'per_epoch_time': per_epoch_time,
        }
    return outputs
def train(dataset: tf.data.Dataset,