{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Matplotlib Linear Regression Animation in Jupyter Notebook\n", "\n", "This is the companion notebook for the Medium article [Matplotlib Linear Regression Animation in Jupyter Notebook](https://bindichen.medium.com/matplotlib-linear-regression-animation-in-jupyter-notebook-2435b711bea2)\n", "\n", "Please check out the article for instructions.\n", "\n", "**License**: [BSD 2-Clause](https://opensource.org/licenses/BSD-2-Clause)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!conda install -c conda-forge -y ffmpeg pillow\n", "!pip install vega_datasets ipympl scikit-learn ffmpeg-python" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%matplotlib widget\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from sklearn.linear_model import LinearRegression\n", "from vega_datasets import data\n", "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'svg'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preprocessing" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
NameMiles_per_GallonCylindersDisplacementHorsepowerWeight_in_lbsAccelerationYearOrigin
0chevrolet chevelle malibu18.08307.0130.0350412.01970-01-01USA
1buick skylark 32015.08350.0165.0369311.51970-01-01USA
2plymouth satellite18.08318.0150.0343611.01970-01-01USA
3amc rebel sst16.08304.0150.0343312.01970-01-01USA
4ford torino17.08302.0140.0344910.51970-01-01USA
\n", "
" ], "text/plain": [ " Name Miles_per_Gallon Cylinders Displacement \\\n", "0 chevrolet chevelle malibu 18.0 8 307.0 \n", "1 buick skylark 320 15.0 8 350.0 \n", "2 plymouth satellite 18.0 8 318.0 \n", "3 amc rebel sst 16.0 8 304.0 \n", "4 ford torino 17.0 8 302.0 \n", "\n", " Horsepower Weight_in_lbs Acceleration Year Origin \n", "0 130.0 3504 12.0 1970-01-01 USA \n", "1 165.0 3693 11.5 1970-01-01 USA \n", "2 150.0 3436 11.0 1970-01-01 USA \n", "3 150.0 3433 12.0 1970-01-01 USA \n", "4 140.0 3449 10.5 1970-01-01 USA " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = data.cars()\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Drop rows with NaN\n", "df.dropna(subset=['Horsepower', 'Miles_per_Gallon'], inplace=True)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Transform data\n", "x = df['Horsepower'].to_numpy().reshape(-1, 1)\n", "y = df['Miles_per_Gallon'].to_numpy().reshape(-1, 1)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": "\n\n\n \n \n \n \n 2022-10-18T14:12:19.457955\n image/svg+xml\n \n \n Matplotlib v3.6.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.scatter(x, y, c='g', label='Horsepower vs. Miles_per_Gallon')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Animations with Interactive Plot" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Warning: Cannot change to a different GUI toolkit: notebook. Using widget instead.\n", "Warning: Cannot change to a different GUI toolkit: notebook. Using widget instead.\n" ] } ], "source": [ "# Enable interactive plot\n", "%matplotlib notebook\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from matplotlib.animation import FuncAnimation\n", "from sklearn.linear_model import LinearRegression" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0ed500966b80441180764cbcd5d94cbd", "version_major": 2, "version_minor": 0 }, "image/png": "", "text/html": [ "\n", "
\n", "
\n", " Figure\n", "
\n", " \n", "
\n", " " ], "text/plain": [ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_data = []\n", "y_data = []\n", "\n", "fig, ax = plt.subplots() # plt.subplots() returns a (Figure, Axes) tuple; unpack both\n", "ax.set_xlim(30, 250)\n", "ax.set_ylim(5, 50)\n", "# Plotting: create empty artists that animate() fills in frame by frame\n", "scatter, = ax.plot([], [], 'go', label='Horsepower vs. Miles_per_Gallon')\n", "line, = ax.plot([], [], 'r', label='Linear Regression')\n", "ax.legend()\n", "\n", "reg = LinearRegression()\n", "\n", "def animate(frame_num):\n", " # Adding data\n", " x_data.append(x[frame_num])\n", " y_data.append(y[frame_num])\n", " # Convert data to numpy array\n", " x_train = np.array(x_data).reshape(-1, 1)\n", " y_train = np.array(y_data).reshape(-1, 1)\n", " # Fit values to a linear regression\n", " reg.fit(x_train, y_train)\n", "\n", " # update data for scatter plot\n", " scatter.set_data((x_data, y_data))\n", " # Predict value and update data for line plot\n", " line.set_data((list(range(250)), reg.predict(np.array([entry for entry in range(250)]).reshape(-1, 1))))\n", "\n", "anim = FuncAnimation(fig, animate, frames=len(x), interval=20)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Animations with embedded HTML5 video" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from matplotlib.animation import FuncAnimation\n", "from IPython import display\n", "\n", "# Turn off interactive mode so figures are not displayed automatically in the notebook\n", "plt.ioff()\n", "# Pass the ffmpeg path\n", "plt.rcParams['animation.ffmpeg_path'] = '/usr/local/bin/ffmpeg'" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/matplotlib/animation.py:879: UserWarning: Animation was deleted without rendering anything. 
This is most likely not intended. To prevent deletion, assign the Animation to a variable, e.g. `anim`, that exists until you output the Animation using `plt.show()` or `anim.save()`.\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_data = []\n", "y_data = []\n", "\n", "fig, ax = plt.subplots()\n", "ax.set_xlim(30, 250)\n", "ax.set_ylim(5, 50)\n", "scatter, = ax.plot([], [], 'go', label='Horsepower vs. Miles_per_Gallon')\n", "line, = ax.plot([], [], 'r', label='Linear Regression')\n", "ax.legend()\n", "\n", "reg = LinearRegression()\n", "\n", "def animate(frame_num):\n", " # Adding data\n", " x_data.append(x[frame_num])\n", " y_data.append(y[frame_num])\n", " # Convert data to numpy array\n", " x_train = np.array(x_data).reshape(-1, 1)\n", " y_train = np.array(y_data).reshape(-1, 1)\n", " reg.fit(x_train, y_train)\n", " \n", " # update data for scatter plot\n", " scatter.set_data((x_data, y_data))\n", " # Predict value and update data for line plot\n", " line.set_data((list(range(250)), reg.predict(np.array([entry for entry in range(250)]).reshape(-1, 1))))\n", "\n", "anim = FuncAnimation(fig, animate, frames=len(x), interval=20)\n", "\n", "video = anim.to_html5_video()\n", "html = display.HTML(video)\n", "display.display(html)\n", "plt.close()\n", "\n", "# Note: GitHub only renders static HTML, so the embedded HTML5 video won't be displayed there.\n", "# The embedded video should work if you host the notebook or open it locally. 
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Thanks for reading\n", "\n", "This is a notebook for the medium article [Matplotlib Linear Regression Animation in Jupyter Notebook](https://bindichen.medium.com/matplotlib-linear-regression-animation-in-jupyter-notebook-2435b711bea2)\n", "\n", "**License**: [BSD 2-Clause](https://opensource.org/licenses/BSD-2-Clause)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 }