diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py
index 525bd86..0438244 100644
--- a/src/twomartens/masterthesis/ssd.py
+++ b/src/twomartens/masterthesis/ssd.py
@@ -151,17 +151,8 @@ def predict(dataset: tf.data.Dataset,
     checkpoint = tf.train.Checkpoint(**checkpointables)
     checkpoint.restore(latest_checkpoint)
     
-    outputs = _predict_one_epoch(dataset, use_dropout, forward_passes_per_image, **checkpointables)
+    outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image, **checkpointables)
     
-    # save predictions
-    filename = 'ssd_predictions.npy'
-    if use_dropout:
-        filename = 'dropout-' + filename
-    output_file = os.path.join(output_path, filename)
-    
-    with open(output_file, 'wb') as file:
-        np.save(file, outputs['decoded_predictions'], allow_pickle=False, fix_imports=False)
-    
     if verbose:
         print((
             f"predict time: {outputs['per_epoch_time']:.2f}, "
@@ -171,35 +162,44 @@ def predict(dataset: tf.data.Dataset,
 
 def _predict_one_epoch(dataset: tf.data.Dataset,
                        use_dropout: bool,
+                       output_path: str,
                        forward_passes_per_image: int,
-                       ssd: tf.keras.Model) -> Dict[str, Union[float, List[List[np.ndarray]]]]:
-    with summary_ops_v2.always_record_summaries():
-        epoch_start_time = time.time()
-        decoded_predictions = []
-        
-        # go through the data set
-        for inputs, _ in dataset:
-            decoded_predictions_batch = []
-            if use_dropout:
-                for _ in range(forward_passes_per_image):
-                    decoded_predictions_pass = ssd(inputs)
-                    decoded_predictions_batch.append(decoded_predictions_pass)
-            else:
-                decoded_predictions_batch.append(ssd(inputs))
-        
-            decoded_predictions.append(decoded_predictions_batch)
-        
-        epoch_end_time = time.time()
-        per_epoch_time = epoch_end_time - epoch_start_time
+                       ssd: tf.keras.Model) -> Dict[str, float]:
+    epoch_start_time = time.time()
+    
+    # prepare filename
+    filename = 'ssd_predictions'
+    if use_dropout:
+        filename = 'dropout-' + filename
+    output_file = os.path.join(output_path, filename)
+    
+    # go through the data set
+    counter = 0
+    for inputs, _ in dataset:
+        decoded_predictions_batch = []
+        if use_dropout:
+            for _ in range(forward_passes_per_image):
+                decoded_predictions_pass = ssd(inputs)
+                decoded_predictions_batch.append(decoded_predictions_pass)
+        else:
+            decoded_predictions_batch.append(ssd(inputs))
         
-        # outputs for epoch
-        outputs = {
-            'per_epoch_time': per_epoch_time,
-            'decoded_predictions': decoded_predictions
-        }
+        # save predictions batch-wise to prevent memory problems
+        # NOTE: counter is an int, so it must be formatted into the
+        # filename (plain str + int concatenation raises TypeError)
+        with open(f'{output_file}-{counter}.npy', 'wb') as file:
+            np.save(file, decoded_predictions_batch, allow_pickle=False, fix_imports=False)
+        
+        counter += 1
+    
+    epoch_end_time = time.time()
+    per_epoch_time = epoch_end_time - epoch_start_time
+    
+    # outputs for epoch
+    outputs = {
+        'per_epoch_time': per_epoch_time,
+    }
     
     return outputs
 
 
 def train(dataset: tf.data.Dataset,