Save predictions batch-wise to prevent memory issues

Signed-off-by: Jim Martens <github@2martens.de>
2019-03-21 17:49:34 +01:00
parent ae109fd131
commit db0d275bdd
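
The change moves the np.save call out of predict (one file for the whole run) and into _predict_one_epoch (one file per batch), so decoded predictions never accumulate in a Python list across the dataset: peak memory is bounded by a single batch. The core pattern, as a standalone sketch; the function and file names here are illustrative, not from the project:

import os
import numpy as np

def save_batchwise(batches, output_path, prefix='predictions'):
    # one .npy file per batch: peak memory is bounded by the largest
    # batch instead of the whole dataset's predictions
    for counter, batch in enumerate(batches):
        output_file = os.path.join(output_path, f'{prefix}-{counter}.npy')
        np.save(output_file, np.asarray(batch), allow_pickle=False)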


@@ -151,17 +151,8 @@ def predict(dataset: tf.data.Dataset,
     checkpoint = tf.train.Checkpoint(**checkpointables)
     checkpoint.restore(latest_checkpoint)
-    outputs = _predict_one_epoch(dataset, use_dropout, forward_passes_per_image, **checkpointables)
-    # save predictions
-    filename = 'ssd_predictions.npy'
-    if use_dropout:
-        filename = 'dropout-' + filename
-    output_file = os.path.join(output_path, filename)
-    with open(output_file, 'wb') as file:
-        np.save(file, outputs['decoded_predictions'], allow_pickle=False, fix_imports=False)
+    outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image, **checkpointables)
     if verbose:
         print((
             f"predict time: {outputs['per_epoch_time']:.2f}, "
@@ -171,35 +162,44 @@
 def _predict_one_epoch(dataset: tf.data.Dataset,
                        use_dropout: bool,
+                       output_path: str,
                        forward_passes_per_image: int,
-                       ssd: tf.keras.Model) -> Dict[str, Union[float, List[List[np.ndarray]]]]:
-    with summary_ops_v2.always_record_summaries():
-        epoch_start_time = time.time()
-        decoded_predictions = []
-
-        # go through the data set
-        for inputs, _ in dataset:
-            decoded_predictions_batch = []
-            if use_dropout:
-                for _ in range(forward_passes_per_image):
-                    decoded_predictions_pass = ssd(inputs)
-                    decoded_predictions_batch.append(decoded_predictions_pass)
-            else:
-                decoded_predictions_batch.append(ssd(inputs))
-
-            decoded_predictions.append(decoded_predictions_batch)
-
-        epoch_end_time = time.time()
-        per_epoch_time = epoch_end_time - epoch_start_time
-        # outputs for epoch
-        outputs = {
-            'per_epoch_time': per_epoch_time,
-            'decoded_predictions': decoded_predictions
-        }
-
-    return outputs
+                       ssd: tf.keras.Model) -> Dict[str, float]:
+    epoch_start_time = time.time()
+
+    # prepare filename
+    filename = 'ssd_predictions'
+    if use_dropout:
+        filename = 'dropout-' + filename
+    output_file = os.path.join(output_path, filename)
+
+    # go through the data set
+    counter = 0
+    for inputs, _ in dataset:
+        decoded_predictions_batch = []
+        if use_dropout:
+            for _ in range(forward_passes_per_image):
+                decoded_predictions_pass = ssd(inputs)
+                decoded_predictions_batch.append(decoded_predictions_pass)
+        else:
+            decoded_predictions_batch.append(ssd(inputs))
+        # save predictions batch-wise to prevent memory problems
+        with open(output_file + '-' + str(counter) + '.npy', 'wb') as file:
+            np.save(file, decoded_predictions_batch, allow_pickle=False, fix_imports=False)
+
+        counter += 1
+
+    epoch_end_time = time.time()
+    per_epoch_time = epoch_end_time - epoch_start_time
+    # outputs for epoch
+    outputs = {
+        'per_epoch_time': per_epoch_time,
+    }
+
+    return outputs
 
 
 def train(dataset: tf.data.Dataset,
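
Since _predict_one_epoch no longer returns the predictions, downstream code has to read the per-batch files back from output_path. A possible reader, assuming the naming scheme the commit writes (ssd_predictions-<counter>.npy, with a dropout- prefix when dropout is used); the function itself is hypothetical and not part of the project:

import glob
import os
import numpy as np

def load_batchwise_predictions(output_path, use_dropout=False):
    # yields one batch of decoded predictions at a time, so the full
    # prediction set never has to fit in memory on the reading side either
    filename = 'ssd_predictions'
    if use_dropout:
        filename = 'dropout-' + filename
    pattern = os.path.join(output_path, filename + '-*.npy')

    def batch_index(path):
        # extract the numeric counter from '<filename>-<counter>.npy'
        stem = os.path.splitext(os.path.basename(path))[0]
        return int(stem.rsplit('-', 1)[1])

    # sort numerically by the batch counter, not lexicographically
    for path in sorted(glob.glob(pattern), key=batch_index):
        yield np.load(path, allow_pickle=False)

Sorting on the numeric counter matters once there are more than ten batches, since a lexicographic sort would put ssd_predictions-10.npy before ssd_predictions-2.npy.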