Discrepancy in figure and axes facecolors when saving animation as GIF using Pillow writer

I’m working on an animated simulation using Matplotlib where I have explicitly set different facecolors for the figure and axes using:

self.fig.set_facecolor("#061323")
self.ax1.set_facecolor("#1c2833")
self.ax2.set_facecolor("#1c2833")

The animation renders perfectly with the expected color separation when saved in MP4 format using ffmpeg, i.e., the figure background and axes backgrounds appear as intended.

However, when I save the same animation as a GIF using the pillow writer, the entire frame (figure and axes) appears uniformly filled with the figure’s facecolor, completely ignoring the distinct axes colors. This effectively flattens the visual hierarchy and affects clarity.

Here’s how I’m saving the animation:

self.ani.save(
    filepath,
    writer="pillow",
    fps=self.fps,
    dpi=dpi,
    savefig_kwargs={
        "facecolor": "#17202a",
        "edgecolor": "none",
        "bbox_inches": "tight",
        "pad_inches": 0.1,
    },
)

I have tried explicitly updating the facecolors again within the update() method as well, but the saved GIF still ignores the axes background color. I am attaching snapshots of two animations in GIF and MP4 formats.

I am also including the complete code for the animation.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
import matplotlib.lines as mlines

from scipy.stats import norm
from datetime import datetime


class BuffonLaplaceSimulationTriangular:
    """
    A class to simulate Buffon's Needle experiment within a regular hexagon with triangular tiling.

    The hexagon is divided into six equilateral triangles by its diagonals. Needles are dropped
    within the hexagon, and intersections with the hexagon's sides and diagonals are counted.
    Intersections are categorized into zero, one, or two (or more). A second subplot shows the
    π estimate with a user-defined Wilson score confidence interval.
    """

    def __init__(
        self,
        needle_length=1.0,
        a=2.0,  # Side length of the hexagon (also side length of equilateral triangles)
        ntrial=500,
        confidence_level=95,
        save_animation=False,
        filename=None,
        fps=20,
    ):
        """
        Initialize the Buffon's Needle simulation for a hexagonal region with triangular tiling.

        Args:
            needle_length (float): Length of the needle (L). Must be positive.
            a (float): Side length of the hexagon (and equilateral triangles). Must be positive.
            ntrial (int): Number of trials/needles to drop. Must be positive.
            confidence_level (int): Confidence level as a percentage (e.g., 95 for 95%).
                                    Must be between 0 and 100. Defaults to 95.
            save_animation (bool): Whether to save the animation. Defaults to False.
            filename (str): Name of the output file. If None, a default name is generated.
                            Must end with .gif or .mp4 if provided and save_animation is True.
            fps (int): Frames per second for the animation. Must be positive.
        """
        # --- Input Validation ---
        if not all(
            isinstance(arg, (int, float)) and arg > 0
            for arg in [needle_length, a, ntrial, fps]
        ):
            raise ValueError(
                "needle_length, a, ntrial, and fps must be positive numbers."
            )
        if (
            save_animation
            and filename
            and not filename.lower().endswith((".gif", ".mp4"))
        ):
            raise ValueError(
                "Filename must end with .gif or .mp4 for saving animations."
            )
        if not 0 < confidence_level < 100:
            # 100 is excluded: norm.ppf(1.0) is infinite, which would break the CI.
            raise ValueError("confidence_level must be strictly between 0 and 100.")

        # --- Initialize Attributes ---
        self.needle_length = float(needle_length)
        self.a = float(a)  # Hexagon side length
        self.ntrial = int(ntrial)
        self.confidence_level = float(confidence_level)
        self.save_animation = bool(save_animation)
        self.filename = str(filename) if filename is not None else None
        self.fps = int(fps)

        # Calculate z-score for the confidence level (two-tailed)
        self.z_score = norm.ppf(1 - (1 - self.confidence_level / 100) / 2)

        # Hexagon vertices (center at origin)
        self.h = np.sqrt(3) * self.a / 2  # Height from center to vertex along y-axis
        self.vertices = np.array(
            [
                [self.a, 0],  # Vertex 0
                [self.a / 2, self.h],  # Vertex 1
                [-self.a / 2, self.h],  # Vertex 2
                [-self.a, 0],  # Vertex 3
                [-self.a / 2, -self.h],  # Vertex 4
                [self.a / 2, -self.h],  # Vertex 5
            ]
        )

        # Define lines to check for intersections (6 sides + 3 diagonals)
        self.lines = []
        # Sides
        for i in range(len(self.vertices)):
            j = (i + 1) % len(self.vertices)
            self.lines.append((self.vertices[i], self.vertices[j]))
        # Diagonals (connecting opposite vertices)
        for i in range(3):
            j = (i + 3) % 6
            self.lines.append((self.vertices[i], self.vertices[j]))

        # Simulation state for intersection types
        self.no_intersection_count = 0
        self.one_intersection_count = 0
        self.two_intersections_count = 0
        self.needle_count = 0
        self.needles_artists = []

        # Data for CI plot
        self.trial_numbers = []
        self.pi_estimates_history = []
        self.ci_lower_history = []
        self.ci_upper_history = []

        self._setup_plot()

    def _setup_plot(self):
        """
        Set up the matplotlib plots for the simulation.

        Creates two subplots: one for the needle simulation within the hexagon
        and one for the π estimate with its confidence interval.
        """
        self.fig, (self.ax1, self.ax2) = plt.subplots(
            2, 1, figsize=(12, 10), gridspec_kw={"height_ratios": [3, 1]}
        )
        self.fig.subplots_adjust(
            left=0.05, right=0.95, top=0.88, bottom=0.1, wspace=0.2
        )
        self.fig.suptitle(
            "Buffon-Laplace Needle Simulation (Hexagonal Grid)",
            fontsize=16,
            y=0.98,
            color="white",
        )
        self.fig.set_facecolor("#061323")
        self.fig.patch.set_facecolor("#061323")

        self.ax1.set_facecolor("#1c2833")
        self.ax1.patch.set_facecolor("#1c2833")
        self.ax1.patch.set_alpha(1.0)

        self.ax2.set_facecolor("#1c2833")
        self.ax2.patch.set_facecolor("#1c2833")
        self.ax2.patch.set_alpha(1.0)

        self.ax1.spines[:].set_color("white")
        self.ax2.spines[:].set_color("white")

        # Define plot limits with padding
        padding = self.a * 0.5
        self.ax1.set_xlim(-self.a - padding, self.a + padding)
        self.ax1.set_ylim(-self.h - padding, self.h + padding)
        self.ax1.set_xticks([])
        self.ax1.set_yticks([])
        self.ax1.set_aspect("equal")

        self._draw_hexagon()

        self.ax1.set_title(
            f"Needle Simulation in Hexagon (a={self.a})",
            pad=10,
            fontsize=12,
            color="white",
        )

        # Initialize legend for ax1
        self.legend_handles = {
            "no_intersection": mlines.Line2D(
                [0], [0], color="red", lw=2, label="No Intersection (0)"
            ),
            "one_intersection": mlines.Line2D(
                [0], [0], color="cyan", lw=2, label="One Intersection (0)"
            ),
            "two_intersections": mlines.Line2D(
                [0], [0], color="lime", lw=2, label="Two Intersections (0)"
            ),
        }
        self.legend_ax1 = self.ax1.legend(
            handles=list(self.legend_handles.values()),
            loc="upper right",
            bbox_to_anchor=(1.38, 1.0),
            fontsize=8,
            labelcolor="white",
            facecolor="#17202a",
        )
        self.ax1.add_artist(self.legend_ax1)

        # Second subplot: π estimate and Confidence Interval
        self.ax2.set_xlim(0, self.ntrial)
        self.ax2.set_ylim(max(0, np.pi - 1.5), np.pi + 1.5)
        self.ax2.set_xlabel("Number of Needles Dropped", color="white")
        self.ax2.set_ylabel(r"$\pi$ Estimate", color="white")
        self.ax2.tick_params(axis="both", colors="white")
        self.ax2.grid(True, alpha=0.3, linestyle=":", color="white")
        self.ax2.set_title(
            f"$\\pi$ Estimate with {int(self.confidence_level)}% Wilson Score Interval",
            pad=10,
            fontsize=12,
            color="white",
        )

        (self.true_pi_line,) = self.ax2.plot(
            [0, self.ntrial],
            [np.pi, np.pi],
            color="cyan",
            linestyle="--",
            label=r"True $\pi$",
        )
        (self.pi_line,) = self.ax2.plot(
            [], [], color="red", lw=1.5, label="$\\pi$ Estimate"
        )
        self.ci_fill_plot = self.ax2.fill_between(
            [0],
            [0],
            [0],
            color="red",
            alpha=0.2,
            label=f"{int(self.confidence_level)}% CI",
        )
        self.ax2.legend(
            loc="upper right", fontsize=8, facecolor="#17202a", labelcolor="white"
        )

    def _draw_hexagon(self):
        """Draws the hexagon with its sides and diagonals, including extensions."""
        line_ext = 0.8  # Extension length for dotted lines

        def get_extended_segments(p1, p2, ext_length):
            """Return three parts of a line: pre-extension, main, post-extension."""
            direction = p2 - p1
            unit_dir = direction / np.linalg.norm(direction)
            pre = p1 - ext_length * unit_dir
            post = p2 + ext_length * unit_dir
            return pre, p1, p2, post

        # Draw hexagon sides
        for i in range(len(self.vertices)):
            j = (i + 1) % len(self.vertices)
            pre, p1, p2, post = get_extended_segments(
                self.vertices[i], self.vertices[j], line_ext
            )
            # Dotted extensions
            self.ax1.plot([pre[0], p1[0]], [pre[1], p1[1]], color="gray", ls="--", lw=2)
            self.ax1.plot(
                [p2[0], post[0]], [p2[1], post[1]], color="gray", ls="--", lw=2
            )
            # Main side
            self.ax1.plot([p1[0], p2[0]], [p1[1], p2[1]], color="gray", ls="-", lw=2)

        # Draw diagonals
        for i in range(3):
            j = (i + 3) % 6
            pre, p1, p2, post = get_extended_segments(
                self.vertices[i], self.vertices[j], line_ext
            )
            # Dotted extensions
            self.ax1.plot([pre[0], p1[0]], [pre[1], p1[1]], color="gray", ls="--", lw=2)
            self.ax1.plot(
                [p2[0], post[0]], [p2[1], post[1]], color="gray", ls="--", lw=2
            )
            # Main diagonal
            self.ax1.plot([p1[0], p2[0]], [p1[1], p2[1]], color="gray", ls="-", lw=2)

    def _is_point_inside_hexagon(self, x, y):
        """
        Check if a point (x, y) lies inside the hexagon using the ray-casting algorithm.

        Args:
            x, y (float): Coordinates of the point.

        Returns:
            bool: True if the point is inside the hexagon, False otherwise.
        """
        inside = False
        for i in range(len(self.vertices)):
            j = (i + 1) % len(self.vertices)
            xi, yi = self.vertices[i]
            xj, yj = self.vertices[j]
            # Check if the point (x, y) crosses the edge from vertex i to j
            if ((yi > y) != (yj > y)) and (
                x < (xj - xi) * (y - yi) / (yj - yi + 1e-10) + xi
            ):
                inside = not inside
        return inside

    def _generate_needle_position(self):
        """
        Generates a random position and orientation for a new needle within the hexagon.

        Returns:
            tuple: (x_center, y_center, theta_angle)
        """
        # Use rejection sampling to ensure the needle center is inside the hexagon
        while True:
            x_center = np.random.uniform(-self.a, self.a)
            y_center = np.random.uniform(-self.h, self.h)
            if self._is_point_inside_hexagon(x_center, y_center):
                break
        theta = np.random.uniform(0, np.pi)
        return x_center, y_center, theta

    def _calculate_endpoints(self, x, y, theta):
        """
        Calculates the (x, y) coordinates of both endpoints of the needle.

        Args:
            x, y (float): Coordinates of the needle's center.
            theta (float): Angle of the needle (radians).

        Returns:
            tuple: (x1, y1, x2, y2) coordinates of the endpoints.
        """
        half_L_cos_theta = (self.needle_length / 2) * np.cos(theta)
        half_L_sin_theta = (self.needle_length / 2) * np.sin(theta)
        x1 = x - half_L_cos_theta
        y1 = y - half_L_sin_theta
        x2 = x + half_L_cos_theta
        y2 = y + half_L_sin_theta
        return x1, y1, x2, y2

    def _check_intersection(self, x1, y1, x2, y2):
        """
        Checks how many lines (hexagon sides or diagonals) the needle intersects.

        Args:
            x1, y1, x2, y2 (float): Coordinates of the needle's endpoints.

        Returns:
            int: Number of intersections (0, 1, or 2+).
        """

        def segments_intersect(p1, p2, q1, q2):
            """Check if line segments (p1,p2) and (q1,q2) intersect."""

            def orientation(p, q, r):
                val = (q[1] - p[1]) * (r[0] - p[0]) - (q[0] - p[0]) * (r[1] - p[1])
                if abs(val) < 1e-10:
                    return 0  # Collinear
                return 1 if val > 0 else 2  # Clockwise or counterclockwise

            def on_segment(p, q, r):
                return (
                    q[0] <= max(p[0], r[0])
                    and q[0] >= min(p[0], r[0])
                    and q[1] <= max(p[1], r[1])
                    and q[1] >= min(p[1], r[1])
                )

            o1 = orientation(p1, p2, q1)
            o2 = orientation(p1, p2, q2)
            o3 = orientation(q1, q2, p1)
            o4 = orientation(q1, q2, p2)

            # General case
            if o1 != o2 and o3 != o4:
                return True

            # Special cases (collinear and overlapping)
            if o1 == 0 and on_segment(p1, q1, p2):
                return True
            if o2 == 0 and on_segment(p1, q2, p2):
                return True
            if o3 == 0 and on_segment(q1, p1, q2):
                return True
            if o4 == 0 and on_segment(q1, p2, q2):
                return True

            return False

        intersections = 0
        needle_p1, needle_p2 = np.array([x1, y1]), np.array([x2, y2])
        for line_p1, line_p2 in self.lines:
            if segments_intersect(needle_p1, needle_p2, line_p1, line_p2):
                intersections += 1

        return min(intersections, 2)  # Cap at 2 for coloring purposes

    def _calculate_pi_and_confidence_interval(self):
        """
        Calculates the pi estimate and Wilson Score confidence interval for a triangular grid.

        Returns:
            tuple: (pi_estimate, ci_lower_bound, ci_upper_bound)
        """
        total_hits = self.one_intersection_count + self.two_intersections_count
        if self.needle_count == 0 or total_hits == 0:
            return 0.0, np.nan, np.nan

        p_cross = float(total_hits / self.needle_count)
        L = self.needle_length
        a = self.a
        x = L / a

        # Pi estimate: π ≈ ( (l/a) * sqrt(3) * (4 - l/a) ) / ( P_cross + (2/3) * (l/a)^2 )
        numerator = x * np.sqrt(3) * (4 - x)
        denominator = p_cross + (2 / 3) * (x**2)
        pi_estimate = numerator / denominator if denominator > 0 else 0.0

        # Wilson Score Confidence Interval
        ci_lower, ci_upper = np.nan, np.nan
        if total_hits > 0 and total_hits < self.needle_count:
            n = float(self.needle_count)
            z = float(self.z_score)
            term1 = p_cross + (z**2) / (2 * n)
            term2 = z * np.sqrt((p_cross * (1 - p_cross)) / n + (z**2) / (4 * n**2))
            denominator_ci = 1 + (z**2) / n
            p_lower = max(0.0, (term1 - term2) / denominator_ci)
            p_upper = min(1.0, (term1 + term2) / denominator_ci)

            # Transform to pi CI
            if p_lower > 0:
                ci_upper = numerator / (p_lower + (2 / 3) * (x**2))
            else:
                ci_upper = float("inf")
            if p_upper > 0:
                ci_lower = numerator / (p_upper + (2 / 3) * (x**2))
            else:
                ci_lower = float("inf")

        return pi_estimate, ci_lower, ci_upper

    def update(self, frame):
        """
        Update function for the animation.

        Args:
            frame (int): Current frame number.

        Returns:
            list: Matplotlib artists to redraw.
        """

        # Ensure face colors are maintained during animation
        self.fig.patch.set_facecolor("#061323")
        self.ax1.patch.set_facecolor("#1c2833")
        self.ax2.patch.set_facecolor("#1c2833")

        if self.needle_count >= self.ntrial:
            return self.needles_artists + [
                self.pi_line,
                self.ci_fill_plot,
                self.true_pi_line,
                self.ax1.title,
                self.ax2.title,
                self.legend_ax1,
            ]

        # Drop a needle
        x, y, theta = self._generate_needle_position()
        x1, y1, x2, y2 = self._calculate_endpoints(x, y, theta)
        intersections = self._check_intersection(x1, y1, x2, y2)

        # Color based on number of intersections
        if intersections == 0:
            color = "red"
            self.no_intersection_count += 1
        elif intersections == 1:
            color = "cyan"
            self.one_intersection_count += 1
        else:  # intersections >= 2
            color = "lime"
            self.two_intersections_count += 1

        # Plot needle
        needle = self.ax1.plot([x1, x2], [y1, y2], c=color, lw=2)[0]
        self.needles_artists.append(needle)
        self.needle_count += 1

        # Update legend
        self.legend_handles["no_intersection"].set_label(
            f"No Intersection ({self.no_intersection_count})"
        )
        self.legend_handles["one_intersection"].set_label(
            f"One Intersection ({self.one_intersection_count})"
        )
        self.legend_handles["two_intersections"].set_label(
            f"Two Intersections ({self.two_intersections_count})"
        )
        self.legend_ax1.remove()
        self.legend_ax1 = self.ax1.legend(
            handles=list(self.legend_handles.values()),
            loc="upper right",
            bbox_to_anchor=(1.38, 1.0),
            fontsize=8,
            labelcolor="white",
            facecolor="#17202a",
        )

        # Update pi estimate and CI
        current_pi_estimate, ci_lower, ci_upper = (
            self._calculate_pi_and_confidence_interval()
        )
        self.trial_numbers.append(self.needle_count)
        self.pi_estimates_history.append(current_pi_estimate)
        self.ci_lower_history.append(ci_lower if ci_lower is not None else np.nan)
        self.ci_upper_history.append(ci_upper if ci_upper is not None else np.nan)

        # Update title
        total_hits = self.one_intersection_count + self.two_intersections_count
        if total_hits > 0:
            error = abs((np.pi - current_pi_estimate) * 100 / np.pi)
            pi_str = f"$\\pi$ estimate: {current_pi_estimate:.4f} | Error: {error:.2f}%"
        else:
            pi_str = "$\\pi$ estimate: N/A | Error: N/A"

        self.ax1.set_title(
            f"Needle Simulation in Hexagon (a={self.a})\n"
            f"Needles Dropped: {self.needle_count} | Total Intersections: {total_hits}\n"
            f"{pi_str}",
            pad=10,
            color="white",
            fontsize=12,
        )

        # Update CI plot
        valid_indices = [
            i
            for i, (pi_est, lower, upper) in enumerate(
                zip(
                    self.pi_estimates_history,
                    self.ci_lower_history,
                    self.ci_upper_history,
                )
            )
            if not (
                np.isnan(pi_est)
                or np.isinf(pi_est)
                or np.isnan(lower)
                or np.isinf(lower)
                or np.isnan(upper)
                or np.isinf(upper)
            )
            and total_hits >= 2
        ]

        plot_trials = [self.trial_numbers[i] for i in valid_indices]
        plot_pi_estimates = [self.pi_estimates_history[i] for i in valid_indices]
        plot_ci_lower = [self.ci_lower_history[i] for i in valid_indices]
        plot_ci_upper = [min(self.ci_upper_history[i], 10) for i in valid_indices]

        if plot_pi_estimates:
            self.pi_line.set_data(plot_trials, plot_pi_estimates)
            if self.ci_fill_plot is not None and self.ci_fill_plot.get_paths():
                self.ci_fill_plot.remove()
                self.ci_fill_plot = None
            self.ci_fill_plot = self.ax2.fill_between(
                plot_trials,
                np.array(plot_ci_lower),
                np.array(plot_ci_upper),
                color="red",
                alpha=0.2,
                label=f"{int(self.confidence_level)}% CI",
            )

            all_y_values = plot_pi_estimates + plot_ci_lower + plot_ci_upper + [np.pi]
            all_y_values = [
                v for v in all_y_values if not np.isinf(v) and not np.isnan(v)
            ]
            if all_y_values:
                min_y = min(all_y_values) * 0.95
                max_y = min(max(all_y_values) * 1.05, 10)
                if abs(max_y - min_y) < 0.5:
                    mid_y = (max_y + min_y) / 2
                    min_y = mid_y - 0.25
                    max_y = mid_y + 0.25
                self.ax2.set_ylim(min_y, max_y)
            else:
                self.ax2.set_ylim(max(0, np.pi - 1.5), np.pi + 1.5)
            self.ax2.set_xlim(0, max(self.needle_count, 1))

        return self.needles_artists + [
            self.pi_line,
            self.ci_fill_plot,
            self.true_pi_line,
            self.ax1.title,
            self.ax2.title,
            self.legend_ax1,
        ]

    def run_and_save_animation(self, dpi=150):
        """
        Run the simulation and optionally save the animation.

        Returns:
            matplotlib.animation.FuncAnimation: The animation object.
        """
        print(
            f"Starting Buffon-Laplace Needle Simulation for {self.ntrial} "
            f"trials within a hexagonal grid (a={self.a})..."
        )

        self.ani = FuncAnimation(
            self.fig,
            self.update,
            frames=self.ntrial,
            interval=int(1000 / self.fps),
            blit=False,
            repeat=False,
        )

        if self.save_animation:
            save_dir = "ANIMATIONS/GRIDS"
            os.makedirs(save_dir, exist_ok=True)
            if self.filename is None:
                current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
                self.filename = (
                    f"buffon_laplace_hexagonal_a{self.a}_"
                    f"{self.ntrial}_trials_CI{int(self.confidence_level)}_"
                    f"{current_time}.gif"
                )
            filepath = os.path.join(save_dir, self.filename)
            print(f"Attempting to save animation to {filepath}...")
            writer = "pillow" if self.filename.lower().endswith(".gif") else "ffmpeg"
            try:
                self.ani.save(
                    filepath,
                    writer=writer,
                    fps=self.fps,
                    dpi=dpi,
                    savefig_kwargs={
                        "facecolor": "#17202a",
                        "edgecolor": "none",
                        "bbox_inches": "tight",
                        "pad_inches": 0.1,
                    },
                )
                print(f"Animation saved successfully to {os.path.abspath(filepath)}")
            except Exception as e:
                print(f"Error saving animation: {e}")
                print("Ensure Pillow (for GIFs) or FFmpeg (for MP4s) is installed.")
            finally:
                plt.close(self.fig)
        else:
            plt.show()

        return self.ani


if __name__ == "__main__":
    print(
        "\n--- Running Buffon-Laplace Hexagonal Grid Simulation (500 trials, 95% CI) ---"
    )
    sim_hexagonal = BuffonLaplaceSimulationTriangular(
        needle_length=0.5,  # L < a * sqrt(3)/2 ≈ 1.732 for best results with a=2
        a=2.0,  # Hexagon side length
        ntrial=100,
        confidence_level=95,
        save_animation=True,
        fps=25,
    )
    sim_hexagonal.run_and_save_animation()
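For reference, the estimator and interval mapping implemented in `_calculate_pi_and_confidence_interval` can be written out as below. It is transcribed from the code as written (not independently derived), with $x = L/a$ and $\hat p$ the observed fraction of needles that cross at least one line. Because the numerator is positive for $0 < x < 4$ and the denominator grows with $p$, $g$ is strictly decreasing, which is why the code turns the Wilson upper bound on $p$ into the lower bound on $\pi$ and vice versa.

```latex
\hat{\pi} = g(\hat p) = \frac{\sqrt{3}\,x\,(4 - x)}{\hat p + \tfrac{2}{3}x^{2}},
\qquad x = \frac{L}{a},
\qquad
\left[p_{\text{lower}},\, p_{\text{upper}}\right]
\;\longmapsto\;
\left[\pi_{\text{lower}},\, \pi_{\text{upper}}\right]
= \left[g(p_{\text{upper}}),\, g(p_{\text{lower}})\right].
```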

My Questions:

  1. Is this a known limitation with the pillow writer in Matplotlib?
  2. Is there any workaround to preserve different figure and axes background colors in GIF output?
  3. Are there specific savefig_kwargs or writer settings I should be using for GIFs to make it behave like the MP4 output?

I suspect the issue comes from a combination of GIF paletting (see the GIF article on Wikipedia), which limits each palette to 256 distinct colors, and antialiasing: the edges of the needles are drawn in many intermediate colors between red/green/cyan and grey (in different weightings), and these end up “taking up” the available palette entries. The color quantization algorithm could then decide to merge the two “background greys”, depending on the specifics of the algorithm.
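To put numbers on the “two background greys”: the `facecolor` in `savefig_kwargs` is passed on to `savefig` for every grabbed frame, so in the saved output the figure background is `#17202a`, not the `#061323` set on the figure. The sketch below (plain Python; the helper names are illustrative) computes Euclidean RGB distances between the three dark colors used in the script. Assuming the quantizer tends to collapse the closest colors first, `#17202a` and `#1c2833` are the likeliest pair to end up as one palette entry.

```python
import math


def hex_to_rgb(hex_color):
    """Convert '#rrggbb' to an (r, g, b) tuple of ints in 0..255."""
    h = hex_color.lstrip("#")
    return tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))


def rgb_distance(c1, c2):
    """Euclidean distance between two hex colors in RGB space."""
    return math.dist(hex_to_rgb(c1), hex_to_rgb(c2))


# Colors taken from the script above
colors = {
    "figure (set_facecolor)": "#061323",
    "figure (savefig_kwargs)": "#17202a",
    "axes": "#1c2833",
}

if __name__ == "__main__":
    names = list(colors)
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            print(f"{a} vs {b}: {rgb_distance(colors[a], colors[b]):.2f}")
```

The savefig background and the axes background are only about 13 units apart, out of a maximum RGB distance of roughly 442. That is easy to lose once antialiased needle edges compete for the 256 palette slots.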

It would be interesting if there were a way to tell PIL “these two specific colors must not be merged during quantization”, but I don’t think such an API is available…

The quantization in Pillow is somewhat opaque and prone to producing unwanted results. On top of that, the GIF format has a very restricted colour palette. IMO, I would simply not create a GIF here. Websites like GIPHY already use MP4 automatically for high-quality 'GIFs' because it looks and compresses better.