4.6. Matplotlib Linear Regression Animation in Jupyter Notebook#

This notebook accompanies the Medium article "Matplotlib Linear Regression Animation in Jupyter Notebook".

Please check out the article for instructions.

License: BSD 2-Clause

%%capture
!conda install -c conda-forge -y ffmpeg pillow
!pip install vega_datasets ipympl scikit-learn ffmpeg-python
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
from vega_datasets import data

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

4.6.1. Data Preprocessing#

# Load the cars dataset shipped with vega_datasets
df = data.cars()
# Preview the first rows to see which columns are available
df.head()
Name Miles_per_Gallon Cylinders Displacement Horsepower Weight_in_lbs Acceleration Year Origin
0 chevrolet chevelle malibu 18.0 8 307.0 130.0 3504 12.0 1970-01-01 USA
1 buick skylark 320 15.0 8 350.0 165.0 3693 11.5 1970-01-01 USA
2 plymouth satellite 18.0 8 318.0 150.0 3436 11.0 1970-01-01 USA
3 amc rebel sst 16.0 8 304.0 150.0 3433 12.0 1970-01-01 USA
4 ford torino 17.0 8 302.0 140.0 3449 10.5 1970-01-01 USA
# Discard cars that are missing either of the two columns used below
df.dropna(subset=['Horsepower', 'Miles_per_Gallon'], inplace=True)
# Selecting with a one-element column list keeps the data two-dimensional,
# giving (n_samples, 1) column vectors for both the feature and the target
x = df[['Horsepower']].to_numpy()
y = df[['Miles_per_Gallon']].to_numpy()
# Static scatter plot of the raw data
plt.scatter(x, y, c='g', label='Horsepower vs. Miles_per_Gallon')
plt.legend()
plt.show()
_images/matplotlib-linear-regression-animation_7_0.svg

4.6.2. 1. Animations with Interactive Plot#

# Enable interactive plot
%matplotlib notebook
Warning: Cannot change to a different GUI toolkit: notebook. Using widget instead.
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from sklearn.linear_model import LinearRegression
# Samples shown so far; `animate` fills these lists as the frames advance
x_data, y_data = [], []

# plt.subplots() returns a (figure, axes) pair
fig, ax = plt.subplots()
# Axis limits are set once, before any data is drawn
ax.set_xlim(30, 250)
ax.set_ylim(5, 50)

# Two initially empty artists that `animate` updates in place.
# ax.plot returns a list of lines, so take its single element.
scatter = ax.plot([], [], 'go', label='Horsepower vs. Miles_per_Gallon')[0]
line = ax.plot([], [], 'r', label='Linear Regression')[0]
ax.legend()

# Model that `animate` refits on every frame
reg = LinearRegression()

def animate(frame_num):
    """Show the first ``frame_num + 1`` samples and the line fitted to them.

    Frame callback for ``FuncAnimation``. Reads the module-level ``x`` and
    ``y`` arrays, refits ``reg`` on the samples selected for this frame and
    updates the ``scatter`` and ``line`` artists in place. ``x_data`` and
    ``y_data`` are left holding exactly the samples currently shown.

    Parameters
    ----------
    frame_num : int
        Zero-based frame index; samples ``0..frame_num`` are used.
    """
    # Rebuild the shared lists from a slice instead of appending to them.
    # FuncAnimation calls this function with the first frame for its initial
    # draw when no init_func is given, and replays the frame sequence when it
    # repeats, so appending would add the same sample more than once.
    shown = frame_num + 1
    x_data[:] = x[:shown]
    y_data[:] = y[:shown]

    # scikit-learn expects 2-D (n_samples, n_features) input
    x_train = np.asarray(x_data).reshape(-1, 1)
    y_train = np.asarray(y_data).reshape(-1, 1)
    # Fit values to a linear regression
    reg.fit(x_train, y_train)

    # Update data for the scatter plot
    scatter.set_data(x_train.ravel(), y_train.ravel())
    # Predict over the horsepower values 0..249 and update the line plot
    grid = np.arange(250).reshape(-1, 1)
    line.set_data(grid.ravel(), np.asarray(reg.predict(grid)).ravel())

