fixed evaluation error

This commit is contained in:
2026-01-24 15:02:07 +01:00
parent 09d1911feb
commit 3b1d3c0497
7 changed files with 2 additions and 2 deletions

View File

@@ -113,7 +113,7 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
print(f"Processing image {i + 1}/{len(input_arrays)}")
input_array = torch.from_numpy(input_arrays[i]).to(
device)
output = model(input_array)
output = model(input_array.unsqueeze(0) if hasattr(input_array, 'dim') and input_array.dim() == 3 else input_array)
output = output.cpu().numpy()
predictions.append(output)