139 lines
4.2 KiB
Python
139 lines
4.2 KiB
Python
"""
|
|
Author: Your Name
|
|
HTL-Grieskirchen 5. Jahrgang, Schuljahr 2025/26
|
|
utils.py
|
|
"""
|
|
|
|
import torch
|
|
import numpy as np
|
|
import os
|
|
from matplotlib import pyplot as plt
|
|
|
|
from architecture import MyModel
|
|
|
|
|
|
def plot(inputs, targets, predictions, path, update, num_samples=5):
    """Save side-by-side Input / Target / Prediction images into the directory `path`.

    One JPEG is written per plotted sample, named ``{update + 1:07d}_{i + 1:02d}.jpg``.

    Args:
        inputs, targets, predictions: Batched image arrays. They are indexed as
            ``data[i, 0:3]``, i.e. a ``(batch, channels, H, W)`` layout is assumed.
            Only the first three channels are drawn (presumably RGB — confirm
            against the data pipeline). Values are clipped to [0, 1] for display.
        path: Output directory; created if it does not exist.
        update: Training-update counter used in the file names.
        num_samples: Maximum number of samples to plot (default 5). Fewer are
            plotted if the batch is smaller.
    """
    os.makedirs(path, exist_ok=True)
    fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
    try:
        # Never index past the end of a small (e.g. final) batch.
        n_plots = min(num_samples, len(inputs), len(targets), len(predictions))
        for i in range(n_plots):
            for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
                ax.clear()
                ax.set_title(title)
                # Select sample i and its first three channels directly rather than
                # slicing + np.squeeze, which would also drop spatial axes of size 1.
                img = np.transpose(np.asarray(data[i, 0:3]), (1, 2, 0))  # (C, H, W) -> (H, W, C)
                ax.imshow(np.clip(img, 0, 1))
                ax.set_axis_off()
            fig.savefig(os.path.join(path, f"{update + 1:07d}_{i + 1:02d}.jpg"))
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)
|
|
|
|
|
|
def testset_plot(input_array, output_array, path, index):
    """Save an Input / Prediction image pair for one test-set sample into `path`.

    The test set has no targets, so only the input and the prediction are shown.
    The file is named ``testset_{index + 1:07d}.jpg``.

    Args:
        input_array, output_array: A single image, either ``(C, H, W)`` or with a
            leading batch dimension ``(B, C, H, W)``, in which case the first
            batch entry is used. Only the first three channels are drawn and
            values are clipped to [0, 1] for display.
        path: Output directory; created if it does not exist.
        index: Zero-based sample index used in the file name.
    """
    os.makedirs(path, exist_ok=True)
    fig, axes = plt.subplots(ncols=2, figsize=(10, 5))
    try:
        for ax, data, title in zip(axes, [input_array, output_array], ["Input", "Prediction"]):
            ax.clear()
            ax.set_title(title)
            img = np.asarray(data)
            # Drop a leading batch dimension *before* selecting channels. Slicing
            # `[0:3]` on a batched array would select batch entries instead, and a
            # 4-channel input (image + known mask) would then be shown as RGBA.
            if img.ndim == 4:
                img = img[0]
            img = np.transpose(img[0:3], (1, 2, 0))  # (C, H, W) -> (H, W, C)
            ax.imshow(np.clip(img, 0, 1))
            ax.set_axis_off()
        fig.savefig(os.path.join(path, f"testset_{index + 1:07d}.jpg"))
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)
|
|
|
|
|
|
def evaluate_model(network: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
    """Return the MSE and RMSE of `network` over all samples in `dataloader`.

    Args:
        network: Model to evaluate. It is switched to eval mode for the pass and
            its previous train/eval mode is restored afterwards (also on error).
        dataloader: Yields ``(input, target)`` batches.
        loss_fn: Loss returning the *mean* over a batch (e.g. ``torch.nn.MSELoss()``);
            each batch value is weighted by its batch size so that a smaller final
            batch does not skew the result.
        device: Device the batches are moved to.

    Returns:
        ``(mse, rmse)`` where ``rmse = 255 * sqrt(mse)``. The factor 255 presumably
        rescales from [0, 1]-normalized pixels to the 0-255 range — confirm
        against the preprocessing.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    was_training = network.training
    network.eval()
    total_loss = 0.0
    num_samples = 0
    try:
        with torch.no_grad():
            for input_array, target in dataloader:
                input_array = input_array.to(device)
                target = target.to(device)
                outputs = network(input_array)
                batch_size = target.shape[0]
                # Weight the per-batch mean by batch size to get the true dataset mean.
                total_loss += loss_fn(outputs, target).item() * batch_size
                num_samples += batch_size
    finally:
        # Restore whatever mode the caller had, instead of forcing train mode.
        network.train(was_training)

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute the loss")

    loss = total_loss / num_samples
    return loss, 255.0 * np.sqrt(loss)
|
|
|
|
|
|
def read_compressed_file(file_path: str):
    """Load the test-set archive stored at `file_path`.

    Returns:
        A tuple ``(input_arrays, known_arrays)`` read from the ``.npz`` entries of
        the same names. Both arrays are read into memory before the archive is
        closed, so they remain usable after this function returns.
    """
    wanted = ("input_arrays", "known_arrays")
    with np.load(file_path) as archive:
        loaded = tuple(archive[name] for name in wanted)
    return loaded
|
|
|
|
|
|
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None):
    """Run a trained ``MyModel`` on the test set and save its predictions.

    The preprocessing below must match the one used during training; adjust it
    here if the training pipeline changes.

    Args:
        model_config: Keyword arguments passed to ``MyModel``.
        state_dict_path: Path to the saved ``state_dict`` of the trained model.
        testset_path: ``.npz`` archive containing ``input_arrays`` and ``known_arrays``.
        device: ``torch.device``, device string, or ``None`` to pick CUDA, then
            MPS, then CPU automatically.
        save_path: Output ``.npz`` path for the predictions.
        plot_path: Directory for preview plots.
        plot_at: Plot every ``plot_at``-th sample; ``0`` or ``None`` disables plotting.
        rmse_value: If given, appended to the file name as ``-{rmse_value:.4f}``.

    Side effects:
        Writes a compressed ``.npz`` with a ``predictions`` uint8 array (values
        clipped to [0, 1] and scaled to 0-255), one model output per test sample
        stacked along a new first axis. It also prints progress for every image.
    """
    if device is None:
        device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    if isinstance(device, str):
        device = torch.device(device)

    model = MyModel(**model_config)
    # map_location lets a state dict saved on a GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(state_dict_path, map_location=device))
    model.to(device)
    model.eval()

    input_arrays, known_arrays = read_compressed_file(testset_path)

    # Preprocessing: images scaled to [0, 1], known mask as float32, stacked along
    # axis 1 (presumably the channel axis of (N, C, H, W) — confirm with training code).
    known_arrays = known_arrays.astype(np.float32)
    input_arrays = input_arrays.astype(np.float32) / 255.0
    input_arrays = np.concatenate((input_arrays, known_arrays), axis=1)

    predictions = []
    with torch.no_grad():
        for i, sample in enumerate(input_arrays):
            print(f"Processing image {i + 1}/{len(input_arrays)}")
            input_array = torch.from_numpy(sample).to(device)

            # The model is fed a batch of one; remove that batch dimension from the
            # output again so the saved predictions are not (N, 1, C, H, W).
            add_batch_dim = input_array.dim() == 3
            output = model(input_array.unsqueeze(0) if add_batch_dim else input_array)
            if add_batch_dim:
                output = output.squeeze(0)
            output = output.cpu().numpy()
            predictions.append(output)

            if plot_at and (i + 1) % plot_at == 0:
                testset_plot(input_array.cpu().numpy(), output, plot_path, i)

    predictions = np.stack(predictions, axis=0)
    predictions = (np.clip(predictions, 0, 1) * 255.0).astype(np.uint8)

    data = {
        "predictions": predictions
    }

    # Append the RMSE to the file name if provided. Only strip a trailing ".npz",
    # so paths containing ".npz" elsewhere (e.g. in a directory name) stay intact.
    if rmse_value is not None:
        base_path = save_path[:-len(".npz")] if save_path.endswith(".npz") else save_path
        save_path = f"{base_path}-{rmse_value:.4f}.npz"

    np.savez_compressed(save_path, **data)
    print(f"Predictions saved at {save_path}")
|