# Keep the result in `anim`: the animation has to stay referenced until it is
# shown, otherwise it is deleted without rendering anything.
# frames=len(x) yields frame numbers 0..len(x)-1 (one per sample) and
# interval=20 is the delay between frames in milliseconds.
# FuncAnimation repeats by default, so `animate` is called again with earlier
# frame numbers once the last frame has been drawn.
anim = FuncAnimation(fig, animate, frames=len(x), interval=20)
plt.show()

4.6.3. 3. Animations with embedded HTML5 video#

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython import display

import shutil

# Turn off interactive mode so the figure is not shown as a plot in the notebook
plt.ioff()
# Tell matplotlib which ffmpeg binary to use for encoding the video.
# Prefer the ffmpeg found on PATH (for example one installed into the active
# environment) and fall back to the previously hard-coded location only when
# none is found there.
plt.rcParams['animation.ffmpeg_path'] = shutil.which('ffmpeg') or '/usr/local/bin/ffmpeg'
# Start again from empty sample lists for this animation
x_data, y_data = [], []

# New figure and axes with the axis limits set once up front
fig, ax = plt.subplots()
ax.set_xlim(30, 250)
ax.set_ylim(5, 50)
# Empty artists that `animate` updates in place; ax.plot returns a list of
# lines, so take its single element
scatter = ax.plot([], [], 'go', label='Horsepower vs. Miles_per_Gallon')[0]
line = ax.plot([], [], 'r', label='Linear Regression')[0]
ax.legend()

# Fresh model that `animate` refits on every frame
reg = LinearRegression()

def animate(frame_num):
    """Show the first ``frame_num + 1`` samples and the line fitted to them.

    Frame callback for ``FuncAnimation``. Reads the module-level ``x`` and
    ``y`` arrays, refits ``reg`` on the samples selected for this frame and
    updates the ``scatter`` and ``line`` artists in place. ``x_data`` and
    ``y_data`` are left holding exactly the samples currently shown.

    Parameters
    ----------
    frame_num : int
        Zero-based frame index; samples ``0..frame_num`` are used.
    """
    # Rebuild the shared lists from a slice instead of appending to them.
    # FuncAnimation calls this function with the first frame for its initial
    # draw when no init_func is given, so appending would add the first
    # sample twice before the video is even rendered.
    shown = frame_num + 1
    x_data[:] = x[:shown]
    y_data[:] = y[:shown]

    # scikit-learn expects 2-D (n_samples, n_features) input
    x_train = np.asarray(x_data).reshape(-1, 1)
    y_train = np.asarray(y_data).reshape(-1, 1)
    reg.fit(x_train, y_train)

    # Update data for the scatter plot
    scatter.set_data(x_train.ravel(), y_train.ravel())
    # Predict over the horsepower values 0..249 and update the line plot
    grid = np.arange(250).reshape(-1, 1)
    line.set_data(grid.ravel(), np.asarray(reg.predict(grid)).ravel())

# Keep the animation referenced by `anim` until it has been rendered.
# frames=len(x) yields one frame per sample; interval=20 is the delay between
# frames in milliseconds.
anim = FuncAnimation(fig, animate, frames=len(x), interval=20)

# Render the animation to an HTML5 <video> tag (encoded with ffmpeg) ...
video = anim.to_html5_video()
# ... wrap the markup so that IPython renders it as HTML, and show it as the
# cell output
html = display.HTML(video)
display.display(html)
# Close the current figure; presumably so that it is not also drawn as a
# static plot below the video
plt.close()

# Note: GitHub only renders static HTML, so the embedded HTML5 video won't be displayed there.
# The embedded video should work if you host the notebook or open it locally.
/Users/wiggles/Library/Python/3.9/lib/python/site-packages/matplotlib/animation.py:879: UserWarning: Animation was deleted without rendering anything. This is most likely not intended. To prevent deletion, assign the Animation to a variable, e.g. `anim`, that exists until you output the Animation using `plt.show()` or `anim.save()`.
  warnings.warn(

4.6.4. Thanks for reading#

This notebook accompanies the Medium article "Matplotlib Linear Regression Animation in Jupyter Notebook".

License: BSD 2-Clause