Cómo crear animaciones en Python

Hoy vamos a aprender cómo crear animaciones en Python, para hacer que nuestras visualizaciones sean mucho más impactantes y podamos dar más información de una forma visual e impactante. En este post, aprenderás a crear todo tipo de animaciones en Python desde animaciones sencillas a gráficos animados como barchartraces. ¿Te suena interesante? ¡Pues vamos a ello!

Funcionamiento de las animaciones en Python

Para crear nuestras animaciones usaremos la función FuncAnimation dentro de matplolib. Un aspecto fundamental para poder crear nuestras animaciones es entender que este paquete no crea animaciones, sino que simplemente se limita a crear animaciones a partir de una serie de gráficos que le pasemos.

Esto es algo muy importante, ya que es un enfoque muy diferente al de otros paquetes como el de gganimate de R (si no sabes cómo funciona, aquí tienes el tutorial).

De hecho, para crear animaciones en Python usando FuncAnimation simplemente debes pasar una función que tiene como valor de entrada un número que hace referencia a un frame y devuelve el gráfico correspondiente a ese frame.

Esto hace que para crear animaciones en Python debamos preparar muy bien los datos. Veámos cómo hacerlo.

Estructura de datos para crear animaciones en Python

Aunque se pueden crear gráficos partiendo de datos con una forma muy diferente, en mi opinión, para que sea más sencillo graficar los datos deben estar en formato tidy, es decir:

  1. Cada variable debe estar en una columna.
  2. Cada observación de esa variable debe ser una fila diferente.

Veamos un ejemplo con el dataset gapminder, que es el que usaremos como ejemplo:

gapminder.head()
       country continent  year  lifeExp       pop   gdpPercap
0  Afghanistan      Asia  1952   28.801   8425333  779.445314
1  Afghanistan      Asia  1957   30.332   9240934  820.853030
2  Afghanistan      Asia  1962   31.997  10267083  853.100710
3  Afghanistan      Asia  1967   34.020  11537966  836.197138
4  Afghanistan      Asia  1972   36.088  13079460  739.981106

Ahora que ya sabemos cómo deben estar los datos para crear una animación en Python, vamos a ver cómo crear diferentes animaciones en Python!

Cómo crear animaciones en Python

Para crear animaciones en Python usaremos las funciones animation del módulo matplotlib. Por tanto, crear una animación es muy sencillo y parecido a crear gráficos con matplotlib. Simplemente debemos crear dos cuestiones:

  • fig: es el objeto que utilizaremos para pintar nuestro gráfico.
  • func: es una función que debe de devolver el estado de la animación para cada frame. Básicamente lo que debemos hacer es crear una función que devuelva todos los gráficos. Siguiendo el ejemplo de la animación de line chart comentada anteriormente, la función debe devolver, en la primera iteración un linechart con el primer año, en la segunda interación un linechart con los dos primeros años y así para todos los años.
  • interval: es el delay en milisegundos entre los diferentes frames de la animación.
  • frames: número de imágenes en las que se va a basar el gráfico. Esto dependerá de cuántos “estados” tenga la animación. Si tenemos una animación con datos en 5 estados diferentes (imaginemos, 5 años), el número de frames será 5, mientras que si tenemos datos de 100 años, el número de frames será 100.

Con estos tres argumentos podemos crear todo tipo de animaciones. Ahora bien, esto puede ser algo complejo (sobre todo la parte del update), así que yo siempre recomendaría primero crear el gráfico que nosotros queremos y, a partir de eso, generar la animación.

En cualquier caso, ya contamos con todo lo básico, así que, ¡veámos cómo crear animaciones en Python!

Cómo crear animación de líneas

Como decía, lo más fácil para crear una animación es primero crear un gráfico que se parezca a lo que nosotros queremos animar. En este caso es muy sencillo, simplemente debemos crear un linechart dell Pib per Cápita para los países España, Italia y Estados Unidos.

Gráfico de Linechart gapminder

Ahora que ya tenemos nuestro gráfico, para crear una animación de líneas, simplemente tendremos que crear una función que, para cada iteración, cree la gráfica de línea pero para los datos que tengamos disponibles.

De esta forma, en la primera iteración la gráfica de línea deberá crear solo un punto para el año 1952, en la segunda iteración creará la gráfica con los dos primeros puntos (1952, 1957), y así hasta completar toda la gráfica.

Por suerte, crear esta iteración habiendo creado ya el gráfico es bastante sencillo, ya que simplemente deberemos utilizar los índices para definir los datos que el gráfico debe coger.

from matplotlib import animation

