Save predictions batch-wise to prevent memory issues

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-03-21 17:49:34 +01:00
parent ae109fd131
commit db0d275bdd

View File

@ -151,16 +151,7 @@ def predict(dataset: tf.data.Dataset,
checkpoint = tf.train.Checkpoint(**checkpointables) checkpoint = tf.train.Checkpoint(**checkpointables)
checkpoint.restore(latest_checkpoint) checkpoint.restore(latest_checkpoint)
outputs = _predict_one_epoch(dataset, use_dropout, forward_passes_per_image, **checkpointables) outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image, **checkpointables)
# save predictions
filename = 'ssd_predictions.npy'
if use_dropout:
filename = 'dropout-' + filename
output_file = os.path.join(output_path, filename)
with open(output_file, 'wb') as file:
np.save(file, outputs['decoded_predictions'], allow_pickle=False, fix_imports=False)
if verbose: if verbose:
print(( print((
@ -171,35 +162,44 @@ def predict(dataset: tf.data.Dataset,
def _predict_one_epoch(dataset: tf.data.Dataset, def _predict_one_epoch(dataset: tf.data.Dataset,
use_dropout: bool, use_dropout: bool,
output_path: str,
forward_passes_per_image: int, forward_passes_per_image: int,
ssd: tf.keras.Model) -> Dict[str, Union[float, List[List[np.ndarray]]]]: ssd: tf.keras.Model) -> Dict[str, float]:
with summary_ops_v2.always_record_summaries(): epoch_start_time = time.time()
epoch_start_time = time.time()
decoded_predictions = []
# go through the data set # prepare filename
for inputs, _ in dataset: filename = 'ssd_predictions'
decoded_predictions_batch = [] if use_dropout:
if use_dropout: filename = 'dropout-' + filename
for _ in range(forward_passes_per_image): output_file = os.path.join(output_path, filename)
decoded_predictions_pass = ssd(inputs)
decoded_predictions_batch.append(decoded_predictions_pass)
else:
decoded_predictions_batch.append(ssd(inputs))
decoded_predictions.append(decoded_predictions_batch) # go through the data set
counter = 0
for inputs, _ in dataset:
decoded_predictions_batch = []
if use_dropout:
for _ in range(forward_passes_per_image):
decoded_predictions_pass = ssd(inputs)
decoded_predictions_batch.append(decoded_predictions_pass)
else:
decoded_predictions_batch.append(ssd(inputs))
epoch_end_time = time.time() # save predictions batch-wise to prevent memory problems
per_epoch_time = epoch_end_time - epoch_start_time with open(output_file + '-' + counter + '.npy', 'wb') as file:
np.save(file, decoded_predictions_batch, allow_pickle=False, fix_imports=False)
# outputs for epoch counter += 1
outputs = {
'per_epoch_time': per_epoch_time,
'decoded_predictions': decoded_predictions
}
return outputs epoch_end_time = time.time()
per_epoch_time = epoch_end_time - epoch_start_time
# outputs for epoch
outputs = {
'per_epoch_time': per_epoch_time,
}
return outputs
def train(dataset: tf.data.Dataset, def train(dataset: tf.data.Dataset,