{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Keras Basics\n", "\n", "Welcome to the section on deep learning! We'll be using Keras with a TensorFlow backend to perform our deep learning operations.\n", "\n", "This means we should first get familiar with some Keras fundamentals!\n", "\n", "## Imports\n", "\n" ] }, { "cell_type": "code", "execution_count": 145, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We will use the famous Iris data set.\n", "_____\n", "More info on the data set:\n", "https://en.wikipedia.org/wiki/Iris_flower_data_set\n", "\n", "## Reading in the Data Set\n", "\n", "Scikit-learn ships with a copy of the Iris data set, so we can load it directly with `load_iris()`. It returns a `Bunch`, a dictionary-like object that holds the features in `.data`, the class labels in `.target`, and a description in `.DESCR`." ] }, { "cell_type": "code", "execution_count": 146, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.datasets import load_iris" ] }, { "cell_type": "code", "execution_count": 147, "metadata": { "collapsed": true }, "outputs": [], "source": [ "iris = load_iris()" ] }, { "cell_type": "code", "execution_count": 148, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "sklearn.utils.Bunch" ] }, "execution_count": 148, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(iris)" ] }, { "cell_type": "code", "execution_count": 149, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ".. 
_iris_dataset:\n", "\n", "Iris plants dataset\n", "--------------------\n", "\n", "**Data Set Characteristics:**\n", "\n", " :Number of Instances: 150 (50 in each of three classes)\n", " :Number of Attributes: 4 numeric, predictive attributes and the class\n", " :Attribute Information:\n", " - sepal length in cm\n", " - sepal width in cm\n", " - petal length in cm\n", " - petal width in cm\n", " - class:\n", " - Iris-Setosa\n", " - Iris-Versicolour\n", " - Iris-Virginica\n", " \n", " :Summary Statistics:\n", "\n", " ============== ==== ==== ======= ===== ====================\n", " Min Max Mean SD Class Correlation\n", " ============== ==== ==== ======= ===== ====================\n", " sepal length: 4.3 7.9 5.84 0.83 0.7826\n", " sepal width: 2.0 4.4 3.05 0.43 -0.4194\n", " petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\n", " petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\n", " ============== ==== ==== ======= ===== ====================\n", "\n", " :Missing Attribute Values: None\n", " :Class Distribution: 33.3% for each of 3 classes.\n", " :Creator: R.A. Fisher\n", " :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n", " :Date: July, 1988\n", "\n", "The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\n", "from Fisher's paper. Note that it's the same as in R, but not as in the UCI\n", "Machine Learning Repository, which has two wrong data points.\n", "\n", "This is perhaps the best known database to be found in the\n", "pattern recognition literature. Fisher's paper is a classic in the field and\n", "is referenced frequently to this day. (See Duda & Hart, for example.) The\n", "data set contains 3 classes of 50 instances each, where each class refers to a\n", "type of iris plant. One class is linearly separable from the other 2; the\n", "latter are NOT linearly separable from each other.\n", "\n", ".. topic:: References\n", "\n", " - Fisher, R.A. 
\"The use of multiple measurements in taxonomic problems\"\n", " Annual Eugenics, 7, Part II, 179-188 (1936); also in \"Contributions to\n", " Mathematical Statistics\" (John Wiley, NY, 1950).\n", " - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n", " (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\n", " - Dasarathy, B.V. (1980) \"Nosing Around the Neighborhood: A New System\n", " Structure and Classification Rule for Recognition in Partially Exposed\n", " Environments\". IEEE Transactions on Pattern Analysis and Machine\n", " Intelligence, Vol. PAMI-2, No. 1, 67-71.\n", " - Gates, G.W. (1972) \"The Reduced Nearest Neighbor Rule\". IEEE Transactions\n", " on Information Theory, May 1972, 431-433.\n", " - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al\"s AUTOCLASS II\n", " conceptual clustering system finds 3 classes in the data.\n", " - Many, many more ...\n" ] } ], "source": [ "print(iris.DESCR)" ] }, { "cell_type": "code", "execution_count": 150, "metadata": { "collapsed": true }, "outputs": [], "source": [ "X = iris.data" ] }, { "cell_type": "code", "execution_count": 151, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[5.1, 3.5, 1.4, 0.2],\n", " [4.9, 3. , 1.4, 0.2],\n", " [4.7, 3.2, 1.3, 0.2],\n", " [4.6, 3.1, 1.5, 0.2],\n", " [5. , 3.6, 1.4, 0.2],\n", " [5.4, 3.9, 1.7, 0.4],\n", " [4.6, 3.4, 1.4, 0.3],\n", " [5. , 3.4, 1.5, 0.2],\n", " [4.4, 2.9, 1.4, 0.2],\n", " [4.9, 3.1, 1.5, 0.1],\n", " [5.4, 3.7, 1.5, 0.2],\n", " [4.8, 3.4, 1.6, 0.2],\n", " [4.8, 3. , 1.4, 0.1],\n", " [4.3, 3. , 1.1, 0.1],\n", " [5.8, 4. , 1.2, 0.2],\n", " [5.7, 4.4, 1.5, 0.4],\n", " [5.4, 3.9, 1.3, 0.4],\n", " [5.1, 3.5, 1.4, 0.3],\n", " [5.7, 3.8, 1.7, 0.3],\n", " [5.1, 3.8, 1.5, 0.3],\n", " [5.4, 3.4, 1.7, 0.2],\n", " [5.1, 3.7, 1.5, 0.4],\n", " [4.6, 3.6, 1. , 0.2],\n", " [5.1, 3.3, 1.7, 0.5],\n", " [4.8, 3.4, 1.9, 0.2],\n", " [5. , 3. , 1.6, 0.2],\n", " [5. 
, 3.4, 1.6, 0.4],\n", " [5.2, 3.5, 1.5, 0.2],\n", " [5.2, 3.4, 1.4, 0.2],\n", " [4.7, 3.2, 1.6, 0.2],\n", " [4.8, 3.1, 1.6, 0.2],\n", " [5.4, 3.4, 1.5, 0.4],\n", " [5.2, 4.1, 1.5, 0.1],\n", " [5.5, 4.2, 1.4, 0.2],\n", " [4.9, 3.1, 1.5, 0.2],\n", " [5. , 3.2, 1.2, 0.2],\n", " [5.5, 3.5, 1.3, 0.2],\n", " [4.9, 3.6, 1.4, 0.1],\n", " [4.4, 3. , 1.3, 0.2],\n", " [5.1, 3.4, 1.5, 0.2],\n", " [5. , 3.5, 1.3, 0.3],\n", " [4.5, 2.3, 1.3, 0.3],\n", " [4.4, 3.2, 1.3, 0.2],\n", " [5. , 3.5, 1.6, 0.6],\n", " [5.1, 3.8, 1.9, 0.4],\n", " [4.8, 3. , 1.4, 0.3],\n", " [5.1, 3.8, 1.6, 0.2],\n", " [4.6, 3.2, 1.4, 0.2],\n", " [5.3, 3.7, 1.5, 0.2],\n", " [5. , 3.3, 1.4, 0.2],\n", " [7. , 3.2, 4.7, 1.4],\n", " [6.4, 3.2, 4.5, 1.5],\n", " [6.9, 3.1, 4.9, 1.5],\n", " [5.5, 2.3, 4. , 1.3],\n", " [6.5, 2.8, 4.6, 1.5],\n", " [5.7, 2.8, 4.5, 1.3],\n", " [6.3, 3.3, 4.7, 1.6],\n", " [4.9, 2.4, 3.3, 1. ],\n", " [6.6, 2.9, 4.6, 1.3],\n", " [5.2, 2.7, 3.9, 1.4],\n", " [5. , 2. , 3.5, 1. ],\n", " [5.9, 3. , 4.2, 1.5],\n", " [6. , 2.2, 4. , 1. ],\n", " [6.1, 2.9, 4.7, 1.4],\n", " [5.6, 2.9, 3.6, 1.3],\n", " [6.7, 3.1, 4.4, 1.4],\n", " [5.6, 3. , 4.5, 1.5],\n", " [5.8, 2.7, 4.1, 1. ],\n", " [6.2, 2.2, 4.5, 1.5],\n", " [5.6, 2.5, 3.9, 1.1],\n", " [5.9, 3.2, 4.8, 1.8],\n", " [6.1, 2.8, 4. , 1.3],\n", " [6.3, 2.5, 4.9, 1.5],\n", " [6.1, 2.8, 4.7, 1.2],\n", " [6.4, 2.9, 4.3, 1.3],\n", " [6.6, 3. , 4.4, 1.4],\n", " [6.8, 2.8, 4.8, 1.4],\n", " [6.7, 3. , 5. , 1.7],\n", " [6. , 2.9, 4.5, 1.5],\n", " [5.7, 2.6, 3.5, 1. ],\n", " [5.5, 2.4, 3.8, 1.1],\n", " [5.5, 2.4, 3.7, 1. ],\n", " [5.8, 2.7, 3.9, 1.2],\n", " [6. , 2.7, 5.1, 1.6],\n", " [5.4, 3. , 4.5, 1.5],\n", " [6. , 3.4, 4.5, 1.6],\n", " [6.7, 3.1, 4.7, 1.5],\n", " [6.3, 2.3, 4.4, 1.3],\n", " [5.6, 3. , 4.1, 1.3],\n", " [5.5, 2.5, 4. , 1.3],\n", " [5.5, 2.6, 4.4, 1.2],\n", " [6.1, 3. , 4.6, 1.4],\n", " [5.8, 2.6, 4. , 1.2],\n", " [5. , 2.3, 3.3, 1. ],\n", " [5.6, 2.7, 4.2, 1.3],\n", " [5.7, 3. 
, 4.2, 1.2],\n", " [5.7, 2.9, 4.2, 1.3],\n", " [6.2, 2.9, 4.3, 1.3],\n", " [5.1, 2.5, 3. , 1.1],\n", " [5.7, 2.8, 4.1, 1.3],\n", " [6.3, 3.3, 6. , 2.5],\n", " [5.8, 2.7, 5.1, 1.9],\n", " [7.1, 3. , 5.9, 2.1],\n", " [6.3, 2.9, 5.6, 1.8],\n", " [6.5, 3. , 5.8, 2.2],\n", " [7.6, 3. , 6.6, 2.1],\n", " [4.9, 2.5, 4.5, 1.7],\n", " [7.3, 2.9, 6.3, 1.8],\n", " [6.7, 2.5, 5.8, 1.8],\n", " [7.2, 3.6, 6.1, 2.5],\n", " [6.5, 3.2, 5.1, 2. ],\n", " [6.4, 2.7, 5.3, 1.9],\n", " [6.8, 3. , 5.5, 2.1],\n", " [5.7, 2.5, 5. , 2. ],\n", " [5.8, 2.8, 5.1, 2.4],\n", " [6.4, 3.2, 5.3, 2.3],\n", " [6.5, 3. , 5.5, 1.8],\n", " [7.7, 3.8, 6.7, 2.2],\n", " [7.7, 2.6, 6.9, 2.3],\n", " [6. , 2.2, 5. , 1.5],\n", " [6.9, 3.2, 5.7, 2.3],\n", " [5.6, 2.8, 4.9, 2. ],\n", " [7.7, 2.8, 6.7, 2. ],\n", " [6.3, 2.7, 4.9, 1.8],\n", " [6.7, 3.3, 5.7, 2.1],\n", " [7.2, 3.2, 6. , 1.8],\n", " [6.2, 2.8, 4.8, 1.8],\n", " [6.1, 3. , 4.9, 1.8],\n", " [6.4, 2.8, 5.6, 2.1],\n", " [7.2, 3. , 5.8, 1.6],\n", " [7.4, 2.8, 6.1, 1.9],\n", " [7.9, 3.8, 6.4, 2. ],\n", " [6.4, 2.8, 5.6, 2.2],\n", " [6.3, 2.8, 5.1, 1.5],\n", " [6.1, 2.6, 5.6, 1.4],\n", " [7.7, 3. , 6.1, 2.3],\n", " [6.3, 3.4, 5.6, 2.4],\n", " [6.4, 3.1, 5.5, 1.8],\n", " [6. , 3. , 4.8, 1.8],\n", " [6.9, 3.1, 5.4, 2.1],\n", " [6.7, 3.1, 5.6, 2.4],\n", " [6.9, 3.1, 5.1, 2.3],\n", " [5.8, 2.7, 5.1, 1.9],\n", " [6.8, 3.2, 5.9, 2.3],\n", " [6.7, 3.3, 5.7, 2.5],\n", " [6.7, 3. , 5.2, 2.3],\n", " [6.3, 2.5, 5. , 1.9],\n", " [6.5, 3. , 5.2, 2. ],\n", " [6.2, 3.4, 5.4, 2.3],\n", " [5.9, 3. 
, 5.1, 1.8]])" ] }, "execution_count": 151, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X" ] }, { "cell_type": "code", "execution_count": 152, "metadata": { "collapsed": true }, "outputs": [], "source": [ "y = iris.target" ] }, { "cell_type": "code", "execution_count": 153, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])" ] }, "execution_count": 153, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "code", "execution_count": 154, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from keras.utils import to_categorical" ] }, { "cell_type": "code", "execution_count": 156, "metadata": { "collapsed": true }, "outputs": [], "source": [ "y = to_categorical(y)" ] }, { "cell_type": "code", "execution_count": 157, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(150, 3)" ] }, "execution_count": 157, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y.shape" ] }, { "cell_type": "code", "execution_count": 158, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", 
" [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", 
" [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.]], dtype=float32)" ] }, "execution_count": 158, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split the Data into Training and Test\n", "\n", "Its time to split the data into a train/test set. Keep in mind, sometimes people like to split 3 ways, train/test/validation. We'll keep things simple for now. **Remember to check out the video explanation as to why we split and what all the parameters mean!**" ] }, { "cell_type": "code", "execution_count": 159, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 160, "metadata": { "collapsed": true }, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)" ] }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[5.7, 2.9, 4.2, 1.3],\n", " [7.6, 3. , 6.6, 2.1],\n", " [5.6, 3. , 4.5, 1.5],\n", " [5.1, 3.5, 1.4, 0.2],\n", " [7.7, 2.8, 6.7, 2. ],\n", " [5.8, 2.7, 4.1, 1. ],\n", " [5.2, 3.4, 1.4, 0.2],\n", " [5. , 3.5, 1.3, 0.3],\n", " [5.1, 3.8, 1.9, 0.4],\n", " [5. , 2. , 3.5, 1. ],\n", " [6.3, 2.7, 4.9, 1.8],\n", " [4.8, 3.4, 1.9, 0.2],\n", " [5. , 3. 
, 1.6, 0.2],\n", " [5.1, 3.3, 1.7, 0.5],\n", " [5.6, 2.7, 4.2, 1.3],\n", " [5.1, 3.4, 1.5, 0.2],\n", " [5.7, 3. , 4.2, 1.2],\n", " [7.7, 3.8, 6.7, 2.2],\n", " [4.6, 3.2, 1.4, 0.2],\n", " [6.2, 2.9, 4.3, 1.3],\n", " [5.7, 2.5, 5. , 2. ],\n", " [5.5, 4.2, 1.4, 0.2],\n", " [6. , 3. , 4.8, 1.8],\n", " [5.8, 2.7, 5.1, 1.9],\n", " [6. , 2.2, 4. , 1. ],\n", " [5.4, 3. , 4.5, 1.5],\n", " [6.2, 3.4, 5.4, 2.3],\n", " [5.5, 2.3, 4. , 1.3],\n", " [5.4, 3.9, 1.7, 0.4],\n", " [5. , 2.3, 3.3, 1. ],\n", " [6.4, 2.7, 5.3, 1.9],\n", " [5. , 3.3, 1.4, 0.2],\n", " [5. , 3.2, 1.2, 0.2],\n", " [5.5, 2.4, 3.8, 1.1],\n", " [6.7, 3. , 5. , 1.7],\n", " [4.9, 3.1, 1.5, 0.2],\n", " [5.8, 2.8, 5.1, 2.4],\n", " [5. , 3.4, 1.5, 0.2],\n", " [5. , 3.5, 1.6, 0.6],\n", " [5.9, 3.2, 4.8, 1.8],\n", " [5.1, 2.5, 3. , 1.1],\n", " [6.9, 3.2, 5.7, 2.3],\n", " [6. , 2.7, 5.1, 1.6],\n", " [6.1, 2.6, 5.6, 1.4],\n", " [7.7, 3. , 6.1, 2.3],\n", " [5.5, 2.5, 4. , 1.3],\n", " [4.4, 2.9, 1.4, 0.2],\n", " [4.3, 3. , 1.1, 0.1],\n", " [6. , 2.2, 5. , 1.5],\n", " [7.2, 3.2, 6. , 1.8],\n", " [4.6, 3.1, 1.5, 0.2],\n", " [5.1, 3.5, 1.4, 0.3],\n", " [4.4, 3. , 1.3, 0.2],\n", " [6.3, 2.5, 4.9, 1.5],\n", " [6.3, 3.4, 5.6, 2.4],\n", " [4.6, 3.4, 1.4, 0.3],\n", " [6.8, 3. , 5.5, 2.1],\n", " [6.3, 3.3, 6. , 2.5],\n", " [4.7, 3.2, 1.3, 0.2],\n", " [6.1, 2.9, 4.7, 1.4],\n", " [6.5, 2.8, 4.6, 1.5],\n", " [6.2, 2.8, 4.8, 1.8],\n", " [7. , 3.2, 4.7, 1.4],\n", " [6.4, 3.2, 5.3, 2.3],\n", " [5.1, 3.8, 1.6, 0.2],\n", " [6.9, 3.1, 5.4, 2.1],\n", " [5.9, 3. , 4.2, 1.5],\n", " [6.5, 3. , 5.2, 2. ],\n", " [5.7, 2.6, 3.5, 1. ],\n", " [5.2, 2.7, 3.9, 1.4],\n", " [6.1, 3. , 4.6, 1.4],\n", " [4.5, 2.3, 1.3, 0.3],\n", " [6.6, 2.9, 4.6, 1.3],\n", " [5.5, 2.6, 4.4, 1.2],\n", " [5.3, 3.7, 1.5, 0.2],\n", " [5.6, 3. , 4.1, 1.3],\n", " [7.3, 2.9, 6.3, 1.8],\n", " [6.7, 3.3, 5.7, 2.1],\n", " [5.1, 3.7, 1.5, 0.4],\n", " [4.9, 2.4, 3.3, 1. ],\n", " [6.7, 3.3, 5.7, 2.5],\n", " [7.2, 3. 
, 5.8, 1.6],\n", " [4.9, 3.6, 1.4, 0.1],\n", " [6.7, 3.1, 5.6, 2.4],\n", " [4.9, 3. , 1.4, 0.2],\n", " [6.9, 3.1, 4.9, 1.5],\n", " [7.4, 2.8, 6.1, 1.9],\n", " [6.3, 2.9, 5.6, 1.8],\n", " [5.7, 2.8, 4.1, 1.3],\n", " [6.5, 3. , 5.5, 1.8],\n", " [6.3, 2.3, 4.4, 1.3],\n", " [6.4, 2.9, 4.3, 1.3],\n", " [5.6, 2.8, 4.9, 2. ],\n", " [5.9, 3. , 5.1, 1.8],\n", " [5.4, 3.4, 1.7, 0.2],\n", " [6.1, 2.8, 4. , 1.3],\n", " [4.9, 2.5, 4.5, 1.7],\n", " [5.8, 4. , 1.2, 0.2],\n", " [5.8, 2.6, 4. , 1.2],\n", " [7.1, 3. , 5.9, 2.1]])" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train" ] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[6.1, 2.8, 4.7, 1.2],\n", " [5.7, 3.8, 1.7, 0.3],\n", " [7.7, 2.6, 6.9, 2.3],\n", " [6. , 2.9, 4.5, 1.5],\n", " [6.8, 2.8, 4.8, 1.4],\n", " [5.4, 3.4, 1.5, 0.4],\n", " [5.6, 2.9, 3.6, 1.3],\n", " [6.9, 3.1, 5.1, 2.3],\n", " [6.2, 2.2, 4.5, 1.5],\n", " [5.8, 2.7, 3.9, 1.2],\n", " [6.5, 3.2, 5.1, 2. ],\n", " [4.8, 3. , 1.4, 0.1],\n", " [5.5, 3.5, 1.3, 0.2],\n", " [4.9, 3.1, 1.5, 0.1],\n", " [5.1, 3.8, 1.5, 0.3],\n", " [6.3, 3.3, 4.7, 1.6],\n", " [6.5, 3. , 5.8, 2.2],\n", " [5.6, 2.5, 3.9, 1.1],\n", " [5.7, 2.8, 4.5, 1.3],\n", " [6.4, 2.8, 5.6, 2.2],\n", " [4.7, 3.2, 1.6, 0.2],\n", " [6.1, 3. , 4.9, 1.8],\n", " [5. , 3.4, 1.6, 0.4],\n", " [6.4, 2.8, 5.6, 2.1],\n", " [7.9, 3.8, 6.4, 2. ],\n", " [6.7, 3. , 5.2, 2.3],\n", " [6.7, 2.5, 5.8, 1.8],\n", " [6.8, 3.2, 5.9, 2.3],\n", " [4.8, 3. , 1.4, 0.3],\n", " [4.8, 3.1, 1.6, 0.2],\n", " [4.6, 3.6, 1. , 0.2],\n", " [5.7, 4.4, 1.5, 0.4],\n", " [6.7, 3.1, 4.4, 1.4],\n", " [4.8, 3.4, 1.6, 0.2],\n", " [4.4, 3.2, 1.3, 0.2],\n", " [6.3, 2.5, 5. , 1.9],\n", " [6.4, 3.2, 4.5, 1.5],\n", " [5.2, 3.5, 1.5, 0.2],\n", " [5. , 3.6, 1.4, 0.2],\n", " [5.2, 4.1, 1.5, 0.1],\n", " [5.8, 2.7, 5.1, 1.9],\n", " [6. 
, 3.4, 4.5, 1.6],\n", " [6.7, 3.1, 4.7, 1.5],\n", " [5.4, 3.9, 1.3, 0.4],\n", " [5.4, 3.7, 1.5, 0.2],\n", " [5.5, 2.4, 3.7, 1. ],\n", " [6.3, 2.8, 5.1, 1.5],\n", " [6.4, 3.1, 5.5, 1.8],\n", " [6.6, 3. , 4.4, 1.4],\n", " [7.2, 3.6, 6.1, 2.5]])" ] }, "execution_count": 112, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test" ] }, { "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 
1.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.]], dtype=float32)" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train" ] }, { "cell_type": "code", "execution_count": 114, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 1., 0.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 1., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [0., 1., 0.],\n", " [0., 0., 1.],\n", " [0., 0., 1.],\n", " [0., 1., 0.],\n", " [0., 0., 1.]], dtype=float32)" ] }, "execution_count": 114, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Standardizing the Data\n", "\n", "Usually when using Neural Networks, you will get better performance when you 
scale the data so that all features fall within the same range, such as 0 to 1 or -1 to 1. Strictly speaking, this is *min-max scaling*: each value x is mapped to (x - min) / (max - min), using the minimum and maximum of that feature. *Standardization* usually refers to something slightly different, shifting each feature to zero mean and unit variance (scikit-learn's `StandardScaler`).\n", "\n", "The scikit-learn library provides `MinMaxScaler` for min-max scaling; by default it maps each feature to the range 0 to 1.\n", "\n", "http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html" ] }, { "cell_type": "code", "execution_count": 115, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.preprocessing import MinMaxScaler" ] }, { "cell_type": "code", "execution_count": 116, "metadata": { "collapsed": true }, "outputs": [], "source": [ "scaler_object = MinMaxScaler()" ] }, { "cell_type": "code", "execution_count": 117, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MinMaxScaler(copy=True, feature_range=(0, 1))" ] }, "execution_count": 117, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scaler_object.fit(X_train)" ] }, { "cell_type": "code", "execution_count": 118, "metadata": { "collapsed": true }, "outputs": [], "source": [ "scaled_X_train = scaler_object.transform(X_train)" ] }, { "cell_type": "code", "execution_count": 119, "metadata": { "collapsed": true }, "outputs": [], "source": [ "scaled_X_test = scaler_object.transform(X_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now the data is scaled! Note that the scaler was fit on `X_train` only and then used to transform both sets, so no information from the test set leaks into preprocessing. This also means `scaled_X_test` can contain values slightly outside the 0 to 1 range." ] }, { "cell_type": "code", "execution_count": 120, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7.7" ] }, "execution_count": 120, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.max()" ] }, { "cell_type": "code", "execution_count": 121, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.0" ] }, "execution_count": 121, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scaled_X_train.max()" ] }, { "cell_type": "code", "execution_count": 122, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[5.7, 2.9, 4.2, 1.3],\n", " [7.6, 3. , 6.6, 2.1],\n", " [5.6, 3. 
, 4.5, 1.5],\n", " [5.1, 3.5, 1.4, 0.2],\n", " [7.7, 2.8, 6.7, 2. ],\n", " [5.8, 2.7, 4.1, 1. ],\n", " [5.2, 3.4, 1.4, 0.2],\n", " [5. , 3.5, 1.3, 0.3],\n", " [5.1, 3.8, 1.9, 0.4],\n", " [5. , 2. , 3.5, 1. ],\n", " [6.3, 2.7, 4.9, 1.8],\n", " [4.8, 3.4, 1.9, 0.2],\n", " [5. , 3. , 1.6, 0.2],\n", " [5.1, 3.3, 1.7, 0.5],\n", " [5.6, 2.7, 4.2, 1.3],\n", " [5.1, 3.4, 1.5, 0.2],\n", " [5.7, 3. , 4.2, 1.2],\n", " [7.7, 3.8, 6.7, 2.2],\n", " [4.6, 3.2, 1.4, 0.2],\n", " [6.2, 2.9, 4.3, 1.3],\n", " [5.7, 2.5, 5. , 2. ],\n", " [5.5, 4.2, 1.4, 0.2],\n", " [6. , 3. , 4.8, 1.8],\n", " [5.8, 2.7, 5.1, 1.9],\n", " [6. , 2.2, 4. , 1. ],\n", " [5.4, 3. , 4.5, 1.5],\n", " [6.2, 3.4, 5.4, 2.3],\n", " [5.5, 2.3, 4. , 1.3],\n", " [5.4, 3.9, 1.7, 0.4],\n", " [5. , 2.3, 3.3, 1. ],\n", " [6.4, 2.7, 5.3, 1.9],\n", " [5. , 3.3, 1.4, 0.2],\n", " [5. , 3.2, 1.2, 0.2],\n", " [5.5, 2.4, 3.8, 1.1],\n", " [6.7, 3. , 5. , 1.7],\n", " [4.9, 3.1, 1.5, 0.2],\n", " [5.8, 2.8, 5.1, 2.4],\n", " [5. , 3.4, 1.5, 0.2],\n", " [5. , 3.5, 1.6, 0.6],\n", " [5.9, 3.2, 4.8, 1.8],\n", " [5.1, 2.5, 3. , 1.1],\n", " [6.9, 3.2, 5.7, 2.3],\n", " [6. , 2.7, 5.1, 1.6],\n", " [6.1, 2.6, 5.6, 1.4],\n", " [7.7, 3. , 6.1, 2.3],\n", " [5.5, 2.5, 4. , 1.3],\n", " [4.4, 2.9, 1.4, 0.2],\n", " [4.3, 3. , 1.1, 0.1],\n", " [6. , 2.2, 5. , 1.5],\n", " [7.2, 3.2, 6. , 1.8],\n", " [4.6, 3.1, 1.5, 0.2],\n", " [5.1, 3.5, 1.4, 0.3],\n", " [4.4, 3. , 1.3, 0.2],\n", " [6.3, 2.5, 4.9, 1.5],\n", " [6.3, 3.4, 5.6, 2.4],\n", " [4.6, 3.4, 1.4, 0.3],\n", " [6.8, 3. , 5.5, 2.1],\n", " [6.3, 3.3, 6. , 2.5],\n", " [4.7, 3.2, 1.3, 0.2],\n", " [6.1, 2.9, 4.7, 1.4],\n", " [6.5, 2.8, 4.6, 1.5],\n", " [6.2, 2.8, 4.8, 1.8],\n", " [7. , 3.2, 4.7, 1.4],\n", " [6.4, 3.2, 5.3, 2.3],\n", " [5.1, 3.8, 1.6, 0.2],\n", " [6.9, 3.1, 5.4, 2.1],\n", " [5.9, 3. , 4.2, 1.5],\n", " [6.5, 3. , 5.2, 2. ],\n", " [5.7, 2.6, 3.5, 1. ],\n", " [5.2, 2.7, 3.9, 1.4],\n", " [6.1, 3. 
, 4.6, 1.4],\n", " [4.5, 2.3, 1.3, 0.3],\n", " [6.6, 2.9, 4.6, 1.3],\n", " [5.5, 2.6, 4.4, 1.2],\n", " [5.3, 3.7, 1.5, 0.2],\n", " [5.6, 3. , 4.1, 1.3],\n", " [7.3, 2.9, 6.3, 1.8],\n", " [6.7, 3.3, 5.7, 2.1],\n", " [5.1, 3.7, 1.5, 0.4],\n", " [4.9, 2.4, 3.3, 1. ],\n", " [6.7, 3.3, 5.7, 2.5],\n", " [7.2, 3. , 5.8, 1.6],\n", " [4.9, 3.6, 1.4, 0.1],\n", " [6.7, 3.1, 5.6, 2.4],\n", " [4.9, 3. , 1.4, 0.2],\n", " [6.9, 3.1, 4.9, 1.5],\n", " [7.4, 2.8, 6.1, 1.9],\n", " [6.3, 2.9, 5.6, 1.8],\n", " [5.7, 2.8, 4.1, 1.3],\n", " [6.5, 3. , 5.5, 1.8],\n", " [6.3, 2.3, 4.4, 1.3],\n", " [6.4, 2.9, 4.3, 1.3],\n", " [5.6, 2.8, 4.9, 2. ],\n", " [5.9, 3. , 5.1, 1.8],\n", " [5.4, 3.4, 1.7, 0.2],\n", " [6.1, 2.8, 4. , 1.3],\n", " [4.9, 2.5, 4.5, 1.7],\n", " [5.8, 4. , 1.2, 0.2],\n", " [5.8, 2.6, 4. , 1.2],\n", " [7.1, 3. , 5.9, 2.1]])" ] }, "execution_count": 122, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train" ] }, { "cell_type": "code", "execution_count": 123, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.41176471, 0.40909091, 0.55357143, 0.5 ],\n", " [0.97058824, 0.45454545, 0.98214286, 0.83333333],\n", " [0.38235294, 0.45454545, 0.60714286, 0.58333333],\n", " [0.23529412, 0.68181818, 0.05357143, 0.04166667],\n", " [1. , 0.36363636, 1. , 0.79166667],\n", " [0.44117647, 0.31818182, 0.53571429, 0.375 ],\n", " [0.26470588, 0.63636364, 0.05357143, 0.04166667],\n", " [0.20588235, 0.68181818, 0.03571429, 0.08333333],\n", " [0.23529412, 0.81818182, 0.14285714, 0.125 ],\n", " [0.20588235, 0. , 0.42857143, 0.375 ],\n", " [0.58823529, 0.31818182, 0.67857143, 0.70833333],\n", " [0.14705882, 0.63636364, 0.14285714, 0.04166667],\n", " [0.20588235, 0.45454545, 0.08928571, 0.04166667],\n", " [0.23529412, 0.59090909, 0.10714286, 0.16666667],\n", " [0.38235294, 0.31818182, 0.55357143, 0.5 ],\n", " [0.23529412, 0.63636364, 0.07142857, 0.04166667],\n", " [0.41176471, 0.45454545, 0.55357143, 0.45833333],\n", " [1. , 0.81818182, 1. 
, 0.875 ],\n", " [0.08823529, 0.54545455, 0.05357143, 0.04166667],\n", " [0.55882353, 0.40909091, 0.57142857, 0.5 ],\n", " [0.41176471, 0.22727273, 0.69642857, 0.79166667],\n", " [0.35294118, 1. , 0.05357143, 0.04166667],\n", " [0.5 , 0.45454545, 0.66071429, 0.70833333],\n", " [0.44117647, 0.31818182, 0.71428571, 0.75 ],\n", " [0.5 , 0.09090909, 0.51785714, 0.375 ],\n", " [0.32352941, 0.45454545, 0.60714286, 0.58333333],\n", " [0.55882353, 0.63636364, 0.76785714, 0.91666667],\n", " [0.35294118, 0.13636364, 0.51785714, 0.5 ],\n", " [0.32352941, 0.86363636, 0.10714286, 0.125 ],\n", " [0.20588235, 0.13636364, 0.39285714, 0.375 ],\n", " [0.61764706, 0.31818182, 0.75 , 0.75 ],\n", " [0.20588235, 0.59090909, 0.05357143, 0.04166667],\n", " [0.20588235, 0.54545455, 0.01785714, 0.04166667],\n", " [0.35294118, 0.18181818, 0.48214286, 0.41666667],\n", " [0.70588235, 0.45454545, 0.69642857, 0.66666667],\n", " [0.17647059, 0.5 , 0.07142857, 0.04166667],\n", " [0.44117647, 0.36363636, 0.71428571, 0.95833333],\n", " [0.20588235, 0.63636364, 0.07142857, 0.04166667],\n", " [0.20588235, 0.68181818, 0.08928571, 0.20833333],\n", " [0.47058824, 0.54545455, 0.66071429, 0.70833333],\n", " [0.23529412, 0.22727273, 0.33928571, 0.41666667],\n", " [0.76470588, 0.54545455, 0.82142857, 0.91666667],\n", " [0.5 , 0.31818182, 0.71428571, 0.625 ],\n", " [0.52941176, 0.27272727, 0.80357143, 0.54166667],\n", " [1. , 0.45454545, 0.89285714, 0.91666667],\n", " [0.35294118, 0.22727273, 0.51785714, 0.5 ],\n", " [0.02941176, 0.40909091, 0.05357143, 0.04166667],\n", " [0. , 0.45454545, 0. , 0. 
],\n", " [0.5 , 0.09090909, 0.69642857, 0.58333333],\n", " [0.85294118, 0.54545455, 0.875 , 0.70833333],\n", " [0.08823529, 0.5 , 0.07142857, 0.04166667],\n", " [0.23529412, 0.68181818, 0.05357143, 0.08333333],\n", " [0.02941176, 0.45454545, 0.03571429, 0.04166667],\n", " [0.58823529, 0.22727273, 0.67857143, 0.58333333],\n", " [0.58823529, 0.63636364, 0.80357143, 0.95833333],\n", " [0.08823529, 0.63636364, 0.05357143, 0.08333333],\n", " [0.73529412, 0.45454545, 0.78571429, 0.83333333],\n", " [0.58823529, 0.59090909, 0.875 , 1. ],\n", " [0.11764706, 0.54545455, 0.03571429, 0.04166667],\n", " [0.52941176, 0.40909091, 0.64285714, 0.54166667],\n", " [0.64705882, 0.36363636, 0.625 , 0.58333333],\n", " [0.55882353, 0.36363636, 0.66071429, 0.70833333],\n", " [0.79411765, 0.54545455, 0.64285714, 0.54166667],\n", " [0.61764706, 0.54545455, 0.75 , 0.91666667],\n", " [0.23529412, 0.81818182, 0.08928571, 0.04166667],\n", " [0.76470588, 0.5 , 0.76785714, 0.83333333],\n", " [0.47058824, 0.45454545, 0.55357143, 0.58333333],\n", " [0.64705882, 0.45454545, 0.73214286, 0.79166667],\n", " [0.41176471, 0.27272727, 0.42857143, 0.375 ],\n", " [0.26470588, 0.31818182, 0.5 , 0.54166667],\n", " [0.52941176, 0.45454545, 0.625 , 0.54166667],\n", " [0.05882353, 0.13636364, 0.03571429, 0.08333333],\n", " [0.67647059, 0.40909091, 0.625 , 0.5 ],\n", " [0.35294118, 0.27272727, 0.58928571, 0.45833333],\n", " [0.29411765, 0.77272727, 0.07142857, 0.04166667],\n", " [0.38235294, 0.45454545, 0.53571429, 0.5 ],\n", " [0.88235294, 0.40909091, 0.92857143, 0.70833333],\n", " [0.70588235, 0.59090909, 0.82142857, 0.83333333],\n", " [0.23529412, 0.77272727, 0.07142857, 0.125 ],\n", " [0.17647059, 0.18181818, 0.39285714, 0.375 ],\n", " [0.70588235, 0.59090909, 0.82142857, 1. ],\n", " [0.85294118, 0.45454545, 0.83928571, 0.625 ],\n", " [0.17647059, 0.72727273, 0.05357143, 0. 
],\n", " [0.70588235, 0.5 , 0.80357143, 0.95833333],\n", " [0.17647059, 0.45454545, 0.05357143, 0.04166667],\n", " [0.76470588, 0.5 , 0.67857143, 0.58333333],\n", " [0.91176471, 0.36363636, 0.89285714, 0.75 ],\n", " [0.58823529, 0.40909091, 0.80357143, 0.70833333],\n", " [0.41176471, 0.36363636, 0.53571429, 0.5 ],\n", " [0.64705882, 0.45454545, 0.78571429, 0.70833333],\n", " [0.58823529, 0.13636364, 0.58928571, 0.5 ],\n", " [0.61764706, 0.40909091, 0.57142857, 0.5 ],\n", " [0.38235294, 0.36363636, 0.67857143, 0.79166667],\n", " [0.47058824, 0.45454545, 0.71428571, 0.70833333],\n", " [0.32352941, 0.63636364, 0.10714286, 0.04166667],\n", " [0.52941176, 0.36363636, 0.51785714, 0.5 ],\n", " [0.17647059, 0.22727273, 0.60714286, 0.66666667],\n", " [0.44117647, 0.90909091, 0.01785714, 0.04166667],\n", " [0.44117647, 0.27272727, 0.51785714, 0.45833333],\n", " [0.82352941, 0.45454545, 0.85714286, 0.83333333]])" ] }, "execution_count": 123, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scaled_X_train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Building the Network with Keras\n", "\n", "Let's build a simple neural network!" 
] }, { "cell_type": "code", "execution_count": 132, "metadata": {}, "outputs": [], "source": [ "from keras.models import Sequential\n", "from keras.layers import Dense" ] }, { "cell_type": "code", "execution_count": 133, "metadata": {}, "outputs": [], "source": [ "model = Sequential()\n", "model.add(Dense(8, input_dim=4, activation='relu'))  # hidden layer 1: takes the 4 iris features\n", "model.add(Dense(8, activation='relu'))  # hidden layer 2: input size is inferred from the previous layer\n", "model.add(Dense(3, activation='softmax'))  # output layer: one probability per iris class\n", "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 134, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "dense_27 (Dense) (None, 8) 40 \n", "_________________________________________________________________\n", "dense_28 (Dense) (None, 8) 72 \n", "_________________________________________________________________\n", "dense_29 (Dense) (None, 3) 27 \n", "=================================================================\n", "Total params: 139\n", "Trainable params: 139\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fit (Train) the Model" ] }, { "cell_type": "code", "execution_count": 135, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/150\n", " - 0s - loss: 1.0926 - acc: 0.3400\n", "Epoch 2/150\n", " - 0s - loss: 1.0871 - acc: 0.3400\n", "Epoch 3/150\n", " - 0s - loss: 1.0814 - acc: 0.3400\n", "Epoch 4/150\n", " - 0s - loss: 1.0760 - acc: 0.3400\n", "Epoch 5/150\n", " - 0s - loss: 1.0700 - acc: 0.3400\n", "Epoch 6/150\n", " - 0s - loss: 1.0640 - acc: 0.3500\n", "Epoch 7/150\n", " - 0s - loss: 1.0581 - acc: 0.3700\n", "Epoch 
8/150\n", " - 0s - loss: 1.0520 - acc: 0.4200\n", "Epoch 9/150\n", " - 0s - loss: 1.0467 - acc: 0.5100\n", "Epoch 10/150\n", " - 0s - loss: 1.0416 - acc: 0.5700\n", "Epoch 11/150\n", " - 0s - loss: 1.0361 - acc: 0.6300\n", "Epoch 12/150\n", " - 0s - loss: 1.0308 - acc: 0.6300\n", "Epoch 13/150\n", " - 0s - loss: 1.0249 - acc: 0.6300\n", "Epoch 14/150\n", " - 0s - loss: 1.0187 - acc: 0.6200\n", "Epoch 15/150\n", " - 0s - loss: 1.0124 - acc: 0.6300\n", "Epoch 16/150\n", " - 0s - loss: 1.0062 - acc: 0.6300\n", "Epoch 17/150\n", " - 0s - loss: 0.9994 - acc: 0.6400\n", "Epoch 18/150\n", " - 0s - loss: 0.9919 - acc: 0.6400\n", "Epoch 19/150\n", " - 0s - loss: 0.9836 - acc: 0.6400\n", "Epoch 20/150\n", " - 0s - loss: 0.9748 - acc: 0.6400\n", "Epoch 21/150\n", " - 0s - loss: 0.9649 - acc: 0.6400\n", "Epoch 22/150\n", " - 0s - loss: 0.9552 - acc: 0.6400\n", "Epoch 23/150\n", " - 0s - loss: 0.9448 - acc: 0.6500\n", "Epoch 24/150\n", " - 0s - loss: 0.9337 - acc: 0.6500\n", "Epoch 25/150\n", " - 0s - loss: 0.9225 - acc: 0.6500\n", "Epoch 26/150\n", " - 0s - loss: 0.9111 - acc: 0.6500\n", "Epoch 27/150\n", " - 0s - loss: 0.8992 - acc: 0.6500\n", "Epoch 28/150\n", " - 0s - loss: 0.8882 - acc: 0.6500\n", "Epoch 29/150\n", " - 0s - loss: 0.8766 - acc: 0.6500\n", "Epoch 30/150\n", " - 0s - loss: 0.8658 - acc: 0.6500\n", "Epoch 31/150\n", " - 0s - loss: 0.8555 - acc: 0.6500\n", "Epoch 32/150\n", " - 0s - loss: 0.8446 - acc: 0.6500\n", "Epoch 33/150\n", " - 0s - loss: 0.8337 - acc: 0.6500\n", "Epoch 34/150\n", " - 0s - loss: 0.8225 - acc: 0.6500\n", "Epoch 35/150\n", " - 0s - loss: 0.8112 - acc: 0.6500\n", "Epoch 36/150\n", " - 0s - loss: 0.7998 - acc: 0.6500\n", "Epoch 37/150\n", " - 0s - loss: 0.7886 - acc: 0.6500\n", "Epoch 38/150\n", " - 0s - loss: 0.7768 - acc: 0.6500\n", "Epoch 39/150\n", " - 0s - loss: 0.7650 - acc: 0.6500\n", "Epoch 40/150\n", " - 0s - loss: 0.7539 - acc: 0.6500\n", "Epoch 41/150\n", " - 0s - loss: 0.7428 - acc: 0.6500\n", "Epoch 42/150\n", " - 0s - loss: 
0.7318 - acc: 0.6500\n", "Epoch 43/150\n", " - 0s - loss: 0.7213 - acc: 0.6500\n", "Epoch 44/150\n", " - 0s - loss: 0.7111 - acc: 0.6500\n", "Epoch 45/150\n", " - 0s - loss: 0.7007 - acc: 0.6500\n", "Epoch 46/150\n", " - 0s - loss: 0.6907 - acc: 0.6500\n", "Epoch 47/150\n", " - 0s - loss: 0.6806 - acc: 0.6500\n", "Epoch 48/150\n", " - 0s - loss: 0.6714 - acc: 0.6500\n", "Epoch 49/150\n", " - 0s - loss: 0.6621 - acc: 0.6500\n", "Epoch 50/150\n", " - 0s - loss: 0.6534 - acc: 0.6500\n", "Epoch 51/150\n", " - 0s - loss: 0.6452 - acc: 0.6500\n", "Epoch 52/150\n", " - 0s - loss: 0.6366 - acc: 0.6500\n", "Epoch 53/150\n", " - 0s - loss: 0.6285 - acc: 0.6500\n", "Epoch 54/150\n", " - 0s - loss: 0.6205 - acc: 0.6500\n", "Epoch 55/150\n", " - 0s - loss: 0.6130 - acc: 0.6500\n", "Epoch 56/150\n", " - 0s - loss: 0.6058 - acc: 0.6500\n", "Epoch 57/150\n", " - 0s - loss: 0.5990 - acc: 0.6500\n", "Epoch 58/150\n", " - 0s - loss: 0.5920 - acc: 0.6500\n", "Epoch 59/150\n", " - 0s - loss: 0.5852 - acc: 0.6700\n", "Epoch 60/150\n", " - 0s - loss: 0.5790 - acc: 0.6700\n", "Epoch 61/150\n", " - 0s - loss: 0.5727 - acc: 0.6900\n", "Epoch 62/150\n", " - 0s - loss: 0.5665 - acc: 0.6900\n", "Epoch 63/150\n", " - 0s - loss: 0.5605 - acc: 0.6900\n", "Epoch 64/150\n", " - 0s - loss: 0.5544 - acc: 0.6900\n", "Epoch 65/150\n", " - 0s - loss: 0.5490 - acc: 0.6900\n", "Epoch 66/150\n", " - 0s - loss: 0.5436 - acc: 0.7000\n", "Epoch 67/150\n", " - 0s - loss: 0.5385 - acc: 0.7000\n", "Epoch 68/150\n", " - 0s - loss: 0.5334 - acc: 0.7000\n", "Epoch 69/150\n", " - 0s - loss: 0.5287 - acc: 0.7100\n", "Epoch 70/150\n", " - 0s - loss: 0.5243 - acc: 0.7200\n", "Epoch 71/150\n", " - 0s - loss: 0.5198 - acc: 0.7300\n", "Epoch 72/150\n", " - 0s - loss: 0.5150 - acc: 0.7300\n", "Epoch 73/150\n", " - 0s - loss: 0.5108 - acc: 0.7400\n", "Epoch 74/150\n", " - 0s - loss: 0.5066 - acc: 0.7400\n", "Epoch 75/150\n", " - 0s - loss: 0.5022 - acc: 0.7600\n", "Epoch 76/150\n", " - 0s - loss: 0.4986 - acc: 0.7500\n", 
"Epoch 77/150\n", " - 0s - loss: 0.4945 - acc: 0.7400\n", "Epoch 78/150\n", " - 0s - loss: 0.4908 - acc: 0.7500\n", "Epoch 79/150\n", " - 0s - loss: 0.4867 - acc: 0.7700\n", "Epoch 80/150\n", " - 0s - loss: 0.4830 - acc: 0.7700\n", "Epoch 81/150\n", " - 0s - loss: 0.4794 - acc: 0.7900\n", "Epoch 82/150\n", " - 0s - loss: 0.4761 - acc: 0.8500\n", "Epoch 83/150\n", " - 0s - loss: 0.4726 - acc: 0.8500\n", "Epoch 84/150\n", " - 0s - loss: 0.4693 - acc: 0.8500\n", "Epoch 85/150\n", " - 0s - loss: 0.4660 - acc: 0.8500\n", "Epoch 86/150\n", " - 0s - loss: 0.4628 - acc: 0.8500\n", "Epoch 87/150\n", " - 0s - loss: 0.4599 - acc: 0.8200\n", "Epoch 88/150\n", " - 0s - loss: 0.4575 - acc: 0.8000\n", "Epoch 89/150\n", " - 0s - loss: 0.4546 - acc: 0.7800\n", "Epoch 90/150\n", " - 0s - loss: 0.4521 - acc: 0.7800\n", "Epoch 91/150\n", " - 0s - loss: 0.4490 - acc: 0.8100\n", "Epoch 92/150\n", " - 0s - loss: 0.4460 - acc: 0.8300\n", "Epoch 93/150\n", " - 0s - loss: 0.4433 - acc: 0.8500\n", "Epoch 94/150\n", " - 0s - loss: 0.4407 - acc: 0.8500\n", "Epoch 95/150\n", " - 0s - loss: 0.4380 - acc: 0.8500\n", "Epoch 96/150\n", " - 0s - loss: 0.4352 - acc: 0.8600\n", "Epoch 97/150\n", " - 0s - loss: 0.4326 - acc: 0.8600\n", "Epoch 98/150\n", " - 0s - loss: 0.4302 - acc: 0.9000\n", "Epoch 99/150\n", " - 0s - loss: 0.4270 - acc: 0.9100\n", "Epoch 100/150\n", " - 0s - loss: 0.4245 - acc: 0.9200\n", "Epoch 101/150\n", " - 0s - loss: 0.4223 - acc: 0.9200\n", "Epoch 102/150\n", " - 0s - loss: 0.4195 - acc: 0.9200\n", "Epoch 103/150\n", " - 0s - loss: 0.4171 - acc: 0.9200\n", "Epoch 104/150\n", " - 0s - loss: 0.4147 - acc: 0.9200\n", "Epoch 105/150\n", " - 0s - loss: 0.4124 - acc: 0.9000\n", "Epoch 106/150\n", " - 0s - loss: 0.4102 - acc: 0.9000\n", "Epoch 107/150\n", " - 0s - loss: 0.4077 - acc: 0.9200\n", "Epoch 108/150\n", " - 0s - loss: 0.4055 - acc: 0.9200\n", "Epoch 109/150\n", " - 0s - loss: 0.4028 - acc: 0.9200\n", "Epoch 110/150\n", " - 0s - loss: 0.4006 - acc: 0.9200\n", "Epoch 
111/150\n", " - 0s - loss: 0.3985 - acc: 0.9200\n", "Epoch 112/150\n", " - 0s - loss: 0.3965 - acc: 0.9200\n", "Epoch 113/150\n", " - 0s - loss: 0.3945 - acc: 0.9200\n", "Epoch 114/150\n", " - 0s - loss: 0.3930 - acc: 0.9200\n", "Epoch 115/150\n", " - 0s - loss: 0.3914 - acc: 0.9000\n", "Epoch 116/150\n", " - 0s - loss: 0.3891 - acc: 0.9100\n", "Epoch 117/150\n", " - 0s - loss: 0.3860 - acc: 0.9200\n", "Epoch 118/150\n", " - 0s - loss: 0.3844 - acc: 0.9200\n", "Epoch 119/150\n", " - 0s - loss: 0.3821 - acc: 0.9500\n", "Epoch 120/150\n", " - 0s - loss: 0.3801 - acc: 0.9500\n", "Epoch 121/150\n", " - 0s - loss: 0.3780 - acc: 0.9500\n", "Epoch 122/150\n", " - 0s - loss: 0.3757 - acc: 0.9500\n", "Epoch 123/150\n", " - 0s - loss: 0.3738 - acc: 0.9500\n", "Epoch 124/150\n", " - 0s - loss: 0.3720 - acc: 0.9500\n", "Epoch 125/150\n", " - 0s - loss: 0.3693 - acc: 0.9500\n", "Epoch 126/150\n", " - 0s - loss: 0.3662 - acc: 0.9500\n", "Epoch 127/150\n", " - 0s - loss: 0.3646 - acc: 0.9300\n", "Epoch 128/150\n", " - 0s - loss: 0.3652 - acc: 0.9200\n", "Epoch 129/150\n", " - 0s - loss: 0.3640 - acc: 0.9100\n", "Epoch 130/150\n", " - 0s - loss: 0.3627 - acc: 0.9100\n", "Epoch 131/150\n", " - 0s - loss: 0.3604 - acc: 0.9100\n", "Epoch 132/150\n", " - 0s - loss: 0.3572 - acc: 0.9300\n", "Epoch 133/150\n", " - 0s - loss: 0.3544 - acc: 0.9300\n", "Epoch 134/150\n", " - 0s - loss: 0.3521 - acc: 0.9300\n", "Epoch 135/150\n", " - 0s - loss: 0.3501 - acc: 0.9500\n", "Epoch 136/150\n", " - 0s - loss: 0.3482 - acc: 0.9500\n", "Epoch 137/150\n", " - 0s - loss: 0.3465 - acc: 0.9500\n", "Epoch 138/150\n", " - 0s - loss: 0.3461 - acc: 0.9500\n", "Epoch 139/150\n", " - 0s - loss: 0.3435 - acc: 0.9300\n", "Epoch 140/150\n", " - 0s - loss: 0.3407 - acc: 0.9500\n", "Epoch 141/150\n", " - 0s - loss: 0.3384 - acc: 0.9500\n", "Epoch 142/150\n", " - 0s - loss: 0.3366 - acc: 0.9500\n", "Epoch 143/150\n", " - 0s - loss: 0.3347 - acc: 0.9500\n", "Epoch 144/150\n", " - 0s - loss: 0.3336 - acc: 0.9400\n", 
"Epoch 145/150\n", " - 0s - loss: 0.3324 - acc: 0.9300\n", "Epoch 146/150\n", " - 0s - loss: 0.3311 - acc: 0.9300\n", "Epoch 147/150\n", " - 0s - loss: 0.3300 - acc: 0.9300\n", "Epoch 148/150\n", " - 0s - loss: 0.3283 - acc: 0.9300\n", "Epoch 149/150\n", " - 0s - loss: 0.3256 - acc: 0.9400\n", "Epoch 150/150\n", " - 0s - loss: 0.3224 - acc: 0.9500\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 135, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Play around with number of epochs as well!\n", "model.fit(scaled_X_train,y_train,epochs=150, verbose=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predicting New Unseen Data\n", "\n", "Let's see how we did by predicting on **new data**. Remember, our model has **never** seen the test data that we scaled previously! This is exactly the process you would use on brand new data: for example, the measurements of a newly collected iris flower, scaled the same way as the training data." ] }, { "cell_type": "code", "execution_count": 136, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.52941176, 0.36363636, 0.64285714, 0.45833333],\n", " [ 0.41176471, 0.81818182, 0.10714286, 0.08333333],\n", " [ 1. , 0.27272727, 1.03571429, 0.91666667],\n", " [ 0.5 , 0.40909091, 0.60714286, 0.58333333],\n", " [ 0.73529412, 0.36363636, 0.66071429, 0.54166667],\n", " [ 0.32352941, 0.63636364, 0.07142857, 0.125 ],\n", " [ 0.38235294, 0.40909091, 0.44642857, 0.5 ],\n", " [ 0.76470588, 0.5 , 0.71428571, 0.91666667],\n", " [ 0.55882353, 0.09090909, 0.60714286, 0.58333333],\n", " [ 0.44117647, 0.31818182, 0.5 , 0.45833333],\n", " [ 0.64705882, 0.54545455, 0.71428571, 0.79166667],\n", " [ 0.14705882, 0.45454545, 0.05357143, 0. ],\n", " [ 0.35294118, 0.68181818, 0.03571429, 0.04166667],\n", " [ 0.17647059, 0.5 , 0.07142857, 0. 
],\n", " [ 0.23529412, 0.81818182, 0.07142857, 0.08333333],\n", " [ 0.58823529, 0.59090909, 0.64285714, 0.625 ],\n", " [ 0.64705882, 0.45454545, 0.83928571, 0.875 ],\n", " [ 0.38235294, 0.22727273, 0.5 , 0.41666667],\n", " [ 0.41176471, 0.36363636, 0.60714286, 0.5 ],\n", " [ 0.61764706, 0.36363636, 0.80357143, 0.875 ],\n", " [ 0.11764706, 0.54545455, 0.08928571, 0.04166667],\n", " [ 0.52941176, 0.45454545, 0.67857143, 0.70833333],\n", " [ 0.20588235, 0.63636364, 0.08928571, 0.125 ],\n", " [ 0.61764706, 0.36363636, 0.80357143, 0.83333333],\n", " [ 1.05882353, 0.81818182, 0.94642857, 0.79166667],\n", " [ 0.70588235, 0.45454545, 0.73214286, 0.91666667],\n", " [ 0.70588235, 0.22727273, 0.83928571, 0.70833333],\n", " [ 0.73529412, 0.54545455, 0.85714286, 0.91666667],\n", " [ 0.14705882, 0.45454545, 0.05357143, 0.08333333],\n", " [ 0.14705882, 0.5 , 0.08928571, 0.04166667],\n", " [ 0.08823529, 0.72727273, -0.01785714, 0.04166667],\n", " [ 0.41176471, 1.09090909, 0.07142857, 0.125 ],\n", " [ 0.70588235, 0.5 , 0.58928571, 0.54166667],\n", " [ 0.14705882, 0.63636364, 0.08928571, 0.04166667],\n", " [ 0.02941176, 0.54545455, 0.03571429, 0.04166667],\n", " [ 0.58823529, 0.22727273, 0.69642857, 0.75 ],\n", " [ 0.61764706, 0.54545455, 0.60714286, 0.58333333],\n", " [ 0.26470588, 0.68181818, 0.07142857, 0.04166667],\n", " [ 0.20588235, 0.72727273, 0.05357143, 0.04166667],\n", " [ 0.26470588, 0.95454545, 0.07142857, 0. ],\n", " [ 0.44117647, 0.31818182, 0.71428571, 0.75 ],\n", " [ 0.5 , 0.63636364, 0.60714286, 0.625 ],\n", " [ 0.70588235, 0.5 , 0.64285714, 0.58333333],\n", " [ 0.32352941, 0.86363636, 0.03571429, 0.125 ],\n", " [ 0.32352941, 0.77272727, 0.07142857, 0.04166667],\n", " [ 0.35294118, 0.18181818, 0.46428571, 0.375 ],\n", " [ 0.58823529, 0.36363636, 0.71428571, 0.58333333],\n", " [ 0.61764706, 0.5 , 0.78571429, 0.70833333],\n", " [ 0.67647059, 0.45454545, 0.58928571, 0.54166667],\n", " [ 0.85294118, 0.72727273, 0.89285714, 1. 
]])" ] }, "execution_count": 136, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scaled_X_test" ] }, { "cell_type": "code", "execution_count": 137, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# model.predict returns class probabilities: one row of 3 softmax outputs per sample.\n", "# model.predict(scaled_X_test)" ] }, { "cell_type": "code", "execution_count": 138, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 2, 1, 2, 0, 1, 2, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,\n", " 0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,\n", " 0, 1, 2, 2, 1, 2], dtype=int64)" ] }, "execution_count": 138, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_classes(scaled_X_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluating Model Performance\n", "\n", "So how well did we do? How do we actually measure \"well\"? Is 95% accuracy good enough? It all depends on the situation. We also need to take metrics like precision and recall into account. Make sure to watch the video discussion on classification evaluation before running this code!" 
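, "\n", "\n", "As a worked example of these metrics, the sketch below recomputes accuracy, precision and recall by hand from the confusion matrix this notebook reports further down (rows are true classes and columns are predicted classes, which is scikit-learn's convention). The results should agree with the `classification_report` output." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: accuracy, precision and recall computed by hand from a confusion matrix.\n", "# Rows are true classes, columns are predicted classes (scikit-learn's convention).\n", "import numpy as np\n", "\n", "cm = np.array([[19, 0, 0],\n", "               [0, 13, 2],\n", "               [0, 0, 16]])  # copied from the confusion_matrix output below\n", "\n", "accuracy = np.trace(cm) / cm.sum()        # correct predictions / all predictions\n", "precision = np.diag(cm) / cm.sum(axis=0)  # per class: TP / (TP + FP), column sums\n", "recall = np.diag(cm) / cm.sum(axis=1)     # per class: TP / (TP + FN), row sums\n", "print(accuracy, precision.round(2), recall.round(2))"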
] }, { "cell_type": "code", "execution_count": 139, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['loss', 'acc']" ] }, "execution_count": 139, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.metrics_names" ] }, { "cell_type": "code", "execution_count": 140, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "50/50 [==============================] - 0s 2ms/step\n" ] }, { "data": { "text/plain": [ "[0.2843634402751923, 0.96]" ] }, "execution_count": 140, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(x=scaled_X_test,y=y_test)" ] }, { "cell_type": "code", "execution_count": 141, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.metrics import confusion_matrix,classification_report" ] }, { "cell_type": "code", "execution_count": 142, "metadata": { "collapsed": true }, "outputs": [], "source": [ "predictions = model.predict_classes(scaled_X_test)" ] }, { "cell_type": "code", "execution_count": 161, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 2, 1, 2, 0, 1, 2, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,\n", " 0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,\n", " 0, 1, 2, 2, 1, 2], dtype=int64)" ] }, "execution_count": 161, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions" ] }, { "cell_type": "code", "execution_count": 163, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 2, 1, 1, 0, 1, 2, 1, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,\n", " 0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,\n", " 0, 1, 2, 2, 1, 2], dtype=int64)" ] }, "execution_count": 163, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_test.argmax(axis=1)" ] }, { "cell_type": "code", "execution_count": 164, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[19, 0, 0],\n", " [ 0, 13, 2],\n", " [ 0, 0, 16]], dtype=int64)" ] }, 
"execution_count": 164, "metadata": {}, "output_type": "execute_result" } ], "source": [ "confusion_matrix(y_test.argmax(axis=1),predictions)" ] }, { "cell_type": "code", "execution_count": 165, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 1.00 1.00 1.00 19\n", " 1 1.00 0.87 0.93 15\n", " 2 0.89 1.00 0.94 16\n", "\n", " micro avg 0.96 0.96 0.96 50\n", " macro avg 0.96 0.96 0.96 50\n", "weighted avg 0.96 0.96 0.96 50\n", "\n" ] } ], "source": [ "print(classification_report(y_test.argmax(axis=1),predictions))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Saving and Loading Models\n", "\n", "Now that we have a trained model, let's see how to save it to disk and load it back." ] }, { "cell_type": "code", "execution_count": 166, "metadata": { "collapsed": true }, "outputs": [], "source": [ "model.save('myfirstmodel.h5')" ] }, { "cell_type": "code", "execution_count": 167, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from keras.models import load_model" ] }, { "cell_type": "code", "execution_count": 168, "metadata": { "collapsed": true }, "outputs": [], "source": [ "newmodel = load_model('myfirstmodel.h5')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Predict on the scaled test data: the model was trained on scaled features.\n", "newmodel.predict_classes(scaled_X_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Great job! You now know how to preprocess data, train a neural network, and evaluate its classification performance!" 
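] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**A note on newer Keras versions:** `predict_classes`, used throughout this notebook, was deprecated and later removed from Keras (around TensorFlow 2.6). For a model with a softmax output layer, like this one, the deprecation message suggested taking the argmax of `model.predict(...)` instead. The sketch below shows that step on a small made-up probability array standing in for real model output." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: turning softmax probabilities into class labels, which is what\n", "# predict_classes returned for a softmax model.\n", "# probs is a made-up stand-in for model.predict(scaled_X_test), one row per sample.\n", "import numpy as np\n", "\n", "probs = np.array([[0.05, 0.80, 0.15],\n", "                  [0.90, 0.08, 0.02],\n", "                  [0.01, 0.29, 0.70]])\n", "\n", "# With the trained model this would be: np.argmax(model.predict(scaled_X_test), axis=-1)\n", "classes = np.argmax(probs, axis=-1)  # index of the most probable class in each row\n", "print(classes)  # [1 0 2]"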
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }