added result, 21.3950

This commit is contained in:
2026-01-24 17:26:03 +01:00
parent 16cdc2ed4e
commit 15cfbe315c
10 changed files with 211 additions and 33 deletions

View File

@@ -81,9 +81,42 @@ def read_compressed_file(file_path: str):
return input_arrays, known_arrays
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None):
def apply_tta(model, input_tensor, device):
"""
Here, one might needs to adjust the code based on the used preprocessing
Apply Test-Time Augmentation for better predictions.
Averages predictions from original and augmented versions.
"""
outputs = []
# Original
out = model(input_tensor)
outputs.append(out)
# Horizontal flip
flipped_h = torch.flip(input_tensor, dims=[3])
out_h = model(flipped_h)
out_h = torch.flip(out_h, dims=[3])
outputs.append(out_h)
# Vertical flip
flipped_v = torch.flip(input_tensor, dims=[2])
out_v = model(flipped_v)
out_v = torch.flip(out_v, dims=[2])
outputs.append(out_v)
# Both flips
flipped_hv = torch.flip(input_tensor, dims=[2, 3])
out_hv = model(flipped_hv)
out_hv = torch.flip(out_hv, dims=[2, 3])
outputs.append(out_hv)
# Average all predictions
return torch.stack(outputs, dim=0).mean(dim=0)
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None, use_tta=True):
"""
Create predictions with optional Test-Time Augmentation for improved results.
"""
if device is None:
@@ -94,7 +127,7 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
device = torch.device(device)
model = MyModel(**model_config)
model.load_state_dict(torch.load(state_dict_path))
model.load_state_dict(torch.load(state_dict_path, weights_only=True))
model.to(device)
model.eval()
@@ -111,9 +144,14 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
with torch.no_grad():
for i in range(len(input_arrays)):
print(f"Processing image {i + 1}/{len(input_arrays)}")
input_array = torch.from_numpy(input_arrays[i]).to(
device)
output = model(input_array.unsqueeze(0) if hasattr(input_array, 'dim') and input_array.dim() == 3 else input_array)
input_array = torch.from_numpy(input_arrays[i]).to(device)
input_tensor = input_array.unsqueeze(0) if input_array.dim() == 3 else input_array
if use_tta:
output = apply_tta(model, input_tensor, device)
else:
output = model(input_tensor)
output = output.cpu().numpy()
predictions.append(output)