fixed evaluation error
This commit is contained in:
@@ -113,7 +113,7 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
||||
print(f"Processing image {i + 1}/{len(input_arrays)}")
|
||||
input_array = torch.from_numpy(input_arrays[i]).to(
|
||||
device)
|
||||
output = model(input_array)
|
||||
output = model(input_array.unsqueeze(0) if hasattr(input_array, 'dim') and input_array.dim() == 3 else input_array)
|
||||
output = output.cpu().numpy()
|
||||
predictions.append(output)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user