Dario Radečić, Author at Towards Data Science

How to Convert a Single HEX Color Code into a Monochrome Color Palette with Python

Spoiler: It's harder than you think.

Colors are hard, especially if you don’t have an eye for design.

Most of us tech professionals don’t. The upside is that Python can do most of the heavy lifting for you. It can generate an entire monochromatic color palette that’ll look stunning on any stacked chart.

The downside is that getting there requires a fair bit of coding. You have to write custom functions to convert a color from HEX to HSL and vice versa, and figure out if the starting color is too bright, as well as how much lighter each next color should be.

Adding insult to injury, I haven’t found a single fully working Python library capable of accomplishing this task.

That’s where this article chimes in.

If you’re a subscriber to my Substack, you can skip the reading and download the notebook instead.

HEX, HSL, RGB – What’s With All The Jargon?

Let’s quickly cover three color formats you need to know:

  • HEX – A six-digit code typically used in web development and graphic design. The color code starts with a #, followed by six hexadecimal digits. Each pair of digits represents the amount of red, green, and blue. For example, in #FF5633, FF represents red, 56 represents green, and 33 represents blue. Values can range from 00 to FF in hexadecimal, or 0 to 255 in decimal.
  • RGB – A color model that defines colors based on their red, green, and blue components. Each value ranges from 0 to 255. Different amounts of red, green, and blue are combined to produce a wide range of colors. It can be easily translated into HEX – just convert the RGB amounts into their hexadecimal representation.
  • HSL – A cylindrical coordinate representation of colors. It describes colors based on hue, saturation, and lightness. Hue is the color itself represented as an angle on a color wheel (0 to 360 degrees with 0 degrees being red, 120 degrees being green, and 240 degrees being blue). Saturation represents the vividness of the color expressed in percentages (0–100%). And lightness represents how light or dark the color is (0% being black and 100% being white).

The process of creating a monochrome palette boils down to adjusting the lightness of a color in HSL format while keeping its hue and saturation unchanged.
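
If you want a quick way to sanity-check the conversions you’re about to implement, Python’s built-in colorsys module can help. A minimal sketch – note that colorsys works with HLS ordering (hue, lightness, saturation) and 0–1 ranges rather than degrees and percentages, and the color below is just an arbitrary example:

import colorsys

hex_color = "#0582CA"  # arbitrary example color
r, g, b = (int(hex_color.lstrip("#")[i:i + 2], 16) / 255 for i in (0, 2, 4))
h, l, s = colorsys.rgb_to_hls(r, g, b)

print(f"H={round(h * 360)}, S={round(s * 100)}%, L={round(l * 100)}%")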

For a better idea, try any color wheel tool online. Enter a HEX code, and you’ll see its other color system versions:

Image 1 – Hex color code in different color models (image by author)

But to get there, you’ll have to convert a HEX color to RGB, RGB to HSL, get an array of colors with different lightness percentages, and then convert it back to HEX.

The problem is, some color codes can’t be converted 100% accurately, and will be estimated. That’s fine, as you’ll end up with a visually indistinguishable color with a slightly different HEX code.

Enough with the theory – let’s start coding!

How to Create a Monochrome Color Palette From Scratch

Start by importing the libraries:

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

You can optionally use my custom Matplotlib theme to make the charts stand out without any manual tweaks:

mpl.style.use("custom_light.mplstyle")

Onto the code now.

The MonochromePaletteGenerator class implements the following methods:

  • __init__() – A class constructor that initializes two attributes: MIN_ALLOWED_LIGHTNESS (integer set to 15 to ensure the palette does not start with too light a color) and max_palette_lightness (integer set to 95, it limits the lightness of the lightest color in the palette).
  • __validate_starting_lightness() – Checks whether the starting lightness value is at or below the limit set in the constructor (i.e., the color is dark enough).
  • hex_to_hsl() – Given a HEX color code, it first converts it into RGB and then calculates lightness, saturation, and hue through a somewhat involved process.
  • hsl_to_hex() – Given the values for hue, saturation, and lightness, the method estimates RGB values through primes and then converts them back to a HEX color code.
  • __create_lightness_range() – A helper method that creates n_colors with linearly spaced lightness values for a monochrome color palette. You’ll want to tweak this method to adjust the spacing between the colors (e.g., don’t use linear spacing).
  • create_hex_code_palette() – Given a starting HEX color code and a number of colors, it translates HEX to HSL, creates a lightness range, and converts a list of HSL colors into HEX format.
  • create_matplotlib_palette() – Uses Matplotlib’s mpl.colors.ListedColormap to create a color palette you can use in Matplotlib.

Here’s what it boils down to in the code:

class MonochromePaletteGenerator:
    def __init__(self, max_palette_lightness: int = 95):
        self.MIN_ALLOWED_LIGHTNESS = 15
        self.max_palette_lightness = max_palette_lightness

    def __validate_starting_lightness(self, starting_lightness: int) -> bool:
        if starting_lightness <= self.MIN_ALLOWED_LIGHTNESS:
            return True
        return False

    @staticmethod
    def hex_to_hsl(hex_color: str) -> tuple:
        # Remove the "#" character if present
        hex_color = hex_color.lstrip('#')

        # Convert hex to RGB
        r = int(hex_color[0:2], 16) / 255.0
        g = int(hex_color[2:4], 16) / 255.0
        b = int(hex_color[4:6], 16) / 255.0

        # Find the maximum and minimum RGB values
        max_val = max(r, g, b)
        min_val = min(r, g, b)
        delta = max_val - min_val

        # Calculate Lightness
        L = (max_val + min_val) / 2

        # Calculate Saturation
        if delta == 0:
            S = 0  # It's a shade of gray, so no saturation
            H = 0  # Hue is undefined for gray, but we can set it to 0
        else:
            if L < 0.5:
                S = delta / (max_val + min_val)
            else:
                S = delta / (2.0 - max_val - min_val)

            # Calculate Hue
            if max_val == r:
                H = ((g - b) / delta) % 6
            elif max_val == g:
                H = (b - r) / delta + 2
            elif max_val == b:
                H = (r - g) / delta + 4

            H *= 60  # Convert hue to degrees

            if H < 0:
                H += 360

        return int(round(H)), int(round(S * 100)), int(round(L * 100))

    @staticmethod
    def hsl_to_hex(h: int, s: int, l: int) -> str:
        # Convert the saturation and lightness percentages to a fraction of 1
        s /= 100
        l /= 100

        # Calculate C, X, and m
        C = (1 - abs(2 * l - 1)) * s  # Chroma
        X = C * (1 - abs((h / 60) % 2 - 1))  # Intermediate value based on hue
        m = l - C / 2  # Lightness adjustment

        # Calculate r', g', b' (primed values)
        if 0 <= h < 60:
            r_prime, g_prime, b_prime = C, X, 0
        elif 60 <= h < 120:
            r_prime, g_prime, b_prime = X, C, 0
        elif 120 <= h < 180:
            r_prime, g_prime, b_prime = 0, C, X
        elif 180 <= h < 240:
            r_prime, g_prime, b_prime = 0, X, C
        elif 240 <= h < 300:
            r_prime, g_prime, b_prime = X, 0, C
        elif 300 <= h < 360:
            r_prime, g_prime, b_prime = C, 0, X
        else:
            raise ValueError("Hue value must be in range [0, 360)")

        # Convert r', g', b' to the range [0, 255]
        r = round((r_prime + m) * 255)
        g = round((g_prime + m) * 255)
        b = round((b_prime + m) * 255)

        # Convert RGB to hex
        return "#{:02X}{:02X}{:02X}".format(r, g, b)

    def __create_lightness_range(self, starting_lightness: int, n_colors: int) -> list:
        # Make `n_colors` number of linearly spaced lightness values between the provided starting color
        # and max allowed lightness
        spaced_values = np.linspace(starting_lightness, self.max_palette_lightness, n_colors)
        lightness = np.round(spaced_values).astype(int)
        return lightness

    def create_hex_code_palette(self, starting_hex: str, n_colors: int) -> list:
        # Convert HEX to HSL
        starting_hsl = self.hex_to_hsl(hex_color=starting_hex)
        # Check if HSL is valid
        is_valid_starting_hex = self.__validate_starting_lightness(starting_lightness=starting_hsl[2])

        # Proceed only if the starting HSL is valid
        if is_valid_starting_hex:
            palette = []
            lightness_range = self.__create_lightness_range(starting_lightness=starting_hsl[2], n_colors=n_colors)
            for lightness in lightness_range:
                # Keep hue and saturation identical
                h = starting_hsl[0]
                s = starting_hsl[1]
                # Only change the lightness
                curr_hex = self.hsl_to_hex(h=h, s=s, l=lightness)
                palette.append(curr_hex)

            return palette

        raise ValueError("Given starting color is too light to construct a palette. Please choose a darker shade.")

    @staticmethod
    def create_matplotlib_palette(colors: list, palette_name: str) -> mpl.colors.ListedColormap:
        return mpl.colors.ListedColormap(
            name=palette_name,
            colors=colors
        )

Up next, let’s see what goes into using this class.

How to Use the MonochromePaletteGenerator class

You can create an object from the class like you normally would:

palette_generator = MonochromePaletteGenerator()

No parameters are required, as max_palette_lightness has a default value.
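
If you’d like the palette to stop at a darker maximum, you can override the default when constructing the object. A small optional example – the variable name and the value 85 are just illustrative:

# Cap the lightest color in the palette at 85% lightness instead of the default 95%
darker_palette_generator = MonochromePaletteGenerator(max_palette_lightness=85)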

To create a blue monochrome palette, provide a dark enough starting hex color code and the number of colors. From there, you can also convert the generated palette into a Matplotlib format:

monochrome_blue = palette_generator.create_hex_code_palette(starting_hex="#051923", n_colors=10)
monochrome_blue_mpl = palette_generator.create_matplotlib_palette(colors=monochrome_blue, palette_name="custom-monochrome-blue")

print(monochrome_blue)
monochrome_blue_mpl
Image 2 – Monochrome blue color palette (image by author)

It works with other colors as well. For example, you can use the following snippet to create a monochrome red palette:

monochrome_red = palette_generator.create_hex_code_palette(starting_hex="#1c0000", n_colors=7)
monochrome_red_mpl = palette_generator.create_matplotlib_palette(colors=monochrome_red, palette_name="custom-monochrome-red")

print(monochrome_red)
monochrome_red_mpl
Image 3 – Monochrome red color palette (image by author)

The create_hex_code_palette() method will fail if the starting color is too bright, as you can see from the following example:

wont_work_green = palette_generator.create_hex_code_palette(starting_hex="#1b663e", n_colors=12)
Image 4 – Exception raised with a too bright starting color (image by author)

The best way to see your color palettes in action is with charts.

Monochrome Color Palettes in Action

Use the following snippet to create a made-up dataset of employee counts per location and department:

df = pd.DataFrame({
    "HR": [50, 63, 40, 68, 35],
    "Engineering": [77, 85, 62, 89, 58],
    "Marketing": [50, 35, 79, 43, 67],
    "Sales": [59, 62, 33, 77, 72],
    "Customer Service": [31, 34, 61, 70, 39],
    "Distribution": [35, 21, 66, 90, 31],
    "Logistics": [50, 54, 13, 71, 32],
    "Production": [22, 51, 54, 28, 40],
    "Maintenance": [50, 32, 61, 69, 50],
    "Quality Control": [20, 21, 88, 89, 39]
}, index=["New York", "San Francisco", "Los Angeles", "Chicago", "Miami"])

df = df.T
df = df.loc[df.sum(axis=1).sort_values().index]

df
Image 5 – Dummy employee dataset (image by author)

There are 5 office locations and 10 departments – an ideal situation for a stacked horizontal bar chart!

You can pass your custom color palette to the colormap parameter:

ax = df.plot(kind="barh", colormap=monochrome_blue_mpl, width=0.8, edgecolor="#000000", stacked=True)

plt.title("Employee Count Per Location And Department", loc="left", fontdict={"weight": "bold"}, y=1.06)
plt.xlabel("Office Location")
plt.ylabel("Count")

plt.show()
Image 6 – Chart with a blue monochromatic color palette (image by author)

The same code works with the monochrome red palette – just change the colormap value:

ax = df.plot(kind="barh", colormap=monochrome_red_mpl, width=0.8, edgecolor="#000000", stacked=True)

plt.title("Employee Count Per Location And Department", loc="left", fontdict={"weight": "bold"}, y=1.06)
plt.xlabel("Office Location")
plt.ylabel("Count")

plt.show()
Image 7 – Chart with a red monochromatic color palette (image by author)

Long story short, it works like a charm – it just needs a dark enough starting color.

Wrapping up

To conclude, a bit of Python code can go a long way.

Usually, to create a monochrome color palette, you’d search online or start with a color you like and adjust the lightness until you get enough variations. Today, you’ve learned how to automate this process.

You can expand this by adding default starting colors, so you don’t need to choose one each time. But that part is simple and I’ll leave it to you.
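
If you want a head start on that, here’s one possible sketch – the HEX codes and the helper name below are just illustrative picks, not part of the original class:

DEFAULT_STARTING_COLORS = {
    "blue": "#051923",
    "red": "#1c0000",
    "green": "#001a0e",  # arbitrary dark green, adjust to taste
}

def palette_from_name(color_name: str, n_colors: int = 10) -> list:
    # Convenience wrapper that picks a predefined dark starting color
    return palette_generator.create_hex_code_palette(
        starting_hex=DEFAULT_STARTING_COLORS[color_name],
        n_colors=n_colors
    )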

Until next time.

-Dario

Looking to level up your data visualization game? Join thousands of like-minded people on my Substack for expert tips and exclusive content.

Data Doodles with Python | Dario Radecic | Substack


Originally published at https://darioradecic.substack.com.

How to Create a Custom Matplotlib Theme and Make Your Charts Go from Boring to Amazing

The best part? You'll only have to do this once.

Every Matplotlib chart has the potential to go viral. But not with the default theme.

Let’s be honest: default styles aren’t going to stop anyone in their tracks. If you want to make your reader pay attention, you’ll need more than a resolution bump or a new font. You’ll need a custom theme.

It’s the only way to make your charts look yours. It’s the only way to make your reader stop scrolling. The good news? Matplotlib makes it incredibly easy to write custom style sheets from scratch. Even better – you can make your custom theme available system-wide!

And in today’s article, you’ll learn how to do just that.

Two Functions You’ll Need For Plotting Bar and Line Charts

For the sake of repeatability, I’ll provide you with two functions for plotting a bar and a line chart.

The goal of setting a custom Matplotlib theme isn’t to change the underlying Python code – but rather to leave it as is – and see what the visual differences are.

Here are the libraries you’ll need to follow along:

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

Let’s go over both functions.

Bar Charts

The plot_bar() function creates a stacked bar chart representing the employee count per location and department.

The data is completely made up, and the only "hard coded" bit is the y position of the title. It’s increased ever so slightly to create more room between the title and the plot.

You can tweak the legend_pos value when calling this function. I’ve deliberately made this part tweakable so you can adjust the offset for the legend. It’ll always be different, so it’s nice to have some flexibility.

Everything else is pretty minimal and self-explanatory:

def plot_bar(legend_pos: tuple = (1.25, 1)):
    x = np.array(["New York", "San Francisco", "Los Angeles", "Chicago", "Miami"])
    y1 = np.array([50, 63, 40, 68, 35])
    y2 = np.array([77, 85, 62, 89, 58])
    y3 = np.array([50, 35, 79, 43, 67])
    y4 = np.array([59, 62, 33, 77, 72])

    plt.bar(x, y1, label="HR")
    plt.bar(x, y2, bottom=y1, label="Engineering")
    plt.bar(x, y3, bottom=y1 + y2, label="Marketing")
    plt.bar(x, y4, bottom=y1 + y2 + y3, label="Sales")

    plt.title("Employee Count Per Location And Department", y=1.06)
    plt.xlabel("Office Location")
    plt.ylabel("Count")

    plt.legend(bbox_to_anchor=legend_pos)
    plt.show()

Line Charts

The plot_line() function is pretty similar, but it displays made-up yearly revenues.

In other words, it shows a line plot of revenue generated by each office location from 2018 to 2023:

def plot_line(legend_pos: tuple = (1.25, 1)):
    x = [2018, 2019, 2020, 2021, 2023]
    ny = [100, 124, 154, 133, 167]
    sf = [89, 95, 109, 139, 155]
    la = [107, 144, 179, 161, 175]
    ch = [59, 91, 99, 117, 141]
    mi = [121, 99, 104, 131, 140]

    plt.plot(x, ny, label="New York")
    plt.plot(x, sf, label="San Francisco")
    plt.plot(x, la, label="Los Angeles")
    plt.plot(x, ch, label="Chicago")
    plt.plot(x, mi, label="Miami")

    plt.title("Yearly Revenue", y=1.06)
    plt.xlabel("Year")
    plt.ylabel("Revenue (M)")

    plt.legend(bbox_to_anchor=legend_pos)   
    plt.show()

So, how do these charts look by default?

A Terrible-Looking Example

As the title suggests – terrible. And I do mean terrible.

Take a look for yourself, and feel free to tweak the legend offset if you get some overlap:

plot_bar(legend_pos=(1.32, 1))
plot_line(legend_pos=(1.05, 1))
Image 1 – Ugly bar chart (image by author)
Image 2 – Ugly line chart (image by author)

The font is meh, the colors are boring, and the resolution is unreasonably low.

You don’t want to share this with your coworkers, let alone your boss. A custom theme can fix all of the above and then some.

Matplotlib Stylesheets – How to Create a Custom Matplotlib Theme

Start by creating a file with an .mplstyle extension. I’ve named mine custom.mplstyle.

In this file, you can tweak everything (literally everything) Matplotlib related. I’ll keep things simple and change a handful of things that will make a world of difference:

  • Generic tweaks – Default figure size and resolution, color palette (Tableau colorblind), and plot spines.
  • Text – Custom font (must be installed on your system), font size, title location and style, and size of axis labels/ticks.
  • Legend – Location and the number of data points representing a single item.
  • Grid – Always show the grid as a light gray dotted line behind the chart contents.
  • Chart type specific tweaks – Default width of a line in a line chart, and including markers.

Paste the following into your .mplstyle file:

# Generic figure tweaks
figure.figsize: 10, 6  
figure.dpi: 125  
savefig.dpi: 300 
savefig.bbox: tight  
axes.prop_cycle: cycler("color", ["006ba4", "ff800e", "ababab", "595959", "5f9ed1", "c85200", "898989", "a2c8ec", "ffbc79", "cfcfcf"])
axes.spines.top: False
axes.spines.right: False

# Text
font.size: 14
font.family: sans-serif
font.sans-serif: IBM Plex Sans
axes.titlesize: 20
axes.titleweight: bold
axes.titlelocation: left
axes.labelsize: large
xtick.labelsize: medium
ytick.labelsize: medium

# Legend
legend.loc: "upper right"
legend.numpoints: 3 
legend.scatterpoints: 3

# Grid
axes.grid: True
axes.axisbelow: True  
grid.color: "#d3d3d3"
grid.linestyle: : 
grid.linewidth: 1.0

# Line chart only
lines.linewidth: 3
lines.marker: o
lines.markersize: 8

So, how can you reference this file in Python?

A Superb-Looking Example

You only need to add one line of code to use your custom theme:

plt.style.use("custom.mplstyle")

A typical place to do so is right after importing Matplotlib. That’s not a requirement, just a best practice.

This is what your charts will look like now:

plot_bar()
plot_line()
Image 3 – Pretty bar chart (image by author)
Image 4 – Pretty line chart (image by author)

Identical code, but a night and day visual difference, to say the least.

How to Make Your Custom Matplotlib Theme Available Globally

Bringing the mplstyle file to every new project isn’t a huge inconvenience, but you can do better.

What "better" means is making the custom stylesheet available globally, or at least inside a virtual environment.

Start by finding out where Matplotlib is installed:

mpl.__file__
Image 5 – Environment-specific Matplotlib location

You need the entire path without the __init__.py part.

Assuming you’re in a notebook environment, run the following command (after replacing the path) to copy the mplstyle file to Matplotlib’s style library directory:

!cp custom.mplstyle /Users/dradecic/miniforge3/envs/py/lib/python3.11/site-packages/matplotlib/mpl-data/stylelib/custom-style.mplstyle

Note that I’ve renamed the file.

If you’re not in a notebook environment, just omit the exclamation at the start and run the command from the console.

Anyhow, after restarting the kernel, you should see your stylesheet listed:

import matplotlib.pyplot as plt

plt.style.available
Image 6 – Available Matplotlib styles (image by author)

You can now reference it from any script or notebook running in the virtual environment:

plt.style.use("custom-style")

To demonstrate, I’ve created a new chart in a separate notebook. The data it shows is completely arbitrary.

The important thing is that the custom theme works:

import matplotlib.pyplot as plt
plt.style.use("custom-style")

x = [2019, 2020, 2021, 2022, 2023, 2024]

plt.plot(x, [100, 125, 200, 150, 225, 250])
plt.plot(x, [150, 220, 250, 300, 200, 150])
plt.plot(x, [200, 150, 275, 175, 250, 300])
plt.title("Custom Styles Demonstration Chart", y=1.06)

plt.show()
Image 7 – Test chart (image by author)

And there it is – a completely custom visualization style just one line of code away.

Wrapping up

To conclude, default Data Visualization themes suck.

You should do everything in your power to make your visuals stand out and capture attention. If you fail to do so, the quality of your message won’t matter. No one will look at it for more than a second. No one will share it.

Custom style sheets in Matplotlib are a great way to capture attention.

You only have to write your theme once, and then use it anywhere with one additional line of code. It’s a lot less friction than pasting dozens of rcParams to every Python script or notebook.

Which style tweaks do you typically include in your visualizations? Let me know in the comment section below.


Looking to level up your data visualization game? Join thousands of like-minded people on my Substack for expert tips and exclusive content.

Data Doodles with Python | Dario Radecic | Substack

Originally published at https://darioradecic.substack.com.

How to Create Custom Color Palettes in Matplotlib – Discrete vs. Linear Colormaps, Explained

Actionable guide on how to bring custom colors to personalize your charts

If there’s one thing that’ll make a good chart great, it’s the color choice.

You can turn any set of hex color codes into a color palette with Matplotlib, and this article will show you how. You’ll also learn the difference between discrete and linear color palettes, and the reasons why one is better than the other.

If you want to get the same data visualization quality I have, follow the steps from my earlier Matplotlib styling articles before proceeding.


How to Create Custom Colormaps in Matplotlib

These are the libraries you’ll need to follow along:

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

Matplotlib allows you to create two types of color palettes:

  • Discrete – The palette has a finite number of color values. Great for categorical data, but you need to make sure the palette has at least as many colors as you have distinct categories.
  • Linear (continuous) – The palette has an "infinite" number of values. Great for continuous and categorical data. You can specify just two colors, and the palette will automatically include all the values between them (think: gradient in Photoshop).

To create a discrete color palette in Matplotlib, run the following:

Python">cmap_discrete = mpl.colors.ListedColormap(
    name="discrete-monochromatic-blue",
    colors=["#051923", "#003554", "#006494", "#0582ca", "#00a6fb"]
)
cmap_discrete
Image 1 – Discrete colormap (image by author)

And to create a linear (continuous) color palette, run this code snippet:

cmap_linear = mpl.colors.LinearSegmentedColormap.from_list(
    name="linear-monochromatic-blue", 
    colors=["#051923", "#003554", "#006494", "#0582ca", "#00a6fb"]
)
cmap_linear
Image 2 – Linear colormap (image by author)

You can see how the discrete palette has 5 distinct colors, while the linear palette takes a continuous range of values.
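
You can see the same thing in code – both colormaps are callable, so sampling them at a few fractions shows the discrete one snapping to its 5 colors while the linear one interpolates:

# Pass a fraction in [0, 1] to a colormap and you get an RGBA color back
for frac in (0.0, 0.3, 0.35, 1.0):
    print(frac, mpl.colors.to_hex(cmap_discrete(frac)), mpl.colors.to_hex(cmap_linear(frac)))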

What’s the difference when visualizing data? That’s what you’ll learn next.

Comparison – Discrete vs. Linear Colormap on Continuous Data

In this section, you’ll create a 10×10 matrix of random numbers [0.0, 1.0) and visualize it as an image.

Run the following snippet to create the data:

data = np.random.random(100).reshape(10, 10)
data
Image 3 – Random normal 10×10 matrix (image by author)

As for the visual, you’ll want to create a 1×2 grid and show the identical dataset colored through a discrete colormap on the left, and a linear colormap on the right:

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

cax1 = ax1.imshow(data, cmap=cmap_discrete)
ax1.set_title("Discrete Colormap", loc="left", fontdict={"weight": "bold"}, y=1.04)
fig.colorbar(cax1, ax=ax1)

cax2 = ax2.imshow(data, cmap=cmap_linear)
ax2.set_title("Linear Colormap", loc="left", fontdict={"weight": "bold"}, y=1.04)
fig.colorbar(cax2, ax=ax2)

plt.tight_layout()
plt.show()
Image 4 – Discrete vs. linear colormap (image by author)

The underlying data is the same, but the plot on the left can take up to 5 possible color values. The one on the right has a much wider range.

To conclude – linear color palettes can be used to visualize both continuous and categorical data, while discrete color palettes can only do the latter without information loss.

Real-World Plots – How to Add Custom Colors to Your Charts

You’ll now see how to apply a custom color palette to bar charts in Matplotlib.

To start, copy the following code snippet to create mock employee count data across 10 departments and 5 office locations:

df = pd.DataFrame({
    "HR": [50, 63, 40, 68, 35],
    "Engineering": [77, 85, 62, 89, 58],
    "Marketing": [50, 35, 79, 43, 67],
    "Sales": [59, 62, 33, 77, 72],
    "Customer Service": [31, 34, 61, 70, 39],
    "Distribution": [35, 21, 66, 90, 31],
    "Logistics": [50, 54, 13, 71, 32],
    "Production": [22, 51, 54, 28, 40],
    "Maintenance": [50, 32, 61, 69, 50],
    "Quality Control": [20, 21, 88, 89, 39]
}, index=["New York", "San Francisco", "Los Angeles", "Chicago", "Miami"])

df = df.T
df = df.loc[df.sum(axis=1).sort_values().index]

df
Image 5 – Custom employee dataset (image by author)

Since there are 5 office locations, and our discrete color palette has 5 colors, it’s a perfect match for visualization.

The only new parameter you need to know in plot() is colormap. You’ll have to provide your palette variable. I’ve added employee counts to each bar segment, but consider this part optional:

ax = df.plot(kind="barh", colormap=cmap_discrete, width=0.8, edgecolor="#000000", stacked=True)

for container in ax.containers:
    ax.bar_label(container, label_type="center", fontsize=10, color="#FFFFFF", fontweight="bold")

plt.title("Employee Count Per Location And Department", loc="left", fontdict={"weight": "bold"}, y=1.06)
plt.xlabel("Office Location")
plt.ylabel("Count")

plt.show()
Image 6 – Stacked horizontal bar chart with a custom discrete colormap (image by author)

Looks great, doesn’t it?

You shouldn’t see any difference when comparing discrete and linear color palettes for this visualization. Why? Because the discrete palette has 5 colors, and you have 5 groups in the data.

Still, let’s define a function that will compare the two palette categories:

def plot_employee_count_comparison(df, cmap1, cmap2):
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7), sharex=True, sharey=True)

    df.plot(kind="barh", colormap=cmap1, width=0.8, edgecolor="#000000", stacked=True, ax=ax1)
    for container in ax1.containers:
        ax1.bar_label(container, label_type="center", fontsize=10, color="#FFFFFF", fontweight="bold")
    ax1.set_title("Employee Count (Discrete Colormap)", loc="left", fontdict={"weight": "bold"}, y=1.06)
    ax1.set_xlabel("Count")
    ax1.set_ylabel("Department")

    df.plot(kind="barh", colormap=cmap2, width=0.8, edgecolor="#000000", stacked=True, ax=ax2)
    for container in ax2.containers:
        ax2.bar_label(container, label_type="center", fontsize=10, color="#FFFFFF", fontweight="bold")
    ax2.set_title("Employee Count (Linear Colormap)", loc="left", fontdict={"weight": "bold"}, y=1.06)
    ax2.set_xlabel("Count")
    ax2.set_ylabel("Department")

    # Adjust layout
    plt.tight_layout()
    plt.show()

plot_employee_count_comparison(df=df, cmap1=cmap_discrete, cmap2=cmap_linear)
Image 7 – Discrete vs. linear colormap (identical results) (image by author)

Both plots are identical. But that won’t always be the case.

What Happens When You Have More Categories Than Colors in a Discrete Color Palette?

Good question. Let’s answer it by expanding the Pandas DataFrame to include two additional office locations:

df = pd.DataFrame({
    "HR": [50, 63, 40, 68, 35, 44, 31],
    "Engineering": [77, 85, 62, 89, 58, 56, 59],
    "Marketing": [50, 35, 79, 43, 67, 31, 24],
    "Sales": [59, 62, 33, 77, 72, 55, 66],
    "Customer Service": [31, 34, 61, 70, 39, 49, 81],
    "Distribution": [35, 21, 66, 90, 31, 67, 81],
    "Logistics": [50, 54, 13, 71, 32, 58, 51],
    "Production": [22, 51, 54, 28, 40, 41, 62],
    "Maintenance": [50, 32, 61, 69, 50, 49, 41],
    "Quality Control": [20, 21, 88, 89, 39, 66, 32]
}, index=["New York", "San Francisco", "Los Angeles", "Chicago", "Miami", "Las Vegas", "Boston"])

df = df.T
df = df.loc[df.sum(axis=1).sort_values().index]

df
Image 8 – Custom employee dataset (2) (image by author)

In theory, a discrete color palette should fail since it only has 5 color values and the dataset has 7 categories.

Let’s use the plot_employee_count_comparison() function to see the differences:

plot_employee_count_comparison(df=df, cmap1=cmap_discrete, cmap2=cmap_linear)
Image 9 – Discrete vs. linear colormap (different results) (image by author)

The left chart is unusable.

Just look at the first column for employee count combinations (77, 85) and (56, 59). They use the same color! With only 5 colors spread across 7 categories, the locations at the edges of the palette end up sharing colors with their neighbors.

In short, not what you want.

You don’t get this type of issue with linear color palettes, even if you construct it from two colors only.


Wrapping up

To summarize, a custom color palette might be just what your chart needs to make it publication-worthy.

Also, if you’re working at a company, chances are they already have a set of colors defined. Using them from the get-go is a guaranteed way to remove at least one submission iteration.

Matplotlib makes creating custom discrete and linear color palettes a breeze. Both can be used on categorical data, but only the latter works properly on continuous data. In this use case, a discrete color palette is obsolete, as you can get the same result (and more) with a linear palette.

What are your favorite color combinations for data visualization? Please share in the comment section below.


Looking to level up your data visualization game? Join thousands of like-minded people on my Substack for expert tips and exclusive content.

Data Doodles with Python | Dario Radecic | Substack

Originally published at https://darioradecic.substack.com.

3 Key Tweaks That Will Make Your Matplotlib Charts Publication Ready

Matplotlib charts are an eyesore by default - here's what to do about it.

Data visualization offers much deeper insights than looking at raw, numerical data.

However, creating appealing charts takes time and effort. Matplotlib is the de facto standard library for data visualization in Python. It’s simple, has been used for decades, and anything you’re looking for is one web search away.

But it’s not all sunshine and rainbows. Matplotlib visualizations look horrendous by default, and you as a data professional will have to turn many cogs to get something usable. Getting you there is the goal of today’s article.

By the end, you’ll have a code snippet you can stick to any Jupyter Notebook.

What’s Wrong with Matplotlib’s Default Styles?

You won’t need to download any dataset to follow along. You’ll create a synthetic time series dataset with increasing trend and repeatable seasonal patterns:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Single season multiplier factors - for seasonality effect
seasonal_multipliers = [1.1, 1.3, 1.2, 1.5, 1.9, 2.3, 2.1, 2.8, 2.0, 1.7, 1.5, 1.2]
# Imitate 10 years of data
xs = np.arange(1, 121)

time_series = []
# Split to 10 chunks - 1 year each
for chunk in np.split(xs, 10):
    for i, val in enumerate(chunk):
        # Multiply value with seasonal scalar
        time_series.append(float(val * seasonal_multipliers[i]))

x = pd.date_range(start="2015-01-01", freq="MS", periods=120)
y = time_series      

print(x[-10:])
print(y[-10:])
Image 1 – Time series data (image by author)

Since the dataset has dates for an index, and float values for the only attribute, you can plot the entire thing directly via plt.plot():

Python">plt.figure(figsize=(9, 6))
plt.plot(x, y)
plt.show()
Image 2 – Default matplotlib chart (image by author)

Everything about it screams 2002. Low resolution. The surrounding box. The font size.

Nothing a couple of tweaks can’t fix.

Tweak #1 – Adjust rcParams to Set the Overall Theme

Tweaking every chart by hand is a sure way to waste your time.

After all, most charts you make will have an underlying theme. It makes sense to declare it once and reuse it everywhere. That’s the role of rcParams.

The following code snippet changes a whole bunch of them and ensures your charts are rendered as SVGs. This last bit won’t matter if you explicitly save your charts to disk, but it will make a huge difference in a notebook environment:

import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")

plt.rcParams.update({
    "figure.figsize": (9, 6),
    "axes.spines.top": False,
    "axes.spines.right": False,
    "font.size": 14,
    "figure.titlesize": "xx-large",
    "xtick.labelsize": "medium",
    "ytick.labelsize": "medium",
    "axes.axisbelow": True
})
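
If you’re ever unsure what a setting is called, or what it’s currently set to, rcParams behaves like a regular dictionary:

# Inspect the current value of a setting, or search for one by name
print(plt.rcParams["figure.figsize"])
print([key for key in plt.rcParams if "spines" in key])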

And now, when you repeat the call to plt.plot(), the chart will look somewhat presentable:

plt.plot(x, y)
plt.title("Sales Over Time")
plt.xlabel("Time Period")
plt.ylabel("Sales in 000")
plt.show()
Image 3 – Adjusting chart theme (image by author)

Not quite there yet, but the idea was just to set an underlying theme. You shouldn’t include chart-specific instructions in rcParams.

Tweak #2 – Bring the Font to the 21st Century

Another thing you can change in rcParams is the font.

You can download any TTF font from the internet, and load it via Matplotlib’s font_manager. I’ll use Roboto Condensed, but feel free to go with anything you like:

import matplotlib.font_manager as font_manager

font_dir = ["/path/to/Roboto_Condensed"]
for font in font_manager.findSystemFonts(font_dir):
    font_manager.fontManager.addfont(font)

plt.rcParams.update({
    "font.family": "Roboto Condensed"
})
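
If you want to confirm the font registered before plotting anything, you can list the names the font manager currently knows about – an optional sanity check:

# Every TTF font Matplotlib can see, filtered down to the one we just added
print(sorted({f.name for f in font_manager.fontManager.ttflist if "Roboto" in f.name}))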

To verify the font has changed, simply rerun the plotting snippet from earlier:

plt.plot(x, y)
plt.title("Sales Over Time")
plt.xlabel("Time Period")
plt.ylabel("Sales in 000")
plt.show()
Image 4 – Changing the font (image by author)

And that’s about all we’ll discuss regarding overall theme changes. Up next, let’s get specific.

Tweak #3 – Make Micro Adjustments Specific to Your Chart Type

Different chart types will have different go-to approaches when it comes to fine-tuning.

For line charts, you can change the line color and width, and potentially even add a filled area section to make the chart look more dashboardy.

The other changes made in the below snippet are purely cosmetic – title location and y-axis limit:

# 1. Line color and width
plt.plot(x, y, color="#1C3041", linewidth=2)
# 2. Add shaded area below the line
plt.fill_between(x, y, color="#1C3041", alpha=0.3)
# 3. Change title location and font weight
plt.title("Sales Over Time", loc="left", fontdict={"weight": "bold"}, y=1.06)

plt.xlabel("Time Period")
plt.ylabel("Sales in 000")
plt.show()
Image 5 – Finalized chart (image by author)

Now this is almost a publication-ready visualization!

A couple of specific tweaks made all the difference, but that wouldn’t be possible without setting a strong foundation.

Wrapping Up

Many data professionals disregard Matplotlib entirely because of how it looks by default.

They think "It’s blurry and awful, I can’t send a visual like this to my boss." The reality couldn’t be further away from the truth. You can change a bunch of parameters and end up with a code block you can bring anywhere – to every script, notebook, and environment.

I encourage you to play around with the parameters to further personalize the overall look and feel, and share your preferences in the comment section below.


Looking to level up your data visualization game? Join thousands of like-minded people on my Substack for expert tips and exclusive content.

Data Doodles with Python | Dario Radecic | Substack

Originally published at https://darioradecic.substack.com.

5 PCA Visualizations You Must Try On Your Next Data Science Project

Which features carry the most weight? How do original features contribute to principal components? These 5 visualization types have the answer.
Photo by Andrew Neel on Unsplash

Principal Component Analysis (PCA) can tell you a lot about your data. In short, it’s a dimensionality reduction technique used to bring high-dimensional datasets into a space that can be visualized.

But I assume you already know that. If not, check my from-scratch guide.

Today, we only care about the visuals. By the end of the article, you’ll know how to create and interpret:

  1. Explained variance plot
  2. Cumulative explained variance plot
  3. 2D/3D component scatter plot
  4. Attribute biplot
  5. Loading score plot

Getting Started – PCA Visualization Prerequisites

I’d love to dive into visualizations right away, but you’ll need data to follow along. This section covers data loading, preprocessing, PCA fitting, and general Matplotlib styling tweaks.

Dataset Info

I’m using the Wine Quality Dataset. It’s available for free on Kaggle and is licensed under the Creative Commons License.

This is what you should see after loading it with Python:

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

data = pd.read_csv("data/WineQT.csv")
data.drop(["Id"], axis=1, inplace=True)
data.head()
Image 1 – Head of the Wine quality dataset (image by author)

PCA assumes numerical features with no missing values, centered around 0 with a standard deviation of 1:

X = data.drop("quality", axis=1)
y = data["quality"]

X_scaled = StandardScaler().fit_transform(X)

pca = PCA().fit(X_scaled)
pca_res = pca.transform(X_scaled)

pca_res_df = pd.DataFrame(pca_res, columns=[f"PC{i}" for i in range(1, pca_res.shape[1] + 1)])
pca_res_df.head()
Image 2 – Wine quality dataset after applying PCA (image by author)

The dataset is now ready for visualization, but any chart you make will look horrendous at best. Let’s fix that.

Matplotlib Visualization Tweaks

I have a full-length article covering Matplotlib styling tweaks, so don’t expect any depth here.

Download a TTF font of your choice to follow along (mine is Roboto Condensed) and replace the path in font_dir to match one on your operating system:

import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager

import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")

font_dir = ["Roboto_Condensed"]
for font in font_manager.findSystemFonts(font_dir):
    font_manager.fontManager.addfont(font)

plt.rcParams["figure.figsize"] = 10, 6
plt.rcParams["axes.spines.top"] = False
plt.rcParams["axes.spines.right"] = False
plt.rcParams["font.size"] = 14
plt.rcParams["figure.titlesize"] = "xx-large"
plt.rcParams["xtick.labelsize"] = "medium"
plt.rcParams["ytick.labelsize"] = "medium"
plt.rcParams["axes.axisbelow"] = True
plt.rcParams["font.family"] = "Roboto Condensed"

You’ll now get modern-looking high-resolution charts any time you plot something.

And that’s all you need to start making awesome PCA visualizations. Let’s dive into the first one next.

PCA Plot #1: Explained Variance Plot

Question: How much of the total variance in the data is captured by each principal component?

If you’re wondering the same, the explained variance plot is where it’s at. You’ll typically see the first couple of components covering a decent chunk of the overall variance, but that might depend on the number of features you’re starting with.

The first component of a dataset with 5 features will capture more total variance than the first component of a dataset with 500 features – duh.

plot_y = [val * 100 for val in pca.explained_variance_ratio_]
plot_x = range(1, len(plot_y) + 1)

bars = plt.bar(plot_x, plot_y, align="center", color="#1C3041", edgecolor="#000000", linewidth=1.2)
for bar in bars:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width() / 2, yval + 0.001, f"{yval:.1f}%", ha="center", va="bottom")

plt.xlabel("Principal Component")
plt.ylabel("Percentage of Explained Variance")
plt.title("Variance Explained per Principal Component", loc="left", fontdict={"weight": "bold"}, y=1.06)
plt.grid(axis="y")
plt.xticks(plot_x)

plt.show()
Image 3 – Explained variance plot (image by author)

28.7% of total variance (from 11 features) is explained with just one component. In other words, you can plot the entire dataset on a single line (1D) and still show ~ a third of the variability.

PCA Plot #2: Cumulative Explained Variance Plot

Question: I want to reduce the dimensionality of my data, but still want to keep at least 90% of the variance. What should I do?

The simplest answer is to modify the first chart so it shows the cumulative sum of explained variance. In code, just iteratively sum up the values up to the current item.

exp_var = [val * 100 for val in pca.explained_variance_ratio_]
plot_y = [sum(exp_var[:i+1]) for i in range(len(exp_var))]
plot_x = range(1, len(plot_y) + 1)

plt.plot(plot_x, plot_y, marker="o", color="#9B1D20")
for x, y in zip(plot_x, plot_y):
    plt.text(x, y + 1.5, f"{y:.1f}%", ha="center", va="bottom")

plt.xlabel("Principal Component")
plt.ylabel("Cumulative Percentage of Explained Variance")
plt.title("Cumulative Variance Explained per Principal Component", loc="left", fontdict={"weight": "bold"}, y=1.06)

plt.yticks(range(0, 101, 5))
plt.grid(axis="y")
plt.xticks(plot_x)

plt.show()
Image 4 – Cumulative explained variance plot (image by author)

So, if you want a narrower dataset that still captures at least 90% of the variance, you’ll want to keep the first 7 principal components.
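
If you’d rather compute that cutoff than read it off the chart, np.cumsum gets you there in a couple of lines:

# Number of components needed to reach at least 90% cumulative explained variance
cum_var = np.cumsum(pca.explained_variance_ratio_) * 100
n_components_90 = int(np.argmax(cum_var >= 90) + 1)
print(n_components_90)  # 7 for this dataset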

PCA Plot #3: 2D/3D Component Scatter Plot

Question: How can I visualize the relationship between records of a high-dimensional dataset? We can’t see more than 3 dimensions at a time.

A scatter plot of the first 2 or 3 principal components is what you’re looking for. Ideally, you should color the data points with distinct values of the target variable (assuming a classification dataset).

Let’s break it down.

2D Scatter Plot

The first two components capture ~ 45% of the variance. It’s a decent amount, but a 2-dimensional scatter plot still won’t account for more than half of it. Something to keep in mind.

total_explained_variance = sum(pca.explained_variance_ratio_[:2]) * 100
colors = ["#1C3041", "#9B1D20", "#0B6E4F", "#895884", "#F07605", "#F5E400"]

pca_2d_df = pd.DataFrame(pca_res[:, :2], columns=["PC1", "PC2"])
pca_2d_df["y"] = data["quality"]

fig, ax = plt.subplots()
for i, target in enumerate(sorted(pca_2d_df["y"].unique())):
    subset = pca_2d_df[pca_2d_df["y"] == target]
    ax.scatter(x=subset["PC1"], y=subset["PC2"], s=70, alpha=0.7, c=colors[i], edgecolors="#000000", label=target)

plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.title(f"Wine Quality Dataset PCA ({total_explained_variance:.2f}% Explained Variance)", loc="left", fontdict={"weight": "bold"}, y=1.06)

ax.legend(title="Wine quality")
plt.show()
Image 5 – Scatter plot of the first two principal components (image by author)

It’s a mess, so let’s try adding a dimension.

3D Scatter Plot

Adding an extra dimension will increase the explained variance to ~ 60%. Keep in mind that 3D charts are more challenging to look at, and the interpretation can somewhat depend on the chart’s angle.

total_explained_variance = sum(pca.explained_variance_ratio_[:3]) * 100
colors = ["#1C3041", "#9B1D20", "#0B6E4F", "#895884", "#F07605", "#F5E400"]

pca_3d_df = pd.DataFrame(pca_res[:, :3], columns=["PC1", "PC2", "PC3"])
pca_3d_df["y"] = data["quality"]

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection="3d")

for i, target in enumerate(sorted(pca_3d_df["y"].unique())):
    subset = pca_3d_df[pca_3d_df["y"] == target]
    ax.scatter(xs=subset["PC1"], ys=subset["PC2"], zs=subset["PC3"], s=70, alpha=0.7, c=colors[i], edgecolors="#000000", label=target)

ax.set_xlabel("Principal Component 1")
ax.set_ylabel("Principal Component 2")
ax.set_zlabel("Principal Component 3")
ax.set_title(f"Wine Quality Dataset PCA ({total_explained_variance:.2f}% Explained Variance)", loc="left", fontdict={"weight": "bold"})

ax.legend(title="Wine quality", loc="lower left")
plt.show()
Image 6 – Scatter plot of the first three principal components (image by author)

To change the perspective, you can play around with the view_init() function. It allows you to change the elevation and azimuth of the axes in degrees:

ax.view_init(elev=<value>, azim=<value>)
Image 7 – Tweaking 3D scatter plot angle (image by author)

Or, you can use an interactive charting library like Plotly and rotate the chart like a sane person.

PCA Plot #4: Biplot

Question: Can I see how the original variables contribute to and correlate with the principal components?

Yup. Annoyingly, this chart type is much easier to make in R, but what can you do.

labels = X.columns
n = len(labels)
coeff = np.transpose(pca.components_)
pc1 = pca.components_[:, 0]
pc2 = pca.components_[:, 1]

plt.figure(figsize=(8, 8))

for i in range(n):
    plt.arrow(x=0, y=0, dx=coeff[i, 0], dy=coeff[i, 1], color="#000000", width=0.003, head_width=0.03)
    plt.text(x=coeff[i, 0] * 1.15, y=coeff[i, 1] * 1.15, s=labels[i], size=13, color="#000000", ha="center", va="center")

plt.axis("square")
plt.title(f"Wine Quality Dataset PCA Biplot", loc="left", fontdict={"weight": "bold"}, y=1.06)
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")

plt.xlim(-1, 1)
plt.ylim(-1, 1)
plt.xticks(np.arange(-1, 1.1, 0.2))
plt.yticks(np.arange(-1, 1.1, 0.2))

plt.axhline(y=0, color="black", linestyle="--")
plt.axvline(x=0, color="black", linestyle="--")
circle = plt.Circle((0, 0), 0.99, color="gray", fill=False)
plt.gca().add_artist(circle)

plt.grid()
plt.show()
Image 8 – Variable biplot (image by author)

So, what are you looking at? This is what you should know when interpreting biplots:

  • Arrow direction: Indicates how the corresponding variable is aligned with the principal component. Arrows that point in the same direction are positively correlated. Arrows pointing in the opposite direction are negatively correlated.
  • Arrow length: Shows how much the variable contributes to the principal components. Longer arrows mean stronger contribution (the variable accounts for more explained variance). Shorter arrows mean weaker contribution (the variable accounts for less explained variance).
  • Angle between arrows: 0º indicates a perfect positive correlation. 180º indicates a perfect negative correlation. 90º indicates no correlation.

From the above chart, you can see how the features sulphates, fixed acidity, and citric acid show a high correlation.

PCA Plot #5: Loading Score Plot

Question: Which variables from the original dataset have the most influence on each principal component?

You can visualize something known as a loading score to find out. It’s a value that represents the weight of an original variable in a given principal component. A higher absolute value indicates higher influence.

These are essentially correlations with a principal component and range from -1 to +1. You don’t often care about the direction, just the magnitude.

loadings = pd.DataFrame(
    data=pca.components_.T * np.sqrt(pca.explained_variance_), 
    columns=[f"PC{i}" for i in range(1, len(X.columns) + 1)],
    index=X.columns
)

fig, axs = plt.subplots(2, 2, figsize=(14, 10), sharex=True, sharey=True)
colors = ["#1C3041", "#9B1D20", "#0B6E4F", "#895884"]

for i, ax in enumerate(axs.flatten()):
    explained_variance = pca.explained_variance_ratio_[i] * 100
    pc = f"PC{i+1}"
    bars = ax.bar(loadings.index, loadings[pc], color=colors[i], edgecolor="#000000", linewidth=1.2)
    ax.set_title(f"{pc} Loading Scores ({explained_variance:.2f}% Explained Variance)", loc="left", fontdict={"weight": "bold"}, y=1.06)
    ax.set_xlabel("Feature")
    ax.set_ylabel("Loading Score")
    ax.grid(axis="y")
    ax.tick_params(axis="x", rotation=90)
    ax.set_ylim(-1, 1)

    for bar in bars:
        yval = bar.get_height()
        offset = yval + 0.02 if yval > 0 else yval - 0.15
        ax.text(bar.get_x() + bar.get_width() / 2, offset, f"{yval:.2f}", ha="center", va="bottom")

plt.tight_layout()
plt.show()
Image 9 – Loading scores for the first four principal components (image by author)

If the first principal component accounts for ~ 29% of the variability and the fixed acidity feature has a high loading score (0.86), it means it’s an important feature and should be kept for further analysis and predictive modeling.

This way, you can hack your way around loading scores to use them as a feature selection technique.
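
As a minimal sketch of that idea, you can rank the original features by their absolute loading on the first component and keep only the top few:

# Features ordered by how strongly they load on PC1, regardless of sign
pc1_ranking = loadings["PC1"].abs().sort_values(ascending=False)
print(pc1_ranking.head())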

Summing up Python PCA Visualizations

To conclude, PCA is a must-know technique for reducing dimensionality and gaining insights. We humans suck at interpreting long lists of numbers, so visualization is the way to go.

Thanks to PCA, you can visualize high-dimensional datasets with 2D/3D charts. You’ll lose some variance along the way, so interpret your visualizations carefully.

Which PCA visualizations do you use in Data Science projects? Make sure to let me know in the comment section below.


Join my newsletter to receive exclusive content on Python, Data Science, and Machine Learning.

Dario’s Substack | Dario Radecic | Substack

Originally published at https://darioradecic.substack.com.

Python Poetry – The Best Data Science Dependency Management Tool?

Poetry makes deploying machine learning applications a breeze - learn how!

If I had a dollar for every time I faced a missing Python dependency or a version mismatch, well, I wouldn’t be a millionaire, but you get the point.

Dependency management is a common problem in Data Science with many potential solutions. A virtual environment is always recommended, but that’s only the beginning. It’s usually followed by keeping track of installed packages. But what about their dependencies? And dependencies of their dependencies? It’s a recursive nightmare.

Poetry might be the solution you’re looking for. It aims to be a one-stop-shop for everything dependency management related, and can even be used for publishing Python packages.

Today, you'll build a simple Machine Learning application locally and then push it to a remote compute instance. If Poetry lives up to its promise, the remote setup should be as simple as running a single shell command.


How to Get Started with Poetry

A small inconvenience with Poetry is that you can’t start with a pip install command. It needs an additional command line tool.

Install Poetry

And that command line tool is called pipx. The installation instructions for Mac are pretty straightforward (follow the above link for other operating systems):

brew install pipx
pipx ensurepath
sudo pipx ensurepath --global

Once pipx is installed, install Poetry with the following command:

pipx install poetry

Just a slight inconvenience, but it’s over now.

Create a new Poetry Project

The poetry new <name> command initializes a new Python project and creates an appropriately named folder:

poetry new ml-demo
Image 1 – New Poetry project output (image by author)

These are the files that’ll get created:

Image 2 – Directory structure (image by author)

The file to focus your attention on right now is pyproject.toml. It contains details of your app, external dependencies, and the minimum required Python version.

If you don’t want to publish the package, which you don’t want in this case, add the following line to the first block:

[tool.poetry]
package-mode = false
Image 3 – The pyproject.toml file (image by author)

You can also remove the tests and ml_demo folders if you wish.

We’ll return to this file after installing a couple of dependencies.

Python Poetry in Action – A Sample Machine Learning Application

Your machine learning app will be a simple decision tree model trained on the Iris dataset and made available through a multi-worker FastAPI service.

Any dependency you need is installed not with pip install, but rather with the poetry add command:

poetry add numpy pandas scikit-learn fastapi gunicorn
Image 4 – Installing Python packages (image by author)

This creates a virtual environment if one doesn't exist, installs the packages, and creates/updates the poetry.lock file:

Image 5 – The poetry.lock file (image by author)

In plain English, the lock file handles the recursive nightmare of dependency management.

The pyproject.toml file was also updated, but it shows only the dependencies you’ve explicitly installed:

Image 6 – The pyproject.toml file (2) (image by author)
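If you can't make out the screenshot, the relevant block of pyproject.toml looks roughly like this. The version constraints below are purely illustrative – yours will reflect whatever Poetry resolved on your machine at install time:

[tool.poetry.dependencies]
python = "^3.12"
numpy = "^2.0"
pandas = "^2.2"
scikit-learn = "^1.5"
fastapi = "^0.111"
gunicorn = "^22.0"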

So far, so good!

Visual Studio Code Setup

You’ll face an annoying issue when you start writing Python code in an editor such as VSCode – the editor doesn’t know where to find the newly created virtual environment!

And neither do you.

The poetry show -v command lists the path to the virtual environment and shows dependencies it has installed:

poetry show -v
Image 7 – Listing virtual environment details (image by author)

If you want only the environment’s absolute path, opt for this command instead:

poetry env info -p
Image 8 – Printing absolute path to the Python environment (image by author)

Now in Visual Studio Code (or any other editor), simply change the path to your Python interpreter. Copy the path returned by the above command, and add /bin/python3.<version> to the end. I’m using Python 3.12, so the part to append becomes /bin/python3.12:

/<path-to-environment>/bin/python3.<version>
Image 9 – Setting up VSCode interpreter (image by author)

The import errors will now disappear.

Machine Learning Application

Onto the app now. Create a utils/ml.py file and copy the following code:

import os
import pickle
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

MODEL_PATH = "ml_models/iris.model"

def train_model() -> DecisionTreeClassifier:
    iris = load_iris()
    X = iris.data
    y = iris.target
    model = DecisionTreeClassifier()
    model.fit(X, y)
    return model

def save_model(model: DecisionTreeClassifier) -> bool:
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    try:
        with open(MODEL_PATH, "wb") as f:
            pickle.dump(model, f)
        return True
    except Exception as e:
        print(f"An error occurred while saving the model: {e}")
        return False

def load_model() -> DecisionTreeClassifier:
    try:
        with open(MODEL_PATH, 'rb') as f:
            model = pickle.load(f)
        return model
    except Exception as e:
        print(f"An error occurred while loading the model: {e}")
        raise e

def predict(
        model: DecisionTreeClassifier, 
        sepal_length: float, 
        sepal_width: float, 
        petal_length: float, 
        petal_width: float
    ) -> dict:
    prediction = model.predict([[sepal_length, sepal_width, petal_length, petal_width]])
    prediction_str = ""
    match prediction:
        case 0:
            prediction_str = "setosa"
        case 1:
            prediction_str = "versicolor"
        case 2:
            prediction_str = "virginica"

    prediction_prob = model.predict_proba([[sepal_length, sepal_width, petal_length, petal_width]])
    return {
        "prediction": prediction_str,
        "prediction_probabilities": prediction_prob[0].tolist()
    }

It’s a simple file with a couple of functions to train the model, make a single prediction, and read/write the model file to/from disk.

Also, create a main.py file to test the functionality:

from utils.ml import train_model, save_model, load_model, predict

if __name__ == "__main__":
    model = train_model()
    did_save = save_model(model=model)
    if did_save:
        print("Model saved")
    loaded_model = load_model()
    pred = predict(
        model=loaded_model,
        sepal_length=4.2,
        sepal_width=3.1,
        petal_length=3.4,
        petal_width=4.1
    )
    print(pred)

To use the virtual environment created by Poetry, you must prepend Python script run commands with poetry run:

poetry run python main.py
Image 10 – Model prediction result (image by author)

You get the prediction class and probabilities back, so it’s safe to assume everything works as advertised.

FastAPI Application

Finally, replace the contents of main.py with the following code snippet:

import os
from fastapi import FastAPI
from pydantic import BaseModel
from utils.ml import train_model, save_model, load_model, predict

app = FastAPI()
model = None

# Request body for prediction
class IrisFeatures(BaseModel):
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float

# Ensure the model is loaded on startup
@app.on_event("startup")
def startup_event():
    if not os.path.exists("ml_models/iris.model"):
        _model = train_model()
        did_save_model = save_model(model=_model)
        if did_save_model:
            print("Model trained and saved successfully.")
        else:
            print("Model training and saving failed.")
    global model
    model = load_model()

# The prediction endpoint
@app.post("/predict")
def make_prediction(iris: IrisFeatures):
    return predict(
        model=model,
        sepal_length=iris.sepal_length,
        sepal_width=iris.sepal_width,
        petal_length=iris.petal_length,
        petal_width=iris.petal_width
    )

It will create or load the model when the application starts (depending on whether the model already exists), and expose the prediction functionality on the /predict endpoint.

For the final bit of testing, run the FastAPI application through Gunicorn and optionally increase the number of workers. It doesn’t make any difference locally, but changing this parameter will allow more users to access your API at the same time:

poetry run gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8500
Image 11 – FastAPI startup (image by author)

Seems like the app is running, so let’s test it:
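The screenshot below was most likely captured in an API client, but you can reproduce the call yourself with a few lines of requests code – the port matches the Gunicorn command above, and the payload mirrors the IrisFeatures model and the values used in main.py earlier:

import requests

payload = {
    "sepal_length": 4.2,
    "sepal_width": 3.1,
    "petal_length": 3.4,
    "petal_width": 4.1
}

# Send a prediction request to the locally running FastAPI app
response = requests.post("http://localhost:8500/predict", json=payload)
print(response.json())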

Image 12 – Model prediction result (2) (image by author)

The correct output is returned! Let’s now transfer the app to a new environment to see what happens.

Test on a New Environment – EC2 Instance

I’ve provisioned a free tier EC2 instance (Ubuntu) on AWS for this section of the article.

Assuming you’ve done the same, and assuming yours also comes with Python 3.12 (specified in pyproject.toml) or later, run the following set of commands to update the system and install Poetry:

sudo apt update && sudo apt upgrade -y
sudo apt install gunicorn -y

sudo apt install pipx -y
pipx ensurepath
sudo pipx ensurepath --global

pipx install poetry
Image 13 – Linux machine setup (image by author)

Visual Studio Code has a neat Remote SSH plugin that allows you to connect to remote instances. But really, any SCP tool will do the trick.

I’ve copied our ml-demo folder with Python code and Poetry environment details:

Image 14 – Directory structure (image by author)

Poetry promises to make dependency management in new environments a breeze. All you have to do is navigate to the application folder and run the poetry install command:

cd ml-demo
poetry install
Image 15 – Package installation (image by author)

It seems like a new virtual environment was created and dependencies were installed.

The ultimate test will be running the same FastAPI startup command you ran moments ago in a local environment. If no errors are raised, Poetry has delivered on its promise:

poetry run gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8500
Image 16 – FastAPI startup on a remote instance (image by author)

The app looks to be running, which is great news! Now, assuming you’ve allowed traffic to port 8500, you will be able to make requests to the /predict endpoint:

Image 17 – Model prediction result (3) (image by author)

You should get the same response back as with the local testing. If the request hangs for a while and then dies, your instance likely doesn’t allow traffic on the port the app is running on.

Anyhow, the deployment process sure was a breeze.


Summing up Python Poetry

Tools like Poetry promise to reduce the number of migraines you get during model deployment. They also allow you to focus on what’s important, such as improving model quality and enriching its response, and not waste time on errors that shouldn’t happen in the first place.

For the short time I've been using Poetry at my day job, I've yet to find an error for which Poetry is to blame. Don't get me wrong, I've seen errors after running the poetry install command, but they've all been OS-related. Like forgetting to install Gunicorn or something even dumber.

Have you used Poetry professionally? Did you encounter any issues and/or limitations? Make sure to let me know in the comment section below.

Read next:

Python Concurrency – A Brain-Friendly Guide for Data Professionals

The post Python Poetry – The Best Data Science Dependency Management Tool? appeared first on Towards Data Science.

]]>
Python Concurrency – A Brain-Friendly Guide for Data Professionals https://towardsdatascience.com/python-concurrency-a-brain-friendly-guide-for-data-professionals-a6215a3e9e26/ Fri, 26 Jul 2024 13:55:32 +0000 https://towardsdatascience.com/python-concurrency-a-brain-friendly-guide-for-data-professionals-a6215a3e9e26/ Moving data around can be slow. Here's how you can squeeze every bit of performance optimization out of Python.

The post Python Concurrency – A Brain-Friendly Guide for Data Professionals appeared first on Towards Data Science.

]]>
Python is often criticized for being among the slowest programming languages. While that claim does hold some weight, it’s vital to point out that Python is often the first programming language newcomers learn. Hence, most of the code is highly unoptimized.

But Python does have a couple of tricks up its sleeve. Taking advantage of concurrent function execution is stupidly simple to implement, yet it can reduce the runtime of your data pipelines tenfold. Instead of hours, it’ll take minutes. All for free.

Today you’ll see how Concurrency works in Python, and you’ll also learn how to deal with exception handling, custom callbacks, and rate limiting. Let’s dig in!


JSON Placeholder – Your Data Source for the Day

The first order of business is configuring a data source. I wanted to avoid something proprietary and something that would take ages to set up. JSON Placeholder – a free REST API service – is the perfect candidate.

Image 1 – Available endpoints (image by author)

To be more precise, you’ll use the /posts/<post-id> endpoint that returns a simple JSON object:

Image 2 – Sample API response (image by author)

I hear you saying – But making an API request will take a split second to finish! That’s true, but there’s always a way to slow it down.

Default Approach – Run Away and Never Look Back

Since a simple API call takes a split second to run, let's make it slower to simulate a typical task in your data pipeline. The get_post(post_id: int) function will:

  • Check for a valid post_id – anything up to 100 is considered valid
  • Sleep for a random number of seconds between 1 and 10
  • Make an API request to get the post, raise an error if it occurs, or return the response otherwise

The sleeping part makes sure the function takes a while to finish:

import time
import random
import requests

def get_post(post_id: int) -> dict:
    # Value check - Posts on the API only go up to ID of 100
    if post_id > 100:
        raise ValueError("Parameter `post_id` must be less than or equal to 100")

    # API URL
    url = f"https://jsonplaceholder.typicode.com/posts/{post_id}"

    # Sleep to imitate a long-running process
    time_to_sleep = random.randint(1, 10)
    time.sleep(time_to_sleep)

    # Fetch the data and return it
    r = requests.get(url)
    r.raise_for_status()
    result = r.json()
    # To indicate how much time fetching took
    result["fetch_time"] = time_to_sleep
    # Remove the longest key-value pair for formatting reasons
    del result["body"]
    return result

if __name__ == "__main__":
    print(get_post(1))
Image 3 – Default approach fetch results (image by author)

Time to sleep is completely random, so you’re likely to get a different result.

Things get painfully slow when you have to run this function multiple times. Python newcomers often opt for a loop, similar to the example below:

import time
import random
import requests
from datetime import datetime

def get_post(post_id: int) -> dict:
    if post_id > 100:
        raise ValueError("Parameter `post_id` must be less than or equal to 100")

    url = f"https://jsonplaceholder.typicode.com/posts/{post_id}"

    time_to_sleep = random.randint(1, 10)
    time.sleep(time_to_sleep)

    r = requests.get(url)
    r.raise_for_status()
    result = r.json()
    result["fetch_time"] = time_to_sleep
    del result["body"]
    return result

if __name__ == "__main__":
    # Measure the time
    time_start = datetime.now()
    print("Starting to fetch posts...\n")

    # Simple iteration
    for post_id in range(1, 11):
        post = get_post(post_id)
        print(post)

    # Print total duration
    time_end = datetime.now()
    print(f"\nAll posts fetched! Took: {(time_end - time_start).seconds} seconds.")
Image 4 – Default approach fetch results (2) (image by author)

53 seconds! That’s the problem with sequential execution. Luckily for you, concurrency can help.

Say Hello to Your New Best Friend – ThreadPoolExecutor

Instead of a traditional loop, you can use the ThreadPoolExecutor class, which is a high-level interface for running functions asynchronously using threads. By default, it uses min(32, number of CPU cores + 4) worker threads, so on an 8-core machine your function will run on 12 threads.

The ThreadPoolExecutor class is easy to use, and it manages a pool of worker threads for you. No manual intervention is needed.

Let’s see how it works in practice.

The Simplest Way to Run Tasks Concurrently in Python

You’ll want to import ThreadPoolExecutor and the as_completed() function. The custom get_post() function remains unchanged.

The real magic happens below.

Essentially, you’re creating a new ThreadPoolExecutor through the context manager syntax (the most common approach). Inside, you’re using the submit() method to add tasks to the executor. The first parameter of this method is your function name, followed by its parameter values. You can dynamically iterate over a range of values for post_id using Python’s list comprehension.

A Future object is returned by the submit() function.

The as_completed() function will extract and print the result as individual threads finish with execution:

import time
import random
import requests
from datetime import datetime
# New imports
from concurrent.futures import ThreadPoolExecutor, as_completed

def get_post(post_id: int) -> dict:
    if post_id > 100:
        raise ValueError("Parameter `post_id` must be less than or equal to 100")

    url = f"https://jsonplaceholder.typicode.com/posts/{post_id}"

    time_to_sleep = random.randint(1, 10)
    time.sleep(time_to_sleep)

    r = requests.get(url)
    r.raise_for_status()
    result = r.json()
    result["fetch_time"] = time_to_sleep
    del result["body"]
    return result

if __name__ == "__main__":
    time_start = datetime.now()
    print("Starting to fetch posts...\n")

    # Run post fetching concurrently
    with ThreadPoolExecutor() as tpe:
        # Submit tasks and get future objects
        futures = [tpe.submit(get_post, post_id) for post_id in range(1, 11)]
        # Process task results
        for future in as_completed(futures):
            # Get and display the result
            result = future.result()
            print(result)

    time_end = datetime.now()
    print(f"\nAll posts fetched! Took: {(time_end - time_start).seconds} seconds.")
Image 5 – Python concurrent execution (image by author)

We’re down to 10 seconds from 53.

Why 10? Because the longest-running call to get_post() decided to sleep for 10 seconds. My machine has 12 CPU cores (16 ThreadPoolExecutor workers), meaning all 10 tasks could run at the same time.

Scaling Things Up

Let's now see what happens if you decide to concurrently run more tasks than you have workers available. Only a single line of code was changed, as indicated by the comment in the snippet below:

import time
import random
import requests
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

def get_post(post_id: int) -> dict:
    if post_id > 100:
        raise ValueError("Parameter `post_id` must be less than or equal to 100")

    url = f"https://jsonplaceholder.typicode.com/posts/{post_id}"

    time_to_sleep = random.randint(1, 10)
    time.sleep(time_to_sleep)

    r = requests.get(url)
    r.raise_for_status()
    result = r.json()
    result["fetch_time"] = time_to_sleep
    del result["body"]
    return result

if __name__ == "__main__":
    time_start = datetime.now()
    print("Starting to fetch posts...\n")

    with ThreadPoolExecutor() as tpe:
        # Submit tasks and get future objects - NOW 100 POSTS IN TOTAL
        futures = [tpe.submit(get_post, post_id) for post_id in range(1, 101)]
        for future in as_completed(futures):
            result = future.result()
            print(result)

    time_end = datetime.now()
    print(f"\nAll posts fetched! Took: {(time_end - time_start).seconds} seconds.")
Image 6 – Python concurrent execution (2) (image by author)

Long story short, ThreadPoolExecutor will start by running min(32, num cores + 4) tasks at a time, and proceed with the others as workers become available.
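If you want more tasks in flight at once than the default allows, you can raise the limit yourself. A quick sketch with an explicit worker count, reusing the same get_post() function:

from concurrent.futures import ThreadPoolExecutor, as_completed

# The default worker count is min(32, os.cpu_count() + 4) - override it explicitly if needed
with ThreadPoolExecutor(max_workers=32) as tpe:
    futures = [tpe.submit(get_post, post_id) for post_id in range(1, 101)]
    for future in as_completed(futures):
        print(future.result())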

How to Handle Failure

Sometimes, the function you want to run concurrently will fail. Inside the for future in as_completed(futures): block, you can add a try/except block to implement exception handling.

To demonstrate, try submitting futures for post_ids up to 150 – as the get_post() function will raise an error for any post_id above 100:

import time
import random
import requests
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

def get_post(post_id: int) -> dict:
    if post_id > 100:
        raise ValueError("Parameter `post_id` must be less than or equal to 100")

    url = f"https://jsonplaceholder.typicode.com/posts/{post_id}"

    time_to_sleep = random.randint(1, 10)
    time.sleep(time_to_sleep)

    r = requests.get(url)
    r.raise_for_status()
    result = r.json()
    result["fetch_time"] = time_to_sleep
    del result["body"]
    return result

if __name__ == "__main__":
    time_start = datetime.now()
    print("Starting to fetch posts...\n")

    with ThreadPoolExecutor() as tpe:
        # Submit tasks and get future objects - NOW 150 POSTS IN TOTAL - 50 WILL FAIL
        futures = [tpe.submit(get_post, post_id) for post_id in range(1, 151)]
        # Process task results
        for future in as_completed(futures):
            # Your typical try/except block
            try:
                result = future.result()
                print(result)
            except Exception as e:
                print(f"Exception raised: {str(e)}")

    time_end = datetime.now()
    print(f"\nAll posts fetched! Took: {(time_end - time_start).seconds} seconds.")
Image 7 – Python concurrent execution (3) (image by author)

You can see that exception handling works like a charm. The order of execution is random, but that’s irrelevant.

Want to Do Custom Stuff? Add Callbacks

In case you don’t want to cram a bunch of code into the for future in as_completed(futures): block, you can call add_done_callback() to call your custom Python function. This function will have access to the Future object.

The following code snippet calls future_callback_fn() when execution on an individual thread finishes:

import time
import random
import requests
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

def get_post(post_id: int) -> dict:
    if post_id > 100:
        raise ValueError("Parameter `post_id` must be less than or equal to 100")

    url = f"https://jsonplaceholder.typicode.com/posts/{post_id}"

    time_to_sleep = random.randint(1, 10)
    time.sleep(time_to_sleep)

    r = requests.get(url)
    r.raise_for_status()
    result = r.json()
    result["fetch_time"] = time_to_sleep
    del result["body"]
    return result

def future_callback_fn(future):
    print(f"[{datetime.now()}] Custom future callback function!")
    # You have access to the future object
    print(future.result())

if __name__ == "__main__":
    time_start = datetime.now()
    print("Starting to fetch posts...\n")

    with ThreadPoolExecutor() as tpe:
        futures = [tpe.submit(get_post, post_id) for post_id in range(1, 11)]
        for future in as_completed(futures):
            # Custom callback
            future.add_done_callback(future_callback_fn)

    # Print total duration
    time_end = datetime.now()
    print(f"\nAll posts fetched! Took: {(time_end - time_start).seconds} seconds.")
Image 8 – Python concurrent execution (4) (image by author)

This is a great way to keep the code inside __main__ short and sweet.

Rate Limiting – How to Get Around That Pesky HTTP 429 Error

Every data professional works with REST APIs. While calling their endpoints concurrently is a good way to reduce the overall runtime, it can result in an HTTP 429 Too Many Requests status.

The reason is simple – the API owner doesn't want you making thousands of requests every second for the sake of performance. Or maybe they're restricting the traffic volume based on your subscription tier.

Whatever the case might be, an easy way around it is to install the requests-ratelimiter library and limit how many requests can be made per day, hour, minute, or second.

The example below demonstrates how to set a limit to 2 requests per second:

import time
import random
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
# New import
from requests_ratelimiter import LimiterSession

# Limit to max 2 calls per second
request_session = LimiterSession(per_second=2)

def get_post(post_id: int) -> dict:
    if post_id > 100:
        raise ValueError("Parameter `post_id` must be less than or equal to 100")

    url = f"https://jsonplaceholder.typicode.com/posts/{post_id}"

    time_to_sleep = random.randint(1, 10)
    time.sleep(time_to_sleep)

    # Use the request_session now
    r = request_session.get(url)
    r.raise_for_status()
    result = r.json()
    result["fetch_time"] = time_to_sleep
    del result["body"]
    return result

if __name__ == "__main__":
    time_start = datetime.now()
    print("Starting to fetch posts...\n")

    # Everything here stays the same
    with ThreadPoolExecutor() as tpe:
        futures = [tpe.submit(get_post, post_id) for post_id in range(1, 16)]
        for future in as_completed(futures):
            result = future.result()
            print(result)

    time_end = datetime.now()
    print(f"\nAll posts fetched! Took: {(time_end - time_start).seconds} seconds.")
Image 9 – Python concurrent execution (5) (image by author)

The code still works flawlessly, but might take more time to finish. Still, it will be miles ahead of any sequential implementation.


Summing up Python Concurrency

Concurrency in Python requires a couple of additional lines of code, but is worth it 9/10 times. If your data pipeline fetches data from multiple sources, and the result of one fetch isn’t used as an input to the other, concurrency is a guaranteed way to speed things up.

Just imagine if you were building a web scraper that parses multiple pages with identical structures. Doing so concurrently would have a massive positive impact on the overall runtime.

Or if you’re downloading data from cloud object storage. Simply write a function to download all files in a folder, and then run it concurrently on the root folder contents.
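As a rough illustration of that last point – the download_file() function and the file keys below are hypothetical placeholders for whatever your storage provider's SDK gives you, but the pattern is the same one used throughout this article:

from concurrent.futures import ThreadPoolExecutor, as_completed

def download_file(file_key: str) -> str:
    # Hypothetical placeholder - fetch a single object from your storage provider here
    ...
    return file_key

# Hypothetical object keys - in practice you'd list them from the bucket/folder
file_keys = ["data/2024/01.parquet", "data/2024/02.parquet", "data/2024/03.parquet"]

with ThreadPoolExecutor() as tpe:
    futures = [tpe.submit(download_file, key) for key in file_keys]
    for future in as_completed(futures):
        print(f"Downloaded: {future.result()}")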

The applications are endless, but it does require a bit of thinking ahead.

Read next:

Python One Billion Row Challenge – From 10 Minutes to 4 Seconds

The post Python Concurrency – A Brain-Friendly Guide for Data Professionals appeared first on Towards Data Science.

]]>
Python One Billion Row Challenge – From 10 Minutes to 4 Seconds https://towardsdatascience.com/python-one-billion-row-challenge-from-10-minutes-to-4-seconds-0718662b303e/ Wed, 08 May 2024 17:12:43 +0000 https://towardsdatascience.com/python-one-billion-row-challenge-from-10-minutes-to-4-seconds-0718662b303e/ The one billion row challenge is exploding in popularity. How well does Python stack up?

The post Python One Billion Row Challenge – From 10 Minutes to 4 Seconds appeared first on Towards Data Science.

]]>
Photo by Alina Grubnyak on Unsplash

The question of how fast a programming language can go through and aggregate 1 billion rows of data has been gaining traction lately. Python, not being the most performant language out there, naturally doesn't stand a chance – especially since the currently top-performing Java implementation takes only 1.535 seconds!

The fundamental rule of the challenge is that no external libraries are allowed. My goal for today is to start by obeying the rules, and then see what happens if you use external libraries and better-suited file formats.

I’ve run all the scripts 5 times and averaged the results.

As for the hardware, I’m using a 16" M3 Pro Macbook Pro with 12 CPU cores and 36 GB of RAM. Your results may vary if you decide to run the code, but hopefully, you should see similar percentage differences between implementations.

Code


What is the 1 Billion Row Challenge?

The idea behind the 1 Billion Row Challenge (1BRC) is simple – go through a .txt file that contains arbitrary temperature measurements and calculate summary statistics for each station (min, mean, and max). The only issues are that you’re working with 1 billion rows and that the data is stored in an uncompressed .txt format (13.8 GB).

The dataset is generated by the data/createMeasurements.py script on my GitHub repo. I’ve copied the script from the source, just to have everything in the same place.

Once you generate the dataset, you’ll end up with a 13.8 GB semicolon-separated text file with two columns – station name and measurement.

And at the end of the aggregation process, you should end up with something like this:

Image 1 – Sample results (image by author)

The actual output format is somewhat different from one implementation to the other, but this is the one I found proposed by the official Python repo.

You now know what the challenge is, so let’s dive into the implementation next!

1 Billion Row Challenge – Pure Python Implementation

This is the only section in which I plan to obey the challenge rules. The reason is simple – Python doesn’t stand a chance with its standard library, and everyone in the industry relies heavily on third-party packages.

Single-Core Implementation

By far the easiest one to implement. You go through the text file and keep track of the measurements in a dictionary. It’s simple for min and max calculations, but mean requires keeping track of the count and then dividing the results.

# https://github.com/ifnesi/1brc#submitting
# Modified the multiprocessing version

def process_file(file_name: str):
    result = dict()

    with open(file_name, "rb") as f:
        for line in f:
            location, measurement = line.split(b";")
            measurement = float(measurement)
            if location not in result:
                result[location] = [
                    measurement,
                    measurement,
                    measurement,
                    1,
                ]
            else:
                _result = result[location]
                if measurement < _result[0]:
                    _result[0] = measurement
                if measurement > _result[1]:
                    _result[1] = measurement
                _result[2] += measurement
                _result[3] += 1

    print("{", end="")
    for location, measurements in sorted(result.items()):
        print(
            f"{location.decode('utf8')}={measurements[0]:.1f}/{(measurements[2] / measurements[3]) if measurements[3] !=0 else 0:.1f}/{measurements[1]:.1f}",
            end=", ",
        )
    print("\b\b} ")

if __name__ == "__main__":
    process_file("data/measurements.txt")

Multi-Core Implementation

The same idea as before, but now you need to split the text file into equally sized chunks and process them in parallel. Here, you compute the statistics for each chunk and then combine the results.

# Code credits: https://github.com/ifnesi/1brc#submitting

import os
import multiprocessing as mp

def get_file_chunks(
    file_name: str,
    max_cpu: int = 8,
) -> list:
    """Split file into chunks"""
    cpu_count = min(max_cpu, mp.cpu_count())

    file_size = os.path.getsize(file_name)
    chunk_size = file_size // cpu_count

    start_end = list()
    with open(file_name, "r+b") as f:

        def is_new_line(position):
            if position == 0:
                return True
            else:
                f.seek(position - 1)
                return f.read(1) == b"\n"

        def next_line(position):
            f.seek(position)
            f.readline()
            return f.tell()

        chunk_start = 0
        while chunk_start < file_size:
            chunk_end = min(file_size, chunk_start + chunk_size)

            while not is_new_line(chunk_end):
                chunk_end -= 1

            if chunk_start == chunk_end:
                chunk_end = next_line(chunk_end)

            start_end.append(
                (
                    file_name,
                    chunk_start,
                    chunk_end,
                )
            )

            chunk_start = chunk_end

    return (
        cpu_count,
        start_end,
    )

def _process_file_chunk(
    file_name: str,
    chunk_start: int,
    chunk_end: int,
) -> dict:
    """Process each file chunk in a different process"""
    result = dict()
    with open(file_name, "rb") as f:
        f.seek(chunk_start)
        for line in f:
            chunk_start += len(line)
            if chunk_start > chunk_end:
                break
            location, measurement = line.split(b";")
            measurement = float(measurement)
            if location not in result:
                result[location] = [
                    measurement,
                    measurement,
                    measurement,
                    1,
                ]  # min, max, sum, count
            else:
                _result = result[location]
                if measurement < _result[0]:
                    _result[0] = measurement
                if measurement > _result[1]:
                    _result[1] = measurement
                _result[2] += measurement
                _result[3] += 1
    return result

def process_file(
    cpu_count: int,
    start_end: list,
) -> dict:
    """Process data file"""
    with mp.Pool(cpu_count) as p:
        # Run chunks in parallel
        chunk_results = p.starmap(
            _process_file_chunk,
            start_end,
        )

    # Combine all results from all chunks
    result = dict()
    for chunk_result in chunk_results:
        for location, measurements in chunk_result.items():
            if location not in result:
                result[location] = measurements
            else:
                _result = result[location]
                if measurements[0] < _result[0]:
                    _result[0] = measurements[0]
                if measurements[1] > _result[1]:
                    _result[1] = measurements[1]
                _result[2] += measurements[2]
                _result[3] += measurements[3]

    # Print final results
    print("{", end="")
    for location, measurements in sorted(result.items()):
        print(
            f"{location.decode('utf8')}={measurements[0]:.1f}/{(measurements[2] / measurements[3]) if measurements[3] !=0 else 0:.1f}/{measurements[1]:.1f}",
            end=", ",
        )
    print("\b\b} ")

if __name__ == "__main__":
    cpu_count, *start_end = get_file_chunks("data/measurements.txt", max_cpu=12)
    process_file(cpu_count=cpu_count, start_end=start_end[0])

PyPy Implementation

Leverages the multiprocessing implementation to a large extent, but uses PyPy instead of CPython. This allows you to use a just-in-time compiler for your Python code. Great in all cases that don’t depend on CPython extensions.

# Code credits: https://github.com/ifnesi/1brc#submitting

import os
import multiprocessing as mp

def get_file_chunks(
    file_name: str,
    max_cpu: int = 8,
) -> list:
    """Split file into chunks"""
    cpu_count = min(max_cpu, mp.cpu_count())

    file_size = os.path.getsize(file_name)
    chunk_size = file_size // cpu_count

    start_end = list()
    with open(file_name, "r+b") as f:

        def is_new_line(position):
            if position == 0:
                return True
            else:
                f.seek(position - 1)
                return f.read(1) == b"\n"

        def next_line(position):
            f.seek(position)
            f.readline()
            return f.tell()

        chunk_start = 0
        while chunk_start < file_size:
            chunk_end = min(file_size, chunk_start + chunk_size)

            while not is_new_line(chunk_end):
                chunk_end -= 1

            if chunk_start == chunk_end:
                chunk_end = next_line(chunk_end)

            start_end.append(
                (
                    file_name,
                    chunk_start,
                    chunk_end,
                )
            )

            chunk_start = chunk_end

    return (
        cpu_count,
        start_end,
    )

def _process_file_chunk(
    file_name: str,
    chunk_start: int,
    chunk_end: int,
    blocksize: int = 1024 * 1024,
) -> dict:
    """Process each file chunk in a different process"""
    result = dict()

    with open(file_name, "r+b") as fh:
        fh.seek(chunk_start)

        tail = b""
        location = None
        byte_count = chunk_end - chunk_start

        while byte_count > 0:
            if blocksize > byte_count:
                blocksize = byte_count
            byte_count -= blocksize

            index = 0
            data = tail + fh.read(blocksize)
            while data:
                if location is None:
                    try:
                        semicolon = data.index(b";", index)
                    except ValueError:
                        tail = data[index:]
                        break

                    location = data[index:semicolon]
                    index = semicolon + 1

                try:
                    newline = data.index(b"\n", index)
                except ValueError:
                    tail = data[index:]
                    break

                value = float(data[index:newline])
                index = newline + 1

                if location not in result:
                    result[location] = [
                        value,
                        value,
                        value,
                        1,
                    ]  # min, max, sum, count
                else:
                    _result = result[location]
                    if value < _result[0]:
                        _result[0] = value
                    if value > _result[1]:
                        _result[1] = value
                    _result[2] += value
                    _result[3] += 1

                location = None

    return result

def process_file(
    cpu_count: int,
    start_end: list,
) -> dict:
    """Process data file"""
    with mp.Pool(cpu_count) as p:
        # Run chunks in parallel
        chunk_results = p.starmap(
            _process_file_chunk,
            start_end,
        )

    # Combine all results from all chunks
    result = dict()
    for chunk_result in chunk_results:
        for location, measurements in chunk_result.items():
            if location not in result:
                result[location] = measurements
            else:
                _result = result[location]
                if measurements[0] < _result[0]:
                    _result[0] = measurements[0]
                if measurements[1] > _result[1]:
                    _result[1] = measurements[1]
                _result[2] += measurements[2]
                _result[3] += measurements[3]

    # Print final results
    print("{", end="")
    for location, measurements in sorted(result.items()):
        print(
            f"{location.decode('utf-8')}={measurements[0]:.1f}/{(measurements[2] / measurements[3]) if measurements[3] !=0 else 0:.1f}/{measurements[1]:.1f}",
            end=", ",
        )
    print("\b\b} ")

if __name__ == "__main__":
    cpu_count, *start_end = get_file_chunks("data/measurements.txt", max_cpu=12)
    process_file(cpu_count, start_end[0])

1BRC Pure Python Results

As for the results, well, take a look for yourself:

Image 2 – Pure Python implementation results (image by author)

That’s pretty much all you can squeeze out of Python’s standard library. Even the PyPy implementation is over 11 times slower than the fastest Java implementation. So no, Python won’t win any speed contest any time soon.

Speeding Things Up – Using 3rd Party Python Libraries

But what if you rely on third-party libraries? I said it before and I'll say it again – it's against the rules of the competition – but I'm beyond it at this point. I just want to make it run faster. No one's going to restrict me to Python's standard library at my day job anyway.

Pandas

A must-know data analysis library for any Python data professional. Not nearly the fastest one, but has a far superior ecosystem. I’ve used Pandas 2.2.2 with the PyArrow engine when reading the text file.

import pandas as pd

df = (
    pd.read_csv("data/measurements.txt", sep=";", header=None, names=["station_name", "measurement"], engine="pyarrow")
        .groupby("station_name")
        .agg(["min", "mean", "max"])
)
df.columns = df.columns.get_level_values(level=1)
df = df.reset_index()
df.columns = ["station_name", "min_measurement", "mean_measurement", "max_measurement"]
df = df.sort_values("station_name")

print("{", end="")
for row in df.itertuples(index=False):
    print(
        f"{row.station_name}={row.min_measurement:.1f}/{row.mean_measurement:.1f}/{row.max_measurement:.1f}",
        end=", "
    )
print("\b\b} ")

Dask

Almost identical API to Pandas, but is lazy evaluated. You can use Dask to scale Pandas code across CPU cores locally or across machines on a cluster.

import dask.dataframe as dd

df = (
    dd.read_csv("data/measurements.txt", sep=";", header=None, names=["station_name", "measurement"], engine="pyarrow")
        .groupby("station_name")
        .agg(["min", "mean", "max"])
        .compute()
)

df.columns = df.columns.get_level_values(level=1)
df = df.reset_index()
df.columns = ["station_name", "min_measurement", "mean_measurement", "max_measurement"]
df = df.sort_values("station_name")

print("{", end="")
for row in df.itertuples(index=False):
    print(
        f"{row.station_name}={row.min_measurement:.1f}/{row.mean_measurement:.1f}/{row.max_measurement:.1f}",
        end=", "
    )
print("\b\b} ")

Polars

Similar to Pandas, but has a multi-threaded query engine written in Rust and offers order of operation optimization. I’ve written about it previously.

# Code credits: https://github.com/ifnesi/1brc#submitting

import polars as pl

df = (
    pl.scan_csv("data/measurements.txt", separator=";", has_header=False, with_column_names=lambda cols: ["station_name", "measurement"])
        .group_by("station_name")
        .agg(
            pl.min("measurement").alias("min_measurement"),
            pl.mean("measurement").alias("mean_measurement"),
            pl.max("measurement").alias("max_measurement")
        )
        .sort("station_name")
        .collect(streaming=True)
)

print("{", end="")
for row in df.iter_rows():
    print(
        f"{row[0]}={row[1]:.1f}/{row[2]:.1f}/{row[3]:.1f}", 
        end=", "
    )
print("\b\b} ")

DuckDB

Open-source, embedded, in-process, relational OLAP DBMS that is typically orders of magnitude faster than Pandas. You can use it from the shell or over 10 different programming languages. In most of them, you can choose between a traditional analytical interface and a SQL interface. I’ve written about it previously.


import duckdb

with duckdb.connect() as conn:
    data = conn.sql("""
        select
            station_name,
            min(measurement) as min_measurement,
            cast(avg(measurement) as decimal(8, 1)) as mean_measurement,
            max(measurement) as max_measurement
        from read_csv(
            'data/measurements.txt',
            header=false,
            columns={'station_name': 'varchar', 'measurement': 'decimal(8, 1)'},
            delim=';',
            parallel=true
        )
        group by station_name
        order by station_name
    """)

    print("{", end="")
    for row in sorted(data.fetchall()):
        print(
            f"{row[0]}={row[1]}/{row[2]}/{row[3]}",
            end=", ",
        )
    print("\b\b} ")

1BRC Third-Party Library Results

The results are interesting, to say at least:

Image 3 – Python data analysis libraries runtime results (image by author)

Pandas is slow – no surprises here. Dask offers pretty much the same performance as the multi-core Python implementation, but with around 100 fewer lines of code. Polars and DuckDB reduce the runtime to below 10 seconds, which is impressive!

Going One Step Further – Ditching .txt for .parquet

There’s still one bit of performance gain we can squeeze out, and that’s changing the data file format. The data/convertToParquet.py file in the repo will do just that.

The idea is to go from uncompressed and unoptimized 13.8 GB of text data to a compressed and columnar-oriented 2.51 GB Parquet file.
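The conversion script itself isn't reproduced in the article, but a minimal DuckDB-based sketch could look like this – the CSV options mirror the DuckDB query above, and the actual data/convertToParquet.py in the repo may differ:

import duckdb

with duckdb.connect() as conn:
    # Read the semicolon-separated text file and write it out as a single Parquet file
    conn.execute("""
        copy (
            select station_name, measurement
            from read_csv(
                'data/measurements.txt',
                header=false,
                columns={'station_name': 'varchar', 'measurement': 'decimal(8, 1)'},
                delim=';'
            )
        ) to 'data/measurements.parquet' (format parquet)
    """)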

The libraries remain the same, so it doesn’t make sense to explain them again. I’ll just provide the source code:

Pandas

import pandas as pd

df = (
    pd.read_parquet("data/measurements.parquet", engine="pyarrow")
        .groupby("station_name")
        .agg(["min", "mean", "max"])
)
df.columns = df.columns.get_level_values(level=1)
df = df.reset_index()
df.columns = ["station_name", "min_measurement", "mean_measurement", "max_measurement"]
df = df.sort_values("station_name")

print("{", end="")
for row in df.itertuples(index=False):
    print(
        f"{row.station_name}={row.min_measurement:.1f}/{row.mean_measurement:.1f}/{row.max_measurement:.1f}",
        end=", "
    )
print("\b\b} ")

Dask

import dask.dataframe as dd

df = (
    dd.read_parquet("data/measurements.parquet")
        .groupby("station_name")
        .agg(["min", "mean", "max"])
        .compute()
)

df.columns = df.columns.get_level_values(level=1)
df = df.reset_index()
df.columns = ["station_name", "min_measurement", "mean_measurement", "max_measurement"]
df = df.sort_values("station_name")

print("{", end="")
for row in df.itertuples(index=False):
    print(
        f"{row.station_name}={row.min_measurement:.1f}/{row.mean_measurement:.1f}/{row.max_measurement:.1f}",
        end=", "
    )
print("\b\b} ")

Polars

import polars as pl

df = (
    pl.scan_parquet("data/measurements.parquet")
        .group_by("station_name")
        .agg(
            pl.min("measurement").alias("min_measurement"),
            pl.mean("measurement").alias("mean_measurement"),
            pl.max("measurement").alias("max_measurement")
        )
        .sort("station_name")
        .collect(streaming=True)
)

print("{", end="")
for row in df.iter_rows():
    print(
        f"{row[0]}={row[1]:.1f}/{row[2]:.1f}/{row[3]:.1f}", 
        end=", "
    )
print("\b\b} ")

DuckDB

import duckdb

with duckdb.connect() as conn:
    data = conn.sql("""
        select
            station_name,
            min(measurement) as min_measurement,
            cast(avg(measurement) as decimal(8, 1)) as mean_measurement,
            max(measurement) as max_measurement
        from parquet_scan('data/measurements.parquet')
        group by station_name
        order by station_name
    """)

    print("{", end="")
    for row in sorted(data.fetchall()):
        print(
            f"{row[0]}={row[1]}/{row[2]}/{row[3]}",
            end=", ",
        )
    print("\b\b} ")

1BRC with Parquet Results

It looks like we have a clear winner:

Image 4 – Data analysis libraries on Parquet format runtime results (image by author)

The DuckDB implementation on the Parquet file format reduced the runtime to below 4 seconds! It's still about 2.5 times slower than the fastest Java implementation (on .txt), but it's something to be happy with.


Conclusion

If there’s one visualization to remember from this article, it has to be the following one:

Image 5 – Average runtime results for all approaches (image by author)

Sure, only the first three columns obey the official competition rules, but I don’t care. Speed is speed. All is fair in love and Python performance optimization.

Python will never be as fast as Java or any other compiled language – that’s the fact. The question you have to answer is how fast is fast enough. For me, less than 4 seconds for 1 billion rows of data is well below that margin.

What are your thoughts on the 1 Billion Row Challenge? Did you manage to implement a faster solution? Let me know in the comment section below.

Read next:

DuckDB and AWS – How to Aggregate 100 Million Rows in 1 Minute

The post Python One Billion Row Challenge – From 10 Minutes to 4 Seconds appeared first on Towards Data Science.

]]>
DuckDB and AWS – How to Aggregate 100 Million Rows in 1 Minute https://towardsdatascience.com/duckdb-and-aws-how-to-aggregate-100-million-rows-in-1-minute-3634eef06b79/ Thu, 25 Apr 2024 19:59:37 +0000 https://towardsdatascience.com/duckdb-and-aws-how-to-aggregate-100-million-rows-in-1-minute-3634eef06b79/ Process huge volumes of data with Python and DuckDB - An AWS S3 example.

The post DuckDB and AWS – How to Aggregate 100 Million Rows in 1 Minute appeared first on Towards Data Science.

]]>
When companies need a secure, performant, and scalable storage solution, they tend to gravitate toward the cloud. One of the most popular platforms in the game is AWS S3 – and for a good reason – it's an industry-leading object storage solution that can serve as a data lake.

The question is – Can you aggregate S3 bucket data without downloading it? And can you do it fast?

The answer is yes to both questions. DuckDB allows you to connect to your S3 bucket directly via the httpfs extension. You'll learn how to use it today by aggregating around 111 million rows split between 37 Parquet files.

Spoiler alert: It will take you around a minute.

Note: I wrote this post because I was searching for a more performant Pandas alternative. My goal was to perform analysis on large datasets locally instead of opting for cloud solutions. I have no affiliations with DuckDB or AWS.


AWS S3 Setup

First things first, you’ll need an AWS account and an S3 bucket. You’ll also want to create an IAM user for which you can generate an access key.

As for the data, I’ve downloaded Yellow Taxi data Parquet files from January 2021 to January 2024 from the following link:

This is what it looks like when loaded into an S3 bucket:

Image 1 – Parquet files in an S3 bucket (image by author)

The bucket now contains 37 Parquet files taking 1.79 GB of space and containing over 111 million rows.
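If you prefer to script the upload instead of clicking through the AWS console, a boto3 sketch like the one below would do. The bucket name, the local data folder, and the yellow_tripdata_*.parquet file pattern are assumptions you'll need to adjust to your own setup:

import boto3
from pathlib import Path

s3 = boto3.client("s3")

# Hypothetical local folder and bucket name - adjust both to your environment
for file in sorted(Path("data").glob("yellow_tripdata_*.parquet")):
    s3.upload_file(str(file), "<your-bucket-name>", file.name)
    print(f"Uploaded {file.name}")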

DuckDB AWS S3 Setup

Setup on the Python end requires the duckdb library and the httpfs extension for DuckDB. Assuming you have the library installed (simple pip installation), import it, and create a new connection:

import duckdb

conn = duckdb.connect()

DuckDB httpfs Extension

The httpfs extension, among other things, allows you to write/read files to and from a given AWS S3 bucket.

Install it and load it with the following Python command (run the installation only once):

conn.execute("""
    INSTALL httpfs;
    LOAD httpfs;
""").df()

You should see a success message like this:

Image 2 – Installing and loading httpfs extension (image by author)

DuckDB S3 Configuration

As for the S3 configuration, provide the region, access key, and secret access key to DuckDB:

conn.execute("""
    SET s3_region = '<your-region>';
    SET s3_access_key_id = '<your-access-key>';
    SET s3_secret_access_key = '<your-secret-key>';
""").df()

You should see a success message once again:

Image 3 – DuckDB S3 configuration (image by author)

And that’s it! You can now query S3 data directly from DuckDB.

Python and DuckDB – How to Get Data From AWS

This section will show how long it takes to run two queries – simple count and aggregation – from 37 Parquet files stored on S3.

Query #1 – Simple Count

To read Parquet data from an S3 bucket, use the parquet_scan() function and provide a glob path to all Parquet files stored in the root path. Just remember to change the bucket name:

res_count = conn.execute("""
    select count(*)
    from parquet_scan('s3://<your-bucket-name>/*.parquet');
""").df()

res_count

Getting a count of over 111 million takes only 7 seconds:

Image 4 – DuckDB count results (image by author)

Query #2 – Monthly Summary Statistics

And now let’s calculate summary statistics over all Parquet files. The goal is to get counts, sums, and averages for certain columns grouped on a monthly level:

res_agg = conn.execute("""
    select 
        period,
        count(*) as num_rides,
        round(avg(trip_duration), 2) as avg_trip_duration,
        round(avg(trip_distance), 2) as avg_trip_distance,
        round(sum(trip_distance), 2) as total_trip_distance,
        round(avg(total_amount), 2) as avg_trip_price,
        round(sum(total_amount), 2) as total_trip_price,
        round(avg(tip_amount), 2) as avg_tip_amount
    from (
        select
            date_part('year', tpep_pickup_datetime) as trip_year,
            strftime(tpep_pickup_datetime, '%Y-%m') as period,
            epoch(tpep_dropoff_datetime - tpep_pickup_datetime) as trip_duration,
            trip_distance,
            total_amount,
            tip_amount
        from parquet_scan('s3://duckdb-bucket-20240422/*.parquet')
        where trip_year >= 2021 and trip_year <= 2024
    )
    group by period
    order by period
""").df()

res_agg

It takes just over a minute to perform this aggregation:

Image 5 – DuckDB aggregation results (image by author)

As a frame of reference, the same operation took under 2 seconds when files were saved on a local disk. Pandas took over 11 minutes for the same task.


Summing up DuckDB and AWS in Python

Overall, if you have huge volumes of data stored on S3, DuckDB is your friend. You can analyze and aggregate data in no time without file downloads – which might not be allowed at all due to privacy and security concerns.

You can also use DuckDB as an analysis and aggregation layer between two AWS S3 storage layers. Raw data goes in, aggregated data goes out – hopefully, into another bucket or subdirectory.
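Here's a hedged sketch of that pattern – the bucket names and the choice of summary columns are placeholders, and it assumes the same httpfs setup shown earlier in the article:

import duckdb

with duckdb.connect() as conn:
    conn.execute("INSTALL httpfs; LOAD httpfs;")
    conn.execute("""
        SET s3_region = '<your-region>';
        SET s3_access_key_id = '<your-access-key>';
        SET s3_secret_access_key = '<your-secret-key>';
    """)
    # Read raw Parquet files from one bucket, write the monthly summary to another
    conn.execute("""
        copy (
            select
                strftime(tpep_pickup_datetime, '%Y-%m') as period,
                count(*) as num_rides,
                round(sum(total_amount), 2) as total_trip_price
            from parquet_scan('s3://<your-raw-bucket>/*.parquet')
            group by period
            order by period
        ) to 's3://<your-aggregated-bucket>/monthly_summary.parquet' (format parquet)
    """)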

Read next:

How to Train a Decision Tree Classifier… In SQL

The post DuckDB and AWS – How to Aggregate 100 Million Rows in 1 Minute appeared first on Towards Data Science.

]]>