{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div style=\"\n",
" border: 2px solid #4CAF50; \n",
" padding: 15px; \n",
" background-color: #f4f4f4; \n",
" border-radius: 10px; \n",
" align-items: center;\">\n",
"\n",
"<h1 style=\"margin: 0; color: #4CAF50;\">Supervised ML Models: Decision Trees and Random Forests</h1>\n",
"<h2 style=\"margin: 5px 0; color: #555;\">DSAI</h2>\n",
"<h3 style=\"margin: 5px 0; color: #555;\">Jakob Eggl</h3>\n",
"\n",
"<div style=\"flex-shrink: 0;\">\n",
" <img src=\"https://www.htl-grieskirchen.at/wp/wp-content/uploads/2022/11/logo_bildschirm-1024x503.png\" alt=\"Logo\" style=\"width: 250px; height: auto;\"/>\n",
"</div>\n",
"<p1> © 2024/25 Jakob Eggl. Use or distribution only with the express permission of the author.</p1>\n",
"</div>\n",
"<div style=\"flex: 1;\">\n",
"</div> "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Decision Trees"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the last of the supervised machine learning models covered in the 4th year, we look at the **decision tree** and, later, the **random forest** classifiers built from it. The idea behind decision trees is very simple, yet they keep proving their worth with good performance in many applications."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A decision tree (*DT*) $\\ldots$\n",
"* is used here for classification\n",
"* can learn arbitrarily complicated relationships (overfitting?)\n",
"* is based on a tree-like structure in which every node represents a decision (condition)\n",
"* consists of:\n",
"    * **Root node:** The starting point of the decision process, where the first splitting rule is applied\n",
"    * **Internal nodes (decision nodes):** Each of these nodes represents a decision based on a particular feature of the data\n",
"    * **Leaf nodes:** These represent the final classifications or predictions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"(from https://blog.gopenai.com/decision-trees-a-deep-dive-into-making-smarter-decisions-b0be706513af)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## An Example Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| Day | Outlook  | Temp. | Humidity | Windy? | Class      |\n",
"|-----|----------|-------|----------|--------|------------|\n",
"| 1   | sunny    | hot   | high     | false  | Don’t Play |\n",
"| 2   | sunny    | hot   | high     | true   | Don’t Play |\n",
"| 3   | overcast | hot   | high     | false  | Play       |\n",
"| 4   | rain     | mild  | high     | false  | Play       |\n",
"| 5   | rain     | cool  | normal   | false  | Play       |\n",
"| 6   | rain     | cool  | normal   | true   | Don’t Play |\n",
"| 7   | overcast | cool  | normal   | true   | Play       |\n",
"| 8   | sunny    | mild  | high     | false  | Don’t Play |\n",
"| 9   | sunny    | cool  | normal   | false  | Play       |\n",
"| 10  | rain     | mild  | normal   | false  | Play       |\n",
"| 11  | sunny    | mild  | normal   | true   | Play       |\n",
"| 12  | overcast | mild  | high     | true   | Play       |\n",
"| 13  | overcast | hot   | normal   | false  | Play       |\n",
"| 14  | rain     | mild  | high     | true   | Don’t Play |\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"(from *Artificial Intelligence: A Modern Approach*)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Description of the Algorithm\n",
"\n",
"We now describe the decision tree algorithm, split into the training process and the inference process."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Training:**\n",
"\n",
"* Iterative process\n",
"* The data are partitioned based on a **splitting criterion** (see below)\n",
"* The goal is to find the split that separates the data best, so that the subsets at the leaf nodes are as homogeneous as possible.\n",
"* Training stops either when the data in every *leaf node* belong to a single class, or\n",
"* when a maximum depth (hyperparameter) has been reached"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Inference:**\n",
"\n",
"* The tree is now fixed and no longer changes\n",
"* A new data point is passed through the decision tree and ends up in a leaf node (the traversal works like nested if-else statements).\n",
"* The new data point is assigned the class that occurs most often among the training samples in that leaf node."
]
},
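{
"cell_type": "markdown",
"metadata": {},
"source": [
"The traversal described above can be sketched as nested if-else statements. The tree below is hand-written for the play/don't-play example dataset and is only an illustration (an assumption, not a tree produced by a training run in this notebook); the function name `predict_play` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def predict_play(outlook, humidity, windy):\n",
"    # Root node: split on Outlook\n",
"    if outlook == 'overcast':\n",
"        return 'Play'  # leaf node\n",
"    elif outlook == 'sunny':\n",
"        # Internal node: split on Humidity\n",
"        if humidity == 'high':\n",
"            return \"Don't Play\"\n",
"        else:\n",
"            return 'Play'\n",
"    else:  # outlook == 'rain'\n",
"        # Internal node: split on Windy?\n",
"        if windy:\n",
"            return \"Don't Play\"\n",
"        else:\n",
"            return 'Play'\n",
"\n",
"# Day 6 of the table: rain, normal humidity, windy\n",
"print(predict_play('rain', 'normal', True))"
]
},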
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Splitting Criteria\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Which of these decision trees would be the best one here?*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"(from https://www.analyticsvidhya.com/blog/2021/02/how-to-split-decision-tree-the-pursuit-to-achieve-purity/)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The most important element in building a decision tree is the criterion used to split the data at each node: during training we have to decide, at every non-leaf node, which feature to split on next. There are several commonly used criteria, and the choice is an *important* **hyperparameter** of decision trees. We look at the two most important options."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Splitting Criterion: Gini Impurity\n",
"\n",
"Gini impurity measures how \"mixed\" the classes in a group are. A perfect split would produce groups that each contain only one class (a purity of 100%). The lower the Gini impurity of the resulting groups, the better the split.\n",
"\n",
"**Example:**\n",
"\n",
"Suppose we have a dataset with 10 entries, 6 of which are \"Yes\" and 4 \"No\". If we split on some feature so that one group contains 5 \"Yes\" and 0 \"No\" and the other contains 1 \"Yes\" and 4 \"No\", this split is better than one in which both groups are strongly mixed."
]
},
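{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the example above, assuming the usual definition of Gini impurity, $G = 1 - \\sum_k p_k^2$, where $p_k$ is the fraction of class $k$ in a group. The helper names `gini` and `weighted_gini` are illustrative and not part of any library."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def gini(counts):\n",
"    # counts: number of samples per class in one group\n",
"    total = sum(counts)\n",
"    return 1.0 - sum((c / total) ** 2 for c in counts)\n",
"\n",
"def weighted_gini(groups):\n",
"    # Average the impurity of the groups, weighted by group size\n",
"    n = sum(sum(g) for g in groups)\n",
"    return sum(sum(g) / n * gini(g) for g in groups)\n",
"\n",
"parent = [6, 4]            # 6 'Yes', 4 'No'\n",
"split = [[5, 0], [1, 4]]   # the split from the example\n",
"\n",
"print(round(gini(parent), 3))          # 0.48\n",
"print(round(weighted_gini(split), 3))  # 0.16"
]
},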
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Splitting Criterion: Entropy (Information Gain)\n",
"\n",
"Entropy measures the \"disorder\", or information content, of a group. A split that produces groups containing a single class (low entropy) is better than one where the classes are mixed (high entropy). Information gain is the reduction in entropy achieved by the split.\n",
"\n",
"**Example:**\n",
"\n",
"Imagine we are deciding whether a person is granted a loan (\"Yes\" or \"No\"). One attribute is the person's income (high/medium/low). If splitting into \"high\" and \"medium/low\" leads to one group with almost only \"Yes\" and the other with almost only \"No\", the entropy has dropped sharply and the split is good."
]
},
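{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch, information gain can be computed for every feature of the example dataset from the table above, assuming the usual definition of entropy, $H = -\\sum_k p_k \\log_2 p_k$. The helper names `entropy` and `information_gain` are illustrative and not part of any library."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"from math import log2\n",
"\n",
"# (outlook, temp, humidity, windy, label) for days 1-14 of the example table\n",
"data = [\n",
"    ('sunny', 'hot', 'high', False, 'no'),\n",
"    ('sunny', 'hot', 'high', True, 'no'),\n",
"    ('overcast', 'hot', 'high', False, 'yes'),\n",
"    ('rain', 'mild', 'high', False, 'yes'),\n",
"    ('rain', 'cool', 'normal', False, 'yes'),\n",
"    ('rain', 'cool', 'normal', True, 'no'),\n",
"    ('overcast', 'cool', 'normal', True, 'yes'),\n",
"    ('sunny', 'mild', 'high', False, 'no'),\n",
"    ('sunny', 'cool', 'normal', False, 'yes'),\n",
"    ('rain', 'mild', 'normal', False, 'yes'),\n",
"    ('sunny', 'mild', 'normal', True, 'yes'),\n",
"    ('overcast', 'mild', 'high', True, 'yes'),\n",
"    ('overcast', 'hot', 'normal', False, 'yes'),\n",
"    ('rain', 'mild', 'high', True, 'no'),\n",
"]\n",
"features = ['outlook', 'temp', 'humidity', 'windy']\n",
"\n",
"def entropy(labels):\n",
"    n = len(labels)\n",
"    return -sum(c / n * log2(c / n) for c in Counter(labels).values())\n",
"\n",
"def information_gain(rows, col):\n",
"    # Entropy before the split minus the size-weighted entropy after it\n",
"    groups = {}\n",
"    for row in rows:\n",
"        groups.setdefault(row[col], []).append(row[-1])\n",
"    after = sum(len(g) / len(rows) * entropy(g) for g in groups.values())\n",
"    return entropy([row[-1] for row in rows]) - after\n",
"\n",
"# Outlook gives the largest gain (about 0.25), so it would be the first split\n",
"for col, name in enumerate(features):\n",
"    print(name, round(information_gain(data, col), 3))"
]
},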
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Splitting on Numeric Features\n",
"\n",
"Splitting on a numeric feature differs from splitting on a categorical one, because numbers lie on a continuous scale. An optimal **threshold** therefore has to be found that divides the data into two groups. The procedure is very similar to the categorical case and works as follows.\n",
"\n",
"1. Identify candidate thresholds:\n",
"    * For each numeric feature, candidate thresholds are tested.\n",
"    * Typically, these thresholds are placed between two neighbouring values in the data.\n",
"\n",
"2. Split the data into two groups:\n",
"    * A split usually has the form:\n",
"        * Left group: values $\\leq$ threshold\n",
"        * Right group: values $>$ threshold\n",
"\n",
"3. Evaluate the quality of the split:\n",
"    * The chosen splitting criterion (e.g. Gini impurity, entropy, or variance reduction for regression trees) is computed on the resulting groups.\n",
"    * The goal is to find the best split according to that criterion.\n",
"\n",
"4. Choose the best split:\n",
"    * The threshold with the largest reduction in impurity (e.g. the lowest weighted Gini impurity of the two groups) is used as the split for this node."
]
},
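{
"cell_type": "markdown",
"metadata": {},
"source": [
"The four steps above can be sketched for a single numeric feature as follows. The data are made up for illustration, and `best_threshold` is a hypothetical helper, not a library function; real implementations are considerably more optimized."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def gini(labels):\n",
"    n = len(labels)\n",
"    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))\n",
"\n",
"def best_threshold(values, labels):\n",
"    # Step 1: candidate thresholds are midpoints between neighbouring sorted values\n",
"    xs = sorted(set(values))\n",
"    candidates = [(a + b) / 2 for a, b in zip(xs, xs[1:])]\n",
"    best = None\n",
"    for t in candidates:\n",
"        # Step 2: split into left (<= t) and right (> t)\n",
"        left = [y for x, y in zip(values, labels) if x <= t]\n",
"        right = [y for x, y in zip(values, labels) if x > t]\n",
"        # Step 3: size-weighted Gini impurity of the two groups\n",
"        score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)\n",
"        # Step 4: keep the threshold with the lowest impurity\n",
"        if best is None or score < best[1]:\n",
"            best = (t, score)\n",
"    return best\n",
"\n",
"values = [1.4, 1.3, 4.7, 4.5, 5.1, 1.5]  # made-up feature values\n",
"labels = [0, 0, 1, 1, 1, 0]              # made-up class labels\n",
"print(best_threshold(values, labels))    # (3.0, 0.0)"
]
},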
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pruning\n",
"\n",
"* Decision trees are very prone to overfitting if the tree is allowed to grow large enough\n",
"* Trees therefore have to be \"cut back\" (*pruning*)\n",
"* Pruning means, for example:\n",
"    * Fixing the maximum depth\n",
"    * Fixing the maximum number of leaves\n",
"\n",
"In addition, features that are not needed can be removed from the dataset. This also helps prevent overfitting!\n",
"\n",
"Another option is to use several trees (see **Random Forests** later)!"
]
},
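{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch, these limits can be set in scikit-learn through the `max_depth` and `max_leaf_nodes` parameters of `DecisionTreeClassifier`; `ccp_alpha` additionally enables cost-complexity pruning. The concrete values below are arbitrary choices for illustration, not recommendations. The loop prints the depth, the number of leaves and the test accuracy of each tree."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"X, y = load_iris(return_X_y=True)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
"# Arbitrary example settings (assumptions for illustration)\n",
"settings = {\n",
"    'unpruned': {},\n",
"    'max_depth=2': {'max_depth': 2},\n",
"    'max_leaf_nodes=3': {'max_leaf_nodes': 3},\n",
"    'ccp_alpha=0.02': {'ccp_alpha': 0.02},\n",
"}\n",
"for name, params in settings.items():\n",
"    tree = DecisionTreeClassifier(random_state=42, **params).fit(X_train, y_train)\n",
"    print(name, tree.get_depth(), tree.get_n_leaves(), tree.score(X_test, y_test))"
]
},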
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Advantages and Disadvantages"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Advantages:**\n",
"* Very easy to interpret\n",
"* Little data preprocessing required (normalization?)\n",
"* Robust to irrelevant features\n",
"\n",
"**Disadvantages:**\n",
"* Prone to overfitting if the tree becomes too large (deep)\n",
"* Very unstable (small changes in the data can cause large changes in the tree)\n",
"* The tree can become very one-sided if the classes are highly imbalanced (the labels are not evenly distributed)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Decision Trees in Python"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"# Load dataset\n",
"X, y = load_iris(as_frame=False, return_X_y=True)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"# Split into training and test sets\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.6, 3.6, 1. , 0.2],\n",
"       [5.7, 4.4, 1.5, 0.4],\n",
"       [6.7, 3.1, 4.4, 1.4],\n",
"       [4.8, 3.4, 1.6, 0.2],\n",
"       [4.4, 3.2, 1.3, 0.2],\n",
"       [6.3, 2.5, 5. , 1.9],\n",
"       [6.4, 3.2, 4.5, 1.5],\n",
"       [5.2, 3.5, 1.5, 0.2],\n",
"       [5. , 3.6, 1.4, 0.2],\n",
"       [5.2, 4.1, 1.5, 0.1],\n",
"       [5.8, 2.7, 5.1, 1.9],\n",
"       [6. , 3.4, 4.5, 1.6],\n",
"       [6.7, 3.1, 4.7, 1.5],\n",
"       [5.4, 3.9, 1.3, 0.4],\n",
"       [5.4, 3.7, 1.5, 0.2],\n",
"       [5.5, 2.4, 3.7, 1. ],\n",
"       [6.3, 2.8, 5.1, 1.5],\n",
"       [6.4, 3.1, 5.5, 1.8],\n",
"       [6.6, 3. , 4.4, 1.4],\n",
"       [7.2, 3.6, 6.1, 2.5],\n",
"       [5.7, 2.9, 4.2, 1.3],\n",
"       [7.6, 3. , 6.6, 2.1],\n",
"       [5.6, 3. , 4.5, 1.5],\n",
"       [5.1, 3.5, 1.4, 0.2],\n",
"       [7.7, 2.8, 6.7, 2. ],\n",
"       [5.8, 2.7, 4.1, 1. ],\n",
"       [5.2, 3.4, 1.4, 0.2],\n",
"       [5. , 3.5, 1.3, 0.3],\n",
"       [5.1, 3.8, 1.9, 0.4],\n",
"       [5. , 2. , 3.5, 1. ],\n",
"       [6.3, 2.7, 4.9, 1.8],\n",
"       [4.8, 3.4, 1.9, 0.2],\n",
"       [5. , 3. , 1.6, 0.2],\n",
"       [5.1, 3.3, 1.7, 0.5],\n",
"       [5.6, 2.7, 4.2, 1.3],\n",
"       [5.1, 3.4, 1.5, 0.2],\n",
"       [5.7, 3. , 4.2, 1.2],\n",
"       [7.7, 3.8, 6.7, 2.2],\n",
"       [4.6, 3.2, 1.4, 0.2],\n",
"       [6.2, 2.9, 4.3, 1.3],\n",
"       [5.7, 2.5, 5. , 2. ],\n",
"       [5.5, 4.2, 1.4, 0.2],\n",
"       [6. , 3. , 4.8, 1.8],\n",
"       [5.8, 2.7, 5.1, 1.9],\n",
"       [6. , 2.2, 4. , 1. ],\n",
"       [5.4, 3. , 4.5, 1.5],\n",
"       [6.2, 3.4, 5.4, 2.3],\n",
"       [5.5, 2.3, 4. , 1.3],\n",
"       [5.4, 3.9, 1.7, 0.4],\n",
"       [5. , 2.3, 3.3, 1. ],\n",
"       [6.4, 2.7, 5.3, 1.9],\n",
"       [5. , 3.3, 1.4, 0.2],\n",
"       [5. , 3.2, 1.2, 0.2],\n",
"       [5.5, 2.4, 3.8, 1.1],\n",
"       [6.7, 3. , 5. , 1.7],\n",
"       [4.9, 3.1, 1.5, 0.2],\n",
"       [5.8, 2.8, 5.1, 2.4],\n",
"       [5. , 3.4, 1.5, 0.2],\n",
"       [5. , 3.5, 1.6, 0.6],\n",
"       [5.9, 3.2, 4.8, 1.8],\n",
"       [5.1, 2.5, 3. , 1.1],\n",
"       [6.9, 3.2, 5.7, 2.3],\n",
"       [6. , 2.7, 5.1, 1.6],\n",
"       [6.1, 2.6, 5.6, 1.4],\n",
"       [7.7, 3. , 6.1, 2.3],\n",
"       [5.5, 2.5, 4. , 1.3],\n",
"       [4.4, 2.9, 1.4, 0.2],\n",
"       [4.3, 3. , 1.1, 0.1],\n",
"       [6. , 2.2, 5. , 1.5],\n",
"       [7.2, 3.2, 6. , 1.8],\n",
"       [4.6, 3.1, 1.5, 0.2],\n",
"       [5.1, 3.5, 1.4, 0.3],\n",
"       [4.4, 3. , 1.3, 0.2],\n",
"       [6.3, 2.5, 4.9, 1.5],\n",
"       [6.3, 3.4, 5.6, 2.4],\n",
"       [4.6, 3.4, 1.4, 0.3],\n",
"       [6.8, 3. , 5.5, 2.1],\n",
"       [6.3, 3.3, 6. , 2.5],\n",
"       [4.7, 3.2, 1.3, 0.2],\n",
"       [6.1, 2.9, 4.7, 1.4],\n",
"       [6.5, 2.8, 4.6, 1.5],\n",
"       [6.2, 2.8, 4.8, 1.8],\n",
"       [7. , 3.2, 4.7, 1.4],\n",
"       [6.4, 3.2, 5.3, 2.3],\n",
"       [5.1, 3.8, 1.6, 0.2],\n",
"       [6.9, 3.1, 5.4, 2.1],\n",
"       [5.9, 3. , 4.2, 1.5],\n",
"       [6.5, 3. , 5.2, 2. ],\n",
"       [5.7, 2.6, 3.5, 1. ],\n",
"       [5.2, 2.7, 3.9, 1.4],\n",
"       [6.1, 3. , 4.6, 1.4],\n",
"       [4.5, 2.3, 1.3, 0.3],\n",
"       [6.6, 2.9, 4.6, 1.3],\n",
"       [5.5, 2.6, 4.4, 1.2],\n",
"       [5.3, 3.7, 1.5, 0.2],\n",
"       [5.6, 3. , 4.1, 1.3],\n",
"       [7.3, 2.9, 6.3, 1.8],\n",
"       [6.7, 3.3, 5.7, 2.1],\n",
"       [5.1, 3.7, 1.5, 0.4],\n",
"       [4.9, 2.4, 3.3, 1. ],\n",
"       [6.7, 3.3, 5.7, 2.5],\n",
"       [7.2, 3. , 5.8, 1.6],\n",
"       [4.9, 3.6, 1.4, 0.1],\n",
"       [6.7, 3.1, 5.6, 2.4],\n",
"       [4.9, 3. , 1.4, 0.2],\n",
"       [6.9, 3.1, 4.9, 1.5],\n",
"       [7.4, 2.8, 6.1, 1.9],\n",
"       [6.3, 2.9, 5.6, 1.8],\n",
"       [5.7, 2.8, 4.1, 1.3],\n",
"       [6.5, 3. , 5.5, 1.8],\n",
"       [6.3, 2.3, 4.4, 1.3],\n",
"       [6.4, 2.9, 4.3, 1.3],\n",
"       [5.6, 2.8, 4.9, 2. ],\n",
"       [5.9, 3. , 5.1, 1.8],\n",
"       [5.4, 3.4, 1.7, 0.2],\n",
"       [6.1, 2.8, 4. , 1.3],\n",
"       [4.9, 2.5, 4.5, 1.7],\n",
"       [5.8, 4. , 1.2, 0.2],\n",
"       [5.8, 2.6, 4. , 1.2],\n",
"       [7.1, 3. , 5.9, 2.1]])"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(120, 4)"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(random_state=42)"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train a decision tree with default hyperparameters (no pruning; setting ccp_alpha > 0 would enable cost-complexity pruning)\n",
"tree = DecisionTreeClassifier(random_state=42)\n",
"tree.fit(X_train, y_train)"
]
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 66,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train accuracy: 1.0\n",
|
||
"Test accuracy: 1.0\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Print accuracy\n",
|
||
"print(\"Train accuracy:\", tree.score(X_train, y_train))\n",
|
||
"print(\"Test accuracy:\", tree.score(X_test, y_test))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 67,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train accuracy: 0.9583333333333334\n",
|
||
"Test accuracy: 1.0\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# We can also use pruning\n",
|
||
"pruned_tree = DecisionTreeClassifier(random_state=42, max_depth=3)\n",
|
||
"pruned_tree.fit(X_train, y_train)\n",
|
||
"\n",
|
||
"# Print accuracy\n",
|
||
"print(\"Train accuracy:\", pruned_tree.score(X_train, y_train))\n",
|
||
"print(\"Test accuracy:\", pruned_tree.score(X_test, y_test))"
|
||
]
|
||
},
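{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides limiting the depth, scikit-learn also supports cost-complexity pruning through the `ccp_alpha` parameter of `DecisionTreeClassifier`. The following cell is a minimal sketch and not part of the original material: it reloads Iris under its own variable names (`X_demo`, `X_tr`, ...), assumes an 80/20 split, and fits one tree per candidate alpha returned by `cost_complexity_pruning_path`. Larger values of `ccp_alpha` should give smaller trees."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Separate variable names so the notebook's X_train / X_test are not overwritten\n",
"X_demo, y_demo = load_iris(return_X_y=True)\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.2, random_state=42)\n",
"\n",
"# Candidate alphas computed from the fully grown tree\n",
"path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_tr, y_tr)\n",
"\n",
"leaf_counts = []\n",
"for alpha in path.ccp_alphas:\n",
"    alpha = max(float(alpha), 0.0)  # guard against tiny negative values from rounding\n",
"    t = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha).fit(X_tr, y_tr)\n",
"    n_leaves = t.get_n_leaves()\n",
"    acc = t.score(X_te, y_te)\n",
"    leaf_counts.append(n_leaves)\n",
"    print(f'ccp_alpha={alpha:.4f}  leaves={n_leaves}  test accuracy={acc:.3f}')"
]
},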
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 68,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Feature importance: [0. 0.01667014 0.90614339 0.07718647]\n",
|
||
"Feature importance: [0. 0. 0.93462632 0.06537368]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Ein weiterer interessanter Aspekt ist es, die Feature Importance zu berechnen\n",
|
||
"# Diese gibt an, wie wichtig die einzelnen Features für die Klassifikation sind\n",
|
||
"\n",
|
||
"# Print feature importance\n",
|
||
"print(\"Feature importance:\", tree.feature_importances_)\n",
|
||
"print(\"Feature importance:\", pruned_tree.feature_importances_)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Somit können wir auch mit Decision Trees entscheiden, welche Features ggf. weggelassen werden können in einem Dataset. Dies ist aber eher kostspielig!"
|
||
]
|
||
},
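{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of this idea, the next cell keeps only the features whose importance exceeds a cut-off and retrains the tree on the reduced data. This is a sketch under assumptions: the threshold of `0.05` is a hypothetical choice, the Iris data is reloaded under separate names (`iris_demo`, `X_tr`, ...), and an 80/20 split is assumed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"iris_demo = load_iris()\n",
"X_tr, X_te, y_tr, y_te = train_test_split(\n",
"    iris_demo.data, iris_demo.target, test_size=0.2, random_state=42\n",
")\n",
"\n",
"# Fit a tree on all features and read off the importances\n",
"full_tree = DecisionTreeClassifier(random_state=42).fit(X_tr, y_tr)\n",
"\n",
"threshold = 0.05  # hypothetical cut-off, chosen only for illustration\n",
"keep = full_tree.feature_importances_ >= threshold\n",
"print('Kept features:', [name for name, k in zip(iris_demo.feature_names, keep) if k])\n",
"\n",
"# Retrain on the reduced feature set and compare test accuracy\n",
"small_tree = DecisionTreeClassifier(random_state=42).fit(X_tr[:, keep], y_tr)\n",
"print('Test accuracy (all features): ', full_tree.score(X_te, y_te))\n",
"print('Test accuracy (kept features):', small_tree.score(X_te[:, keep], y_te))"
]
},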
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Random Forests"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Eine weitere Möglichkeit Overfitting zu vermeiden ist es, sogenannte **Ensemble Methoden** zu verwenden. Diese kombinieren mehrere Decision Trees und helfen so, die Generalisierung für das Dataset zu erhöhen. Eine spezielle Variante davon sind die **Random Forest** Classifier."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"\n",
|
||
"\n",
|
||
"(von https://de.wikipedia.org/wiki/Random_Forest)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Eigenschaften von Random Forests\n",
|
||
"\n",
|
||
"* Wir trainieren $N$ kleine Decision Trees, genannt Teilbäume.\n",
|
||
"* Jeder Teilbaum wird auf einer Teilmenge der gesamten Trainingsdaten trainiert\n",
|
||
"* Jeder Teilbaum overfitted diesen Teil der Daten ziemlich sicher\n",
|
||
"* Jeder Teilbaum ist somit ein Experte auf seinem Gebiet\n",
|
||
"* Bei der Vorhersage für neue Datenpunkte werden dann die Entscheidungen der einzelnen Teilbäume kombiniert. Dies passiert zum Beispiel mit einem Mehrheitsbeschluss.\n",
|
||
"* Durch die Mehrheitsbildung wird das Overfitting wieder reduziert."
|
||
]
|
||
},
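{
"cell_type": "markdown",
"metadata": {},
"source": [
"The properties listed above can be sketched by hand with plain decision trees. This is an illustrative sketch, not the scikit-learn implementation: the number of trees (`n_trees = 25`), the bootstrap sampling with replacement, and `max_features='sqrt'` are assumptions made for the example. Note that scikit-learn's `RandomForestClassifier` averages the predicted class probabilities of its trees instead of counting hard votes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"X_demo, y_demo = load_iris(return_X_y=True)\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.3, random_state=42)\n",
"\n",
"rng = np.random.default_rng(42)\n",
"n_trees = 25  # hypothetical ensemble size\n",
"trees = []\n",
"for _ in range(n_trees):\n",
"    # Draw a bootstrap sample: same size as the training set, with replacement\n",
"    idx = rng.integers(0, len(X_tr), size=len(X_tr))\n",
"    t = DecisionTreeClassifier(max_features='sqrt', random_state=int(rng.integers(0, 1_000_000)))\n",
"    trees.append(t.fit(X_tr[idx], y_tr[idx]))\n",
"\n",
"# One row of predictions per tree, one column per test sample\n",
"votes = np.array([t.predict(X_te) for t in trees])\n",
"\n",
"# Majority vote per test sample (Iris has the three classes 0, 1, 2)\n",
"majority = np.array([np.bincount(col, minlength=3).argmax() for col in votes.T])\n",
"ensemble_accuracy = (majority == y_te).mean()\n",
"print('Ensemble test accuracy:', ensemble_accuracy)"
]
},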
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Hpyerparameter von Random Forest Models\n",
|
||
"\n",
|
||
"* (Fast) gleiche Hyperparameter wie bei einzelnen Decision Trees\n",
|
||
"* Random State für reproduzierbaren Zufall beim Auswählen der zufälligen Teildatasets\n",
|
||
"* Anzahl der Bäume die verwendet wird"
|
||
]
|
||
},
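{
"cell_type": "markdown",
"metadata": {},
"source": [
"In scikit-learn these hyperparameters correspond to arguments of `RandomForestClassifier`, for example `n_estimators` (number of trees), `max_depth` (as for a single tree) and `random_state`. The cell below is a small sketch that compares a few values of `n_estimators` with 5-fold cross-validation on Iris. The chosen values and `max_depth=3` are arbitrary assumptions for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"X_demo, y_demo = load_iris(return_X_y=True)\n",
"\n",
"cv_means = {}\n",
"for n in (1, 10, 100):\n",
"    forest = RandomForestClassifier(n_estimators=n, max_depth=3, random_state=42)\n",
"    # Mean accuracy over 5 cross-validation folds\n",
"    cv_means[n] = cross_val_score(forest, X_demo, y_demo, cv=5).mean()\n",
"    print(f'n_estimators={n:>3}: mean CV accuracy = {cv_means[n]:.3f}')"
]
},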
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Random Forests in Python"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 69,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Import necessary libraries\n",
|
||
"from sklearn.datasets import load_iris\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||
"from sklearn.metrics import accuracy_score"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 70,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Load the Iris dataset\n",
|
||
"iris = load_iris()\n",
|
||
"X = iris.data\n",
|
||
"y = iris.target\n",
|
||
"\n",
|
||
"# Split the dataset into training and testing sets\n",
|
||
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 71,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style>#sk-container-id-8 {\n",
|
||
" /* Definition of color scheme common for light and dark mode */\n",
|
||
" --sklearn-color-text: black;\n",
|
||
" --sklearn-color-line: gray;\n",
|
||
" /* Definition of color scheme for unfitted estimators */\n",
|
||
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
|
||
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
|
||
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
|
||
" --sklearn-color-unfitted-level-3: chocolate;\n",
|
||
" /* Definition of color scheme for fitted estimators */\n",
|
||
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
|
||
" --sklearn-color-fitted-level-1: #d4ebff;\n",
|
||
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
|
||
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
|
||
"\n",
|
||
" /* Specific color for light theme */\n",
|
||
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
|
||
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
" --sklearn-color-icon: #696969;\n",
|
||
"\n",
|
||
" @media (prefers-color-scheme: dark) {\n",
|
||
" /* Redefinition of color scheme for dark theme */\n",
|
||
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
|
||
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
" --sklearn-color-icon: #878787;\n",
|
||
" }\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 pre {\n",
|
||
" padding: 0;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 input.sk-hidden--visually {\n",
|
||
" border: 0;\n",
|
||
" clip: rect(1px 1px 1px 1px);\n",
|
||
" clip: rect(1px, 1px, 1px, 1px);\n",
|
||
" height: 1px;\n",
|
||
" margin: -1px;\n",
|
||
" overflow: hidden;\n",
|
||
" padding: 0;\n",
|
||
" position: absolute;\n",
|
||
" width: 1px;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-dashed-wrapped {\n",
|
||
" border: 1px dashed var(--sklearn-color-line);\n",
|
||
" margin: 0 0.4em 0.5em 0.4em;\n",
|
||
" box-sizing: border-box;\n",
|
||
" padding-bottom: 0.4em;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-container {\n",
|
||
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
|
||
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
|
||
" so we also need the `!important` here to be able to override the\n",
|
||
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
|
||
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
|
||
" display: inline-block !important;\n",
|
||
" position: relative;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-text-repr-fallback {\n",
|
||
" display: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"div.sk-parallel-item,\n",
|
||
"div.sk-serial,\n",
|
||
"div.sk-item {\n",
|
||
" /* draw centered vertical line to link estimators */\n",
|
||
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
|
||
" background-size: 2px 100%;\n",
|
||
" background-repeat: no-repeat;\n",
|
||
" background-position: center center;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Parallel-specific style estimator block */\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-parallel-item::after {\n",
|
||
" content: \"\";\n",
|
||
" width: 100%;\n",
|
||
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
|
||
" flex-grow: 1;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-parallel {\n",
|
||
" display: flex;\n",
|
||
" align-items: stretch;\n",
|
||
" justify-content: center;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" position: relative;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-parallel-item {\n",
|
||
" display: flex;\n",
|
||
" flex-direction: column;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-parallel-item:first-child::after {\n",
|
||
" align-self: flex-end;\n",
|
||
" width: 50%;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-parallel-item:last-child::after {\n",
|
||
" align-self: flex-start;\n",
|
||
" width: 50%;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-parallel-item:only-child::after {\n",
|
||
" width: 0;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Serial-specific style estimator block */\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-serial {\n",
|
||
" display: flex;\n",
|
||
" flex-direction: column;\n",
|
||
" align-items: center;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" padding-right: 1em;\n",
|
||
" padding-left: 1em;\n",
|
||
"}\n",
|
||
"\n",
|
||
"\n",
|
||
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
|
||
"clickable and can be expanded/collapsed.\n",
|
||
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
|
||
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
|
||
"*/\n",
|
||
"\n",
|
||
"/* Pipeline and ColumnTransformer style (default) */\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-toggleable {\n",
|
||
" /* Default theme specific background. It is overwritten whether we have a\n",
|
||
" specific estimator or a Pipeline/ColumnTransformer */\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Toggleable label */\n",
|
||
"#sk-container-id-8 label.sk-toggleable__label {\n",
|
||
" cursor: pointer;\n",
|
||
" display: block;\n",
|
||
" width: 100%;\n",
|
||
" margin-bottom: 0;\n",
|
||
" padding: 0.5em;\n",
|
||
" box-sizing: border-box;\n",
|
||
" text-align: center;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 label.sk-toggleable__label-arrow:before {\n",
|
||
" /* Arrow on the left of the label */\n",
|
||
" content: \"▸\";\n",
|
||
" float: left;\n",
|
||
" margin-right: 0.25em;\n",
|
||
" color: var(--sklearn-color-icon);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 label.sk-toggleable__label-arrow:hover:before {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Toggleable content - dropdown */\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-toggleable__content {\n",
|
||
" max-height: 0;\n",
|
||
" max-width: 0;\n",
|
||
" overflow: hidden;\n",
|
||
" text-align: left;\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-toggleable__content.fitted {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-toggleable__content pre {\n",
|
||
" margin: 0.2em;\n",
|
||
" border-radius: 0.25em;\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-toggleable__content.fitted pre {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
|
||
" /* Expand drop-down */\n",
|
||
" max-height: 200px;\n",
|
||
" max-width: 100%;\n",
|
||
" overflow: auto;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
|
||
" content: \"▾\";\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Pipeline/ColumnTransformer-specific style */\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Estimator-specific style */\n",
|
||
"\n",
|
||
"/* Colorize estimator box */\n",
|
||
"#sk-container-id-8 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-label label.sk-toggleable__label,\n",
|
||
"#sk-container-id-8 div.sk-label label {\n",
|
||
" /* The background is the default theme color */\n",
|
||
" color: var(--sklearn-color-text-on-default-background);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* On hover, darken the color of the background */\n",
|
||
"#sk-container-id-8 div.sk-label:hover label.sk-toggleable__label {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Label box, darken color on hover, fitted */\n",
|
||
"#sk-container-id-8 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Estimator label */\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-label label {\n",
|
||
" font-family: monospace;\n",
|
||
" font-weight: bold;\n",
|
||
" display: inline-block;\n",
|
||
" line-height: 1.2em;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-label-container {\n",
|
||
" text-align: center;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Estimator-specific */\n",
|
||
"#sk-container-id-8 div.sk-estimator {\n",
|
||
" font-family: monospace;\n",
|
||
" border: 1px dotted var(--sklearn-color-border-box);\n",
|
||
" border-radius: 0.25em;\n",
|
||
" box-sizing: border-box;\n",
|
||
" margin-bottom: 0.5em;\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-estimator.fitted {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* on hover */\n",
|
||
"#sk-container-id-8 div.sk-estimator:hover {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 div.sk-estimator.fitted:hover {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
|
||
"\n",
|
||
"/* Common style for \"i\" and \"?\" */\n",
|
||
"\n",
|
||
".sk-estimator-doc-link,\n",
|
||
"a:link.sk-estimator-doc-link,\n",
|
||
"a:visited.sk-estimator-doc-link {\n",
|
||
" float: right;\n",
|
||
" font-size: smaller;\n",
|
||
" line-height: 1em;\n",
|
||
" font-family: monospace;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" border-radius: 1em;\n",
|
||
" height: 1em;\n",
|
||
" width: 1em;\n",
|
||
" text-decoration: none !important;\n",
|
||
" margin-left: 1ex;\n",
|
||
" /* unfitted */\n",
|
||
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
"}\n",
|
||
"\n",
|
||
".sk-estimator-doc-link.fitted,\n",
|
||
"a:link.sk-estimator-doc-link.fitted,\n",
|
||
"a:visited.sk-estimator-doc-link.fitted {\n",
|
||
" /* fitted */\n",
|
||
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* On hover */\n",
|
||
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
|
||
".sk-estimator-doc-link:hover,\n",
|
||
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
|
||
".sk-estimator-doc-link:hover {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
" color: var(--sklearn-color-background);\n",
|
||
" text-decoration: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
".sk-estimator-doc-link.fitted:hover,\n",
|
||
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
".sk-estimator-doc-link.fitted:hover {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
" color: var(--sklearn-color-background);\n",
|
||
" text-decoration: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Span, style for the box shown on hovering the info icon */\n",
|
||
".sk-estimator-doc-link span {\n",
|
||
" display: none;\n",
|
||
" z-index: 9999;\n",
|
||
" position: relative;\n",
|
||
" font-weight: normal;\n",
|
||
" right: .2ex;\n",
|
||
" padding: .5ex;\n",
|
||
" margin: .5ex;\n",
|
||
" width: min-content;\n",
|
||
" min-width: 20ex;\n",
|
||
" max-width: 50ex;\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" box-shadow: 2pt 2pt 4pt #999;\n",
|
||
" /* unfitted */\n",
|
||
" background: var(--sklearn-color-unfitted-level-0);\n",
|
||
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
|
||
"}\n",
|
||
"\n",
|
||
".sk-estimator-doc-link.fitted span {\n",
|
||
" /* fitted */\n",
|
||
" background: var(--sklearn-color-fitted-level-0);\n",
|
||
" border: var(--sklearn-color-fitted-level-3);\n",
|
||
"}\n",
|
||
"\n",
|
||
".sk-estimator-doc-link:hover span {\n",
|
||
" display: block;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
|
||
"\n",
|
||
"#sk-container-id-8 a.estimator_doc_link {\n",
|
||
" float: right;\n",
|
||
" font-size: 1rem;\n",
|
||
" line-height: 1em;\n",
|
||
" font-family: monospace;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" border-radius: 1rem;\n",
|
||
" height: 1rem;\n",
|
||
" width: 1rem;\n",
|
||
" text-decoration: none;\n",
|
||
" /* unfitted */\n",
|
||
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 a.estimator_doc_link.fitted {\n",
|
||
" /* fitted */\n",
|
||
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* On hover */\n",
|
||
"#sk-container-id-8 a.estimator_doc_link:hover {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
" color: var(--sklearn-color-background);\n",
|
||
" text-decoration: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-8 a.estimator_doc_link.fitted:hover {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
"}\n",
|
||
"</style><div id=\"sk-container-id-8\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier(random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-8\" type=\"checkbox\" checked><label for=\"sk-estimator-id-8\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> RandomForestClassifier<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.ensemble.RandomForestClassifier.html\">?<span>Documentation for RandomForestClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>RandomForestClassifier(random_state=42)</pre></div> </div></div></div></div>"
|
||
],
|
||
"text/plain": [
|
||
"RandomForestClassifier(random_state=42)"
|
||
]
|
||
},
|
||
"execution_count": 71,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Create a Random Forest Classifier\n",
|
||
"rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)\n",
|
||
"\n",
|
||
"# Train the classifier\n",
|
||
"rf_classifier.fit(X_train, y_train)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 72,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Accuracy: 1.00\n",
|
||
"sepal length (cm): 0.1041\n",
|
||
"sepal width (cm): 0.0446\n",
|
||
"petal length (cm): 0.4173\n",
|
||
"petal width (cm): 0.4340\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Make predictions on the test set\n",
|
||
"y_pred = rf_classifier.predict(X_test)\n",
|
||
"\n",
|
||
"# Evaluate the accuracy of the classifier\n",
|
||
"accuracy = accuracy_score(y_test, y_pred)\n",
|
||
"print(f'Accuracy: {accuracy:.2f}')\n",
|
||
"\n",
|
||
"# Print feature importances\n",
|
||
"importances = rf_classifier.feature_importances_\n",
|
||
"feature_names = iris.feature_names\n",
|
||
"for feature_name, importance in zip(feature_names, importances):\n",
|
||
" print(f'{feature_name}: {importance:.4f}')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Aufgabe"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"* Verwende nun die folgenden Datasets und versuche die bestmögliche Performance zu erreichen.\n",
|
||
"* Lade dazu das Dataset mit den bekannten Methoden (Laden mit Hilfe von `pd.read_csv`)\n",
|
||
"* Überlege, wie du bei schlechter Performance diese verbessern kannst. Zum Beipsiel: Normalisieren, Ausreißer entfernen etc.\n",
|
||
"* Müssen wir ggf. Features entfernen?\n",
|
||
"* Gehören ggf. Features mit einem Ordinal-Encoder oder mit einem Onehot-Encoder encodiert?\n",
|
||
"* Verwende für jedes Dataset eigene Code-Zellen und dokumentiere für die verschiedenen Durchläufe die Ergebnisse (zBsp. Accuracy, Confusion Matrix oder den MSE)\n",
|
||
"\n",
|
||
"**Datasets:**\n",
|
||
"* Breast Cancer `breast_cancer.csv` (verwendet von https://archive.ics.uci.edu/dataset/17/breast+cancer+wisconsin+diagnostic)\n",
|
||
"* Diabetes `diabetes.csv` (verwendet von https://www.kaggle.com/uciml/pima-indians-diabetes-database)\n",
|
||
"* Stroke Prediction `stroke.csv` (verwendet von https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset)\n",
|
||
"\n",
|
||
"*Hinweis:* Überlege dir stets, welchen Problemtyp du verwendest und verwende dementsprechend das richtige Model! "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 73,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"from sklearn import datasets\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"from sklearn.svm import SVC\n",
|
||
"from sklearn.metrics import accuracy_score, classification_report"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 74,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"#Breast Cancer\n",
|
||
"ds = pd.read_csv(\"../../_data/breast_cancer.csv\") # ggf. etwas anders als unser bisheriges Dataset\n",
|
||
"ds = ds.dropna()\n",
|
||
"y = ds.Diagnosis\n",
|
||
"X = ds.drop('Diagnosis', axis=1)\n",
|
||
"X = X.select_dtypes(include=[np.number])\n",
|
||
"# Aufteilen der Daten in Trainings- und Testset (80% Training, 20% Test)\n",
|
||
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Accuracy: 0.9561\n",
|
||
"sepal length (cm): 0.0707\n",
|
||
"sepal width (cm): 0.0066\n",
|
||
"petal length (cm): 0.0587\n",
|
||
"petal width (cm): 0.0140\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Create a Random Forest Classifier\n",
|
||
"rf_classifier = RandomForestClassifier(n_estimators=10\n",
|
||
"\n",
|
||
", random_state=42)\n",
|
||
"\n",
|
||
"# Train the classifier\n",
|
||
"rf_classifier.fit(X_train, y_train)\n",
|
||
"# Make predictions on the test set\n",
|
||
"y_pred = rf_classifier.predict(X_test)\n",
|
||
"\n",
|
||
"# Evaluate the accuracy of the classifier\n",
|
||
"accuracy = accuracy_score(y_test, y_pred)\n",
|
||
"print(f'Accuracy: {accuracy:.4f}')\n",
|
||
"\n",
|
||
"# Print feature importances\n",
|
||
"importances = rf_classifier.feature_importances_\n",
|
||
"feature_names = iris.feature_names\n",
|
||
"for feature_name, importance in zip(feature_names, importances):\n",
|
||
" print(f'{feature_name}: {importance:.4f}')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 76,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"#Diabetes\n",
|
||
"ds = pd.read_csv(\"../../_data/diabetes.csv\") # ggf. etwas anders als unser bisheriges Dataset\n",
|
||
"ds = ds.dropna()\n",
|
||
"y = ds.Outcome\n",
|
||
"X = ds.drop('Outcome', axis=1)\n",
|
||
"X = X.select_dtypes(include=[np.number])\n",
|
||
"\n",
|
||
"# Aufteilen der Daten in Trainings- und Testset (80% Training, 20% Test)\n",
|
||
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 77,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Accuracy: 0.7662\n",
|
||
"sepal length (cm): 0.0733\n",
|
||
"sepal width (cm): 0.2249\n",
|
||
"petal length (cm): 0.0998\n",
|
||
"petal width (cm): 0.0591\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Create a Random Forest Classifier\n",
|
||
"rf_classifier = RandomForestClassifier(n_estimators=10, random_state=42)\n",
|
||
"\n",
|
||
"# Train the classifier\n",
|
||
"rf_classifier.fit(X_train, y_train)\n",
|
||
"# Make predictions on the test set\n",
|
||
"y_pred = rf_classifier.predict(X_test)\n",
|
||
"\n",
|
||
"# Evaluate the accuracy of the classifier\n",
|
||
"accuracy = accuracy_score(y_test, y_pred)\n",
|
||
"print(f'Accuracy: {accuracy:.4f}')\n",
|
||
"\n",
|
||
"# Print feature importances\n",
|
||
"importances = rf_classifier.feature_importances_\n",
|
||
"feature_names = iris.feature_names\n",
|
||
"for feature_name, importance in zip(feature_names, importances):\n",
|
||
" print(f'{feature_name}: {importance:.4f}')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 78,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Stroke\n",
|
||
"ds = pd.read_csv(\"../../_data/stroke.csv\") # ggf. etwas anders als unser bisheriges Dataset\n",
|
||
"ds = ds.dropna()\n",
|
||
"y = ds.stroke\n",
|
||
"X = ds.drop('stroke', axis=1)\n",
|
||
"X = X.select_dtypes(include=[np.number])\n",
|
||
"# Aufteilen der Daten in Trainings- und Testset (80% Training, 20% Test)\n",
|
||
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 79,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Accuracy: 0.9460\n",
|
||
"sepal length (cm): 0.2555\n",
|
||
"sepal width (cm): 0.1966\n",
|
||
"petal length (cm): 0.0195\n",
|
||
"petal width (cm): 0.0245\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Create a Random Forest Classifier\n",
|
||
"rf_classifier = RandomForestClassifier(n_estimators=10, random_state=42)\n",
|
||
"\n",
|
||
"# Train the classifier\n",
|
||
"rf_classifier.fit(X_train, y_train)\n",
|
||
"# Make predictions on the test set\n",
|
||
"y_pred = rf_classifier.predict(X_test)\n",
|
||
"\n",
|
||
"# Evaluate the accuracy of the classifier\n",
|
||
"accuracy = accuracy_score(y_test, y_pred)\n",
|
||
"print(f'Accuracy: {accuracy:.4f}')\n",
|
||
"\n",
|
||
"# Print feature importances\n",
|
||
"importances = rf_classifier.feature_importances_\n",
|
||
"feature_names = iris.feature_names\n",
|
||
"for feature_name, importance in zip(feature_names, importances):\n",
|
||
" print(f'{feature_name}: {importance:.4f}')"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "dsai",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.20"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|