Save predictions batch-wise to prevent memory issues
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -151,17 +151,8 @@ def predict(dataset: tf.data.Dataset,
|
||||
checkpoint = tf.train.Checkpoint(**checkpointables)
|
||||
checkpoint.restore(latest_checkpoint)
|
||||
|
||||
outputs = _predict_one_epoch(dataset, use_dropout, forward_passes_per_image, **checkpointables)
|
||||
outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image, **checkpointables)
|
||||
|
||||
# save predictions
|
||||
filename = 'ssd_predictions.npy'
|
||||
if use_dropout:
|
||||
filename = 'dropout-' + filename
|
||||
output_file = os.path.join(output_path, filename)
|
||||
|
||||
with open(output_file, 'wb') as file:
|
||||
np.save(file, outputs['decoded_predictions'], allow_pickle=False, fix_imports=False)
|
||||
|
||||
if verbose:
|
||||
print((
|
||||
f"predict time: {outputs['per_epoch_time']:.2f}, "
|
||||
@ -171,35 +162,44 @@ def predict(dataset: tf.data.Dataset,
|
||||
|
||||
def _predict_one_epoch(dataset: tf.data.Dataset,
                       use_dropout: bool,
                       output_path: str,
                       forward_passes_per_image: int,
                       ssd: tf.keras.Model) -> Dict[str, float]:
    """Run one prediction epoch over ``dataset``, saving predictions batch-wise.

    Each batch's decoded predictions are written to their own ``.npy`` file
    under ``output_path`` (named ``[dropout-]ssd_predictions-<batch>.npy``),
    so the full prediction set never has to be held in memory at once.

    Args:
        dataset: dataset yielding ``(inputs, labels)`` pairs; labels are
            ignored here.
        use_dropout: if True, run ``forward_passes_per_image`` stochastic
            forward passes per batch (MC dropout) instead of a single pass.
        output_path: directory the per-batch prediction files are written to.
        forward_passes_per_image: number of forward passes per batch when
            dropout is active.
        ssd: the SSD Keras model; called as ``ssd(inputs)``.

    Returns:
        Dict with a single key ``'per_epoch_time'`` holding the wall-clock
        seconds spent on the epoch.
    """
    with summary_ops_v2.always_record_summaries():
        epoch_start_time = time.time()

        # Prepare the output filename stem; the 'dropout-' prefix marks
        # MC-dropout runs so they do not clobber deterministic runs.
        filename = 'ssd_predictions'
        if use_dropout:
            filename = 'dropout-' + filename
        output_file = os.path.join(output_path, filename)

        # Go through the data set batch by batch.
        for counter, (inputs, _) in enumerate(dataset):
            decoded_predictions_batch = []
            if use_dropout:
                for _ in range(forward_passes_per_image):
                    decoded_predictions_pass = ssd(inputs)
                    decoded_predictions_batch.append(decoded_predictions_pass)
            else:
                decoded_predictions_batch.append(ssd(inputs))

            # Save predictions batch-wise to prevent memory problems.
            # BUG FIX: the original built the path as
            # ``output_file + '-' + counter + '.npy'`` — concatenating the
            # int counter into a str raises TypeError on the first batch.
            # Format the counter into the name instead.
            # NOTE(review): np.save with allow_pickle=False assumes the
            # stacked per-pass predictions form a regular array — confirm
            # the decoder emits fixed-shape output.
            with open(f'{output_file}-{counter}.npy', 'wb') as file:
                np.save(file, decoded_predictions_batch,
                        allow_pickle=False, fix_imports=False)

        epoch_end_time = time.time()
        per_epoch_time = epoch_end_time - epoch_start_time

        # Outputs for the epoch; the predictions themselves are already
        # on disk, so only the timing is returned.
        outputs = {
            'per_epoch_time': per_epoch_time,
        }

        return outputs
|
||||
|
||||
|
||||
def train(dataset: tf.data.Dataset,
|
||||
|
||||
Reference in New Issue
Block a user