countries_plot = ['Spain', 'Italy', 'United States']
linechart_plot = gapminder.loc[gapminder['country'].isin(countries_plot), :]

# Define colors
colors = ['red', 'green', 'blue']

fig, ax = plt.subplots()

def update_linechart(i):
  for j in range(len(colors)):
    country = countries_plot[j]
    color = colors[j]

    data = linechart_plot.loc[linechart_plot['country'] == country,:]
    ax.plot(data.year[:i], data.gdpPercap[:i], color)

Con esto, ya hemos creado todo lo que necesitamos para nuestra animación. Ahora simplemente la tenemos que llamar usando la función FuncAnimation que he explicado previamente. En este tensido, de cara

num_frames = len(linechart_plot['year'].unique())        
anim = animation.FuncAnimation(fig, update_linechart, frames = num_frames)
anim.save('linechart.gif')
Animación linechart hecha en Python

¡Ya tenemos nuestra animación de linechart creada con Python! Sencillo, ¿verdad? Ahora sigamos viendo cómo crear animaciones de barcharts!

Cómo crear una animación de barchart en Python

Una buena práctica para que crear nuestro barchart (y todas las animaciones más allá de linecharts) sea más sencillo es filtrar los datos dentro de la propia función de iteración. Esto nos facilitará mucho la creación de animaciones y hará que entenderlas sea mucho más fácil.

De todos modos, como siempre que queremos crear una animación, debemos empezar por graficar lo que queremos llegar a conseguir. Así pues, en este caso voy a crear barchart muy simple en el que veamos cómo ha evolucionado el Pib per Cápita de diferentes países.

countries_plot = ['Spain', 'Italy', 'United States','Ireland','China']
barchart_data  = gapminder.loc[gapminder['country'].isin(countries_plot), :]

font = {
    'weight': 'normal',
    'size'  :  40,
    'color': 'lightgray'
}

colors =['#FF0000','#169b62','#008c45','#aa151b','#002868']

data_temp = barchart_data.loc[barchart_data['year'] == 2007, :]

fig, ax = plt.subplots(figsize=(10, 5))
ax.clear()
ax.barh(data_temp.country,data_temp.gdpPercap, color = colors)

ax.text(0.95, 0.2, data_temp['year'].iloc[0],
        horizontalalignment='right',
        verticalalignment='top',
        transform=ax.transAxes,
       fontdict=font)
plt.show()
Gráfico de barras gapminder

En este gráfico veremos como he incluido dos cambios importantes, que pueden tener un impacto muy importante en tu gráfico.

Tip 1. Limpiar el gráfico anterior

Si creamos el gráfico dentro de nuestro objeto ax, los gráficos se nos van a ir “amontonando”, lo que puede hacer que nuestros datos no sean los reales. Lo peor de todo es que, si no usas transparencia en el gráfico, depende del del tipo de gráfico puede que ni te des cuenta de esto.

Para evitar esto, en cada iteración deberemos llamar a ax.clear(), de tal forma que limpie el resultado proveniente del frame anterior.

Tip 2. Filtra tus datos en la función de update

Hacer una animación para un único año es relativamente sencillo. Sin embargo, hacerlo para muchos años parece algo más complejo. Es por eso que para facilitar la creación de animaciones en Python yo recomiendo:

  1. Crear una lista con todos los estados posibles de la animación. En mi caso los estados son los años, así que creo una lista con todos los años que puedo llegar a plotear.
  2. Filtra el dataset completo en función del estado dentro de la propia función de update.

Con estos dos pasos se te hará mucho más fácil crear animaciones.

Así pues, voy a crear la función de update de mi animación de barplot teniendo en cuenta los dos puntos anteriores:

countries_plot = ['Spain', 'Italy', 'United States','Ireland','China']
barchart_data  = gapminder.loc[gapminder['country'].isin(countries_plot), :]

font = {
    'weight': 'normal',
    'size'  :  40,
    'color': 'lightgray'
}

years = barchart_data['year'].unique()
colors =['#FF0000','#169b62','#008c45','#aa151b','#002868']

fig, ax = plt.subplots(figsize=(10, 5))
label = ax.text(0.95, 0.2, years[0],
            horizontalalignment='right',
            verticalalignment='top',
            transform=ax.transAxes,
            fontdict=font)

def update_barchart(i):
  year = years[i]

  data_temp = barchart_data.loc[barchart_data['year'] == year, :]
  ax.clear()
  ax.barh(data_temp.country,data_temp.gdpPercap, color = colors)
  label.set_text(year)

anim = animation.FuncAnimation(fig, update_barchart, frames = len(years))
anim.save('barchart.gif')  
Animación barchart hecha en Python

