{
"cells": [
{
"cell_type": "markdown",
"id": "2580d14d",
"metadata": {},
"source": [
"<div style=\"\n",
" border: 2px solid #4CAF50; \n",
" padding: 15px; \n",
" background-color: #f4f4f4; \n",
" border-radius: 10px; \n",
" align-items: center;\">\n",
"\n",
"<h1 style=\"margin: 0; color: #4CAF50;\">Neural Networks: An Example (Classification) (Solution)</h1>\n",
"<h2 style=\"margin: 5px 0; color: #555;\">DSAI</h2>\n",
"<h3 style=\"margin: 5px 0; color: #555;\">Jakob Eggl</h3>\n",
"\n",
"<div style=\"flex-shrink: 0;\">\n",
" <img src=\"https://www.htl-grieskirchen.at/wp/wp-content/uploads/2022/11/logo_bildschirm-1024x503.png\" alt=\"Logo\" style=\"width: 250px; height: auto;\"/>\n",
"</div>\n",
"<p> © 2025/26 Jakob Eggl. Use or distribution only with the express permission of the author.</p>\n",
"</div>\n",
"<div style=\"flex: 1;\">\n",
"</div> "
]
},
{
"cell_type": "markdown",
"id": "e1a0eaf8",
"metadata": {},
"source": [
"We now want to build a neural network for classification. For this we use a very well-known dataset, MNIST. It is freely available and exists in many variations (for example Fashion-MNIST, which contains images of clothing instead of digits).\n",
"\n",
"We start with the standard MNIST dataset. It contains handwritten digits from $0$ to $9$, and the goal is to recognize the correct digit."
]
},
{
"cell_type": "markdown",
"id": "cc8846a2",
"metadata": {},
"source": [
"\n",
"\n",
"(Image source: https://de.wikipedia.org/wiki/MNIST-Datenbank)"
]
},
{
"cell_type": "markdown",
"id": "51c9fecc",
"metadata": {},
"source": [
"In total, the MNIST dataset contains $60\\,\\mathrm{k}$ training images and $10\\,\\mathrm{k}$ test images, each a $28 \\times 28$ grayscale image. The classes are roughly balanced, i.e. there are approximately equally many images with label \"1\", label \"2\", and so on."
]
},
{
"cell_type": "markdown",
"id": "e054c9dc",
"metadata": {},
"source": [
"# Solution"
]
},
{
"cell_type": "markdown",
"id": "7b7c7bd4",
"metadata": {},
"source": [
"First, we make sure that everyone has downloaded the MNIST dataset. The data path used below can be adjusted if necessary."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d2fdca5",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"from torchvision import datasets, transforms\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38992064",
"metadata": {},
"outputs": [],
"source": [
"data_path = os.path.join(\"..\", \"..\", \"_data\", \"mnist_data\")\n",
"\n",
"# ToTensor converts the uint8 pixel values {0, 1, ..., 255} to floats in [0, 1] with shape (1, 28, 28)\n",
"train_dataset = datasets.MNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor())\n",
"test_dataset = datasets.MNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor())\n",
"\n",
"# Split the official test set in half: one half for testing, the other half for validation\n",
"test_size = len(test_dataset) // 2\n",
"valid_size = len(test_dataset) - test_size\n",
"\n",
"# A fixed seed makes the split reproducible across runs\n",
"test_dataset, valid_dataset = random_split(test_dataset, [test_size, valid_size], generator=torch.Generator().manual_seed(42))"
]
},
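{
"cell_type": "markdown",
"id": "totensor-demo-md",
"metadata": {},
"source": [
"To see what `transforms.ToTensor()` does, here is a small sketch on a made-up $1 \\times 3$ array instead of a real MNIST image. The uint8 values are divided by $255$, and a channel dimension is added in front, so an $H \\times W$ image becomes a $1 \\times H \\times W$ float tensor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "totensor-demo-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from torchvision import transforms\n",
"\n",
"# Made-up 1 x 3 grayscale image with uint8 pixel values (black, mid-gray, white)\n",
"tiny_image = np.array([[0, 128, 255]], dtype=np.uint8)\n",
"\n",
"tensor = transforms.ToTensor()(tiny_image)\n",
"\n",
"print(tensor.shape)  # channel dimension added in front: (C, H, W)\n",
"print(tensor)        # pixel values divided by 255, i.e. in [0, 1]"
]
},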
{
"cell_type": "markdown",
"id": "7e4570f5",
"metadata": {},
"source": [
"With the code above we directly obtain a PyTorch `Dataset`, so later we only need to create the `DataLoader`s."
]
},
{
"cell_type": "markdown",
"id": "2fae606c",
"metadata": {},
"source": [
"A quick **recap**: *How do you create your own dataset?*"
]
},
{
"cell_type": "markdown",
"id": "3504b37b",
"metadata": {},
"source": [
"For example like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c5ae338",
"metadata": {},
"outputs": [],
"source": [
"class MyDataSetThatIsNeverUsed(Dataset):\n",
"    def __init__(self, transform=None):\n",
"        super().__init__()\n",
"        self.transform = transform  # optional transformation, applied in __getitem__\n",
"\n",
"    def __len__(self):\n",
"        return 0  # number of samples (this template is empty)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        # Load sample idx and apply self.transform here. Returns input and label\n",
"        return torch.tensor([]), torch.tensor(0)"
]
},
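{
"cell_type": "markdown",
"id": "custom-dataset-demo-md",
"metadata": {},
"source": [
"The template above is intentionally empty. This is a sketch of how a filled-in version could look: a small dataset that wraps tensors already in memory and applies an optional transform in `__getitem__`. The class name `InMemoryDataset`, the random data and the transform are hypothetical and only serve as an illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "custom-dataset-demo-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"\n",
"class InMemoryDataset(Dataset):\n",
"    \"\"\"Hypothetical example: wraps input and label tensors that are already in memory.\"\"\"\n",
"\n",
"    def __init__(self, inputs, labels, transform=None):\n",
"        super().__init__()\n",
"        assert len(inputs) == len(labels), \"inputs and labels must have the same length\"\n",
"        self.inputs = inputs\n",
"        self.labels = labels\n",
"        self.transform = transform\n",
"\n",
"    def __len__(self):\n",
"        return len(self.inputs)  # number of samples\n",
"\n",
"    def __getitem__(self, idx):\n",
"        x = self.inputs[idx]\n",
"        if self.transform is not None:\n",
"            x = self.transform(x)  # e.g. normalization or augmentation\n",
"        return x, self.labels[idx]  # input and label, like the MNIST dataset\n",
"\n",
"\n",
"# 10 random 1 x 28 x 28 images with labels in {0, ..., 9}\n",
"inputs = torch.rand(10, 1, 28, 28)\n",
"labels = torch.randint(0, 10, (10,))\n",
"\n",
"ds = InMemoryDataset(inputs, labels, transform=lambda x: (x - 0.5) / 0.5)\n",
"loader = DataLoader(ds, batch_size=4, shuffle=True)\n",
"\n",
"x, y = ds[0]\n",
"print(len(ds), x.shape, y.shape)\n",
"for batch_x, batch_y in loader:\n",
"    print(batch_x.shape, batch_y.shape)  # the last batch only contains 2 samples"
]
},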
{
"cell_type": "markdown",
"id": "1068fe47",
"metadata": {},
"source": [
"Apart from that, we start again by defining the device (in general, it is a good habit to define it once at the beginning)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e5a91e4",
"metadata": {},
"outputs": [],
"source": [
"# Prefer an NVIDIA GPU (CUDA), then Apple Silicon (MPS), otherwise fall back to the CPU\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"id": "b992eade",
"metadata": {},
"source": [
"Since we already have the datasets, we now define the DataLoaders."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01f002c0",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 64\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  # reshuffle the training data every epoch\n",
"test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
"valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)"
]
},
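{
"cell_type": "markdown",
"id": "dataloader-batches-md",
"metadata": {},
"source": [
"How many batches does a `DataLoader` produce? With the default `drop_last=False`, a dataset with $N$ samples and batch size $B$ gives $\\lceil N / B \\rceil$ batches, and the last batch may be smaller. For the $60\\,000$ MNIST training images and `batch_size = 64` this is $\\lceil 937.5 \\rceil = 938$, which is the value of `len(train_loader)` that the training loop prints as `Step [.../938]`. The following sketch checks this on a made-up `TensorDataset` with 150 samples."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dataloader-batches-code",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"# 150 made-up samples, each a 1 x 28 x 28 tensor with an integer label\n",
"dataset = TensorDataset(torch.rand(150, 1, 28, 28), torch.randint(0, 10, (150,)))\n",
"loader = DataLoader(dataset, batch_size=64, shuffle=True)\n",
"\n",
"batch_sizes = [inputs.shape[0] for inputs, _ in loader]\n",
"print(len(loader), batch_sizes)  # the last batch contains the remaining samples\n",
"\n",
"# Same formula for the MNIST training set\n",
"print(math.ceil(60000 / 64))"
]
},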
{
"cell_type": "markdown",
"id": "6d262f12",
"metadata": {},
"source": [
"Let's also take a look at a few images from the training set."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "661246a3",
"metadata": {},
"outputs": [],
"source": [
"# First batch: images with shape (64, 1, 28, 28) and labels with shape (64,)\n",
"example_data, example_targets = next(iter(train_loader))\n",
"\n",
"plt.figure(figsize=(8, 3))\n",
"for i in range(6):\n",
"    plt.subplot(1, 6, i+1)\n",
"    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n",
"    plt.title(f\"{example_targets[i].item()}\")\n",
"    plt.xticks([])\n",
"    plt.yticks([])\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c14a7e3c",
"metadata": {},
"source": [
"Next, we define the network. What do we need to pay attention to compared to regression?\n",
"\n",
"(In short: the $28 \\times 28$ images have to be flattened into vectors of length $784$, and the output layer needs one neuron per class, i.e. $10$. It returns raw scores, so-called logits, which the loss function turns into probabilities internally.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2d9cd26",
"metadata": {},
"outputs": [],
"source": [
"class MNISTClassifier(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.layers = nn.Sequential(\n",
"            nn.Flatten(),  # Very important! Why? -> Linear layers expect vectors: (N, 1, 28, 28) becomes (N, 784). We will see that CNNs don't need this flattening!\n",
"            nn.Linear(28*28, 256),\n",
"            nn.ReLU(),\n",
"            nn.Linear(256, 128),\n",
"            nn.ReLU(),\n",
"            nn.Linear(128, 64),\n",
"            nn.ReLU(),\n",
"            nn.Linear(64, 10),  # one output (logit) per class, no softmax here\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        return self.layers(x)"
]
},
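{
"cell_type": "markdown",
"id": "flatten-shapes-md",
"metadata": {},
"source": [
"To see why `nn.Flatten()` is needed, we can send a made-up batch through the same layer stack and print the shapes. The sketch rebuilds the architecture of `MNISTClassifier` as a plain `nn.Sequential`, so the cell runs on its own. It also counts the trainable parameters: a linear layer with $n_{in}$ inputs and $n_{out}$ outputs has $n_{in} \\cdot n_{out} + n_{out}$ parameters (weights plus biases)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "flatten-shapes-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Same layer stack as MNISTClassifier, rebuilt here so the cell is self-contained\n",
"layers = nn.Sequential(\n",
"    nn.Flatten(),\n",
"    nn.Linear(28 * 28, 256), nn.ReLU(),\n",
"    nn.Linear(256, 128), nn.ReLU(),\n",
"    nn.Linear(128, 64), nn.ReLU(),\n",
"    nn.Linear(64, 10),\n",
")\n",
"\n",
"dummy_batch = torch.rand(64, 1, 28, 28)  # made-up batch in the MNIST shape (N, C, H, W)\n",
"\n",
"flat = nn.Flatten()(dummy_batch)\n",
"logits = layers(dummy_batch)\n",
"print(flat.shape)    # all dimensions except the batch dimension are merged: 1 * 28 * 28 = 784\n",
"print(logits.shape)  # one score (logit) per class for every image\n",
"\n",
"n_params = sum(p.numel() for p in layers.parameters() if p.requires_grad)\n",
"print(n_params)"
]
},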
{
"cell_type": "code",
"execution_count": null,
"id": "4e413ca3",
"metadata": {},
"outputs": [],
"source": [
"model = MNISTClassifier().to(device)"
]
},
{
"cell_type": "markdown",
"id": "cf9deca1",
"metadata": {},
"source": [
"Which loss do we want to use? Which optimizer?\n",
"\n",
"(For classification with several classes we use the cross-entropy loss, `nn.CrossEntropyLoss`, and Adam as the optimizer.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe340898",
"metadata": {},
"outputs": [],
"source": [
"lr = 0.001\n",
"\n",
"criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels, applies log-softmax internally\n",
"optimizer = optim.Adam(model.parameters(), lr=lr)"
]
},
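{
"cell_type": "markdown",
"id": "cross-entropy-demo-md",
"metadata": {},
"source": [
"`nn.CrossEntropyLoss` combines a log-softmax over the class scores with the negative log-likelihood of the true class, averaged over the batch:\n",
"\n",
"$$\\ell = -\\frac{1}{N}\\sum_{i=1}^{N} \\log \\frac{e^{z_{i,y_i}}}{\\sum_{k=1}^{10} e^{z_{i,k}}},$$\n",
"\n",
"where $z_{i,k}$ is the logit of class $k$ for image $i$ and $y_i$ is its true label. This is also why the network ends with a plain `nn.Linear` layer without a softmax. The following sketch compares the built-in loss with a manual computation on made-up logits."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cross-entropy-demo-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"logits = torch.randn(4, 10)           # made-up raw outputs for 4 images and 10 classes\n",
"targets = torch.tensor([3, 0, 9, 1])  # true class indices (not one-hot encoded)\n",
"\n",
"builtin = nn.CrossEntropyLoss()(logits, targets)\n",
"\n",
"# Manual version: log-softmax over the classes, then the negative log-probability of the true class\n",
"log_probs = F.log_softmax(logits, dim=1)\n",
"manual = -log_probs[torch.arange(len(targets)), targets].mean()\n",
"\n",
"print(builtin.item(), manual.item())  # both values agree up to floating-point rounding"
]
},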
{
"cell_type": "markdown",
"id": "2b93a21c",
"metadata": {},
"source": [
"Now we come to the training method. This time we write it as a separate function, and we do the same for the evaluation. (We define the evaluation function first because the training function calls it.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4da656f8",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_model(model, data_loader, criterion):\n",
"    model.eval()  # evaluation mode (matters for layers such as Dropout or BatchNorm)\n",
"    loss_total = 0.0\n",
"    correct = 0\n",
"    total = 0\n",
"\n",
"    with torch.no_grad():  # no gradients needed for evaluation\n",
"        for data, target in data_loader:\n",
"            data, target = data.to(device), target.to(device)\n",
"            output = model(data)\n",
"            loss = criterion(output, target)\n",
"            loss_total += loss.item() * data.size(0)  # criterion averages over the batch, so weight by batch size\n",
"\n",
"            _, predicted = torch.max(output, 1)  # index of the largest logit = predicted class\n",
"            total += target.size(0)\n",
"            correct += (predicted == target).sum().item()\n",
"\n",
"    avg_loss = loss_total / total\n",
"    accuracy = 100.0 * correct / total\n",
"    return avg_loss, accuracy"
]
},
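{
"cell_type": "markdown",
"id": "weighted-loss-demo-md",
"metadata": {},
"source": [
"Why does `evaluate_model` multiply the loss by `data.size(0)`? `nn.CrossEntropyLoss` returns the *mean* over a batch, and the last batch is usually smaller. Simply averaging the batch means would give this small batch too much weight. The sketch below uses made-up per-sample losses to compare both variants with the true mean over all samples."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "weighted-loss-demo-code",
"metadata": {},
"outputs": [],
"source": [
"# Made-up per-sample losses, split into batches of size 3, 3 and 1 (the last batch is smaller)\n",
"batches = [[0.2, 0.4, 0.6], [0.1, 0.3, 0.5], [2.0]]\n",
"\n",
"true_mean = sum(sum(b) for b in batches) / sum(len(b) for b in batches)\n",
"\n",
"# Variant 1: plain average of the batch means (no weighting)\n",
"naive = sum(sum(b) / len(b) for b in batches) / len(batches)\n",
"\n",
"# Variant 2: weight each batch mean by its batch size, as in evaluate_model\n",
"weighted = sum((sum(b) / len(b)) * len(b) for b in batches) / sum(len(b) for b in batches)\n",
"\n",
"print(f\"true mean: {true_mean:.4f}, naive: {naive:.4f}, weighted: {weighted:.4f}\")"
]
},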
{
"cell_type": "code",
"execution_count": null,
"id": "be0815fe",
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, train_loader, valid_loader, criterion, optimizer, save_path: str = None,\n",
"                epochs=20, validate_at=1, print_at=100, patience=3):\n",
"\n",
"    if save_path is None:\n",
"        save_path = os.path.join(\"..\", \"models\", \"nn_8_best_model.pth\")\n",
"    save_dir = os.path.dirname(save_path)\n",
"    if save_dir:\n",
"        os.makedirs(save_dir, exist_ok=True)  # make sure the target folder exists before torch.save\n",
"\n",
"    best_loss = float(\"inf\")\n",
"    patience_counter = 0\n",
"\n",
"    for epoch in range(1, epochs+1):\n",
"        model.train()  # back to training mode (evaluate_model switches to eval mode)\n",
"        running_loss = 0.0\n",
"\n",
"        for batch_idx, (data, target) in enumerate(train_loader):\n",
"            data, target = data.to(device), target.to(device)\n",
"\n",
"            optimizer.zero_grad()             # reset the gradients of the previous step\n",
"            output = model(data)              # forward pass\n",
"            loss = criterion(output, target)\n",
"            loss.backward()                   # backpropagation\n",
"            optimizer.step()                  # parameter update\n",
"\n",
"            running_loss += loss.item()\n",
"\n",
"            if (batch_idx+1) % print_at == 0:\n",
"                print(f\"Epoch [{epoch}/{epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}\")\n",
"\n",
"        print(f\"Epoch [{epoch}/{epochs}] - Average Training Loss: {running_loss / len(train_loader):.4f}\")\n",
"\n",
"        if epoch % validate_at == 0:\n",
"            val_loss, val_acc = evaluate_model(model, valid_loader, criterion)\n",
"            print(f\"Epoch [{epoch}/{epochs}] - Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%\")\n",
"\n",
"            if val_loss < best_loss:\n",
"                best_loss = val_loss\n",
"                patience_counter = 0\n",
"                torch.save(model.state_dict(), save_path)\n",
"                print(f\">>> Found a better model and saved it at '{save_path}'\")\n",
"            else:\n",
"                patience_counter += 1\n",
"                print(f\"No Improvement. Early Stopping Counter: {patience_counter}/{patience}\")\n",
"                if patience_counter >= patience:\n",
"                    print(\"Early Stopping triggered.\")\n",
"                    break"
]
},
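{
"cell_type": "markdown",
"id": "early-stopping-demo-md",
"metadata": {},
"source": [
"The early-stopping rule in `train_model` can be checked in isolation. The following sketch uses a hypothetical helper, `early_stopping_epoch` (not part of PyTorch), that replays a list of made-up validation losses with the same logic. A new best loss resets the counter. Otherwise the counter grows, and training stops once it reaches `patience`. Epochs are counted from 1, as in `train_model`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "early-stopping-demo-code",
"metadata": {},
"outputs": [],
"source": [
"def early_stopping_epoch(val_losses, patience=3):\n",
"    \"\"\"Hypothetical helper: returns (stop_epoch, best_epoch) for a list of validation losses.\"\"\"\n",
"    best_loss = float(\"inf\")\n",
"    best_epoch = 0\n",
"    patience_counter = 0\n",
"    for epoch, val_loss in enumerate(val_losses, start=1):\n",
"        if val_loss < best_loss:\n",
"            best_loss = val_loss\n",
"            best_epoch = epoch  # train_model would save the model here\n",
"            patience_counter = 0\n",
"        else:\n",
"            patience_counter += 1\n",
"            if patience_counter >= patience:\n",
"                return epoch, best_epoch  # early stopping triggered\n",
"    return len(val_losses), best_epoch  # ran through all epochs\n",
"\n",
"\n",
"# Made-up validation losses: improvement until epoch 3, then three epochs without improvement\n",
"losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.38, 0.20]\n",
"print(early_stopping_epoch(losses, patience=3))\n",
"print(early_stopping_epoch(losses, patience=4))"
]
},
{
"cell_type": "markdown",
"id": "early-stopping-demo-note",
"metadata": {},
"source": [
"With `patience=3`, the made-up improvement in epoch 7 is never reached. A larger patience costs more training time but is more robust against such plateaus."
]
},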
{
"cell_type": "markdown",
"id": "3afc2ae3",
"metadata": {},
"source": [
"Last but not least, we train the model. For this, we first define the hyperparameters (some of them are defined a second time here, just for completeness)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32230cc3",
"metadata": {},
"outputs": [],
"source": [
"### HYPERPARAMETERS ###\n",
"\n",
"model = MNISTClassifier().to(device)  # fresh, untrained model\n",
"criterion = nn.CrossEntropyLoss()\n",
"lr = 0.001\n",
"optimizer = optim.Adam(model.parameters(), lr=lr)\n",
"epochs = 20\n",
"validate_at = 1\n",
"print_at = 200\n",
"early_stopping_patience = 3\n",
"save_path = os.path.join(\"..\", \"models\", \"nn_8_best_model_mnist.pth\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "453d2b87",
"metadata": {},
"outputs": [],
"source": [
"train_model(model, train_loader, valid_loader, criterion, optimizer,\n",
"            epochs=epochs, validate_at=validate_at, print_at=print_at,\n",
"            patience=early_stopping_patience, save_path=save_path)"
]
},
{
"cell_type": "markdown",
"id": "fd829890",
"metadata": {},
"source": [
"Finally, we evaluate the best model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d288c09d",
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(save_path, map_location=device))  # weights with the lowest validation loss\n",
"test_loss, test_acc = evaluate_model(model, test_loader, criterion)\n",
"print(f\"Final Test Loss: {test_loss:.4f}\")\n",
"print(f\"Final Test Accuracy: {test_acc:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "60aa40ec",
"metadata": {},
"source": [
"___"
]
},
{
"cell_type": "markdown",
"id": "35559661",
"metadata": {},
"source": [
"Are we satisfied? What could be improved?"
]
},
{
"cell_type": "markdown",
"id": "e2a63701",
"metadata": {},
"source": [
"We could use a (different) transformation, for example normalizing the inputs and adding a small amount of data augmentation."
]
},
{
"cell_type": "markdown",
"id": "cc538713",
"metadata": {},
"source": [
"For this, let's first compute the mean and the variance (or rather the standard deviation) of the training data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40ad9243",
"metadata": {},
"outputs": [],
"source": [
"# Exact mean and standard deviation over all pixels of the training set.\n",
"# (Averaging per-batch means and stds is only an approximation: the last batch is smaller,\n",
"#  and the mean of the batch stds is not the overall std.)\n",
"n_pixels = 0\n",
"pixel_sum = 0.0\n",
"pixel_sq_sum = 0.0\n",
"for imgs, _ in train_loader:\n",
"    n_pixels += imgs.numel()\n",
"    pixel_sum += imgs.sum().item()\n",
"    pixel_sq_sum += (imgs ** 2).sum().item()\n",
"\n",
"mean = pixel_sum / n_pixels\n",
"std = (pixel_sq_sum / n_pixels - mean ** 2) ** 0.5  # Var[X] = E[X^2] - E[X]^2\n",
"print(f\"mean: {mean:.4f}, std: {std:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1cf0477",
"metadata": {},
"outputs": [],
"source": [
"train_transform = transforms.Compose([\n",
"    transforms.RandomRotation(10),  # small data augmentation: random rotation by up to +/- 10 degrees\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize((0.1307,), (0.3081,))  # (x - mean) / std with the MNIST training statistics\n",
"])\n",
"\n",
"test_transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize((0.1307,), (0.3081,))  # same normalization, but no augmentation for validation/test\n",
"])"
]
},
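{
"cell_type": "markdown",
"id": "normalize-demo-md",
"metadata": {},
"source": [
"`transforms.Normalize(mean, std)` computes $x' = (x - \\mu) / \\sigma$ for each channel. With $\\mu = 0.1307$ and $\\sigma = 0.3081$, a black pixel ($x = 0$) becomes about $-0.4242$ and a white pixel ($x = 1$) about $2.8215$. The normalized training data therefore has approximately mean $0$ and standard deviation $1$. The sketch checks the formula on a made-up tensor in the shape `ToTensor` produces."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "normalize-demo-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import transforms\n",
"\n",
"normalize = transforms.Normalize((0.1307,), (0.3081,))\n",
"\n",
"# Made-up 1 x 1 x 3 image after ToTensor: black, mid-gray and white pixel\n",
"x = torch.tensor([[[0.0, 0.5, 1.0]]])\n",
"\n",
"by_hand = (x - 0.1307) / 0.3081\n",
"print(normalize(x))\n",
"print(by_hand)  # same values, computed with the formula"
]
},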
{
"cell_type": "code",
"execution_count": null,
"id": "e7d0e06f",
"metadata": {},
"outputs": [],
"source": [
"data_path = os.path.join(\"..\", \"..\", \"_data\", \"mnist_data\")\n",
"\n",
"# Same data as before, but now with the new transformations (augmentation only for training)\n",
"train_dataset = datasets.MNIST(root=data_path, train=True, download=True, transform=train_transform)\n",
"test_dataset = datasets.MNIST(root=data_path, train=False, download=True, transform=test_transform)\n",
"\n",
"test_size = len(test_dataset) // 2\n",
"valid_size = len(test_dataset) - test_size\n",
"\n",
"# A fixed seed makes the split reproducible across runs\n",
"test_dataset, valid_dataset = random_split(test_dataset, [test_size, valid_size], generator=torch.Generator().manual_seed(42))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f8f37e5",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 64\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
"valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf9c9c5c",
"metadata": {},
"outputs": [],
"source": [
"### HYPERPARAMETERS ###\n",
"\n",
"model = MNISTClassifier().to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"lr = 0.001\n",
"optimizer = optim.Adam(model.parameters(), lr=lr)\n",
"epochs = 10\n",
"validate_at = 1\n",
"print_at = 200\n",
"early_stopping_patience = 3\n",
"save_path = os.path.join(\"..\", \"models\", \"nn_8_best_model_mnist_transform.pth\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eab17123",
"metadata": {},
"outputs": [],
"source": [
"train_model(model, train_loader, valid_loader, criterion, optimizer,\n",
"            epochs=epochs, validate_at=validate_at, print_at=print_at,\n",
"            patience=early_stopping_patience, save_path=save_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "247a1257",
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(save_path, map_location=device))\n",
"test_loss, test_acc = evaluate_model(model, test_loader, criterion)\n",
"print(f\"Final Test Loss: {test_loss:.4f}\")\n",
"print(f\"Final Test Accuracy: {test_acc:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "7e0cfe7b",
"metadata": {},
"source": [
"Of course, we could also use more epochs, an even larger network, a different learning rate, a different optimizer, etc. We will not do that here, though."
]
},
{
"cell_type": "markdown",
"id": "2c35fdc1",
"metadata": {},
"source": [
"___"
]
},
{
"cell_type": "markdown",
"id": "ddb2a810",
"metadata": {},
"source": [
"We now use the **Fashion-MNIST** dataset and run everything again."
]
},
{
"cell_type": "markdown",
"id": "6b8e57aa",
"metadata": {},
"source": [
"It consists of images of clothing items, again $28 \\times 28$ grayscale images with 10 labels. So, at least at first, we do not need to adapt our model."
]
},
{
"cell_type": "markdown",
"id": "06c551b9",
"metadata": {},
"source": [
"Here, too, there are $60\\,\\mathrm{k}$ training images and $10\\,\\mathrm{k}$ test images."
]
},
{
"cell_type": "markdown",
"id": "f80dc124",
"metadata": {},
"source": [
"We copy the most important parts and modify them slightly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06ed22e0",
"metadata": {},
"outputs": [],
"source": [
"data_path = os.path.join(\"..\", \"..\", \"_data\", \"fashion_mnist_data\")\n",
"\n",
"# ToTensor converts the uint8 pixel values {0, 1, ..., 255} to floats in [0, 1]\n",
"train_dataset = datasets.FashionMNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor())\n",
"test_dataset = datasets.FashionMNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor())\n",
"\n",
"test_size = len(test_dataset) // 2\n",
"valid_size = len(test_dataset) - test_size\n",
"\n",
"# A fixed seed makes the split reproducible across runs\n",
"test_dataset, valid_dataset = random_split(test_dataset, [test_size, valid_size], generator=torch.Generator().manual_seed(42))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11e3c83a",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 64\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
"valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dac37e73",
"metadata": {},
"outputs": [],
"source": [
"example_data, example_targets = next(iter(train_loader))\n",
"\n",
"label_dict = {\n",
"    0: \"T-Shirt\",\n",
"    1: \"Trouser\",\n",
"    2: \"Pullover\",\n",
"    3: \"Dress\",\n",
"    4: \"Coat\",\n",
"    5: \"Sandal\",\n",
"    6: \"Shirt\",\n",
"    7: \"Sneaker\",\n",
"    8: \"Bag\",\n",
"    9: \"Ankle Boot\"\n",
"}\n",
"\n",
"plt.figure(figsize=(8, 3))\n",
"for i in range(6):\n",
"    plt.subplot(1, 6, i+1)\n",
"    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n",
"    plt.title(f\"{label_dict[example_targets[i].item()]}\")\n",
"    plt.xticks([])\n",
"    plt.yticks([])\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "7950b546",
"metadata": {},
"source": [
"We use the same model as before, only with a different class name."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25451f9e",
"metadata": {},
"outputs": [],
"source": [
"class FashionMNISTClassifier(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.layers = nn.Sequential(\n",
"            nn.Flatten(),  # Very important! Why? -> Linear layers expect vectors: (N, 1, 28, 28) becomes (N, 784)\n",
"            nn.Linear(28*28, 256),\n",
"            nn.ReLU(),\n",
"            nn.Linear(256, 128),\n",
"            nn.ReLU(),\n",
"            nn.Linear(128, 64),\n",
"            nn.ReLU(),\n",
"            nn.Linear(64, 10),  # one output (logit) per class\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        return self.layers(x)"
]
},
{
"cell_type": "markdown",
"id": "e60d3519",
"metadata": {},
"source": [
"Now we train the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12b4a9d7",
"metadata": {},
"outputs": [],
"source": [
"### HYPERPARAMETERS ###\n",
"\n",
"model = FashionMNISTClassifier().to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"lr = 0.001\n",
"optimizer = optim.Adam(model.parameters(), lr=lr)\n",
"epochs = 20\n",
"validate_at = 1\n",
"print_at = 200\n",
"early_stopping_patience = 3\n",
"save_path = os.path.join(\"..\", \"models\", \"nn_8_best_model_fashion_mnist.pth\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eecd887f",
"metadata": {},
"outputs": [],
"source": [
"train_model(model, train_loader, valid_loader, criterion, optimizer,\n",
"            epochs=epochs, validate_at=validate_at, print_at=print_at,\n",
"            patience=early_stopping_patience, save_path=save_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b63239a2",
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(save_path, map_location=device))\n",
"test_loss, test_acc = evaluate_model(model, test_loader, criterion)\n",
"print(f\"Final Test Loss: {test_loss:.4f}\")\n",
"print(f\"Final Test Accuracy: {test_acc:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "cfb071d1",
"metadata": {},
"source": [
"This performance is not really good. An accuracy of about $90\\,\\%$ means that, on average, roughly one in ten images is assigned to the wrong class. Let's take a quick look at the confusion matrix."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4be3ab7",
"metadata": {},
"outputs": [],
"source": [
"# Confusion matrix of the FashionMNIST model on the test set.\n",
"# We iterate over test_loader so that the images get the same ToTensor scaling as during training\n",
"# (the raw .data tensor of the dataset holds uint8 values 0..255, which the model has never seen).\n",
"model.eval()\n",
"all_preds, all_targets = [], []\n",
"with torch.no_grad():\n",
"    for data, target in test_loader:\n",
"        output = model(data.to(device))\n",
"        all_preds.append(output.argmax(dim=1).cpu())\n",
"        all_targets.append(target)\n",
"\n",
"cm = confusion_matrix(torch.cat(all_targets).numpy(), torch.cat(all_preds).numpy())\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=list(label_dict.values()), yticklabels=list(label_dict.values()))\n",
"plt.xlabel(\"Predicted label\")\n",
"plt.ylabel(\"True label\")\n",
"plt.title(\"Confusion Matrix for FashionMNIST Classification\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "2f6df1f8",
"metadata": {},
"source": [
"(As a reminder: the $y$-axis corresponds to the ground truth and the $x$-axis to the prediction.)"
]
},
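{
"cell_type": "markdown",
"id": "per-class-accuracy-md",
"metadata": {},
"source": [
"Since the rows of the matrix are the true classes, dividing the diagonal by the row sums gives the per-class accuracy (the recall of each class). This shows which classes the model confuses most often. Here is a small sketch with made-up labels for three classes, using the same `confusion_matrix` function from scikit-learn:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "per-class-accuracy-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# Made-up ground truth and predictions for 3 classes\n",
"y_true = [0, 0, 1, 1, 2, 2]\n",
"y_pred = [0, 1, 1, 1, 2, 0]\n",
"\n",
"cm = confusion_matrix(y_true, y_pred)  # rows: true class, columns: predicted class\n",
"per_class_acc = cm.diagonal() / cm.sum(axis=1)\n",
"\n",
"print(cm)\n",
"print(per_class_acc)"
]
},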
{
"cell_type": "markdown",
"id": "c792fa9a",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"id": "ef946e2e",
"metadata": {},
"source": [
"## Is a fully connected neural network always sufficient?"
]
},
{
"cell_type": "markdown",
"id": "f438ffcc",
"metadata": {},
"source": [
"**No!**\n",
"\n",
"Many problems call for other architectures. Even if such a performance may be acceptable in some situations, we will see, especially when we later tackle the **image inpainting** challenge, that we need other architectures because they work much better."
]
},
{
"cell_type": "markdown",
"id": "43a29933",
"metadata": {},
"source": [
"Next, we will therefore get to know a new architecture that handles images much better than feed-forward neural networks: the so-called ***CNNs*** (*Convolutional Neural Networks*)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dsai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}