fixed evaluation error

This commit is contained in:
2026-01-24 15:02:07 +01:00
parent 09d1911feb
commit 3b1d3c0497
7 changed files with 2 additions and 2 deletions

Binary file not shown.

Binary file not shown.

View File

@@ -32,7 +32,7 @@ if __name__ == '__main__':
config_dict['print_train_stats_at'] = 10
config_dict['print_stats_at'] = 100
config_dict['plot_at'] = 10
config_dict['plot_at'] = 100
config_dict['validate_at'] = 100
network_config = {

View File

@@ -113,7 +113,7 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
print(f"Processing image {i + 1}/{len(input_arrays)}")
input_array = torch.from_numpy(input_arrays[i]).to(
device)
output = model(input_array)
output = model(input_array.unsqueeze(0) if hasattr(input_array, 'dim') and input_array.dim() == 3 else input_array)
output = output.cpu().numpy()
predictions.append(output)