¡Animación de barplot en Python lista! Como ves, filtrar los datos dentro de la propia función de update hace que todo sea mucho más sencillo.

Ahora que ya tenemos más control sobre las animaciones, vamos a crear una animación algo más compleja pero mucho más impactante: vamos a animar un gráfico de scatter plot en Python. ¡Vamos a ello!

Cómo animar Scatter Plot en Python

Una vez más, para crear la animación del scatter plot lo primero de todo es crear el gráfico para un único año. Para ello vamos a seguir exactamente los mismos casos que para crear la animación de barplot: primero filtramos los datos y después creamos el gráfico.

En este caso, al haber muchos países colorearé los países en base al continente y, además, les daré transparencia:

import numpy as np
import matplotlib

fig, ax = plt.subplots(figsize=(10, 5))

scatter_data = gapminder.copy()

# Create a color depending on
conditions = [
  scatter_data.continent == 'Asia',
  scatter_data.continent == 'Europe',
  scatter_data.continent == 'Africa',
  scatter_data.continent == 'Americas',
  scatter_data.continent == 'Oceania',
]

values = list(range(5))

scatter_data['color'] = np.select(conditions, values)


font = {
    'weight': 'normal',
    'size'  :  40,
    'color': 'lightgray'
}

years = scatter_data['year'].unique()

data_temp = scatter_data.loc[scatter_data['year'] == years[-1], :]
label = ax.text(0.95, 0.25, years[-1],
            horizontalalignment='right',
            verticalalignment='top',
            transform=ax.transAxes,
            fontdict=font)

colors =[f'C{i}' for i in np.arange(1, 6)]
cmap, norm = matplotlib.colors.from_levels_and_colors(np.arange(1, 5+2), colors)

scatter = ax.scatter(data_temp.gdpPercap,
                     data_temp.lifeExp,
                     s=data_temp['pop']/500000, 
                     alpha = 0.5, 
                     c=data_temp.color, 
                     cmap=cmap,
                     norm=norm)
label.set_text(years[-1])
plt.show()
Scatter plot del PIB per Cápita y esperanza de vida en 2007

Ahora que ya tenemos nuestra gráfica montada, ahora debemos convertirlo en función para animarla. En este caso, resulta fundamental que antes de cada frame limpiemos el contenido anterior del gráfico dentro del objeto ax, ya que sino la animación no quedará bien.

Más allá de eso, el procedimiento para crear la animación de scatter plot es el mismo que el explicado previamente para crear otro tipo de animaciones en Python:


fig, ax = plt.subplots(figsize=(10, 5))

years = scatter_data['year'].unique()

colors =[f'C{i}' for i in np.arange(1, 6)]
cmap, norm = matplotlib.colors.from_levels_and_colors(np.arange(1, 5+2), colors)


label = ax.text(0.95, 0.25, years[0],
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes,
                fontdict=font)


def update_scatter(i):

    year = years[i]

    data_temp = scatter_data.loc[scatter_data['year'] == year, :]
    ax.clear()
    label = ax.text(0.95, 0.20, years[i],
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes,
                fontdict=font)
    ax.scatter(
        data_temp['gdpPercap'],
        data_temp['lifeExp'],
        s=data_temp['pop']/500000, 
        alpha = 0.5, 
        c=data_temp.color, 
        cmap=cmap,
        norm=norm
    )

    label.set_text(year)

anim = animation.FuncAnimation(fig, update_scatter, frames = len(years), interval = 30)
anim.save('scatter.gif')  
Animación scatter hecha en Python

¡Ya tenemos nuestro scatter plot animado! Ahora vayamos a por la última de las animaciones que vamos aprender a crear en Python: una animación de barchart race.

Cómo crear animaciones de barplot race en Python

La animación de barchart race es muy similar a la animación de barplot que hemos hecho anteriormente. La principal diferencia reside en que, en el barplot race los datos están ordenados, de tal forma que veamos cómo ha ido evolucionando el top de X observaciones para una variable (puede ser desde valoración en bolsa a uso de videojuegos o, como en nuestro caso, el PIB per Cápita de los países).

Así pues, para crear nuestro barplot race necesitaremos tener, para cada uno de los años, cuál es el ranking de los países. Para ello, yo recomiendo utilizar el método rank de pandas , ya que podremos obtener los rankings de una forma muy sencilla.

Una vez tenemos el ranking, simplemente deberemos filtrar los datos para quedarnos con el número de observaciones que nos interese, en mi caso 10.

