{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"___\n",
"\n",
"___\n",
"\n",
"# Keras Basics\n",
"\n",
"Welcome to the section on deep learning! We'll be using Keras with a TensorFlow backend to perform our deep learning operations.\n",
"\n",
"This means we should get familiar with some Keras fundamentals and basics!\n",
"\n",
"## Imports\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 145,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset\n",
"\n",
"We will use the famous Iris dataset.\n",
"_____\n",
"More info on the data set:\n",
"https://en.wikipedia.org/wiki/Iris_flower_data_set\n",
"\n",
"## Reading in the Data Set\n",
"\n",
"The dataset ships with scikit-learn, so we can load it directly."
]
},
{
"cell_type": "code",
"execution_count": 146,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris"
]
},
{
"cell_type": "code",
"execution_count": 147,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"iris = load_iris()"
]
},
{
"cell_type": "code",
"execution_count": 148,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sklearn.utils.Bunch"
]
},
"execution_count": 148,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(iris)"
]
},
{
"cell_type": "code",
"execution_count": 149,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
".. _iris_dataset:\n",
"\n",
"Iris plants dataset\n",
"--------------------\n",
"\n",
"**Data Set Characteristics:**\n",
"\n",
" :Number of Instances: 150 (50 in each of three classes)\n",
" :Number of Attributes: 4 numeric, predictive attributes and the class\n",
" :Attribute Information:\n",
" - sepal length in cm\n",
" - sepal width in cm\n",
" - petal length in cm\n",
" - petal width in cm\n",
" - class:\n",
" - Iris-Setosa\n",
" - Iris-Versicolour\n",
" - Iris-Virginica\n",
" \n",
" :Summary Statistics:\n",
"\n",
" ============== ==== ==== ======= ===== ====================\n",
" Min Max Mean SD Class Correlation\n",
" ============== ==== ==== ======= ===== ====================\n",
" sepal length: 4.3 7.9 5.84 0.83 0.7826\n",
" sepal width: 2.0 4.4 3.05 0.43 -0.4194\n",
" petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\n",
" petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\n",
" ============== ==== ==== ======= ===== ====================\n",
"\n",
" :Missing Attribute Values: None\n",
" :Class Distribution: 33.3% for each of 3 classes.\n",
" :Creator: R.A. Fisher\n",
" :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n",
" :Date: July, 1988\n",
"\n",
"The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\n",
"from Fisher's paper. Note that it's the same as in R, but not as in the UCI\n",
"Machine Learning Repository, which has two wrong data points.\n",
"\n",
"This is perhaps the best known database to be found in the\n",
"pattern recognition literature. Fisher's paper is a classic in the field and\n",
"is referenced frequently to this day. (See Duda & Hart, for example.) The\n",
"data set contains 3 classes of 50 instances each, where each class refers to a\n",
"type of iris plant. One class is linearly separable from the other 2; the\n",
"latter are NOT linearly separable from each other.\n",
"\n",
".. topic:: References\n",
"\n",
" - Fisher, R.A. \"The use of multiple measurements in taxonomic problems\"\n",
" Annual Eugenics, 7, Part II, 179-188 (1936); also in \"Contributions to\n",
" Mathematical Statistics\" (John Wiley, NY, 1950).\n",
" - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n",
" (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\n",
" - Dasarathy, B.V. (1980) \"Nosing Around the Neighborhood: A New System\n",
" Structure and Classification Rule for Recognition in Partially Exposed\n",
" Environments\". IEEE Transactions on Pattern Analysis and Machine\n",
" Intelligence, Vol. PAMI-2, No. 1, 67-71.\n",
" - Gates, G.W. (1972) \"The Reduced Nearest Neighbor Rule\". IEEE Transactions\n",
" on Information Theory, May 1972, 431-433.\n",
" - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al\"s AUTOCLASS II\n",
" conceptual clustering system finds 3 classes in the data.\n",
" - Many, many more ...\n"
]
}
],
"source": [
"print(iris.DESCR)"
]
},
{
"cell_type": "code",
"execution_count": 150,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X = iris.data"
]
},
{
"cell_type": "code",
"execution_count": 151,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[5.1, 3.5, 1.4, 0.2],\n",
" [4.9, 3. , 1.4, 0.2],\n",
" [4.7, 3.2, 1.3, 0.2],\n",
" [4.6, 3.1, 1.5, 0.2],\n",
" [5. , 3.6, 1.4, 0.2],\n",
" [5.4, 3.9, 1.7, 0.4],\n",
" [4.6, 3.4, 1.4, 0.3],\n",
" [5. , 3.4, 1.5, 0.2],\n",
" [4.4, 2.9, 1.4, 0.2],\n",
" [4.9, 3.1, 1.5, 0.1],\n",
" [5.4, 3.7, 1.5, 0.2],\n",
" [4.8, 3.4, 1.6, 0.2],\n",
" [4.8, 3. , 1.4, 0.1],\n",
" [4.3, 3. , 1.1, 0.1],\n",
" [5.8, 4. , 1.2, 0.2],\n",
" [5.7, 4.4, 1.5, 0.4],\n",
" [5.4, 3.9, 1.3, 0.4],\n",
" [5.1, 3.5, 1.4, 0.3],\n",
" [5.7, 3.8, 1.7, 0.3],\n",
" [5.1, 3.8, 1.5, 0.3],\n",
" [5.4, 3.4, 1.7, 0.2],\n",
" [5.1, 3.7, 1.5, 0.4],\n",
" [4.6, 3.6, 1. , 0.2],\n",
" [5.1, 3.3, 1.7, 0.5],\n",
" [4.8, 3.4, 1.9, 0.2],\n",
" [5. , 3. , 1.6, 0.2],\n",
" [5. , 3.4, 1.6, 0.4],\n",
" [5.2, 3.5, 1.5, 0.2],\n",
" [5.2, 3.4, 1.4, 0.2],\n",
" [4.7, 3.2, 1.6, 0.2],\n",
" [4.8, 3.1, 1.6, 0.2],\n",
" [5.4, 3.4, 1.5, 0.4],\n",
" [5.2, 4.1, 1.5, 0.1],\n",
" [5.5, 4.2, 1.4, 0.2],\n",
" [4.9, 3.1, 1.5, 0.2],\n",
" [5. , 3.2, 1.2, 0.2],\n",
" [5.5, 3.5, 1.3, 0.2],\n",
" [4.9, 3.6, 1.4, 0.1],\n",
" [4.4, 3. , 1.3, 0.2],\n",
" [5.1, 3.4, 1.5, 0.2],\n",
" [5. , 3.5, 1.3, 0.3],\n",
" [4.5, 2.3, 1.3, 0.3],\n",
" [4.4, 3.2, 1.3, 0.2],\n",
" [5. , 3.5, 1.6, 0.6],\n",
" [5.1, 3.8, 1.9, 0.4],\n",
" [4.8, 3. , 1.4, 0.3],\n",
" [5.1, 3.8, 1.6, 0.2],\n",
" [4.6, 3.2, 1.4, 0.2],\n",
" [5.3, 3.7, 1.5, 0.2],\n",
" [5. , 3.3, 1.4, 0.2],\n",
" [7. , 3.2, 4.7, 1.4],\n",
" [6.4, 3.2, 4.5, 1.5],\n",
" [6.9, 3.1, 4.9, 1.5],\n",
" [5.5, 2.3, 4. , 1.3],\n",
" [6.5, 2.8, 4.6, 1.5],\n",
" [5.7, 2.8, 4.5, 1.3],\n",
" [6.3, 3.3, 4.7, 1.6],\n",
" [4.9, 2.4, 3.3, 1. ],\n",
" [6.6, 2.9, 4.6, 1.3],\n",
" [5.2, 2.7, 3.9, 1.4],\n",
" [5. , 2. , 3.5, 1. ],\n",
" [5.9, 3. , 4.2, 1.5],\n",
" [6. , 2.2, 4. , 1. ],\n",
" [6.1, 2.9, 4.7, 1.4],\n",
" [5.6, 2.9, 3.6, 1.3],\n",
" [6.7, 3.1, 4.4, 1.4],\n",
" [5.6, 3. , 4.5, 1.5],\n",
" [5.8, 2.7, 4.1, 1. ],\n",
" [6.2, 2.2, 4.5, 1.5],\n",
" [5.6, 2.5, 3.9, 1.1],\n",
" [5.9, 3.2, 4.8, 1.8],\n",
" [6.1, 2.8, 4. , 1.3],\n",
" [6.3, 2.5, 4.9, 1.5],\n",
" [6.1, 2.8, 4.7, 1.2],\n",
" [6.4, 2.9, 4.3, 1.3],\n",
" [6.6, 3. , 4.4, 1.4],\n",
" [6.8, 2.8, 4.8, 1.4],\n",
" [6.7, 3. , 5. , 1.7],\n",
" [6. , 2.9, 4.5, 1.5],\n",
" [5.7, 2.6, 3.5, 1. ],\n",
" [5.5, 2.4, 3.8, 1.1],\n",
" [5.5, 2.4, 3.7, 1. ],\n",
" [5.8, 2.7, 3.9, 1.2],\n",
" [6. , 2.7, 5.1, 1.6],\n",
" [5.4, 3. , 4.5, 1.5],\n",
" [6. , 3.4, 4.5, 1.6],\n",
" [6.7, 3.1, 4.7, 1.5],\n",
" [6.3, 2.3, 4.4, 1.3],\n",
" [5.6, 3. , 4.1, 1.3],\n",
" [5.5, 2.5, 4. , 1.3],\n",
" [5.5, 2.6, 4.4, 1.2],\n",
" [6.1, 3. , 4.6, 1.4],\n",
" [5.8, 2.6, 4. , 1.2],\n",
" [5. , 2.3, 3.3, 1. ],\n",
" [5.6, 2.7, 4.2, 1.3],\n",
" [5.7, 3. , 4.2, 1.2],\n",
" [5.7, 2.9, 4.2, 1.3],\n",
" [6.2, 2.9, 4.3, 1.3],\n",
" [5.1, 2.5, 3. , 1.1],\n",
" [5.7, 2.8, 4.1, 1.3],\n",
" [6.3, 3.3, 6. , 2.5],\n",
" [5.8, 2.7, 5.1, 1.9],\n",
" [7.1, 3. , 5.9, 2.1],\n",
" [6.3, 2.9, 5.6, 1.8],\n",
" [6.5, 3. , 5.8, 2.2],\n",
" [7.6, 3. , 6.6, 2.1],\n",
" [4.9, 2.5, 4.5, 1.7],\n",
" [7.3, 2.9, 6.3, 1.8],\n",
" [6.7, 2.5, 5.8, 1.8],\n",
" [7.2, 3.6, 6.1, 2.5],\n",
" [6.5, 3.2, 5.1, 2. ],\n",
" [6.4, 2.7, 5.3, 1.9],\n",
" [6.8, 3. , 5.5, 2.1],\n",
" [5.7, 2.5, 5. , 2. ],\n",
" [5.8, 2.8, 5.1, 2.4],\n",
" [6.4, 3.2, 5.3, 2.3],\n",
" [6.5, 3. , 5.5, 1.8],\n",
" [7.7, 3.8, 6.7, 2.2],\n",
" [7.7, 2.6, 6.9, 2.3],\n",
" [6. , 2.2, 5. , 1.5],\n",
" [6.9, 3.2, 5.7, 2.3],\n",
" [5.6, 2.8, 4.9, 2. ],\n",
" [7.7, 2.8, 6.7, 2. ],\n",
" [6.3, 2.7, 4.9, 1.8],\n",
" [6.7, 3.3, 5.7, 2.1],\n",
" [7.2, 3.2, 6. , 1.8],\n",
" [6.2, 2.8, 4.8, 1.8],\n",
" [6.1, 3. , 4.9, 1.8],\n",
" [6.4, 2.8, 5.6, 2.1],\n",
" [7.2, 3. , 5.8, 1.6],\n",
" [7.4, 2.8, 6.1, 1.9],\n",
" [7.9, 3.8, 6.4, 2. ],\n",
" [6.4, 2.8, 5.6, 2.2],\n",
" [6.3, 2.8, 5.1, 1.5],\n",
" [6.1, 2.6, 5.6, 1.4],\n",
" [7.7, 3. , 6.1, 2.3],\n",
" [6.3, 3.4, 5.6, 2.4],\n",
" [6.4, 3.1, 5.5, 1.8],\n",
" [6. , 3. , 4.8, 1.8],\n",
" [6.9, 3.1, 5.4, 2.1],\n",
" [6.7, 3.1, 5.6, 2.4],\n",
" [6.9, 3.1, 5.1, 2.3],\n",
" [5.8, 2.7, 5.1, 1.9],\n",
" [6.8, 3.2, 5.9, 2.3],\n",
" [6.7, 3.3, 5.7, 2.5],\n",
" [6.7, 3. , 5.2, 2.3],\n",
" [6.3, 2.5, 5. , 1.9],\n",
" [6.5, 3. , 5.2, 2. ],\n",
" [6.2, 3.4, 5.4, 2.3],\n",
" [5.9, 3. , 5.1, 1.8]])"
]
},
"execution_count": 151,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X"
]
},
{
"cell_type": "code",
"execution_count": 152,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y = iris.target"
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])"
]
},
"execution_count": 153,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y"
]
},
{
"cell_type": "code",
"execution_count": 154,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from keras.utils import to_categorical"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y = to_categorical(y)"
]
},
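{
"cell_type": "markdown",
"metadata": {},
"source": [
"`to_categorical` one-hot encodes the integer labels: class `0` becomes `[1, 0, 0]`, class `1` becomes `[0, 1, 0]`, and class `2` becomes `[0, 0, 1]`. This matches the three output neurons we'll use with a `categorical_crossentropy` loss. The class names corresponding to those indices are available on the Bunch object:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"iris.target_names"
]
},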
{
"cell_type": "code",
"execution_count": 157,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(150, 3)"
]
},
"execution_count": 157,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.shape"
]
},
{
"cell_type": "code",
"execution_count": 158,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.]], dtype=float32)"
]
},
"execution_count": 158,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split the Data into Training and Test\n",
"\n",
"It's time to split the data into a train/test set. Keep in mind that people sometimes split three ways: train/validation/test. We'll keep things simple for now. **Remember to check out the video explanation as to why we split and what all the parameters mean!**"
]
},
{
"cell_type": "code",
"execution_count": 159,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 160,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)"
]
},
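{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `test_size=0.33`, the 150 samples split into 100 for training and 50 for testing. The `random_state` argument seeds the shuffle so the split is reproducible:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_train.shape, X_test.shape"
]
},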
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[5.7, 2.9, 4.2, 1.3],\n",
" [7.6, 3. , 6.6, 2.1],\n",
" [5.6, 3. , 4.5, 1.5],\n",
" [5.1, 3.5, 1.4, 0.2],\n",
" [7.7, 2.8, 6.7, 2. ],\n",
" [5.8, 2.7, 4.1, 1. ],\n",
" [5.2, 3.4, 1.4, 0.2],\n",
" [5. , 3.5, 1.3, 0.3],\n",
" [5.1, 3.8, 1.9, 0.4],\n",
" [5. , 2. , 3.5, 1. ],\n",
" [6.3, 2.7, 4.9, 1.8],\n",
" [4.8, 3.4, 1.9, 0.2],\n",
" [5. , 3. , 1.6, 0.2],\n",
" [5.1, 3.3, 1.7, 0.5],\n",
" [5.6, 2.7, 4.2, 1.3],\n",
" [5.1, 3.4, 1.5, 0.2],\n",
" [5.7, 3. , 4.2, 1.2],\n",
" [7.7, 3.8, 6.7, 2.2],\n",
" [4.6, 3.2, 1.4, 0.2],\n",
" [6.2, 2.9, 4.3, 1.3],\n",
" [5.7, 2.5, 5. , 2. ],\n",
" [5.5, 4.2, 1.4, 0.2],\n",
" [6. , 3. , 4.8, 1.8],\n",
" [5.8, 2.7, 5.1, 1.9],\n",
" [6. , 2.2, 4. , 1. ],\n",
" [5.4, 3. , 4.5, 1.5],\n",
" [6.2, 3.4, 5.4, 2.3],\n",
" [5.5, 2.3, 4. , 1.3],\n",
" [5.4, 3.9, 1.7, 0.4],\n",
" [5. , 2.3, 3.3, 1. ],\n",
" [6.4, 2.7, 5.3, 1.9],\n",
" [5. , 3.3, 1.4, 0.2],\n",
" [5. , 3.2, 1.2, 0.2],\n",
" [5.5, 2.4, 3.8, 1.1],\n",
" [6.7, 3. , 5. , 1.7],\n",
" [4.9, 3.1, 1.5, 0.2],\n",
" [5.8, 2.8, 5.1, 2.4],\n",
" [5. , 3.4, 1.5, 0.2],\n",
" [5. , 3.5, 1.6, 0.6],\n",
" [5.9, 3.2, 4.8, 1.8],\n",
" [5.1, 2.5, 3. , 1.1],\n",
" [6.9, 3.2, 5.7, 2.3],\n",
" [6. , 2.7, 5.1, 1.6],\n",
" [6.1, 2.6, 5.6, 1.4],\n",
" [7.7, 3. , 6.1, 2.3],\n",
" [5.5, 2.5, 4. , 1.3],\n",
" [4.4, 2.9, 1.4, 0.2],\n",
" [4.3, 3. , 1.1, 0.1],\n",
" [6. , 2.2, 5. , 1.5],\n",
" [7.2, 3.2, 6. , 1.8],\n",
" [4.6, 3.1, 1.5, 0.2],\n",
" [5.1, 3.5, 1.4, 0.3],\n",
" [4.4, 3. , 1.3, 0.2],\n",
" [6.3, 2.5, 4.9, 1.5],\n",
" [6.3, 3.4, 5.6, 2.4],\n",
" [4.6, 3.4, 1.4, 0.3],\n",
" [6.8, 3. , 5.5, 2.1],\n",
" [6.3, 3.3, 6. , 2.5],\n",
" [4.7, 3.2, 1.3, 0.2],\n",
" [6.1, 2.9, 4.7, 1.4],\n",
" [6.5, 2.8, 4.6, 1.5],\n",
" [6.2, 2.8, 4.8, 1.8],\n",
" [7. , 3.2, 4.7, 1.4],\n",
" [6.4, 3.2, 5.3, 2.3],\n",
" [5.1, 3.8, 1.6, 0.2],\n",
" [6.9, 3.1, 5.4, 2.1],\n",
" [5.9, 3. , 4.2, 1.5],\n",
" [6.5, 3. , 5.2, 2. ],\n",
" [5.7, 2.6, 3.5, 1. ],\n",
" [5.2, 2.7, 3.9, 1.4],\n",
" [6.1, 3. , 4.6, 1.4],\n",
" [4.5, 2.3, 1.3, 0.3],\n",
" [6.6, 2.9, 4.6, 1.3],\n",
" [5.5, 2.6, 4.4, 1.2],\n",
" [5.3, 3.7, 1.5, 0.2],\n",
" [5.6, 3. , 4.1, 1.3],\n",
" [7.3, 2.9, 6.3, 1.8],\n",
" [6.7, 3.3, 5.7, 2.1],\n",
" [5.1, 3.7, 1.5, 0.4],\n",
" [4.9, 2.4, 3.3, 1. ],\n",
" [6.7, 3.3, 5.7, 2.5],\n",
" [7.2, 3. , 5.8, 1.6],\n",
" [4.9, 3.6, 1.4, 0.1],\n",
" [6.7, 3.1, 5.6, 2.4],\n",
" [4.9, 3. , 1.4, 0.2],\n",
" [6.9, 3.1, 4.9, 1.5],\n",
" [7.4, 2.8, 6.1, 1.9],\n",
" [6.3, 2.9, 5.6, 1.8],\n",
" [5.7, 2.8, 4.1, 1.3],\n",
" [6.5, 3. , 5.5, 1.8],\n",
" [6.3, 2.3, 4.4, 1.3],\n",
" [6.4, 2.9, 4.3, 1.3],\n",
" [5.6, 2.8, 4.9, 2. ],\n",
" [5.9, 3. , 5.1, 1.8],\n",
" [5.4, 3.4, 1.7, 0.2],\n",
" [6.1, 2.8, 4. , 1.3],\n",
" [4.9, 2.5, 4.5, 1.7],\n",
" [5.8, 4. , 1.2, 0.2],\n",
" [5.8, 2.6, 4. , 1.2],\n",
" [7.1, 3. , 5.9, 2.1]])"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[6.1, 2.8, 4.7, 1.2],\n",
" [5.7, 3.8, 1.7, 0.3],\n",
" [7.7, 2.6, 6.9, 2.3],\n",
" [6. , 2.9, 4.5, 1.5],\n",
" [6.8, 2.8, 4.8, 1.4],\n",
" [5.4, 3.4, 1.5, 0.4],\n",
" [5.6, 2.9, 3.6, 1.3],\n",
" [6.9, 3.1, 5.1, 2.3],\n",
" [6.2, 2.2, 4.5, 1.5],\n",
" [5.8, 2.7, 3.9, 1.2],\n",
" [6.5, 3.2, 5.1, 2. ],\n",
" [4.8, 3. , 1.4, 0.1],\n",
" [5.5, 3.5, 1.3, 0.2],\n",
" [4.9, 3.1, 1.5, 0.1],\n",
" [5.1, 3.8, 1.5, 0.3],\n",
" [6.3, 3.3, 4.7, 1.6],\n",
" [6.5, 3. , 5.8, 2.2],\n",
" [5.6, 2.5, 3.9, 1.1],\n",
" [5.7, 2.8, 4.5, 1.3],\n",
" [6.4, 2.8, 5.6, 2.2],\n",
" [4.7, 3.2, 1.6, 0.2],\n",
" [6.1, 3. , 4.9, 1.8],\n",
" [5. , 3.4, 1.6, 0.4],\n",
" [6.4, 2.8, 5.6, 2.1],\n",
" [7.9, 3.8, 6.4, 2. ],\n",
" [6.7, 3. , 5.2, 2.3],\n",
" [6.7, 2.5, 5.8, 1.8],\n",
" [6.8, 3.2, 5.9, 2.3],\n",
" [4.8, 3. , 1.4, 0.3],\n",
" [4.8, 3.1, 1.6, 0.2],\n",
" [4.6, 3.6, 1. , 0.2],\n",
" [5.7, 4.4, 1.5, 0.4],\n",
" [6.7, 3.1, 4.4, 1.4],\n",
" [4.8, 3.4, 1.6, 0.2],\n",
" [4.4, 3.2, 1.3, 0.2],\n",
" [6.3, 2.5, 5. , 1.9],\n",
" [6.4, 3.2, 4.5, 1.5],\n",
" [5.2, 3.5, 1.5, 0.2],\n",
" [5. , 3.6, 1.4, 0.2],\n",
" [5.2, 4.1, 1.5, 0.1],\n",
" [5.8, 2.7, 5.1, 1.9],\n",
" [6. , 3.4, 4.5, 1.6],\n",
" [6.7, 3.1, 4.7, 1.5],\n",
" [5.4, 3.9, 1.3, 0.4],\n",
" [5.4, 3.7, 1.5, 0.2],\n",
" [5.5, 2.4, 3.7, 1. ],\n",
" [6.3, 2.8, 5.1, 1.5],\n",
" [6.4, 3.1, 5.5, 1.8],\n",
" [6.6, 3. , 4.4, 1.4],\n",
" [7.2, 3.6, 6.1, 2.5]])"
]
},
"execution_count": 112,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_test"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.]], dtype=float32)"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.]], dtype=float32)"
]
},
"execution_count": 114,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Standardizing the Data\n",
"\n",
"Usually when using neural networks, you will get better performance when you scale the data. Strictly speaking, standardization means rescaling to zero mean and unit variance; here we'll use min-max normalization instead, which rescales the values so they all fit within a fixed range, like 0 to 1 or -1 to 1.\n",
"\n",
"The scikit-learn library provides a convenient class for this:\n",
"\n",
"http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.preprocessing import MinMaxScaler"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"scaler_object = MinMaxScaler()"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MinMaxScaler(copy=True, feature_range=(0, 1))"
]
},
"execution_count": 117,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scaler_object.fit(X_train)"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"scaled_X_train = scaler_object.transform(X_train)"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"scaled_X_test = scaler_object.transform(X_test)"
]
},
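{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that we fit the scaler on the training data only, then apply the same transform to both sets. Fitting on the full dataset (or on the test set) would leak information about the test data into the preprocessing step. Because the scaler only saw the training data, the transformed test values can fall slightly outside the 0-1 range:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"scaled_X_test.min(), scaled_X_test.max()"
]
},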
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, now we have the data scaled!"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7.7"
]
},
"execution_count": 120,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.max()"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 121,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scaled_X_train.max()"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[5.7, 2.9, 4.2, 1.3],\n",
" [7.6, 3. , 6.6, 2.1],\n",
" [5.6, 3. , 4.5, 1.5],\n",
" [5.1, 3.5, 1.4, 0.2],\n",
" [7.7, 2.8, 6.7, 2. ],\n",
" [5.8, 2.7, 4.1, 1. ],\n",
" [5.2, 3.4, 1.4, 0.2],\n",
" [5. , 3.5, 1.3, 0.3],\n",
" [5.1, 3.8, 1.9, 0.4],\n",
" [5. , 2. , 3.5, 1. ],\n",
" [6.3, 2.7, 4.9, 1.8],\n",
" [4.8, 3.4, 1.9, 0.2],\n",
" [5. , 3. , 1.6, 0.2],\n",
" [5.1, 3.3, 1.7, 0.5],\n",
" [5.6, 2.7, 4.2, 1.3],\n",
" [5.1, 3.4, 1.5, 0.2],\n",
" [5.7, 3. , 4.2, 1.2],\n",
" [7.7, 3.8, 6.7, 2.2],\n",
" [4.6, 3.2, 1.4, 0.2],\n",
" [6.2, 2.9, 4.3, 1.3],\n",
" [5.7, 2.5, 5. , 2. ],\n",
" [5.5, 4.2, 1.4, 0.2],\n",
" [6. , 3. , 4.8, 1.8],\n",
" [5.8, 2.7, 5.1, 1.9],\n",
" [6. , 2.2, 4. , 1. ],\n",
" [5.4, 3. , 4.5, 1.5],\n",
" [6.2, 3.4, 5.4, 2.3],\n",
" [5.5, 2.3, 4. , 1.3],\n",
" [5.4, 3.9, 1.7, 0.4],\n",
" [5. , 2.3, 3.3, 1. ],\n",
" [6.4, 2.7, 5.3, 1.9],\n",
" [5. , 3.3, 1.4, 0.2],\n",
" [5. , 3.2, 1.2, 0.2],\n",
" [5.5, 2.4, 3.8, 1.1],\n",
" [6.7, 3. , 5. , 1.7],\n",
" [4.9, 3.1, 1.5, 0.2],\n",
" [5.8, 2.8, 5.1, 2.4],\n",
" [5. , 3.4, 1.5, 0.2],\n",
" [5. , 3.5, 1.6, 0.6],\n",
" [5.9, 3.2, 4.8, 1.8],\n",
" [5.1, 2.5, 3. , 1.1],\n",
" [6.9, 3.2, 5.7, 2.3],\n",
" [6. , 2.7, 5.1, 1.6],\n",
" [6.1, 2.6, 5.6, 1.4],\n",
" [7.7, 3. , 6.1, 2.3],\n",
" [5.5, 2.5, 4. , 1.3],\n",
" [4.4, 2.9, 1.4, 0.2],\n",
" [4.3, 3. , 1.1, 0.1],\n",
" [6. , 2.2, 5. , 1.5],\n",
" [7.2, 3.2, 6. , 1.8],\n",
" [4.6, 3.1, 1.5, 0.2],\n",
" [5.1, 3.5, 1.4, 0.3],\n",
" [4.4, 3. , 1.3, 0.2],\n",
" [6.3, 2.5, 4.9, 1.5],\n",
" [6.3, 3.4, 5.6, 2.4],\n",
" [4.6, 3.4, 1.4, 0.3],\n",
" [6.8, 3. , 5.5, 2.1],\n",
" [6.3, 3.3, 6. , 2.5],\n",
" [4.7, 3.2, 1.3, 0.2],\n",
" [6.1, 2.9, 4.7, 1.4],\n",
" [6.5, 2.8, 4.6, 1.5],\n",
" [6.2, 2.8, 4.8, 1.8],\n",
" [7. , 3.2, 4.7, 1.4],\n",
" [6.4, 3.2, 5.3, 2.3],\n",
" [5.1, 3.8, 1.6, 0.2],\n",
" [6.9, 3.1, 5.4, 2.1],\n",
" [5.9, 3. , 4.2, 1.5],\n",
" [6.5, 3. , 5.2, 2. ],\n",
" [5.7, 2.6, 3.5, 1. ],\n",
" [5.2, 2.7, 3.9, 1.4],\n",
" [6.1, 3. , 4.6, 1.4],\n",
" [4.5, 2.3, 1.3, 0.3],\n",
" [6.6, 2.9, 4.6, 1.3],\n",
" [5.5, 2.6, 4.4, 1.2],\n",
" [5.3, 3.7, 1.5, 0.2],\n",
" [5.6, 3. , 4.1, 1.3],\n",
" [7.3, 2.9, 6.3, 1.8],\n",
" [6.7, 3.3, 5.7, 2.1],\n",
" [5.1, 3.7, 1.5, 0.4],\n",
" [4.9, 2.4, 3.3, 1. ],\n",
" [6.7, 3.3, 5.7, 2.5],\n",
" [7.2, 3. , 5.8, 1.6],\n",
" [4.9, 3.6, 1.4, 0.1],\n",
" [6.7, 3.1, 5.6, 2.4],\n",
" [4.9, 3. , 1.4, 0.2],\n",
" [6.9, 3.1, 4.9, 1.5],\n",
" [7.4, 2.8, 6.1, 1.9],\n",
" [6.3, 2.9, 5.6, 1.8],\n",
" [5.7, 2.8, 4.1, 1.3],\n",
" [6.5, 3. , 5.5, 1.8],\n",
" [6.3, 2.3, 4.4, 1.3],\n",
" [6.4, 2.9, 4.3, 1.3],\n",
" [5.6, 2.8, 4.9, 2. ],\n",
" [5.9, 3. , 5.1, 1.8],\n",
" [5.4, 3.4, 1.7, 0.2],\n",
" [6.1, 2.8, 4. , 1.3],\n",
" [4.9, 2.5, 4.5, 1.7],\n",
" [5.8, 4. , 1.2, 0.2],\n",
" [5.8, 2.6, 4. , 1.2],\n",
" [7.1, 3. , 5.9, 2.1]])"
]
},
"execution_count": 122,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.41176471, 0.40909091, 0.55357143, 0.5 ],\n",
" [0.97058824, 0.45454545, 0.98214286, 0.83333333],\n",
" [0.38235294, 0.45454545, 0.60714286, 0.58333333],\n",
" [0.23529412, 0.68181818, 0.05357143, 0.04166667],\n",
" [1. , 0.36363636, 1. , 0.79166667],\n",
" [0.44117647, 0.31818182, 0.53571429, 0.375 ],\n",
" [0.26470588, 0.63636364, 0.05357143, 0.04166667],\n",
" [0.20588235, 0.68181818, 0.03571429, 0.08333333],\n",
" [0.23529412, 0.81818182, 0.14285714, 0.125 ],\n",
" [0.20588235, 0. , 0.42857143, 0.375 ],\n",
" [0.58823529, 0.31818182, 0.67857143, 0.70833333],\n",
" [0.14705882, 0.63636364, 0.14285714, 0.04166667],\n",
" [0.20588235, 0.45454545, 0.08928571, 0.04166667],\n",
" [0.23529412, 0.59090909, 0.10714286, 0.16666667],\n",
" [0.38235294, 0.31818182, 0.55357143, 0.5 ],\n",
" [0.23529412, 0.63636364, 0.07142857, 0.04166667],\n",
" [0.41176471, 0.45454545, 0.55357143, 0.45833333],\n",
" [1. , 0.81818182, 1. , 0.875 ],\n",
" [0.08823529, 0.54545455, 0.05357143, 0.04166667],\n",
" [0.55882353, 0.40909091, 0.57142857, 0.5 ],\n",
" [0.41176471, 0.22727273, 0.69642857, 0.79166667],\n",
" [0.35294118, 1. , 0.05357143, 0.04166667],\n",
" [0.5 , 0.45454545, 0.66071429, 0.70833333],\n",
" [0.44117647, 0.31818182, 0.71428571, 0.75 ],\n",
" [0.5 , 0.09090909, 0.51785714, 0.375 ],\n",
" [0.32352941, 0.45454545, 0.60714286, 0.58333333],\n",
" [0.55882353, 0.63636364, 0.76785714, 0.91666667],\n",
" [0.35294118, 0.13636364, 0.51785714, 0.5 ],\n",
" [0.32352941, 0.86363636, 0.10714286, 0.125 ],\n",
" [0.20588235, 0.13636364, 0.39285714, 0.375 ],\n",
" [0.61764706, 0.31818182, 0.75 , 0.75 ],\n",
" [0.20588235, 0.59090909, 0.05357143, 0.04166667],\n",
" [0.20588235, 0.54545455, 0.01785714, 0.04166667],\n",
" [0.35294118, 0.18181818, 0.48214286, 0.41666667],\n",
" [0.70588235, 0.45454545, 0.69642857, 0.66666667],\n",
" [0.17647059, 0.5 , 0.07142857, 0.04166667],\n",
" [0.44117647, 0.36363636, 0.71428571, 0.95833333],\n",
" [0.20588235, 0.63636364, 0.07142857, 0.04166667],\n",
" [0.20588235, 0.68181818, 0.08928571, 0.20833333],\n",
" [0.47058824, 0.54545455, 0.66071429, 0.70833333],\n",
" [0.23529412, 0.22727273, 0.33928571, 0.41666667],\n",
" [0.76470588, 0.54545455, 0.82142857, 0.91666667],\n",
" [0.5 , 0.31818182, 0.71428571, 0.625 ],\n",
" [0.52941176, 0.27272727, 0.80357143, 0.54166667],\n",
" [1. , 0.45454545, 0.89285714, 0.91666667],\n",
" [0.35294118, 0.22727273, 0.51785714, 0.5 ],\n",
" [0.02941176, 0.40909091, 0.05357143, 0.04166667],\n",
" [0. , 0.45454545, 0. , 0. ],\n",
" [0.5 , 0.09090909, 0.69642857, 0.58333333],\n",
" [0.85294118, 0.54545455, 0.875 , 0.70833333],\n",
" [0.08823529, 0.5 , 0.07142857, 0.04166667],\n",
" [0.23529412, 0.68181818, 0.05357143, 0.08333333],\n",
" [0.02941176, 0.45454545, 0.03571429, 0.04166667],\n",
" [0.58823529, 0.22727273, 0.67857143, 0.58333333],\n",
" [0.58823529, 0.63636364, 0.80357143, 0.95833333],\n",
" [0.08823529, 0.63636364, 0.05357143, 0.08333333],\n",
" [0.73529412, 0.45454545, 0.78571429, 0.83333333],\n",
" [0.58823529, 0.59090909, 0.875 , 1. ],\n",
" [0.11764706, 0.54545455, 0.03571429, 0.04166667],\n",
" [0.52941176, 0.40909091, 0.64285714, 0.54166667],\n",
" [0.64705882, 0.36363636, 0.625 , 0.58333333],\n",
" [0.55882353, 0.36363636, 0.66071429, 0.70833333],\n",
" [0.79411765, 0.54545455, 0.64285714, 0.54166667],\n",
" [0.61764706, 0.54545455, 0.75 , 0.91666667],\n",
" [0.23529412, 0.81818182, 0.08928571, 0.04166667],\n",
" [0.76470588, 0.5 , 0.76785714, 0.83333333],\n",
" [0.47058824, 0.45454545, 0.55357143, 0.58333333],\n",
" [0.64705882, 0.45454545, 0.73214286, 0.79166667],\n",
" [0.41176471, 0.27272727, 0.42857143, 0.375 ],\n",
" [0.26470588, 0.31818182, 0.5 , 0.54166667],\n",
" [0.52941176, 0.45454545, 0.625 , 0.54166667],\n",
" [0.05882353, 0.13636364, 0.03571429, 0.08333333],\n",
" [0.67647059, 0.40909091, 0.625 , 0.5 ],\n",
" [0.35294118, 0.27272727, 0.58928571, 0.45833333],\n",
" [0.29411765, 0.77272727, 0.07142857, 0.04166667],\n",
" [0.38235294, 0.45454545, 0.53571429, 0.5 ],\n",
" [0.88235294, 0.40909091, 0.92857143, 0.70833333],\n",
" [0.70588235, 0.59090909, 0.82142857, 0.83333333],\n",
" [0.23529412, 0.77272727, 0.07142857, 0.125 ],\n",
" [0.17647059, 0.18181818, 0.39285714, 0.375 ],\n",
" [0.70588235, 0.59090909, 0.82142857, 1. ],\n",
" [0.85294118, 0.45454545, 0.83928571, 0.625 ],\n",
" [0.17647059, 0.72727273, 0.05357143, 0. ],\n",
" [0.70588235, 0.5 , 0.80357143, 0.95833333],\n",
" [0.17647059, 0.45454545, 0.05357143, 0.04166667],\n",
" [0.76470588, 0.5 , 0.67857143, 0.58333333],\n",
" [0.91176471, 0.36363636, 0.89285714, 0.75 ],\n",
" [0.58823529, 0.40909091, 0.80357143, 0.70833333],\n",
" [0.41176471, 0.36363636, 0.53571429, 0.5 ],\n",
" [0.64705882, 0.45454545, 0.78571429, 0.70833333],\n",
" [0.58823529, 0.13636364, 0.58928571, 0.5 ],\n",
" [0.61764706, 0.40909091, 0.57142857, 0.5 ],\n",
" [0.38235294, 0.36363636, 0.67857143, 0.79166667],\n",
" [0.47058824, 0.45454545, 0.71428571, 0.70833333],\n",
" [0.32352941, 0.63636364, 0.10714286, 0.04166667],\n",
" [0.52941176, 0.36363636, 0.51785714, 0.5 ],\n",
" [0.17647059, 0.22727273, 0.60714286, 0.66666667],\n",
" [0.44117647, 0.90909091, 0.01785714, 0.04166667],\n",
" [0.44117647, 0.27272727, 0.51785714, 0.45833333],\n",
" [0.82352941, 0.45454545, 0.85714286, 0.83333333]])"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scaled_X_train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building the Network with Keras\n",
"\n",
"Let's build a simple neural network!"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
"from keras.models import Sequential\n",
"from keras.layers import Dense"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"model = Sequential()\n",
"model.add(Dense(8, input_dim=4, activation='relu'))  # 4 input features\n",
"model.add(Dense(8, activation='relu'))  # input_dim is only needed on the first layer\n",
"model.add(Dense(3, activation='softmax'))  # one output unit per class\n",
"model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])"
]
},
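{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each `Dense` layer holds `inputs * units + units` parameters (one weight per input/unit pair, plus one bias per unit). A quick hand check against the summary below:\n",
"\n",
"- first hidden layer: 4 × 8 + 8 = 40\n",
"- second hidden layer: 8 × 8 + 8 = 72\n",
"- output layer: 8 × 3 + 3 = 27\n",
"\n",
"which gives the 139 total parameters reported by `model.summary()`."
]
},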
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense_27 (Dense) (None, 8) 40 \n",
"_________________________________________________________________\n",
"dense_28 (Dense) (None, 8) 72 \n",
"_________________________________________________________________\n",
"dense_29 (Dense) (None, 3) 27 \n",
"=================================================================\n",
"Total params: 139\n",
"Trainable params: 139\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fit (Train) the Model"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/150\n",
" - 0s - loss: 1.0926 - acc: 0.3400\n",
"Epoch 2/150\n",
" - 0s - loss: 1.0871 - acc: 0.3400\n",
"Epoch 3/150\n",
" - 0s - loss: 1.0814 - acc: 0.3400\n",
"Epoch 4/150\n",
" - 0s - loss: 1.0760 - acc: 0.3400\n",
"Epoch 5/150\n",
" - 0s - loss: 1.0700 - acc: 0.3400\n",
"Epoch 6/150\n",
" - 0s - loss: 1.0640 - acc: 0.3500\n",
"Epoch 7/150\n",
" - 0s - loss: 1.0581 - acc: 0.3700\n",
"Epoch 8/150\n",
" - 0s - loss: 1.0520 - acc: 0.4200\n",
"Epoch 9/150\n",
" - 0s - loss: 1.0467 - acc: 0.5100\n",
"Epoch 10/150\n",
" - 0s - loss: 1.0416 - acc: 0.5700\n",
"Epoch 11/150\n",
" - 0s - loss: 1.0361 - acc: 0.6300\n",
"Epoch 12/150\n",
" - 0s - loss: 1.0308 - acc: 0.6300\n",
"Epoch 13/150\n",
" - 0s - loss: 1.0249 - acc: 0.6300\n",
"Epoch 14/150\n",
" - 0s - loss: 1.0187 - acc: 0.6200\n",
"Epoch 15/150\n",
" - 0s - loss: 1.0124 - acc: 0.6300\n",
"Epoch 16/150\n",
" - 0s - loss: 1.0062 - acc: 0.6300\n",
"Epoch 17/150\n",
" - 0s - loss: 0.9994 - acc: 0.6400\n",
"Epoch 18/150\n",
" - 0s - loss: 0.9919 - acc: 0.6400\n",
"Epoch 19/150\n",
" - 0s - loss: 0.9836 - acc: 0.6400\n",
"Epoch 20/150\n",
" - 0s - loss: 0.9748 - acc: 0.6400\n",
"Epoch 21/150\n",
" - 0s - loss: 0.9649 - acc: 0.6400\n",
"Epoch 22/150\n",
" - 0s - loss: 0.9552 - acc: 0.6400\n",
"Epoch 23/150\n",
" - 0s - loss: 0.9448 - acc: 0.6500\n",
"Epoch 24/150\n",
" - 0s - loss: 0.9337 - acc: 0.6500\n",
"Epoch 25/150\n",
" - 0s - loss: 0.9225 - acc: 0.6500\n",
"Epoch 26/150\n",
" - 0s - loss: 0.9111 - acc: 0.6500\n",
"Epoch 27/150\n",
" - 0s - loss: 0.8992 - acc: 0.6500\n",
"Epoch 28/150\n",
" - 0s - loss: 0.8882 - acc: 0.6500\n",
"Epoch 29/150\n",
" - 0s - loss: 0.8766 - acc: 0.6500\n",
"Epoch 30/150\n",
" - 0s - loss: 0.8658 - acc: 0.6500\n",
"Epoch 31/150\n",
" - 0s - loss: 0.8555 - acc: 0.6500\n",
"Epoch 32/150\n",
" - 0s - loss: 0.8446 - acc: 0.6500\n",
"Epoch 33/150\n",
" - 0s - loss: 0.8337 - acc: 0.6500\n",
"Epoch 34/150\n",
" - 0s - loss: 0.8225 - acc: 0.6500\n",
"Epoch 35/150\n",
" - 0s - loss: 0.8112 - acc: 0.6500\n",
"Epoch 36/150\n",
" - 0s - loss: 0.7998 - acc: 0.6500\n",
"Epoch 37/150\n",
" - 0s - loss: 0.7886 - acc: 0.6500\n",
"Epoch 38/150\n",
" - 0s - loss: 0.7768 - acc: 0.6500\n",
"Epoch 39/150\n",
" - 0s - loss: 0.7650 - acc: 0.6500\n",
"Epoch 40/150\n",
" - 0s - loss: 0.7539 - acc: 0.6500\n",
"Epoch 41/150\n",
" - 0s - loss: 0.7428 - acc: 0.6500\n",
"Epoch 42/150\n",
" - 0s - loss: 0.7318 - acc: 0.6500\n",
"Epoch 43/150\n",
" - 0s - loss: 0.7213 - acc: 0.6500\n",
"Epoch 44/150\n",
" - 0s - loss: 0.7111 - acc: 0.6500\n",
"Epoch 45/150\n",
" - 0s - loss: 0.7007 - acc: 0.6500\n",
"Epoch 46/150\n",
" - 0s - loss: 0.6907 - acc: 0.6500\n",
"Epoch 47/150\n",
" - 0s - loss: 0.6806 - acc: 0.6500\n",
"Epoch 48/150\n",
" - 0s - loss: 0.6714 - acc: 0.6500\n",
"Epoch 49/150\n",
" - 0s - loss: 0.6621 - acc: 0.6500\n",
"Epoch 50/150\n",
" - 0s - loss: 0.6534 - acc: 0.6500\n",
"Epoch 51/150\n",
" - 0s - loss: 0.6452 - acc: 0.6500\n",
"Epoch 52/150\n",
" - 0s - loss: 0.6366 - acc: 0.6500\n",
"Epoch 53/150\n",
" - 0s - loss: 0.6285 - acc: 0.6500\n",
"Epoch 54/150\n",
" - 0s - loss: 0.6205 - acc: 0.6500\n",
"Epoch 55/150\n",
" - 0s - loss: 0.6130 - acc: 0.6500\n",
"Epoch 56/150\n",
" - 0s - loss: 0.6058 - acc: 0.6500\n",
"Epoch 57/150\n",
" - 0s - loss: 0.5990 - acc: 0.6500\n",
"Epoch 58/150\n",
" - 0s - loss: 0.5920 - acc: 0.6500\n",
"Epoch 59/150\n",
" - 0s - loss: 0.5852 - acc: 0.6700\n",
"Epoch 60/150\n",
" - 0s - loss: 0.5790 - acc: 0.6700\n",
"Epoch 61/150\n",
" - 0s - loss: 0.5727 - acc: 0.6900\n",
"Epoch 62/150\n",
" - 0s - loss: 0.5665 - acc: 0.6900\n",
"Epoch 63/150\n",
" - 0s - loss: 0.5605 - acc: 0.6900\n",
"Epoch 64/150\n",
" - 0s - loss: 0.5544 - acc: 0.6900\n",
"Epoch 65/150\n",
" - 0s - loss: 0.5490 - acc: 0.6900\n",
"Epoch 66/150\n",
" - 0s - loss: 0.5436 - acc: 0.7000\n",
"Epoch 67/150\n",
" - 0s - loss: 0.5385 - acc: 0.7000\n",
"Epoch 68/150\n",
" - 0s - loss: 0.5334 - acc: 0.7000\n",
"Epoch 69/150\n",
" - 0s - loss: 0.5287 - acc: 0.7100\n",
"Epoch 70/150\n",
" - 0s - loss: 0.5243 - acc: 0.7200\n",
"Epoch 71/150\n",
" - 0s - loss: 0.5198 - acc: 0.7300\n",
"Epoch 72/150\n",
" - 0s - loss: 0.5150 - acc: 0.7300\n",
"Epoch 73/150\n",
" - 0s - loss: 0.5108 - acc: 0.7400\n",
"Epoch 74/150\n",
" - 0s - loss: 0.5066 - acc: 0.7400\n",
"Epoch 75/150\n",
" - 0s - loss: 0.5022 - acc: 0.7600\n",
"Epoch 76/150\n",
" - 0s - loss: 0.4986 - acc: 0.7500\n",
"Epoch 77/150\n",
" - 0s - loss: 0.4945 - acc: 0.7400\n",
"Epoch 78/150\n",
" - 0s - loss: 0.4908 - acc: 0.7500\n",
"Epoch 79/150\n",
" - 0s - loss: 0.4867 - acc: 0.7700\n",
"Epoch 80/150\n",
" - 0s - loss: 0.4830 - acc: 0.7700\n",
"Epoch 81/150\n",
" - 0s - loss: 0.4794 - acc: 0.7900\n",
"Epoch 82/150\n",
" - 0s - loss: 0.4761 - acc: 0.8500\n",
"Epoch 83/150\n",
" - 0s - loss: 0.4726 - acc: 0.8500\n",
"Epoch 84/150\n",
" - 0s - loss: 0.4693 - acc: 0.8500\n",
"Epoch 85/150\n",
" - 0s - loss: 0.4660 - acc: 0.8500\n",
"Epoch 86/150\n",
" - 0s - loss: 0.4628 - acc: 0.8500\n",
"Epoch 87/150\n",
" - 0s - loss: 0.4599 - acc: 0.8200\n",
"Epoch 88/150\n",
" - 0s - loss: 0.4575 - acc: 0.8000\n",
"Epoch 89/150\n",
" - 0s - loss: 0.4546 - acc: 0.7800\n",
"Epoch 90/150\n",
" - 0s - loss: 0.4521 - acc: 0.7800\n",
"Epoch 91/150\n",
" - 0s - loss: 0.4490 - acc: 0.8100\n",
"Epoch 92/150\n",
" - 0s - loss: 0.4460 - acc: 0.8300\n",
"Epoch 93/150\n",
" - 0s - loss: 0.4433 - acc: 0.8500\n",
"Epoch 94/150\n",
" - 0s - loss: 0.4407 - acc: 0.8500\n",
"Epoch 95/150\n",
" - 0s - loss: 0.4380 - acc: 0.8500\n",
"Epoch 96/150\n",
" - 0s - loss: 0.4352 - acc: 0.8600\n",
"Epoch 97/150\n",
" - 0s - loss: 0.4326 - acc: 0.8600\n",
"Epoch 98/150\n",
" - 0s - loss: 0.4302 - acc: 0.9000\n",
"Epoch 99/150\n",
" - 0s - loss: 0.4270 - acc: 0.9100\n",
"Epoch 100/150\n",
" - 0s - loss: 0.4245 - acc: 0.9200\n",
"Epoch 101/150\n",
" - 0s - loss: 0.4223 - acc: 0.9200\n",
"Epoch 102/150\n",
" - 0s - loss: 0.4195 - acc: 0.9200\n",
"Epoch 103/150\n",
" - 0s - loss: 0.4171 - acc: 0.9200\n",
"Epoch 104/150\n",
" - 0s - loss: 0.4147 - acc: 0.9200\n",
"Epoch 105/150\n",
" - 0s - loss: 0.4124 - acc: 0.9000\n",
"Epoch 106/150\n",
" - 0s - loss: 0.4102 - acc: 0.9000\n",
"Epoch 107/150\n",
" - 0s - loss: 0.4077 - acc: 0.9200\n",
"Epoch 108/150\n",
" - 0s - loss: 0.4055 - acc: 0.9200\n",
"Epoch 109/150\n",
" - 0s - loss: 0.4028 - acc: 0.9200\n",
"Epoch 110/150\n",
" - 0s - loss: 0.4006 - acc: 0.9200\n",
"Epoch 111/150\n",
" - 0s - loss: 0.3985 - acc: 0.9200\n",
"Epoch 112/150\n",
" - 0s - loss: 0.3965 - acc: 0.9200\n",
"Epoch 113/150\n",
" - 0s - loss: 0.3945 - acc: 0.9200\n",
"Epoch 114/150\n",
" - 0s - loss: 0.3930 - acc: 0.9200\n",
"Epoch 115/150\n",
" - 0s - loss: 0.3914 - acc: 0.9000\n",
"Epoch 116/150\n",
" - 0s - loss: 0.3891 - acc: 0.9100\n",
"Epoch 117/150\n",
" - 0s - loss: 0.3860 - acc: 0.9200\n",
"Epoch 118/150\n",
" - 0s - loss: 0.3844 - acc: 0.9200\n",
"Epoch 119/150\n",
" - 0s - loss: 0.3821 - acc: 0.9500\n",
"Epoch 120/150\n",
" - 0s - loss: 0.3801 - acc: 0.9500\n",
"Epoch 121/150\n",
" - 0s - loss: 0.3780 - acc: 0.9500\n",
"Epoch 122/150\n",
" - 0s - loss: 0.3757 - acc: 0.9500\n",
"Epoch 123/150\n",
" - 0s - loss: 0.3738 - acc: 0.9500\n",
"Epoch 124/150\n",
" - 0s - loss: 0.3720 - acc: 0.9500\n",
"Epoch 125/150\n",
" - 0s - loss: 0.3693 - acc: 0.9500\n",
"Epoch 126/150\n",
" - 0s - loss: 0.3662 - acc: 0.9500\n",
"Epoch 127/150\n",
" - 0s - loss: 0.3646 - acc: 0.9300\n",
"Epoch 128/150\n",
" - 0s - loss: 0.3652 - acc: 0.9200\n",
"Epoch 129/150\n",
" - 0s - loss: 0.3640 - acc: 0.9100\n",
"Epoch 130/150\n",
" - 0s - loss: 0.3627 - acc: 0.9100\n",
"Epoch 131/150\n",
" - 0s - loss: 0.3604 - acc: 0.9100\n",
"Epoch 132/150\n",
" - 0s - loss: 0.3572 - acc: 0.9300\n",
"Epoch 133/150\n",
" - 0s - loss: 0.3544 - acc: 0.9300\n",
"Epoch 134/150\n",
" - 0s - loss: 0.3521 - acc: 0.9300\n",
"Epoch 135/150\n",
" - 0s - loss: 0.3501 - acc: 0.9500\n",
"Epoch 136/150\n",
" - 0s - loss: 0.3482 - acc: 0.9500\n",
"Epoch 137/150\n",
" - 0s - loss: 0.3465 - acc: 0.9500\n",
"Epoch 138/150\n",
" - 0s - loss: 0.3461 - acc: 0.9500\n",
"Epoch 139/150\n",
" - 0s - loss: 0.3435 - acc: 0.9300\n",
"Epoch 140/150\n",
" - 0s - loss: 0.3407 - acc: 0.9500\n",
"Epoch 141/150\n",
" - 0s - loss: 0.3384 - acc: 0.9500\n",
"Epoch 142/150\n",
" - 0s - loss: 0.3366 - acc: 0.9500\n",
"Epoch 143/150\n",
" - 0s - loss: 0.3347 - acc: 0.9500\n",
"Epoch 144/150\n",
" - 0s - loss: 0.3336 - acc: 0.9400\n",
"Epoch 145/150\n",
" - 0s - loss: 0.3324 - acc: 0.9300\n",
"Epoch 146/150\n",
" - 0s - loss: 0.3311 - acc: 0.9300\n",
"Epoch 147/150\n",
" - 0s - loss: 0.3300 - acc: 0.9300\n",
"Epoch 148/150\n",
" - 0s - loss: 0.3283 - acc: 0.9300\n",
"Epoch 149/150\n",
" - 0s - loss: 0.3256 - acc: 0.9400\n",
"Epoch 150/150\n",
" - 0s - loss: 0.3224 - acc: 0.9500\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 135,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Play around with number of epochs as well!\n",
"model.fit(scaled_X_train,y_train,epochs=150, verbose=2)"
]
},
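{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you'd like to monitor held-out performance while training, `fit` also accepts a `validation_data` argument. A minimal sketch reusing the scaled test split from above (in practice you'd carve out a separate validation split; left commented so re-running the notebook doesn't continue training):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reports val_loss/val_acc alongside the training metrics each epoch.\n",
"# history = model.fit(scaled_X_train, y_train, epochs=150, verbose=2,\n",
"#                     validation_data=(scaled_X_test, y_test))"
]
},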
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predicting New Unseen Data\n",
"\n",
"Let's see how we did by predicting on **new data**. Remember, our model has **never** seen the test data that we scaled previously! This is exactly the same process you would use on totally brand-new data, for example, a freshly measured iris flower."
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.52941176, 0.36363636, 0.64285714, 0.45833333],\n",
" [ 0.41176471, 0.81818182, 0.10714286, 0.08333333],\n",
" [ 1. , 0.27272727, 1.03571429, 0.91666667],\n",
" [ 0.5 , 0.40909091, 0.60714286, 0.58333333],\n",
" [ 0.73529412, 0.36363636, 0.66071429, 0.54166667],\n",
" [ 0.32352941, 0.63636364, 0.07142857, 0.125 ],\n",
" [ 0.38235294, 0.40909091, 0.44642857, 0.5 ],\n",
" [ 0.76470588, 0.5 , 0.71428571, 0.91666667],\n",
" [ 0.55882353, 0.09090909, 0.60714286, 0.58333333],\n",
" [ 0.44117647, 0.31818182, 0.5 , 0.45833333],\n",
" [ 0.64705882, 0.54545455, 0.71428571, 0.79166667],\n",
" [ 0.14705882, 0.45454545, 0.05357143, 0. ],\n",
" [ 0.35294118, 0.68181818, 0.03571429, 0.04166667],\n",
" [ 0.17647059, 0.5 , 0.07142857, 0. ],\n",
" [ 0.23529412, 0.81818182, 0.07142857, 0.08333333],\n",
" [ 0.58823529, 0.59090909, 0.64285714, 0.625 ],\n",
" [ 0.64705882, 0.45454545, 0.83928571, 0.875 ],\n",
" [ 0.38235294, 0.22727273, 0.5 , 0.41666667],\n",
" [ 0.41176471, 0.36363636, 0.60714286, 0.5 ],\n",
" [ 0.61764706, 0.36363636, 0.80357143, 0.875 ],\n",
" [ 0.11764706, 0.54545455, 0.08928571, 0.04166667],\n",
" [ 0.52941176, 0.45454545, 0.67857143, 0.70833333],\n",
" [ 0.20588235, 0.63636364, 0.08928571, 0.125 ],\n",
" [ 0.61764706, 0.36363636, 0.80357143, 0.83333333],\n",
" [ 1.05882353, 0.81818182, 0.94642857, 0.79166667],\n",
" [ 0.70588235, 0.45454545, 0.73214286, 0.91666667],\n",
" [ 0.70588235, 0.22727273, 0.83928571, 0.70833333],\n",
" [ 0.73529412, 0.54545455, 0.85714286, 0.91666667],\n",
" [ 0.14705882, 0.45454545, 0.05357143, 0.08333333],\n",
" [ 0.14705882, 0.5 , 0.08928571, 0.04166667],\n",
" [ 0.08823529, 0.72727273, -0.01785714, 0.04166667],\n",
" [ 0.41176471, 1.09090909, 0.07142857, 0.125 ],\n",
" [ 0.70588235, 0.5 , 0.58928571, 0.54166667],\n",
" [ 0.14705882, 0.63636364, 0.08928571, 0.04166667],\n",
" [ 0.02941176, 0.54545455, 0.03571429, 0.04166667],\n",
" [ 0.58823529, 0.22727273, 0.69642857, 0.75 ],\n",
" [ 0.61764706, 0.54545455, 0.60714286, 0.58333333],\n",
" [ 0.26470588, 0.68181818, 0.07142857, 0.04166667],\n",
" [ 0.20588235, 0.72727273, 0.05357143, 0.04166667],\n",
" [ 0.26470588, 0.95454545, 0.07142857, 0. ],\n",
" [ 0.44117647, 0.31818182, 0.71428571, 0.75 ],\n",
" [ 0.5 , 0.63636364, 0.60714286, 0.625 ],\n",
" [ 0.70588235, 0.5 , 0.64285714, 0.58333333],\n",
" [ 0.32352941, 0.86363636, 0.03571429, 0.125 ],\n",
" [ 0.32352941, 0.77272727, 0.07142857, 0.04166667],\n",
" [ 0.35294118, 0.18181818, 0.46428571, 0.375 ],\n",
" [ 0.58823529, 0.36363636, 0.71428571, 0.58333333],\n",
" [ 0.61764706, 0.5 , 0.78571429, 0.70833333],\n",
" [ 0.67647059, 0.45454545, 0.58928571, 0.54166667],\n",
" [ 0.85294118, 0.72727273, 0.89285714, 1. ]])"
]
},
"execution_count": 136,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scaled_X_test"
]
},
{
"cell_type": "code",
"execution_count": 137,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Spits out probabilities by default.\n",
"# model.predict(scaled_X_test)"
]
},
{
"cell_type": "code",
"execution_count": 138,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 0, 2, 1, 2, 0, 1, 2, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,\n",
" 0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,\n",
" 0, 1, 2, 2, 1, 2], dtype=int64)"
]
},
"execution_count": 138,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.predict_classes(scaled_X_test)"
]
},
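{
"cell_type": "markdown",
"metadata": {},
"source": [
"Under the hood, `predict_classes` simply runs `predict` to get the per-class probabilities and takes the argmax along the class axis. The equivalent by hand, using the `np` import from earlier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"probs = model.predict(scaled_X_test)  # shape (50, 3); each row sums to ~1\n",
"np.argmax(probs, axis=1)  # same labels as predict_classes"
]
},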
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluating Model Performance\n",
"\n",
"So how well did we do? How do we actually measure \"well\"? Is 95% accuracy good enough? It all depends on the situation. We also need to take into account metrics like recall and precision. Make sure to watch the video discussion on classification evaluation before running this code!"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['loss', 'acc']"
]
},
"execution_count": 139,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.metrics_names"
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"50/50 [==============================] - 0s 2ms/step\n"
]
},
{
"data": {
"text/plain": [
"[0.2843634402751923, 0.96]"
]
},
"execution_count": 140,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(x=scaled_X_test,y=y_test)"
]
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix,classification_report"
]
},
{
"cell_type": "code",
"execution_count": 142,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"predictions = model.predict_classes(scaled_X_test)"
]
},
{
"cell_type": "code",
"execution_count": 161,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 0, 2, 1, 2, 0, 1, 2, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,\n",
" 0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,\n",
" 0, 1, 2, 2, 1, 2], dtype=int64)"
]
},
"execution_count": 161,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions"
]
},
{
"cell_type": "code",
"execution_count": 163,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 0, 2, 1, 1, 0, 1, 2, 1, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,\n",
" 0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,\n",
" 0, 1, 2, 2, 1, 2], dtype=int64)"
]
},
"execution_count": 163,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_test.argmax(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 164,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[19, 0, 0],\n",
" [ 0, 13, 2],\n",
" [ 0, 0, 16]], dtype=int64)"
]
},
"execution_count": 164,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"confusion_matrix(y_test.argmax(axis=1),predictions)"
]
},
{
"cell_type": "code",
"execution_count": 165,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 19\n",
" 1 1.00 0.87 0.93 15\n",
" 2 0.89 1.00 0.94 16\n",
"\n",
" micro avg 0.96 0.96 0.96 50\n",
" macro avg 0.96 0.96 0.96 50\n",
"weighted avg 0.96 0.96 0.96 50\n",
"\n"
]
}
],
"source": [
"print(classification_report(y_test.argmax(axis=1),predictions))"
]
},
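{
"cell_type": "markdown",
"metadata": {},
"source": [
"These numbers can be read straight off the confusion matrix above. For class 1, all 13 predictions of that class were correct (precision 13/13 = 1.00), but only 13 of its 15 true samples were recovered (recall 13/15 ≈ 0.87). For class 2, 16 of the 18 predictions were correct (precision 16/18 ≈ 0.89) and all 16 true samples were found (recall 16/16 = 1.00)."
]
},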
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Saving and Loading Models\n",
"\n",
"Now that we have a model trained, let's see how we can save and load it."
]
},
{
"cell_type": "code",
"execution_count": 166,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"model.save('myfirstmodel.h5')"
]
},
{
"cell_type": "code",
"execution_count": 167,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from keras.models import load_model"
]
},
{
"cell_type": "code",
"execution_count": 168,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"newmodel = load_model('myfirstmodel.h5')"
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 0, 2, 1, 2, 0, 1, 2, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,\n",
"       0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,\n",
"       0, 1, 2, 2, 1, 2], dtype=int64)"
]
},
"execution_count": 169,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Be sure to pass the *scaled* features -- the loaded model was trained on them!\n",
"newmodel.predict_classes(scaled_X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great job! You now know how to preprocess data, train a neural network, and evaluate its classification performance!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}