Added nn6
This commit is contained in:
3
image-inpainting/.gitignore
vendored
Normal file
3
image-inpainting/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
data/*
|
||||
*.zip
|
||||
*.jpg
|
||||
BIN
image-inpainting/src/__pycache__/architecture.cpython-314.pyc
Normal file
BIN
image-inpainting/src/__pycache__/architecture.cpython-314.pyc
Normal file
Binary file not shown.
BIN
image-inpainting/src/__pycache__/datasets.cpython-314.pyc
Normal file
BIN
image-inpainting/src/__pycache__/datasets.cpython-314.pyc
Normal file
Binary file not shown.
BIN
image-inpainting/src/__pycache__/train.cpython-314.pyc
Normal file
BIN
image-inpainting/src/__pycache__/train.cpython-314.pyc
Normal file
Binary file not shown.
BIN
image-inpainting/src/__pycache__/utils.cpython-314.pyc
Normal file
BIN
image-inpainting/src/__pycache__/utils.cpython-314.pyc
Normal file
Binary file not shown.
11
image-inpainting/src/architecture.py
Normal file
11
image-inpainting/src/architecture.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Author: Your Name
|
||||
HTL-Grieskirchen 5. Jahrgang, Schuljahr 2025/26
|
||||
architecture.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
class MyModel(torch.nn.Module):
|
||||
# TODO: Implement the model architecture.
|
||||
pass
|
||||
43
image-inpainting/src/datasets.py
Normal file
43
image-inpainting/src/datasets.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Author: Your Name
|
||||
HTL-Grieskirchen 5. Jahrgang, Schuljahr 2025/26
|
||||
datasets.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import glob
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
IMAGE_DIMENSION = 100
|
||||
|
||||
|
||||
def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tuple) -> tuple[np.ndarray, np.ndarray]:
|
||||
image_array, known_array = None, None
|
||||
|
||||
# TODO: Implement the logic to create input and known arrays based on offset and spacing
|
||||
|
||||
return image_array, known_array
|
||||
|
||||
def resize(img: Image):
|
||||
pass
|
||||
def preprocess(input_array: np.ndarray):
|
||||
pass
|
||||
|
||||
class ImageDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset class for loading images from a folder
|
||||
"""
|
||||
|
||||
def __init__(self, datafolder: str):
|
||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imagefiles)
|
||||
|
||||
def __getitem__(self, idx:int):
|
||||
pass
|
||||
|
||||
# TODO: Implement the __init__, __len__, and __getitem__ methods
|
||||
49
image-inpainting/src/main.py
Normal file
49
image-inpainting/src/main.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Author: Your Name
|
||||
HTL-Grieskirchen 5. Jahrgang, Schuljahr 2025/26
|
||||
main.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from utils import create_predictions
|
||||
|
||||
|
||||
from train import train
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config_dict = dict()
|
||||
|
||||
config_dict['seed'] = 42
|
||||
config_dict['testset_ratio'] = 0.1
|
||||
config_dict['validset_ratio'] = 0.1
|
||||
config_dict['results_path'] = os.path.join("results")
|
||||
config_dict['data_path'] = os.path.join("data", "dataset")
|
||||
config_dict['device'] = None
|
||||
config_dict['learningrate'] = 1e-3
|
||||
config_dict['weight_decay'] = 1e-5 # default is 0
|
||||
config_dict['n_updates'] = 50000
|
||||
config_dict['batchsize'] = 32
|
||||
config_dict['early_stopping_patience'] = 3
|
||||
config_dict['use_wandb'] = False
|
||||
|
||||
config_dict['print_train_stats_at'] = 10
|
||||
config_dict['print_stats_at'] = 100
|
||||
config_dict['plot_at'] = 100
|
||||
config_dict['validate_at'] = 100
|
||||
|
||||
network_config = {
|
||||
'n_in_channels': 4
|
||||
}
|
||||
|
||||
config_dict['network_config'] = network_config
|
||||
|
||||
train(**config_dict)
|
||||
|
||||
testset_path = os.path.join("data", "challenge_testset.npz")
|
||||
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
|
||||
save_path = os.path.join(config_dict['results_path'], "testset", "my_submission_name.npz")
|
||||
plot_path = os.path.join(config_dict['results_path'], "testset", "plots")
|
||||
|
||||
# Comment out, if predictions are required
|
||||
create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, save_path, plot_path, plot_at=20)
|
||||
166
image-inpainting/src/train.py
Normal file
166
image-inpainting/src/train.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Author: Your Name
|
||||
HTL-Grieskirchen 5. Jahrgang, Schuljahr 2025/26
|
||||
train.py
|
||||
"""
|
||||
|
||||
import datasets
|
||||
from architecture import MyModel
|
||||
from utils import plot, evaluate_model
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Subset
|
||||
|
||||
import wandb
|
||||
|
||||
|
||||
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
||||
weight_decay, n_updates, use_wandb, print_train_stats_at, print_stats_at, plot_at, validate_at, batchsize,
|
||||
network_config: dict):
|
||||
np.random.seed(seed=seed)
|
||||
torch.manual_seed(seed=seed)
|
||||
|
||||
if device is None:
|
||||
device = torch.device(
|
||||
"cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
|
||||
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
|
||||
if use_wandb:
|
||||
wandb.login()
|
||||
wandb.init(project="image_inpainting", config={
|
||||
"learning_rate": learningrate,
|
||||
"weight_decay": weight_decay,
|
||||
"n_updates": n_updates,
|
||||
"batch_size": batchsize,
|
||||
"validation_ratio": validset_ratio,
|
||||
"testset_ratio": testset_ratio,
|
||||
"early_stopping_patience": early_stopping_patience,
|
||||
})
|
||||
|
||||
# Prepare a path to plot to
|
||||
plotpath = os.path.join(results_path, "plots")
|
||||
os.makedirs(plotpath, exist_ok=True)
|
||||
|
||||
image_dataset = datasets.ImageDataset(datafolder=data_path)
|
||||
|
||||
n_total = len(image_dataset)
|
||||
n_test = int(n_total * testset_ratio)
|
||||
n_valid = int(n_total * validset_ratio)
|
||||
n_train = n_total - n_test - n_valid
|
||||
indices = np.random.permutation(n_total)
|
||||
dataset_train = Subset(image_dataset, indices=indices[0:n_train])
|
||||
dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid])
|
||||
dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total])
|
||||
|
||||
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||
|
||||
del image_dataset
|
||||
|
||||
dataloader_train = DataLoader(dataset=dataset_train, batch_size=batchsize,
|
||||
num_workers=0, shuffle=True)
|
||||
dataloader_valid = DataLoader(dataset=dataset_valid, batch_size=1,
|
||||
num_workers=0, shuffle=False)
|
||||
dataloader_test = DataLoader(dataset=dataset_test, batch_size=1,
|
||||
num_workers=0, shuffle=False)
|
||||
|
||||
# initializing the model
|
||||
network = MyModel(**network_config)
|
||||
network.to(device)
|
||||
network.train()
|
||||
|
||||
# defining the loss
|
||||
mse_loss = torch.nn.MSELoss()
|
||||
|
||||
# defining the optimizer
|
||||
optimizer = torch.optim.Adam(network.parameters(), lr=learningrate, weight_decay=weight_decay)
|
||||
|
||||
if use_wandb:
|
||||
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
||||
|
||||
i = 0
|
||||
counter = 0
|
||||
best_validation_loss = np.inf
|
||||
loss_list = []
|
||||
|
||||
saved_model_path = os.path.join(results_path, "best_model.pt")
|
||||
|
||||
print(f"Started training on device {device}")
|
||||
|
||||
while i < n_updates:
|
||||
|
||||
for input, target in dataloader_train:
|
||||
|
||||
input, target = input.to(device), target.to(device)
|
||||
|
||||
if (i + 1) % print_train_stats_at == 0:
|
||||
print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}')
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = network(input)
|
||||
|
||||
loss = mse_loss(output, target)
|
||||
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
loss_list.append(loss.item())
|
||||
|
||||
# writing the stats to wandb
|
||||
if use_wandb and (i+1) % print_stats_at == 0:
|
||||
wandb.log({"training/loss_per_batch": loss.item()}, step=i)
|
||||
|
||||
# plotting
|
||||
if (i + 1) % plot_at == 0:
|
||||
print(f"Plotting images, current update {i + 1}")
|
||||
plot(input.cpu().numpy(), target.detach().cpu().numpy(), output.detach().cpu().numpy(), plotpath, i)
|
||||
|
||||
# evaluating model every validate_at sample
|
||||
if (i + 1) % validate_at == 0:
|
||||
print(f"Evaluation of the model:")
|
||||
val_loss, val_rmse = evaluate_model(network, dataloader_valid, mse_loss, device)
|
||||
print(f"val_loss: {val_loss}")
|
||||
print(f"val_RMSE: {val_rmse}")
|
||||
|
||||
if use_wandb:
|
||||
wandb.log({"validation/loss": val_loss,
|
||||
"validation/RMSE": val_rmse}, step=i)
|
||||
# wandb histogram
|
||||
|
||||
# Save best model for early stopping
|
||||
if val_loss < best_validation_loss:
|
||||
best_validation_loss = val_loss
|
||||
torch.save(network.state_dict(), saved_model_path)
|
||||
print(f"Saved new best model with val_loss: {best_validation_loss}")
|
||||
counter = 0
|
||||
else:
|
||||
counter += 1
|
||||
|
||||
if counter >= early_stopping_patience:
|
||||
print("Stopped training because of early stopping")
|
||||
i = n_updates
|
||||
break
|
||||
|
||||
i += 1
|
||||
if i >= n_updates:
|
||||
print("Finished training because maximum number of updates reached")
|
||||
break
|
||||
|
||||
print("Evaluating the self-defined testset")
|
||||
network.load_state_dict(torch.load(saved_model_path))
|
||||
testset_loss, testset_rmse = evaluate_model(network=network, dataloader=dataloader_test, loss_fn=mse_loss,
|
||||
device=device)
|
||||
|
||||
print(f'testset_loss of model: {testset_loss}, RMSE = {testset_rmse}')
|
||||
|
||||
if use_wandb:
|
||||
wandb.summary["testset/loss"] = testset_loss
|
||||
wandb.summary["testset/RMSE"] = testset_rmse
|
||||
wandb.finish()
|
||||
133
image-inpainting/src/utils.py
Normal file
133
image-inpainting/src/utils.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Author: Your Name
|
||||
HTL-Grieskirchen 5. Jahrgang, Schuljahr 2025/26
|
||||
utils.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from architecture import MyModel
|
||||
|
||||
|
||||
def plot(inputs, targets, predictions, path, update):
|
||||
"""Plotting the inputs, targets and predictions to file `path`"""
|
||||
|
||||
os.makedirs(path, exist_ok=True)
|
||||
fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
|
||||
|
||||
for i in range(len(inputs)):
|
||||
for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
|
||||
ax.clear()
|
||||
ax.set_title(title)
|
||||
img = data[i:i + 1:, 0:3, :, :]
|
||||
img = np.squeeze(img)
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
img = np.clip(img, 0, 1)
|
||||
ax.imshow(img)
|
||||
ax.set_axis_off()
|
||||
fig.savefig(os.path.join(path, f"{update + 1:07d}_{i + 1:02d}.jpg"))
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def testset_plot(input_array, output_array, path, index):
|
||||
"""Plotting the inputs, targets and predictions to file `path` for testset (no targets available)"""
|
||||
|
||||
os.makedirs(path, exist_ok=True)
|
||||
fig, axes = plt.subplots(ncols=2, figsize=(10, 5))
|
||||
|
||||
for ax, data, title in zip(axes, [input_array, output_array], ["Input", "Prediction"]):
|
||||
ax.clear()
|
||||
ax.set_title(title)
|
||||
img = data[0:3, :, :]
|
||||
img = np.squeeze(img)
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
img = np.clip(img, 0, 1)
|
||||
ax.imshow(img)
|
||||
ax.set_axis_off()
|
||||
fig.savefig(os.path.join(path, f"testset_{index + 1:07d}.jpg"))
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def evaluate_model(network: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
|
||||
"""Returnse MSE and RMSE of the model on the provided dataloader"""
|
||||
network.eval()
|
||||
loss = 0.0
|
||||
with torch.no_grad():
|
||||
for data in dataloader:
|
||||
input_array, target = data
|
||||
input_array = input_array.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
outputs = network(input_array)
|
||||
|
||||
loss += loss_fn(outputs, target).item()
|
||||
|
||||
loss = loss / len(dataloader)
|
||||
|
||||
network.train()
|
||||
|
||||
return loss, 255.0 * np.sqrt(loss)
|
||||
|
||||
|
||||
def read_compressed_file(file_path: str):
|
||||
with np.load(file_path) as data:
|
||||
input_arrays = data['input_arrays']
|
||||
known_arrays = data['known_arrays']
|
||||
return input_arrays, known_arrays
|
||||
|
||||
|
||||
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20):
|
||||
"""
|
||||
Here, one might needs to adjust the code based on the used preprocessing
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = torch.device(
|
||||
"cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
|
||||
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
|
||||
model = MyModel(**model_config)
|
||||
model.load_state_dict(torch.load(state_dict_path))
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
input_arrays, known_arrays = read_compressed_file(testset_path)
|
||||
|
||||
known_arrays = known_arrays.astype(np.float32)
|
||||
|
||||
input_arrays = input_arrays.astype(np.float32) / 255.0
|
||||
|
||||
input_arrays = np.concatenate((input_arrays, known_arrays), axis=1)
|
||||
|
||||
predictions = list()
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(len(input_arrays)):
|
||||
print(f"Processing image {i + 1}/{len(input_arrays)}")
|
||||
input_array = torch.from_numpy(input_arrays[i]).to(
|
||||
device)
|
||||
output = model(input_array)
|
||||
output = output.cpu().numpy()
|
||||
predictions.append(output)
|
||||
|
||||
if (i + 1) % plot_at == 0:
|
||||
testset_plot(input_array.cpu().numpy(), output, plot_path, i)
|
||||
|
||||
predictions = np.stack(predictions, axis=0)
|
||||
|
||||
predictions = (np.clip(predictions, 0, 1) * 255.0).astype(np.uint8)
|
||||
|
||||
data = {
|
||||
"predictions": predictions
|
||||
}
|
||||
|
||||
np.savez_compressed(save_path, **data)
|
||||
|
||||
print(f"Predictions saved at {save_path}")
|
||||
Reference in New Issue
Block a user