Por último, una vez tengamos nuestros datos filtrados, solo habrá que crear un gráfico de barras horizontales donde el eje vertical sea el ranking. Además, para hacer que el gráfico sea más entendible, cambiaremos el nombre del tick por el nombre del país. Esto lo haremos con el parámetro tick_label de la función barh.

Así pues, veámos cómo sería para un único caso:

barchartrace_data  = gapminder.copy()
n_observations = 10

font = {
    'weight': 'normal',
    'size'  :  40,
    'color': 'lightgray'
}

data_temp = barchartrace_data.loc[barchartrace_data['year'] == 1952, :]

# Create rank and get first 10 countries
data_temp['ranking'] = data_temp.gdpPercap.rank(method='max', ascending = False).values

data_temp = data_temp.loc[data_temp['ranking'] <= n_observations]


colors = plt.cm.Dark2(range(6))

fig, ax = plt.subplots(figsize=(10, 5))
ax.barh(y = data_temp['ranking'] ,
        width = data_temp.gdpPercap, 
        tick_label=data_temp['country'],
       color=colors)
ax.text(0.95, 0.2, data_temp['year'].iloc[0],
        horizontalalignment='right',
        verticalalignment='top',
        transform=ax.transAxes,
       fontdict=font)

ax.set_ylim(ax.get_ylim()[::-1]) # Revert axis
plt.show()
Base para crear animaciones barplotrace en Python

¡Perfecto! Ya tenemos nuestra base creada. Ahora solo queda crear nuestra función de animación. En este caso, yo recomiendo que el ranking y selección de los países se haga dentro de la propia función de actualización, ya que facilita mucho el entendimiento y permite aprovechar el código de nuestro gráfico base.

Así pues, la función de update de nuestro barplot race es la siguiente:

barchartrace_data  = gapminder.copy()
n_observations = 10

fig, ax = plt.subplots(figsize=(10, 5))

font = {
    'weight': 'normal',
    'size'  :  40,
    'color': 'lightgray'
}

years = barchartrace_data['year'].unique()

label = ax.text(0.95, 0.20, years[0],
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes,
                fontdict=font)


def update_barchart_race(i):

    year = years[i]

    data_temp = barchartrace_data.loc[barchartrace_data['year'] == year, :]

    # Create rank and get first 10 countries
    data_temp['prueba'] = data_temp['gdpPercap'].rank(ascending = False)
    data_temp = data_temp.loc[data_temp['prueba'] <= n_observations]

    colors = plt.cm.Dark2(range(6))

    ax.clear()
    ax.barh(y = data_temp['prueba'] ,
            width = data_temp.gdpPercap, 
            tick_label=data_temp['country'],
           color=colors)

    label = ax.text(0.95, 0.20, year,
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes,
                fontdict=font)

    ax.set_ylim(ax.get_ylim()[::-1]) # Revert axis


anim = animation.FuncAnimation(fig, update_barchart_race, frames = len(years))
anim.save('barchart_race.gif')  
Cómo crear animaciones de barchart race en Python

¡Barplot race creado! Ya hemos visto cómo crear diferentes tipos de animaciones en Python. Sin embargo, no parece que sean del todo visuales, ya que simplemente se limitan a poner un gráfico encima del otro. Así pues, veámos cómo podemos mejorar la fluidez de nuestras animaciones en Python. ¡Vamos a ello!

Cómo mejorar la fluidez de las animaciones de Python

Como he explicado más arriba en este post, la función FuncAnimation se limita a crear la animación poniendo las imágenes que genera nuestra función de update. Como nuestras imágenes cambian año a año, nuestra animación dará pequeños “saltos”.

Así pues, para que nuestra animación sea mucho más fluida, deberemos crear más datos entre cada uno de los estados que tenemos. De esta forma conseguiremos:

  1. Tener datos intermedios, por lo que el salto de las animaciones no será tan grande.
  2. Tener muchos más frames, de tal forma que para la misma duración de la animación, tendrá más fps (frames por segundo) haciendo que se vea mucho más fluida.

Para crear este objetivo, vamos a realizar lo siguiente:

  1. Crear más observaciones entre los datos que ya tenemos. Estas observaiones estarán vacías.
  2. Imputar datos a esas nuevas observaciones vacías mediante la interpolación entre los estados.

Suena complejo, pero es más fácil de lo que parece. Veámos cómo hacerlo:

Crear más observaciones entre los datos que ya tenemos

Crear más observaciones de las que ya tenemos entre los estados actuales es muy sencillo. Simplemente debemos partir de un índice que vaya de 0 al número de observaciones que tengamos. Esto lo podemos conseguir con el método reset_index.

Una vez nuestros datos son así, simplemente podemos cambiar el índice actual de los datos multiplicando cada índice por el número de frames entre estados que queramos crear. Si queremos crear 10 frames, al multiplicar el índice antiguo por 10, la segunda observación (índice 1) pasará a tener el índice 10 y entre medias se habrán creado muchas variables vacías.

En cualquier caso, para que la interpolación funcione bien, deberemos tener los datos en el formato adecuado, que es:

  • Cada fila debe ser un estado, un año en mi caso.
  • Cada columna debe ser la observación que nosotros vayamos a graficar, en mi caso, un país.
  • El valor debe ser la variable que vayamos a graficar. En mi caso seguiré creando el barplot race, por lo que la variable sigue siendo el gdpPercap.

Así pues, esto es lo que debemos realizar:

barchartrace_data  = gapminder.copy()
n_observations = 10
n_frames_between_states = 30

barchartrace_data= barchartrace_data.pivot('year', 'country', 'gdpPercap')
barchartrace_data['year'] = barchartrace_data.index

barchartrace_data.reset_index(drop = True, inplace = True)
barchartrace_data.index = barchartrace_data.index * n_frames_between_states
barchartrace_data =  barchartrace_data.reindex(range(barchartrace_data.index.max()+1))

barchartrace_data.iloc[:15,:5]
country  Afghanistan      Albania      Algeria       Angola    Argentina
0         779.445314  1601.056136  2449.008185  3520.610273  5911.315053
1                NaN          NaN          NaN          NaN          NaN
2                NaN          NaN          NaN          NaN          NaN
3                NaN          NaN          NaN          NaN          NaN
4                NaN          NaN          NaN          NaN          NaN
5                NaN          NaN          NaN          NaN          NaN
6                NaN          NaN          NaN          NaN          NaN
7                NaN          NaN          NaN          NaN          NaN
8                NaN          NaN          NaN          NaN          NaN
9                NaN          NaN          NaN          NaN          NaN
10               NaN          NaN          NaN          NaN          NaN
11               NaN          NaN          NaN          NaN          NaN
12               NaN          NaN          NaN          NaN   

Como ves, cada columna es un país y he creado 30 nuevas observaciones entre los estados que ya tenía. Una vez hecho esto, ya podemos ver cómo imputar esos nuevos datos mediante interpolación.

Imputar datos a esas nuevas observaciones vacías mediante la interpolación entre los estados

Para imputar los datos vacíos, vamos a usar la interpolación. Esto se puede realizar con el método interpolate de pandas. Existen diferentes métodos de interpolación (puedes encontrar los método aquí) y cada uno darán un efecto diferente, como puedes ver en la siguiente animación hecha por Nicholas A Rossi (enlace).

Efecto de la interpolación en las animaciones de Python

En nuestro caso lo vamos a hacer fácil, dejando los valores por defecto del método, esto es, aplicando una interpolación lineal. Aunque sea lo más sencillo, el cambio va a ser importante, solo hay que ver la diferencia entre usar la interpolación lineal y no usar interpolación en la animación.

Así pues, podemos interpolar nuestros datos de la siguiente forma:

barchartrace_data = barchartrace_data.interpolate()
barchartrace_data.iloc[:15,:5]
country  Afghanistan      Albania      Algeria       Angola    Argentina
0         779.445314  1601.056136  2449.008185  3520.610273  5911.315053
1         780.825572  1612.430406  2467.840446  3530.854613  5942.833092
2         782.205829  1623.804677  2486.672708  3541.098952  5974.351130
3         783.586086  1635.178947  2505.504969  3551.343292  6005.869169
4         784.966343  1646.553217  2524.337230  3561.587632  6037.387208
5         786.346600  1657.927487  2543.169491  3571.831972  6068.905246
6         787.726858  1669.301758  2562.001753  3582.076311  6100.423285
7         789.107115  1680.676028  2580.834014  3592.320651  6131.941323
8         790.487372  1692.050298  2599.666275  3602.564991  6163.459362
9         791.867629  1703.424568  2618.498536  3612.809331  6194.977401
10        793.247886  1714.798839  2637.330798  3623.053670  6226.495439
11        794.628143  1726.173109  2656.163059  3633.298010  6258.013478
12        796.008401  1737.547379  2674.995320  3643.542350  6289.531517
13        797.388658  1748.921649  2693.827581  3653.786690  6321.049555
14        798.768915  1760.295920  2712.659843  3664.031029  6352.567594

Por último, ahora que ya tenemos nuestros datos interpolados, vamos a cambiar la forma de nuestro dataframe para que siga manteniendo la forma que tenía antes, es decir, que tanto el año, como el país como el PIB per Cápita sean variables. Esto lo podemos conseguir con el método melt de pandas.

# Hacemos otro pivot para volver a los datos originales
barchartrace_data = barchartrace_data.melt(id_vars='year', var_name ='country', value_name  = 'gdpPercap')

barchartrace_data.iloc[:15,:5]
           year      country   gdpPercap
0   1952.000000  Afghanistan  779.445314
1   1952.166667  Afghanistan  780.825572
2   1952.333333  Afghanistan  782.205829
3   1952.500000  Afghanistan  783.586086
4   1952.666667  Afghanistan  784.966343
5   1952.833333  Afghanistan  786.346600
6   1953.000000  Afghanistan  787.726858
7   1953.166667  Afghanistan  789.107115
8   1953.333333  Afghanistan  790.487372
9   1953.500000  Afghanistan  791.867629
10  1953.666667  Afghanistan  793.247886
11  1953.833333  Afghanistan  794.628143
12  1954.000000  Afghanistan  796.008401
13  1954.166667  Afghanistan  797.388658
14  1954.333333  Afghanistan  798.768915

Si te fijas, tenemos un dataframe exactamente igual que el que teníamos cuando hemos hecho la animación del barchart race previamente, solo que ahora tenemos muchos más datos intermedios. Así pues, simplemente debemos replicar el código de antes para conseguir buenos resultados:

import math

n_observations = 10

fig, ax = plt.subplots(figsize=(10, 5))

font = {
    'weight': 'normal',
    'size'  :  40,
    'color': 'lightgray'
}

years = barchartrace_data['year'].unique()

label = ax.text(0.95, 0.20, years[0],
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes,
                fontdict=font)

colors = plt.cm.Dark2(range(200))

def update_barchart_race(i):

    year = years[i]

    data_temp = barchartrace_data.loc[barchartrace_data['year'] == year, :]

    # Create rank and get first 10 countries
    data_temp['ranking'] = data_temp['gdpPercap'].rank(method = 'first',ascending = False)
    data_temp = data_temp.loc[data_temp['ranking'] <= n_observations]

    ax.clear()
    ax.barh(y = data_temp['ranking'] ,
            width = data_temp.gdpPercap, 
            tick_label=data_temp['country'],
           color=colors)

    label = ax.text(0.95, 0.20, math.floor(year),
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes,
                fontdict=font)

    ax.set_ylim(ax.get_ylim()[::-1]) # Revert axis


anim = animation.FuncAnimation(fig, update_barchart_race, frames = len(years))
anim.save('barchart_race_cool.gif', fps = 20)  
Cómo mejorar las animaciones en Python

¡Animación mejorada! Ahora queda mucho mejor, ¿verdad? Sin embargo, hay una cosa que quizás se pueda seguir mejorando: los estados intermedio. Y es que, aunque las gráficas se animen, los cambios de posición siguen siendo saltos. Veamos cómo crear ese movimiento horizontal de nuestras animaciones en Python.

Cómo crear movimiento horizontal en barchart race

La razón por la cual las posiciones siguen dando “saltos” es que, aunque hayamos interpolado los datos de la gráfica, no hemos interpolado los datos de las posiciones. Así pues, podemos crear una interpolación de los datos de las posiciones y juntarnos a nuestro dataset anterior.

Para ello, primero tendremos que quedarnos con el ranking de cada país para cada año. Una vez tengamos ese dataframe, el proceso será exactamente el mismo al realizado anteriormente: pivotamos, interpolamos y deshacemos el pivotado con un melt.

Importante: para que este método funcione debemos usar el mismo sistema de interpolación que hemos usado anteriormente.

ranking_data  = gapminder.copy()
n_observations = 10
n_frames_between_states = 30


#barchartrace_data['ranking'] 
ranking = ranking_data.groupby('year')['gdpPercap'].rank(method = 'first', ascending = False)
ranking = ranking.rename('ranking', axis = 1)
ranking_data = ranking_data.join(ranking)
ranking_data = ranking_data.pivot('year', 'country', 'ranking')

ranking_data['year'] =ranking_data.index
ranking_data.reset_index(drop = True, inplace = True)

ranking_data.index = ranking_data.index * n_frames_between_states
ranking_data =  ranking_data.reindex(range(ranking_data.index.max()+1))
ranking_data = ranking_data.interpolate('linear')
ranking_data = ranking_data.melt(id_vars='year', var_name ='country', value_name  = 'ranking')

ranking_data.iloc[:15,:5]
           year      country     ranking
0   1952.000000  Afghanistan  113.000000
1   1952.166667  Afghanistan  113.066667
2   1952.333333  Afghanistan  113.133333
3   1952.500000  Afghanistan  113.200000
4   1952.666667  Afghanistan  113.266667
5   1952.833333  Afghanistan  113.333333
6   1953.000000  Afghanistan  113.400000
7   1953.166667  Afghanistan  113.466667
8   1953.333333  Afghanistan  113.533333
9   1953.500000  Afghanistan  113.600000
10  1953.666667  Afghanistan  113.666667
11  1953.833333  Afghanistan  113.733333
12  1954.000000  Afghanistan  113.800000
13  1954.166667  Afghanistan  113.866667
14  1954.333333  Afghanistan  113.933333

Ahora que tenemos esta información, deberemos unirla con la información que teníamos en el dataset previo:

barchartrace_data = ranking_data.merge(barchartrace_data, 
                                       left_on = ['country','year'], 
                                       right_on = ['country','year'])

barchartrace_data.head()
          year      country     ranking   gdpPercap
0  1952.000000  Afghanistan  113.000000  779.445314
1  1952.166667  Afghanistan  113.066667  780.825572
2  1952.333333  Afghanistan  113.133333  782.205829
3  1952.500000  Afghanistan  113.200000  783.586086
4  1952.666667  Afghanistan  113.266667  784.966343

Como veis, el ranking no son números redondos sino que se van modificando ligeramente en cada estado. Así pues, ya podemos crear de nuevo la animación. En este caso, como ya tenemos el ranking calculado, no hará falta que lo volvamos a calcular:

import math

n_observations = 10

fig, ax = plt.subplots(figsize=(10, 5))

font = {
    'weight': 'normal',
    'size'  :  40,
    'color': 'lightgray'
}

years = barchartrace_data['year'].unique()

label = ax.text(0.95, 0.20, years[0],
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes,
                fontdict=font)

# Create colors
# 1. Get continent
continent = gapminder[['country','continent']].drop_duplicates().reset_index(drop = True)

# 2. Add continent info
barchartrace_data = barchartrace_data.merge(continent,left_on = 'country', right_on = 'country')

# 3. Use continent to get color
conditions = [
  barchartrace_data['continent'] == 'Asia',
  barchartrace_data['continent'] == 'Europe',
  barchartrace_data['continent'] == 'Africa',
  barchartrace_data['continent'] == 'Americas',
  barchartrace_data['continent'] == 'Oceania',
]

values = ['#0275d8', '#5cb85c', '#5bc0de', '#f0ad4e', '#d9534f']

barchartrace_data['color'] = np.select(conditions, values)


def update_barchart_race(i):

    year = years[i]

    data_temp = barchartrace_data.loc[barchartrace_data['year'] == year, :]

    # Create rank and get first 10 countries
    data_temp = data_temp.loc[data_temp['ranking'] <= n_observations]

    ax.clear()
    ax.barh(y = data_temp['ranking'] ,
            width = data_temp.gdpPercap, 
            tick_label=data_temp['country'],
           color=data_temp['color'])

    label = ax.text(0.95, 0.20, math.floor(year),
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes,
                fontdict=font)

    ax.set_ylim(ax.get_ylim()[::-1]) # Revert axis


anim = animation.FuncAnimation(fig, update_barchart_race, frames = len(years), )
anim.save('barchart_race_cool2.gif', fps=30)  
Cómo mejorar las animaciones en Python

¡Ya tenemos nuestra animación creada en Python! Como ves, la interpolación de ha hecho que nuestra animación sea mucho más atractiva y parezca muchísimo más fluida de lo que era al principio. ¡Y todo, con solo unas pocas líneas de código! ¿No es fantástico? Pero eso no es todo, veamos otro truquito más que nos permitirá mejorar mucho nuestras animaciones en Python.

Evitar saltos en las animaciones

Un problema típico en las animaciones, como ha ocurrido en la animación del scatter plot, es que haya saltos entre frames. Esto se debe a que los ejes del gráfico cambian, haciendo que el contenido del mismo parezca diferente cuando, en realidad, no lo es.

Arreglar esto es bastante sencillo. Simplemente para cada frame se debe fijar el valor máximo de los ejes X e Y. Ese valor será el máximo que alcanzará el gráfico en toda la serie. De esta forma, conseguiremos evitar esos tirones.

Esto sobre todo es importante aplicarlo cuando animamos gráficos como el scatter plot. Sin embargo, a la hora de animar el linechart no es recomendable aplicarlo ya que reduce mucho el impacto visual de la animación.

Así pues, vamos a rehacer la animación del scatter plot, pero esta vez aplicando una mayor fluidez mediante la interpolación y evitando los saltos mediante la fijación de las escalas.

En este caso, deberemos aplicar la interpolación a las tres variables que se utilizan en la animación. Así pues, para facilitar el proceso crearemos una función que nos haga la interpolación.

scatter_data = gapminder.copy()
n_frames_between_states = 30
continent = gapminder[['country','continent']].drop_duplicates().reset_index(drop = True)


def interpolate_data(data,frame,obs,variable, n_new_frames, interpolation = 'linear'):
    data= data.pivot(frame, obs, variable)
    data[frame] = data.index
    data.reset_index(drop = True, inplace = True)
    data.index = data.index * n_new_frames
    data =  data.reindex(range(data.index.max()+1))
    data = data.interpolate(interpolation)
    data = data.melt(id_vars= frame, var_name = obs, value_name  = variable)
    return data

# Interpolate data
scatter_data_pop = interpolate_data(scatter_data, 'year', 'country','pop',30)
scatter_data_gdpPerCap = interpolate_data(scatter_data, 'year', 'country','gdpPercap',30)
scatter_data_lifeExp = interpolate_data(scatter_data, 'year', 'country','lifeExp',30)

# Merge the datasets
scatter_data = scatter_data_gdpPerCap.merge(scatter_data_pop,
                   left_on = ['country','year'], 
                   right_on = ['country','year'])

scatter_data = scatter_data.merge(scatter_data_lifeExp,
                   left_on = ['country','year'], 
                   right_on = ['country','year']).merge(continent)

scatter_data.head()
	year	        country	        gdpPercap	pop	        lifeExp	       continent
0	1952.000000	Afghanistan	779.445314	8425333.0	28.801000	Asia
1	1952.166667	Afghanistan	780.825572	8452519.7	28.852033	Asia
2	1952.333333	Afghanistan	782.205829	8479706.4	28.903067	Asia
3	1952.500000	Afghanistan	783.586086	8506893.1	28.954100	Asia
4	1952.666667	Afghanistan	784.966343	8534079.8	29.005133	Asia

Ahora que ya tenemos el dataset, podemos crear la animación. Para fijar los límites de la animación simplemente debemos utilizar el método set_xlim.

fig, ax = plt.subplots(figsize=(10, 5))

years = scatter_data['year'].unique()

conditions = [
  scatter_data.continent == "Asia",
  scatter_data.continent == "Europe",
  scatter_data.continent == "Africa",
  scatter_data.continent == "Americas",
  scatter_data.continent == "Oceania",
]

values = list(range(5))

scatter_data['color'] = np.select(conditions, values)
colors =[f"C{i}" for i in np.arange(1, 6)]
cmap, norm = matplotlib.colors.from_levels_and_colors(np.arange(1, 5+2), colors)

# Get maximum values
x_max = scatter_data['gdpPercap'].max()
y_max = scatter_data['lifeExp'].max()


  
def update_scatter(i):
    
    year = years[i]
    
    data_temp = scatter_data.loc[scatter_data['year'] == year, :]
    ax.clear()
    
    
    label = ax.text(0.95, 0.25, years[0],
                    horizontalalignment='right',
                    verticalalignment='top',
                    transform=ax.transAxes,
                    fontdict=font)
    
    # Set limits
    ax.set_xlim((0,x_max))
    ax.set_ylim((0,y_max))
    
    ax.scatter(
        data_temp['gdpPercap'],
        data_temp['lifeExp'],
        s=data_temp['pop']/500000, 
        alpha = 0.5, 
        c=data_temp.color, 
        cmap=cmap,
        norm=norm
    )
    
    label.set_text(math.floor(year))
    
anim = animation.FuncAnimation(fig, update_scatter, frames = len(years))
anim.save('scatter2.gif', fps = 20)
Cómo mejorar nuestras animaciones en Python

Conclusión

Sin duda alguna crear animaciones en Python es algo que va a permitirte crear gráficos muy visuales que generen mucho más impacto. Esto es algo básico que te va a permitir desde generar mucho más impactantes a poder explicar procesos de una forma más sencilla, como hice con el algoritmo k-Mean en este post.

Además, si estás acostumbrado a trabajar con pandas y matplotlib y entiendes el funcionamiento detrás de las funciones de animación, es algo muy sencillo de poder hacer.

Espero que este post te haya gustado. Si es así, te animo a suscribirte para estar al día de todos los posts que voy subiendo. En cualquier caso, ¡nos vemos en el próximo!