{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Keras - 1D Linear Fitting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now going to explore how to implement a very simple 1D linear regression with Keras. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by defining the parameters of an ideal linear function, which we are going to estimate through a linear fit." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# target parameters of f(x) = m*x + b\n", "m = 3 # slope\n", "b = 0 # intercept" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's generate a set of training data which slightly deviates from this ideal behaviour because of random noise:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# numpy is all we need to generate the data\n", "import numpy as np\n", "\n", "# generate training inputs: 20 evenly spaced points in [-1, 1)\n", "x = np.arange(-1, 1, 0.1)\n", "y_target = m * x + b # ideal (target) linear function\n", "\n", "noise_amp = 1.0 # noise amplitude\n", "# add uniform noise in [-noise_amp/2, noise_amp/2) to get the noisy measurements\n", "# from which we want to estimate the regression parameters\n", "y_train = m * x + b + noise_amp * (np.random.rand(len(x)) - 0.5)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXsAAAD8CAYAAACW/ATfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAIABJREFUeJzt3Xl4VOX9/vH3EyBACKtAUJaEXZYoGkCxFY0rdUNRWxWtW41W+7Mu/bZA1Lo0Yt1arbWWuragqVoQQQRRg+ICSqyaQFgChH0PWwhZ5/n9MYOONCGTmTOZMzP367rmSubMmWfu68zw4eSZcz7HWGsREZHYlhDpACIiEn4q9iIicUDFXkQkDqjYi4jEARV7EZE4oGIvIhIHVOxFROKAir2ISBxQsRcRiQPNI/GinTt3tmlpaSGPc+DAAdq0aRN6oDBwczZwdz5lC46bs4G780VLtvz8/J3W2i5BDWStbfJbRkaGdUJeXp4j44SDm7NZ6+58yhYcN2ez1t35oiUbsMQGWXc1jSMiEgdU7EVE4oCKvYhIHFCxFxGJAyEXe2NMK2PMF8aYb4wxS40xDzgRTEREnOPEoZeVwBnW2jJjTAvgE2PMu9baRQ6MLSIiDgi52PsOByrz3W3hu+nyVyIiLmKsA5clNMY0A/KBfsBfrbW/q2OdLCALICUlJSM3Nzfk1y0rKyM5OTnkccLBzdnA3fmULThuzgbuzhct2TIzM/OttcODGijYA/TrugEdgDxg6JHW00lVkefmfMoWHDdns9bd+cKWbepUa1NTrTXG+3Pq1EYP4cqTqqy1e3zFfoyT44qIRJ1p0yArC9atA2u9P7OyvMsjwImjcboYYzr4fm8NnA0sD3VcEZGolp0N5eU/XFZe7l0eAU4cjXM08Ipv3j4BeN1aO9uBcUVEotf69Y1bHmZOHI3zLXCCA1lERGJHr17eqZu6lkeAzqAVEQmHnBxISvrhsqQk7/IIULEXEQmH8eNhyhRITQVjvD+nTPEuj4CIXLxERCQujB8fseJ+OO3Zi4iEyYHKGh6ft4LdB6oiHUV79iIi4ZC3fDv3vFXIpj0H6dOlDeNO7BHRPCr2IiIO2r6/ggdnLWP2t1vo1zWZN24ZxYi0TpGOpWIvIuIEj8fy+pINPDyniIpqD3edPYCbT+tDy+bNIh0NULEXEQlZ8fYyJk0v4IuSUk7q3YmHx6XTt4u7Gqup2IuIBKmyppa/LVjNs3mraZ3YjEcvPY7Lh/fAGBPpaP9DxV5EJAhfrC1l4vRvWb3jABcdfwz3XjCYLm1bRjpWvVTsRUQaYW95NY/MLeK1LzbQo2NrXr5+BKcP7BrpWA1SsRcRCYC1lncKtnD/28vYXV5F1ug+3HFWf5ISo6OMRkdKEZEI2ri7nPtmLuXD5dtJ796el68fwdDu7SMdq1FU7EVE6lHrsbz8WQlPvLcCgHsvGMy1o1Jp3iz6mg+o2IuI1KFw014mTi+gYNNezji2Kw+OHUKPjkkNP9GlVOxFRPyUV9Xw5/dX8cIna+mYlMgzV53A+elHu/JwysZQsRcR8VmwwtvPZuPug1w5sicTxgyifVKLSMdyhIq9iMS9vZWW21/7L29/s5m+Xdrw+s2jGNk78v1snKRiLyJxy1pvP5sHPymn2lPBHWf155en93VNPxsnqdiLSFxavcPbz2bx2lIGdEzg2etPpV9Xd/WzcZKKvYjElaoaD899tJpnPiymVYsEHhmXTtcDq2O60IOuVCUibjZtGqSlQUKC9+e0aSENt6SklPOfXsiT81dyzpAU3r/7NK4Y2YuEKD/SJhDasxcRd5o2DbKyoLzce3/dOu99aPR1XfcerOaPc5fz6uL1dO/QmpeuG0Hmse7vZ+Mk7dmLiDtlZ39f6A8pL/cuD5C1ljkFWzjryY/I/WI9v/hxb967c3Tghd7hvywiKeQ9e2NMT+CfQApggSnW2qdCHVdE4tz69Y1bfpjNew5y38x
C3i/azpBj2vHitSNI79GIfjYO/mXhBk5M49QAd1trvzLGtAXyjTHzrbXLHBhbROJVr17eAlvX8iOo9Vhe8fWz8VjIPm8Q1/8orfH9bI70l0UUFvuQp3GstVustV/5ft8PFAHdQx1XRGLEoamQ/PzGTYXk5EDSYb1okpK8y+uxdPNexj37KQ/OXsbwtE68d+dobhrdJ7jGZSH+ZeE2xlrr3GDGpAEfA0OttfsOeywLyAJISUnJyM3NDfn1ysrKSE525+FSbs4G7s6nbMFxZbbSUu/eucdDWY8eJG/c6J3/Tk2FTgGcoVpaCps2QVUVJCZC9+51Pq+y1jKzuJq5JdUkt4CrBrXkpG7NAu5nU+e2Kyjwvu7hEhMhPT2gcZ3gny0zMzPfWjs8qIGstY7cgGQgHxjX0LoZGRnWCXl5eY6MEw5uzmatu/MpW3BcmS011VqwFmze449/97tNTXXsJT5asd3++I8f2NTfzba/e/Mbu/tAZaPHqHPbTZ1qbVLS95nBe3/q1NBDB5kNWGKDrNGOHHppjGkB/AeYZq2d7sSYIhIDwjgVsrOskj/MXsZbX2+mT5c25GadzMl9jgp53O8cmpfPzvbm7dXLO4UUhfP14MzROAZ4ASiy1j4ZeiQRiRlBfsl6JNZa3szfSM6cIg5U1nD7mf259fS+tGoRhn4248dHbXE/nBN79j8CrgEKjDFf+5ZNstbOcWBsEYlmOTk/PHwRGvyS9UjW7jzApOkFfL5mFyPSOvLwJen0T2nrUNjYFnKxt9Z+AsT+ucYi0nj+UyHg/WI2iKmQqhoPUz5ezdMfFtOyeQIPX5LOFSN6kpCg0hMotUsQkfA6NBWyYAGUlDT66fnrdjNpegErtu3n/PSj+f2Fg+narpXjMWOdir2IuNK+imoem7uCqYvXcXS7Vrxw7XDOHJQS6VhRS71xRMR15hZu5ewnP2La4nVcf0pv5t91WnCFPoZ624RKe/Yi4hpb9h7kvplLmb9sG4OObseUa4ZzfM8OwQ0WY71tQqViLyIRV+ux/OvzEh6bt4Jaa5n4k2O54ce9aRFMm4NDYqy3TahU7EUkooq27GPC9AK+2bCH0QO6kHPxUHp2Smr4iQ2Jsd42oVKxF5GIqKiu5akPVvGPj9fQvnULnrpiGBcdf0zA/WwaFIYTuqKZir2INLlPVu0k+60C1u0q56fDezDpvEF0SEp09kUcPqEr2qnYi0iT2VVWSc47RUz/7yZ6d27DqzedxCl9O4fnxWKst02oVOxFJOystXyyqZo7P/6Issoabj+jH7dm9gtPPxt/MdTbJlQq9iISViU7D5D9VgGfFleRkdqRyePSGaB+Nk1OxV5EwqK61sOUj9fw9AerSGyWwM8HJ3L/1aPUzyZCVOxFxHFfrd/NxP94+9n8ZGg37r9oCEVfLVKhjyAVexFxzP6Kah6bt4J/LVpHt3at+MfPh3P2YG+bg6IIZ4t3KvYi4oh5S7fy+5lL2ba/gmtHpfGbcweS3FIlxi30TohISLbureD3bxcyb6m3n81z12QwLNh+NhI2KvYiEpRaj2Xa4nU8OncFNR4PE35yLDeG2s9GwkbFXkQabfnWfUycXsB/1+/h1P6dybk4nV5HOdDPRsJGxV5EAlZRXctfPlzF3z9aQ7vWLfjzz4YxdpiD/WwkbFTsRSQgnxbvJHtGASW7yrksowfZ5w2iYxuH+9lI2KjYi8gRlR6oIuedIv7z1UbSjkri1V+cxCn9wtTPRsJG36SIxLogL81nrWX6Vxs584kFzPx6E7dl9mXuHaNV6KOU9uxFYlmQl+Zbt+sA2TMK+aR4Jyf26sDkcccxsJv62UQzFXuRWNbIS/NV13p4fuFa/vz+ShKbJfDQxUMZP7KX2hzEAEeKvTHmReACYLu1dqgTY4qIAxpxab7/rt/NxOkFLN+6nzFDvP1surVvFeaA0lSc2rN/GXgG+KdD44mIEwK4NF9ZZQ2Pz1vBK5+XkNK2FVOuyeCcId2aLqM
0CUeKvbX2Y2NMmhNjiYiDGrg03/xl27hvZiFb93n72dx9zgDatmoRobASTpqzF4ll9Vyab9uFl/L7f+Uzd+lWju3WlmfHn8gJvTpGNquElbHWOjOQd89+dn1z9saYLCALICUlJSM3Nzfk1ywrKyM5OTnkccLBzdnA3fmULTiBZPNYy4INNbyxsooaD4zt14IxaS1o3gRfwEb7tosU/2yZmZn51trhQQ1krXXkBqQBhYGsm5GRYZ2Ql5fnyDjh4OZs1ro7n7IFp6FsK7bus+Oe/dSm/m62veofn9u1O8qaJphPNG+7SPLPBiyxQdZoTeOIxLiK6lqe+bCYv3+8muSWzXni8uMZd2J39bOJM46cQWuMeQ34HBhojNlojLnRiXFFJDSfrd7JT55ayDN5xVx43DF8cPfpXJrRo3GFPsgzcMVdnDoa50onxhERZ+w+UEXOnCLezN9Ir05JTL3xJH7cP4g2B0GegSvuo2kckRhirWXm15t5cPYy9h6s5pen9+X2M/rTOrFZcAM28gxccS8Ve5EYsb3cw89f/IKFq3ZyfM8OTBuXzqCj24U2aCPOwBV3U7EXiXLVtR5e+GQtT35ykBbNq3jgoiFcfXIqzZw4nDKAM3AlOqjYi0SxbzbsYcL0Aoq27OOErs149sbTOLp9a+deoIEzcCV6qNiLRKGyyhqeeG8Fr3xWQpe2LXnu6gxa7VzubKGHes/A1Xx99FGxF4ky7/v62WzZV8HVJ6Xyf2MG0q5VCxYsWB6eFxw/XsU9BqjYi0SJ7fsquH/WUuYUbGVASjJvXnUKGanqZyOBUbEXcTmPx/Lal+t55N3lVNZ4+L9zB3LTqX1IbK6rikrgVOxFXGzVtv1MnF7AknW7GdXnKHIuGUqfLu5s2CXupl0DEReqqK7lyfdWcN7TCyneUcZjPcp59eEr6JPSTi0LJCjasxdxmUVrdjFpegFrdh7gkhO6c0/FMo669Sa1LJCQqNiLuMSe8iomz1nOv5dsoGen1vzzhpGMHtAF0i5WywIJmYq9SIRZa5n17RYenLWU3eXV3HxaH+44c8D3/WzUskAcoGIvEkEbSsu5561CPlq5g+N7tOeVG0Yy5Jj2P1xJLQvEASr2IhFQU+vhpU9LeHL+ShIM/P7Cwfx8VFrd/WzUskAcoGIv0sQKNu5lwvRvWbp5H2cN6sqDY4dyTIcjtDlQywJxgIq9SBM5UFnDk/NX8tKna+mc3JK/jT+RMUO7BXbVKLUskBCp2Is0gQ+Xb+Pet5ayac9Bxp/Ui9+OOZb2rVtEOpbEERV7kTDavr+CB2Yt451vt9C/azJv3jKK4WmdIh1L4pCKvUgYeDyWfy/ZwOQ5RVRUe7j77AHcfFpf9bORiFGxF3FY8fb9TJpeyBclpZzUuxMPj0unr/rZSIRpN0PEIZU1tfxp/krOe+oTVmzbz6OXHkdu1sneQj9tmrenTUKCettIRGjPXsQBi9fsYtKMAlbvOMDYYcdw7wWD6Zzc0vvgtGk/PE5evW0kAlTsRUKwt7yaR+YW8doXG+jRsTUvXz+C0wd2/eFK2dnqbSMR50ixN8aMAZ4CmgHPW2sfcWJcEbey1jL72y08MGsZu8uryBrdhzvO6k9SYh3/pNTbRlwg5GJvjGkG/BU4G9gIfGmMedtauyzUsUXcaOPucu59q5C8FTtI796el68fwdDu7et/gnrbiAs48QXtSKDYWrvGWlsF5AJjHRhXxD2mTaOmdx/mLVzD2TlzWbxqO/deMJgZt55y5EIP3tYGSUk/XKbeNtLEnCj23YENfvc3+paJxIZp0yjMnszFp/2a1w6kMKrkG9578TZuXPcZzZsF8E9o/HiYMgVSU8EY788pUzRfL03KWGtDG8CYy4Ax1tpf+O5fA5xkrf3VYetlAVkAKSkpGbm5uSG9LkBZWRnJye48ftnN2cDd+dyUrbLGMmPRZuaVtaetqeWy1BpOLSvBGCAxEdLTIx3xO27abnVxc75oyZaZmZlvrR0
e1EDW2pBuwChgnt/9icDEIz0nIyPDOiEvL8+RccLBzdmsdXc+t2T7cPk2e8rkD2zq72bbCefeZve0bGPzHn/cWvDejIl0xB9wy3arj5vzRUs2YIkNslY7cTTOl0B/Y0xvYBNwBXCVA+OKRMSO/ZU8OHsZs77ZTL+uybzx/pOMyP/wf1fUF6wSRUIu9tbaGmPMr4B5eA+9fNFauzTkZCJNzFrL60s2kPOOt5/NnWcN4JbT+9Ay5QbIWqSLh0hUc+Q4e2vtHGCOE2OJRMLqHWVMml7A4rWljOzdiYcvSadfV98crv/FQ8D7BasuHiJRRmfQSlyrrKnluQVr+GteMa1aJPDHS9O5PKMnCYdfHvDQxUMWLICSkkhEFQmJir3ErS9LSpk4vYDi7WVcdLy3n02Xti0jHUskLFTsJe7sPVjNH+cu59XF6+neoTUvXT+CzMP72YjEGBV7iRvWWt4t3Mrv317KrrJKbjq1N3eePaDufjYiMUafcokLm/Yc5L63Cvlg+XaGdm/HS9c10M9GJMao2EtMq/VYXvmshMffW4G1cM/5g7julLTA2hyIxBAVe4lZSzfvZeL0Ar7duJfTB3bhobFD6dkpqeEnisQgFXuJOQeravnz+yt5/pO1dExqwV+uPIELjjsaY0zDTxaJUSr2ElM+WrmDe94qYEPpQa4c2ZMJYwbRPqlFpGOJRJyKvcSEnWWVPDR7GTO/3kyfLm34d9bJnNTnqEjHEnENFXuJatZa3sjfSM47RZRX1fDrM/tza2ZfWjZvFuloIq6iYi9Ra82OMibNKGDRmlJGpHVk8rh0+nVtG+lYIq6kYi9Rp6rGw98/Ws1f8opp2TyByePS+dnwOvrZiMh3dLCxRJX8daWc//RCnpi/krMHp/DBXadx5cheDRf6adMgLQ0SErw/p01rirgirqE9e4kK+yqqeXTucqYuWs8x7VvxwrXDOXNQSmBPnjYNsrK+70e/bp33PqhNscQNFXtxNWstc339bHaWVXLDj3pz9zkDaNOyER/d7OwfXngEvPezs1XsJW5oGkcC18RTIZv3HOSmf+bzy2lf0Tm5JW/d9iPuu3Bw4wo9wPr1jVsuEoO0Zy+BacKpEI+1vPzpWh6bt4Jaa5l03rHc8KPewfez6dXLm7eu5SJxQnv2EpgjTYU4qGjLPv6wqIL7Zy0jI60T8+88jazRfUNrXJaT471mrD9dQ1bijPbsJTBhngo5WFXLUx+s4h8L15DU3PLUFcO46PhjnOln438N2fXrvXv0uoasxBkVewlMGKdCFq7aQfaMQtaXlvPT4T0Y3a6UC4Z1D3ncHzh0DVmROKVpHAlMGKZCdpVVcte/v+aaF76gWYLh1ZtO4tHLjic5USdHiThNe/YSGAenQqy1/OerTeS8s4yyyhpuP6Mft2b2o1UL9bMRCRft2Uvgxo+HkhLweLw/gyj0JTsPMP75xfzmjW/o0yWZd24/lbvOGRhYoddZsCJBC2nP3hhzOXA/MAgYaa1d4kQoiT1VNR7+sXANT32wipbNEvjDxUO5KpA2B4foLFiRkIQ6jVMIjAP+7kAWiVH563YzaXoBK7bt5ydDu3H/RUNIadeqcYPoLFiRkIRU7K21RYAu9yZ12ldRzWNzVzB18Tq6tWvFP34+nLMHB9jP5nA6C1YkJPqCVsLC28+mkO37K7nulDTuPmcgyY1tc+BPZ8GKhMRYa4+8gjHvA93qeCjbWjvTt84C4DdHmrM3xmQBWQApKSkZubm5wWb+TllZGcnJySGPEw5uzgbhy1da4WHqsiq+2l5Lz7YJXD8kkT4dfF++lpbCpk1QVQWJidC9O3TqFFi20lJvsfd4vl+WkACpqXWOES5ufl/dnA3cnS9asmVmZuZba4cHNZC1NuQbsAAYHuj6GRkZ1gl5eXmOjBMObs5mrfP5amo99pXP1toh9821A++ZY/+2oNhW1dR+v8LUqdYmJVkL39+SkrzLA802daq1qanWGuP9Wcdzw83N76u
bs1nr7nzRkg1YYoOs05rGkZAt37qPCf8p4OsNezi1f2dyLk6n11GHnYDlxBesOgtWJGihHnp5CfAXoAvwjjHma2vtuY4kE9erqK7l6Q9WMeXjNbRr3YI//2wYY4fV089GX7CKRFSoR+PMAGY4lEWiyKfFO5k0o4B1u8q5LKMH2ecNomObxPqfoC9YRSJKZ9BKo5QeqOLu179h/POLMcCrvziJxy8//siFHtRmWCTCNGcvAbHWMuO/m3ho9jL2V9Twq8x+/OqMRvSzUZthkYjSnn28OdRfJj8/4P4y63Yd4JoXvuCu178hrXMb3rn9VH5zboD9bPw50FtHRIKjPft40sj+MtW1Hp5fuJY/v7+SxGYJPDR2CONPSg28n42IuIaKfTxpxOGP/12/m4nTC1i+dT9jhnj72XRr38h+NiLiGir28SSAwx/3V1TzxHsreeXzElLatmLKNRmcM6SuE6hFJJpozj7ahNLTvb7DHH3L31u6lbOf/JhXPi/h2lFpzL9rtAq9SIzQnn00CbWne07OD58PkJTE1vsf5v5/5TN36VaO7daWv119Iif06uh8fhGJGBX7aBJqywH/wx8BT2oa0+56lEfXtqOqdju/HTOQm07tQ4tm+oNPJNao2EcTJ1oO+PrLbJz1IZf9dipfrd/Dj/t1IOeSoaQe1caZnCLiOir20cSBlgMV1bU882Exf/vsIO1a1/DkT4/nkhO66wI0IjFOxT6a1DPnHmjLgc9W7yR7RiFrdx7glGOa88yNp9OpoTYHIhITVOyjSZAtB3YfqCJnThFv5m8k9agkpt54EjWbClXoReKIin20aURPd2stM7/ezIOzl7HvYDW3nt6X28/sT6sWzViwKcw5RcRVVOxj1Ppd5WS/VcDCVTsZ1rMDk8elM+jodpGOJSIRomIfY6prPbzwibefTfOEBB709bNppn42InFNxT6GfLNhDxOmF1C0ZR/nDE7hgbFDOLp960jHEhEXULGPAWWVNTzx3gpe+ayELm1b8tzVGYwZqjYHIvI9Ffso9/6ybdw3s5At+yq45uRUfnPuQNq1ahHpWCLiMir2UWr7vgrun7WUOQVbGZjSlr9cdSIZqepnIyJ1U7GPMh6P5bUv1/PIu8uprPHwf+cOJGu0+tmIyJGpQjS1EFoUr9y2n8v//jnZMwpJ796eeXeM5rbMfir0ItIg7dk3pSBbFFdU1/JsXjF/+2g1yS2b88TlxzPuRPWzEZHAqdg3pSBaFH++ehfZMwpYs/MA407oTvb5gzgquWUThBWRWBJSsTfGPAZcCFQBq4HrrbV7nAgWkxrRonhPeRUPzyni9SUb6dUpiX/dOJJT+3cJc0ARiVWhTvbOB4Zaa48DVgITQ48Uwxq4LCAc6mezibOe/Ij/fLWJX57el3l3jFahF5GQhLRnb619z+/uIuCy0OLEuAZaFG8oLSf7rUI+XrmD43t24J83pDP4GPWzEZHQOTlnfwPwbwfHiz31tCiuueJKXvx4NU/OX0kzY7j/wsFcMypN/WxExDHGWnvkFYx5H6jr3Ptsa+1M3zrZwHBgnK1nQGNMFpAFkJKSkpGbmxtKbgDKyspITk4OeZxwCDTb2r21vFRYxfr9Hk7o2oyrByVyVOvwH0oZC9suEpQteG7OFy3ZMjMz8621w4MayFob0g24DvgcSAr0ORkZGdYJeXl5jowTDg1lK6uotg+8vdT2njDbjvjDfPtuwWbr8XiaJpyN7m0XScoWPDfni5ZswBIbZK0O9WicMcBvgdOsteUNrS9eHxRt476ZS9m05yBXn9yL3445Vv1sRCSsQp2zfwZoCcz3neCzyFp7S8ipYtT2fRU8MGsZ7xRsoX/XZN68ZRTD0zpFOpaIxIFQj8bp51SQWObxWHK/3MDkd4uorPZw99kDuPm0viQ2V5sDEWkaOoM2zIq372fi9AK+LNnNyX068fAl6fTp4s4vgkQkdqnYh0m1x/Kn+St5dkExSYnNefSy47g8o4f62Yh
IRKjYh8HiNbu499ODbD2wirHDjuHeCwbTWf1sRCSCVOwdtLe8msnvFpH75QY6tza8csNIThugNgciEnkq9g6w1jL72y08MGsZu8uruHl0H05M3KpCLyKuoWIfoo27y7n3rULyVuwgvXt7Xr5+BEO7t2fBgm2RjiYi8p34PPYvhKtFHVJT6+H5hWs4+8mPWby2lPsuGMxbt/2Iod3bOx5XRCRU8bdnH+TVovwVbtrLhOnfUrhpH2cc25WHLh5K9w6twxRYRCR08Vfsg7ha1HerVdXwp/kreeGTtXRq05K/XnUi56V30+GUIuJ68VfsG3G1KH95K7Zzz4xCNu05yJUjezFhzLG0T1I/GxGJDvFX7Hv18k7d1LW8Djv2V/Lg7GXM+mYz/bom88YtoxihfjYiEmXir9g3cLWoQ6y1vL5kAznvFFFR7eHOswZwy+l9aNm8WRMHFhEJXfwV+3quFuU/X796RxmTpheweG0pI3t3YvK4dPqqn42IRLH4K/bgLex1fBlbWVPLcwvW8Ne8Ylq1SOCPl6ZzeUZPEnR5QBGJcvFZ7OvwZUkpE6cXULy9jIuO9/az6dJW/WxEJDbEfbHfe7CaP85dzquL19O9Q2teun4EmQO7RjqWiIij4rbYW2uZU7CV+2ctZVdZJTed2ps7zx5AUmLcbhIRiWHR2S7hULuD/Pyg2h1s2nOQX7yyhNte/Ypu7Vrx9q9+TPb5g1XoRSRmRV91C6HdQa3H8vJnJTzx3gqshXvOH8R1p6TRvFl0/p8nIhKo6Cv2QbY7KNy0l0kzCvh2414yB3bhoYuH0qNjUpjDioi4Q/Tt0jay3UF5VQ2T5xQx9q+fsnlPBc9cdQIvXjcitELvQNdMEZGmFH179o1od/DRyh1kzyhg4+6DXDmyJxPGDAq9n40DXTNFRJpa9O3Z5+R42xv4O6zdwc6ySn6d+1+uffELWjZP4PWbRzF53HHONC470jSSiIhLRd+evX+7A4DU1O/aHVhreSN/IznvFHGwqpY7zurPL0/v62w/myC7ZoqIRFJIxd4Y8xAwFvAA24HrrLWbnQh2RIfaHSxYACUlAKzZUcakGQUsWlPKyLROPDwunX5dw9DPppFdM0VE3CBEO0GEAAAHj0lEQVTUaZzHrLXHWWuHAbOB+xzI1ChVNR7+8sEqxjy1kGWb9/HIuHRys04OT6GHgKaRRETcJqQ9e2vtPr+7bQAbWpzGWbW7lpynF7JqexkXHHc09104mK5tW4X3RQPomiki4jYhz9kbY3KAnwN7gcyQEwVo8rtF/H1xhbefzXUjyDy2CfvZ1NM1U0TErYy1R94ZN8a8D3Sr46Fsa+1Mv/UmAq2stb+vZ5wsIAsgJSUlIzc3N+jQAHnrq9mwp5KfDm5Dq+bua0FcVlZGcrJ7e+C7OZ+yBcfN2cDd+aIlW2ZmZr61dnhQA1lrHbkBvYDCQNbNyMiwTsjLy3NknHBwczZr3Z1P2YLj5mzWujtftGQDltgga3RIX9AaY/r73R0LLA9lPBERCY9Q5+wfMcYMxHvo5TrgltAjiYiI00I9GudSp4KIiEj4RF+7BBERaTQVexGROKBiLyISB1TsRUTigIq9iEgcaPAM2rC8qDE78B6qGarOwE4HxgkHN2cDd+dTtuC4ORu4O1+0ZEu11nYJZpCIFHunGGOW2GBPHQ4zN2cDd+dTtuC4ORu4O188ZNM0johIHFCxFxGJA9Fe7KdEOsARuDkbuDufsgXHzdnA3fliPltUz9mLiEhgon3PXkREAuD6Ym+MudwYs9QY4zHG1PuNtDFmjDFmhTGm2BgzwW95b2PMYt/yfxtjEh3M1skYM98Ys8r3s2Md62QaY772u1UYYy72PfayMWat32PDnMoWaD7ferV+Gd72Wx7pbTfMGPO57/3/1hjzM7/HHN929X2G/B5v6dsOxb7tkub32ETf8hXGmHNDzRJEtruMMct82+kDY0yq32N1vr9NmO06Y8wOvwy/8HvsWt9nYJU
x5lqnswWY709+2VYaY/b4PRa2bWeMedEYs90YU1jP48YY87Qv97fGmBP9Hmv8dgu2EX5T3YBBwEBgATC8nnWaAauBPkAi8A0w2PfY68AVvt+fA37pYLZHgQm+3ycAf2xg/U5AKZDku/8ycFkYt11A+YCyepZHdNsBA4D+vt+PAbYAHcKx7Y70GfJb51bgOd/vVwD/9v0+2Ld+S6C3b5xmTZwt0+9z9ctD2Y70/jZhtuuAZ+p4bidgje9nR9/vHZs632Hr/z/gxSbadqOBE6nnok/AecC7gAFOBhaHst1cv2dvrS2y1q5oYLWRQLG1do21tgrIBcYaYwxwBvCmb71XgIsdjDfWN2agY18GvGutLXcww5E0Nt933LDtrLUrrbWrfL9vBrYDQZ1QEoA6P0NHyPwmcKZvO40Fcq21ldbatUCxb7wmy2atzfP7XC0Cejj4+iFlO4JzgfnW2lJr7W5gPjAmwvmuBF5zOEOdrLUf4935q89Y4J/WaxHQwRhzNEFuN9cX+wB1Bzb43d/oW3YUsMdaW3PYcqekWGu3+H7fCqQ0sP4V/O8HKcf3J9qfjDEtHczWmHytjDFLjDGLDk0x4bJtZ4wZiXfPbLXfYie3XX2foTrX8W2XvXi3UyDPDXc2fzfi3SM8pK73t6mzXep7r940xvRs5HObIh++qa/ewId+i8O57RpSX/agtluoV6pyhAnwouaRcKRs/nestdYYU++hTb7/kdOBeX6LJ+ItdIl4D6/6HfBgBPKlWms3GWP6AB8aYwrwFrKQOLzt/gVca631+BaHvO1ikTHmamA4cJrf4v95f621q+seISxmAa9ZayuNMTfj/evojCZ8/UBdAbxpra31WxbpbecYVxR7a+1ZIQ6xCejpd7+Hb9kuvH/6NPftiR1a7kg2Y8w2Y8zR1totvoK0/QhD/RSYYa2t9hv70J5tpTHmJeA3jcnmVD5r7SbfzzXGmAXACcB/cMG2M8a0A97B+x//Ir+xQ952h6nvM1TXOhuNMc2B9ng/Y4E8N9zZMMachfc/0tOstZWHltfz/jpVsBrMZq3d5Xf3ebzf1xx67umHPXeBQ7kCzufnCuA2/wVh3nYNqS97UNstVqZxvgT6G+/RI4l437S3rffbjDy8c+UA1wJO/qXwtm/MQMb+n7lAX5E7ND9+MVDnt/LhzGeM6XhoCsQY0xn4EbDMDdvO917OwDtv+eZhjzm97er8DB0h82XAh77t9DZwhfEerdMb6A98EWKeRmUzxpwA/B24yFq73W95ne9vE2c72u/uRUCR7/d5wDm+jB2Bc/jhX75Nks+X8Vi8X3Z+7rcs3NuuIW8DP/cdlXMysNe3kxPcdgvXN81O3YBL8M5JVQLbgHm+5ccAc/zWOw9Yifd/3Wy/5X3w/sMrBt4AWjqY7SjgA2AV8D7Qybd8OPC833ppeP83Tjjs+R8CBXgL1VQg2eFt12A+4BRfhm98P290y7YDrgaqga/9bsPCte3q+gzhnRq6yPd7K992KPZtlz5+z832PW8F8JMw/DtoKNv7vn8fh7bT2w29v02YbTKw1JchDzjW77k3+LZnMXC909kCyee7fz/wyGHPC+u2w7vzt8X3Gd+I97uWW4BbfI8b4K++3AX4HY0YzHbTGbQiInEgVqZxRETkCFTsRUTigIq9iEgcULEXEYkDKvYiInFAxV5EJA6o2IuIxAEVexGROPD/ASzwlQBGs/j7AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot training and target dataset\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "plt.plot(x, y_target)\n", "plt.scatter(x, y_train, color='r')\n", "plt.grid(True); plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you remember how a single node of a neural network works, you can easily see that a single neuron is enough for the job: with a linear activation it computes $y = w x + b$, exactly the function we want to fit. So let's use a simple Sequential model with just one layer made of one neuron!" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# compose the NN model\n", "import tensorflow as tf\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Dense\n", "\n", "model = Sequential()\n", "model.add(Dense(1, input_shape=(1,))) # one neuron, one input; the default activation is linear\n", "\n", "# compile the model choosing optimizer, loss and metrics objects\n", "model.compile(optimizer='sgd', loss='mse', metrics=['mse']) # metrics is optional here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To do: add a plot showing a 2D map over the two parameters, i.e. the neuron weight and the bias.\n", "\n", "Here we use the MSE loss, $\\mathrm{MSE}(w, b) = \\frac{1}{N}\\sum_{i=1}^{N}\\left(y_i - (w x_i + b)\\right)^2$, which is exactly the least-squares criterion we would use to estimate the parameters of a linear fit." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "dense_3 (Dense) (None, 1) 2 \n", "=================================================================\n", "Total params: 2\n", "Trainable params: 2\n", "Non-trainable params: 0\n", 
"_________________________________________________________________\n" ] } ], "source": [ "# get a summary of our composed model\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now going to train our model: we feed the neuron with the set of training pairs (x, y_train), from which the optimizer finds the weight and bias that minimize the Mean Squared Error loss function (our linear regression criterion)." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation MSE: 0.0048\n" ] } ], "source": [ "# we have to choose the batch_size and the number of epochs\n", "batch_size = 100 # defaults to 32; with only 20 samples, each epoch is a single full batch\n", "epochs = 1000\n", "\n", "model.fit(x, y_train,\n", "          batch_size=batch_size,\n", "          epochs=epochs,\n", "          shuffle=True, # shuffle the training data before each epoch\n", "          validation_data=(x, y_target), # used to evaluate the loss and any model metrics at the end of each epoch\n", "          verbose=0, # 0 = silent, 1 = progress bar, 2 = one line per epoch\n", "          )\n", "\n", "# evaluate returns [loss, mse]: this is an error measure, not an accuracy\n", "score = model.evaluate(x, y_target, verbose=0)\n", "print(\"Validation MSE: %.4f\" % score[1])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[array([[2.8795085]], dtype=float32), array([-0.00351556], dtype=float32)]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# return the learned weights: [kernel (slope), bias (intercept)]\n", "model.get_weights()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Available callbacks for Keras `model.fit`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `.fit` method also accepts callbacks, i.e. objects whose methods are invoked at given points of the training (for example at the end of each epoch) to customize the fitting procedure with special actions.\n", "\n", "Keras provides several predefined callbacks, for example:\n", "- `TerminateOnNaN()`: terminates training when a NaN loss is encountered\n", "- `ProgbarLogger()`: prints metrics to 
stdout\n", "- `ModelCheckpoint(filepath)`: saves the model after every epoch\n", "- `EarlyStopping`: stops training when a monitored quantity has stopped improving\n", "- `LambdaCallback`: creates simple custom callbacks on the fly\n", "\n", "You can select one or more callbacks and pass them as a list to the `callbacks` argument of the `fit` method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also create a callback from scratch by subclassing the Keras `Callback` class and overriding some of its base methods:\n", "- `on_epoch_begin` and `on_epoch_end`\n", "- `on_batch_begin` and `on_batch_end`\n", "- `on_train_begin` and `on_train_end`\n", "\n", "A callback has access to its associated model through the attribute `self.model`, so it can monitor and access many of the quantities involved in the optimization process (for example the current weights or predictions). In addition, `on_epoch_end` and `on_batch_end` receive a `logs` dictionary with the current loss and metric values." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are going to build a custom callback that shows how the model estimate converges during training: every 10 epochs it plots the current predictions against the target values." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from IPython.display import clear_output\n", "\n", "class PlotCurrentEstimate(tf.keras.callbacks.Callback):\n", "    \"\"\"Keras Callback which plots the current model estimate against the reference target\"\"\"\n", "\n", "    def __init__(self, x_valid, y_valid, plot_every=10):\n", "        super().__init__()\n", "        # matplotlib accepts numpy arrays directly, so no conversion to lists is needed\n", "        self.x_valid = np.asarray(x_valid)\n", "        self.y_valid = np.asarray(y_valid)\n", "        self.plot_every = plot_every # redraw the plot every `plot_every` epochs\n", "\n", "    def on_epoch_end(self, epoch, logs=None):\n", "        # epoch is 0-based: predict and redraw only every `plot_every` epochs\n", "        if (epoch + 1) % self.plot_every == 0:\n", "            self.y_curr = self.model.predict(self.x_valid, verbose=0).ravel()\n", "\n", "            clear_output(wait=True)\n", "            self.eplot = plt.subplot(1,1,1)\n", "            self.eplot.clear()\n", "            self.eplot.scatter(self.x_valid, self.y_curr, color=\"blue\", s=4, marker=\"o\", label=\"estimate\")\n", "            self.eplot.scatter(self.x_valid, self.y_valid, color=\"red\", s=4, marker=\"x\", label=\"valid\")\n", "            self.eplot.legend()\n", "            plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will also use an `EarlyStopping` callback on the `val_loss` quantity. It stops the training process once `val_loss` has not improved for a given number of epochs (`patience`), avoiding a long stretch of wasted computation that produces no useful results.\n", "\n", "`keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False)`\n", "\n", "Stop training when a monitored quantity has stopped improving.\n", "\n", "Arguments:\n", "\n", "- `monitor`: quantity to be monitored. \n", "- `min_delta`: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement. \n", "- `patience`: number of epochs with no improvement after which training will be stopped. \n", "- `verbose`: verbosity mode. \n", "- `mode`: one of {auto, min, max}. In min mode, training will stop when the quantity monitored has stopped decreasing; in max mode it will stop when the quantity monitored has stopped increasing; in auto mode, the direction is automatically inferred from the name of the monitored quantity. \n", "- `baseline`: baseline value for the monitored quantity to reach. Training will stop if the model doesn't show improvement over the baseline. \n", "- `restore_best_weights`: whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used." 
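] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the roles of `min_delta` and `patience` concrete, the next cell is a minimal pure-Python sketch of the stopping rule in `min` mode, applied to a precomputed list of per-epoch `val_loss` values. It is not the Keras implementation: the helper `early_stopping_epoch` and the toy loss history are hypothetical, the logic only roughly mirrors the Keras 2.x `EarlyStopping` callback, and `baseline` and `restore_best_weights` are left out." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "def early_stopping_epoch(losses, min_delta=0.0, patience=0):\n", "    \"\"\"Hypothetical helper: return the 1-based index of the last epoch an EarlyStopping-like\n", "    rule (min mode) would run on `losses`, or None if it never stops training.\"\"\"\n", "    best = math.inf  # best monitored value seen so far\n", "    wait = 0         # consecutive epochs without improvement\n", "    for epoch, current in enumerate(losses, start=1):\n", "        if current < best - min_delta:  # decrease larger than min_delta counts as improvement\n", "            best = current\n", "            wait = 0\n", "        else:\n", "            wait += 1\n", "            if wait >= patience:\n", "                return epoch\n", "    return None\n", "\n", "# toy val_loss history: it improves for 3 epochs, then plateaus\n", "history = [0.5, 0.3, 0.2, 0.2, 0.21, 0.2, 0.2]\n", "print(early_stopping_epoch(history, min_delta=0, patience=2))     # -> 5\n", "print(early_stopping_epoch(history, min_delta=0.15, patience=2))  # -> 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With `patience=100`, a `val_loss` that never improves after the first epoch stops training at epoch 101 (`early_stopping_epoch([0.0065] * 200, patience=100)` returns `101`). This is consistent with the log of the next cell, which ends at `Epoch 101/1000` with `val_loss` stuck at 0.0065." 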
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAF3FJREFUeJzt3X9wVfWZx/HPkxCIiq6IKaJIwalbQYIYouNup2oLAq0OyiILHd2VrQ4I0ersrruoM5REHLvbGZnZNqNhrGv9UZVN18K2dPAXHdupv0IFRKhIKUwjrQS0VC0ISZ794xyYW8yve8/JPffkvF8zd+6955z7vQ/fXJ6cfM/3+V5zdwEAsqEs6QAAAMVD0geADCHpA0CGkPQBIENI+gCQISR9AMiQyEnfzCrN7DUz22Rmb5lZfRyBAQDiZ1Hn6ZuZSTrJ3T8yswpJv5B0m7u/EkeAAID4DIragAe/NT4Kn1aENyq+AKAERU76kmRm5ZI2SPqcpEZ3f7Wn408//XQfM2ZMHG8NAJmxYcOGfe5eFaWNWJK+u3dImmRmp0p6xswmuPuW3GPMbIGkBZI0evRotbS0xPHWAJAZZrY7ahuxzt5x9z9KWi9pRhf7Vrp7rbvXVlVF+kUFAChQHLN3qsIzfJnZCZKukPTrqO0CAOIXx/DOSEnfD8f1yyStcvcfx9AuACBmccze2SzpwqjtHDlyRK2trTp06FDUpjKnsrJSo0aNUkVFRdKhAChxsVzIjUNra6tOPvlkjRkzRsHUf/SFu2v//v1qbW3V2LFjkw4HQIkrmWUYDh06pOHDh5Pw82RmGj58OH8hAeiTkkn6kkj4BaLfAPRVSSV9ABjQSuDraUn6BXjkkUe0Z8+eY89vuukmbd26NXK7u3bt0g9+8IPI7QAoQd/8plRWFtwniKRfgOOT/kMPPaTx48dHbpekDwxQ7lJDQ/C4oSHRM36Sfo7HH39cF198sSZNmqSFCxeqo6ND8+fP14QJE1RdXa0VK1aoublZLS0tuu666zRp0iQdPHhQl19++bFlJYYOHao77rhD559/vqZOnarXXntNl19+uc455xytWbNGUpDcv/jFL6qmpkY1NTX65S9/KUlasmSJfv7zn2vSpElasWKFOjo6dMcdd+iiiy7SxIkT1dTUlFjfAIjATPfYUkkK7pO8DufuRb9NnjzZj7d169ZPbSumrVu3+lVXXeWHDx92d/dFixb5smXLfOrUqceO+eCDD9zd/bLLLvPXX3/92Pbc55J87dq17u5+zTXX+BVXXOGHDx/2jRs3+gUXXODu7h9//LEfPHjQ3d23b9/uR/tj/fr1fuWVVx5rt6mpye+55x53dz906JBPnjzZd+7c2W38AErX4sXu5WWdvnhx4W1IavGI+bdk5ukXoq5OamqSFi6UGhujtfXCCy9ow4YNuuiiiyRJBw8e1IwZM7Rz507deuutuvLKKzVt2rRe2xk8eLBmzAiWHqqurtaQIUNUUVGh6upq7dq1S1JQiHbLLbdo48aNKi8v1/bt27ts69lnn9XmzZvV3NwsSTpw4IDeeecd5uMDKdTYKDU2Jj/TLtVJv6lJ6ugI7qMmfXfXDTfcoPvuu+8vtt97771at26dHnzwQa1atUoPP/xwj+1UVFQcm0JZVlamIUOGHHvc3t4uSVqxYoVGjBihTZs2qbOzU5WVld3G9J3vfEfTp0+P9o8DgFCqx/QXLpTKy4P7qKZMmaLm5mbt3btXkvT+++9r9+7d6uzs1OzZs7V8+XL96le/kiSdfPLJ+vDDDwt+rwMHDmjkyJEqKyvTY489po6Oji7bnT59uh544AEdOXJEkrR9+3Z9/PHHBb8vAKT6TD/4cymetsaPH6/ly5dr2rR
p6uzsVEVFhe6//37NmjVLnZ2dknTsr4D58+fr5ptv1gknnKCXX3457/davHixZs+erUcffVQzZszQSSedJEmaOHGiysvLdcEFF2j+/Pm67bbbtGvXLtXU1MjdVVVVpR/96Efx/IMBZFLk78gtRG1trR//JSrbtm3TuHHjih7LQEH/AQOfmW1w99oobaR6eAcAiqoEKmqjIukDQF+USEVtVCR9AOhNCVXURkXSB4DemGltbVBRu7Y24YraiFI9ewcAimXmG/Xq0DKVv2FqTzqYCDjTB4A+COqCLJa6oCSR9As0dOhQSdKePXt07bXXdnlM7kJsANKtsVFqb4+vNigpJP2IzjzzzGNr4wBAqSPph5YsWaLGnF/hy5Yt0/LlyzVlyhTV1NSourpaq1ev/tTrdu3apQkTJkgKFmmbN2+exo0bp1mzZungwYNFix8A+iL9F3LdY7mSPnfuXN1+++2qq6uTJK1atUrr1q3TN77xDZ1yyinat2+fLrnkEs2cObPb76R94IEHdOKJJ2rbtm3avHmzampqIscFAHFK95l+jMUSF154ofbu3as9e/Zo06ZNGjZsmM444wzdddddmjhxoqZOnap3331X7733XrdtvPTSS7r++uslBevoTJw4MXJcAGKU4vn1cUlv0u+HYok5c+aoublZTz/9tObOnasnnnhCbW1t2rBhgzZu3KgRI0bo0KFDkd8HQAIGSEVtVJGTvpmdbWbrzWyrmb1lZrfFEVgf3lhaGhRLaGk8xRJz587VU089pebmZs2ZM0cHDhzQZz7zGVVUVGj9+vXavXt3j6+/9NJLj33H7ZYtW7R58+bIMQGIwQCqqI0qjjP9dkn/4u7jJV0iqc7Mon9LeF/U10udncF9DM4//3x9+OGHOuusszRy5Ehdd911amlpUXV1tR599FGdd955Pb5+0aJF+uijjzRu3DgtXbpUkydPjiUuABENoIraqGJfWtnMVkv6rrs/190xLK0cP/oP6NmgQVJHh6u83NSe0pLaklta2czGSLpQ0qtxtgsAUQ2UitqoYkv6ZjZU0g8l3e7uf+pi/wIzazGzlra2trjeFgD6ZKBU1EYVS9I3swoFCf8Jd//fro5x95XuXuvutVVVVV22k8S3eA0E9BuAvopj9o5J+p6kbe5+f6HtVFZWav/+/SSwPLm79u/fr8rKyqRDAZACcVTkfkHSP0h608w2htvucve1+TQyatQotba2iqGf/FVWVmrUqFFJhwEgBSInfXf/haTI858qKio0duzYqM0AAHqQ3opcANnD8G9kJH0A6cAyCrGIvTirL7oqzgKAbrkHCf+ozs5MVtWWXHEWAPQLM91jwTIK91i2l1GIiqQPIBX+sKheg8o69YdF8ay1lVUM7wBASjC8AwDIC0kfADKEpA8AGULSB4AMIekDKB4qahNH0gdQHFTUlgSmbALof1TUxoIpmwDSgYrakkHSB1AUVNSWBoZ3ACAlGN4BAOSFpA8AGULSB4AMIekDQIaQ9AH0HRW1qUfSB9A3VNQOCEzZBNA7KmpLAlM2ARSHmdbWBhW1a2upqE2zQUkHACAdZr5Rrw4tU/kbpvakg0HBYjnTN7OHzWyvmW2Joz0ApWfhQqm83LRwYdKRIIq4hncekTQjprYAlKDGRqm9PbhHesWS9N39JUnvx9EWAKD/cCEXADKkaEnfzBaYWYuZtbS1tRXrbQEAOYqW9N19pbvXunttVVVVsd4WQC4qajOP4R0gK6ioheKbsvmkpJclfd7MWs3sxjjaBRATd6mhIXjc0MAZf4bFNXvna+4+0t0r3H2Uu38vjnYBxISKWoSoyAUygopaSIzpA5lBRS0kVtkEgNRglU0AQF5I+gCQISR9AMgQkj6QJsyvR0QkfSAtqKhFDJi9A6QB31ELMXsHyA4qahETKnKBlKCiFnHgTB9ICSpqEQfG9AEgJRjTBwDkhaQPABlC0geADCHpA8VERS0SRtIHioWKWpQAZu8AxUBFLWL
A7B0gLcykpUFFrZZSUYvkkPSBIqnbV69BZZ2q21efdCjIMJI+UCRNTVJHp6mpKelIkGUkfaBIgmUUxDIKSBQXcgEgJbiQCwDISyxJ38xmmNnbZrbDzJbE0SYAIH6Rk76ZlUtqlPQVSeMlfc3MxkdtFyhJVNQi5eI4079Y0g533+nuhyU9JenqGNoFSgsVtRgA4kj6Z0n6Xc7z1nDbXzCzBWbWYmYtbW1tMbwtUETuUkND8LihgTN+pFbRLuS6+0p3r3X32qqqqmK9LRAPKmoxQMSR9N+VdHbO81HhNmBAoaIWA0EcSf91Seea2VgzGyxpnqQ1MbQLlBQqajEQRE767t4u6RZJ6yRtk7TK3d+K2i5QaqioxUBARS4ApAQVuQCAvJD0ASBDSPoAkCEkfWQLRVXIOJI+soNlFABm7yAj+GJyDADM3gH6ykxra4NlFNbWsowCsmtQ0gEAxTLzjXp1aJnK3zC1Jx0MkBDO9JEZQUWtUVGLTGNMHwBSgjF9AEBeSPoAkCEkfQDIEJI+0oWKWiASkj7Sg4paIDJm7yAdqKgFmL2DDKGiFogFFblIDSpqgeg400dqUFELRMeYPgCkBGP6AIC8kPQBIENI+gCQISR9FB9VtUBiSPooLqpqgURFSvpmNsfM3jKzTjOLdEUZGeAuNTQEjxsaOOMHEhD1TH+LpL+T9FIMsWCgM5OWBlW1WkpVLZCESEnf3be5+9txBYOBr25fvQaVdapuX33SoQCZVLQxfTNbYGYtZtbS1tZWrLdFiWlqkjo6TU1NSUcCZFOvSd/MnjezLV3crs7njdx9pbvXunttVVVV4REj1YKlFMRSCkBCel1wzd2nFiMQZENjY3ADkAymbAJAhkSdsjnLzFol/Y2kn5jZunjCAgD0h0jr6bv7M5KeiSkWpIU70y2BlGJ4B/mhohZINdbTR9/xPbVAolhPH8VFRS2QeiR95IWKWiDdSPrICxW1QLqR9JEXKmqBdONCLgCkBBdyAQB5IekDQIaQ9LOIb6wCMouknzVU1AKZxoXcLKGiFkg1LuQiP1TUAplH0s8YKmqBbCPpZwwVtUC2kfQzhopaINu4kAsAKcGFXABAXkj6AJAhJP00oqIWQIFI+mlDRS2ACLiQmyZU1AKZxoXcrKGiFkBEnOmnkTsJH8igxM/0zezbZvZrM9tsZs+Y2alR2kPv6uqkQRWmurqkIwGQRlGHd56TNMHdJ0raLunO6CGhJ01NUkeHWEYBQEEiJX13f9bd28Onr0gaFT0k9IRlFABEEeeF3K9L+mmM7aELjY1Se3twDwD5GtTbAWb2vKQzuth1t7uvDo+5W1K7pCd6aGeBpAWSNHr06IKCBQBE02vSd/epPe03s/mSrpI0xXuYCuTuKyWtlILZO/mFOcAw+wZAQqLO3pkh6d8kzXT3P8cT0gBHRS2ABEWap29mOyQNkbQ/3PSKu9/c2+syO0+filoAESQ+T9/dP+fuZ7v7pPDWa8LPNCpqASSMitwkMKYPoACJn+kjf1TUAkgSSb/IqKgFkCSSfpFRUQsgSYzpA0BKMKYPAMgLSb8QfEctgJQi6eeLiloAKcaYfj6oqAWQIMb0i42KWgApx5l+IaioBZAAzvQTQEUtgDQj6eeJiloAaUbSzxMVtQDSjDF9AEgJxvQBAHkh6QNAhmQz6bOMAoCMyl7SZxkFABmWrQu5LKMAIMW4kJsvllEAkHHZSvqS6vbVa1BZp+r21ScdCgAUXeaSflOT1NFpVNQCyKTMJX0qagFkWbYu5AJAiiV+IdfM7jGzzWa20cyeNbMzo7QHAOhfUYd3vu3uE919kqQfS1oaQ0wAgH4SKem7+59ynp4kqThjRVTUAkBBIl/INbN7zex3kq5TMc70qagFgIL1eiHXzJ6XdEYXu+5299U5x90pqdLdu8zGZrZA0gJJGj169OTdu3fnHy0VtQAyrCgXct19qrtP6OK2+rhDn5A0u4d2Vrp7rbvXVlVVFRYtFbUAEEn
U2Tvn5jy9WtKvo4XTB/X1wRl+PRW1AJCvqGP63zKzLWa2WdI0SbfFEFOP+GJyAChc6oqzBg0Kvpi8vFxqb485MAAoYYkXZyWBZRQAoHCpO9MHgKzK5Jk+AKBwJH0AyBCSPgBkCEkfADKEpA8AGULSB4AMIekDQIYkMk/fzNokFbDM5jGnS9oXUzj9gfiiKeX4Sjk2ifiiKuX4Tpd0krsXuGJlIJGkH5WZtUQtUOhPxBdNKcdXyrFJxBdVKccXV2wM7wBAhpD0ASBD0pr0VyYdQC+IL5pSjq+UY5OIL6pSji+W2FI5pg8AKExaz/QBAAUo2aRvZnPM7C0z6zSzbq9Ym9kMM3vbzHaY2ZKc7WPN7NVw+9NmNjjm+E4zs+fM7J3wflgXx3zJzDbm3A6Z2TXhvkfM7Lc5+yYVO77wuI6cGNbkbO+3/utj300ys5fDz8BmM5ubs69f+q67z1LO/iFhX+wI+2ZMzr47w+1vm9n0OOIpIL5/NrOtYX+9YGafzdnX5c+5iLHNN7O2nBhuytl3Q/hZeMfMbog7tj7GtyIntu1m9secff3ad+F7PGxme81sSzf7zcz+K4x/s5nV5OzLr//cvSRvksZJ+rykn0mq7eaYckm/kXSOpMGSNkkaH+5bJWle+PhBSYtiju8/JS0JHy+R9B+9HH+apPclnRg+f0TStf3Yf32KT9JH3Wzvt/7rS2yS/lrSueHjMyX9XtKp/dV3PX2Wco5ZLOnB8PE8SU+Hj8eHxw+RNDZspzyB+L6U8/ladDS+nn7ORYxtvqTvdvHa0yTtDO+HhY+HFTu+446/VdLDxei7nPe4VFKNpC3d7P+qpJ9KMkmXSHq10P4r2TN9d9/m7m/3ctjFkna4+053PyzpKUlXm5lJ+rKk5vC470u6JuYQrw7b7Wv710r6qbv/OeY4upNvfMcUof96jc3dt7v7O+HjPZL2SopUlNKLLj9Lxx2TG3ezpClhX10t6Sl3/8TdfytpR9heUeNz9/U5n69XJI2KOYaCY+vBdEnPufv77v6BpOckzUg4vq9JejLmGHrk7i8pOCnsztWSHvXAK5JONbORKqD/Sjbp99FZkn6X87w13DZc0h/dvf247XEa4e6/Dx//QdKIXo6fp09/kO4N/1RbYWZDEoqv0sxazOyVo0NP6v/+y6vvzOxiBWdov8nZHHffdfdZ6vKYsG8OKOirvry2GPHlulHBmeFRXf2cix3b7PBn1mxmZ+f52mLEp3BIbKykF3M292ff9VV3/4a8+29Q7KHlwcyel3RGF7vudvfVxY7neD3Fl/vE3d3Mup0GFf5Grpa0LmfznQoS3mAFU7H+XVJDAvF91t3fNbNzJL1oZm8qSGaRxNx3j0m6wd07w82R+24gM7PrJdVKuixn86d+zu7+m65b6Bf/J+lJd//EzBYq+Ivpy0V8/76aJ6nZ3TtytiXdd7FKNOm7+9SITbwr6eyc56PCbfsV/PkzKDwjO7o9tvjM7D0zG+nuvw8T094emvp7Sc+4+5Gcto+e6X5iZv8t6V+TiM/d3w3vd5rZzyRdKOmHith/ccRmZqdI+omCk4BXctqO3Hdd6O6z1NUxrWY2SNJfKfis9eW1xYhPZjZVwS/Wy9z9k6Pbu/k5x5W4eo3N3ffnPH1IwXWdo6+9/LjX/iymuPocX455kupyN/Rz3/VVd/+GvPsv7cM7r0s614KZJoMV/MDWeHCFY72CcXRJukFS3H85rAnb7Uv7nxojDJPd0fHzayR1edW+P+Mzs2FHh0bM7HRJX5C0tQj915fYBkt6RsE4ZvNx+/qj77r8LPUQ97WSXgz7ao2keRbM7hkr6VxJr8UQU17xmdmFkpokzXT3vTnbu/w5Fzm2kTlPZ0raFj5eJ2laGOMwSdP0l38RFyW+MMbzFFwMfTlnW3/3XV+tkfSP4SyeSyQdCE9+8u+//r4qXehN0iwF41OfSHpP0rpw+5mS1uYc91VJ2xX85r07Z/s5Cv7j7ZD0P5KGxBz
fcEkvSHpH0vOSTgu310p6KOe4MQp+G5cd9/oXJb2pIGE9LmloseOT9LdhDJvC+xuL0X99jO16SUckbcy5TerPvuvqs6Rg2Ghm+Lgy7IsdYd+ck/Pau8PXvS3pK/30f6K3+J4P/68c7a81vf2cixjbfZLeCmNYL+m8nNd+PezTHZL+KYm+C58vk/St417X730Xvs+TCmaoHVGQ926UdLOkm8P9JqkxjP9N5cxozLf/qMgFgAxJ+/AOACAPJH0AyBCSPgBkCEkfADKEpA8AGULSB4AMIekDQIaQ9AEgQ/4fXvTJVr7xMZ8AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "20/20 [==============================] - 0s 14ms/step - loss: 0.1069 - val_loss: 0.0065\n", "Epoch 101/1000\n", "20/20 [==============================] - 0s 376us/step - loss: 0.1069 - val_loss: 0.0065\n" ] }, { "data": { "text/plain": [ "[array([[3.046707]], dtype=float32), array([-0.07363325], dtype=float32)]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "plot_estimate = PlotCurrentEstimate(x, y_target)\n", "\n", "# stop when val_loss has not improved for 100 consecutive epochs\n", "earlystop = tf.keras.callbacks.EarlyStopping(monitor='val_loss',\n", "    min_delta=0, patience=100, mode='auto')\n", "\n", "# when the notebook is run top to bottom, this continues training the model fitted above\n", "model.fit(x, y_train, batch_size=100, epochs=1000,\n", "          validation_data=(x, y_target),\n", "          callbacks=[plot_estimate, earlystop]\n", "          )\n", "\n", "model.get_weights()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exercises\n", "\n", "1. Try different Keras optimizers.\n", "\n", "1. Try to extend the model to fit a polynomial of order N.\n", "   - How many layers do you need?\n", "   - Can you make good predictions using a non-linear activation function?\n", "   - Can you identify the meaning of the weights?\n", "\n", "\n", "1. Try to extend the model with at least two layers and fit a 2D Gaussian distribution or a simple trigonometric 2D function such as f(x,y) = sin(x+y)." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }