from matplotlib.axes import Axes
from matplotlib.lines import Line2D
from matplotlib.collections import LineCollection
from matplotlib.ticker import FixedLocator, AutoLocator, ScalarFormatter, FixedLocator
from matplotlib import transforms
from matplotlib.projections import register_projection

import numpy as np

class SkewXAxes(Axes):
    """
    Rectilinear axes whose x direction is skewed by 45 degrees on screen.

    After the ordinary (unskewed) data transform, each point is shifted to
    the right by as many pixels as it lies above the bottom edge of the
    axes.  Lines of constant x therefore run up and to the right at 45
    degrees, while lines of constant y stay horizontal.
    """
    # The projection must specify a name.  This will be used by the
    # user to select the projection, i.e. ``subplot(111,
    # projection='skewx')``.
    name = 'skewx'

    def set_xlim(self, *args, **kwargs):
        """
        Set the x-axis view limits.

        Thin pass-through to ``Axes.set_xlim``: all positional and keyword
        arguments are forwarded, and the base implementation's return value
        is returned to the caller.
        """
        return Axes.set_xlim(self, *args, **kwargs)

    def draw(self, *args, **kwargs):
        '''
        Refresh the skew transform, then draw as a normal Axes.

        The skew depends on the axes' current pixel position.  Recomputing
        it here, just before ``Axes.draw``, lets resizes be handled without
        registering callbacks.  The amount of work done here is kept to a
        minimum.  All arguments are forwarded to ``Axes.draw``.
        '''
        self._update_data_transform()
        return Axes.draw(self, *args, **kwargs)

    def _update_data_transform(self):
        '''
        Recompute the skew for the current axes position and rebuild the
        full data transform.

        The skew maps display coordinates (x, y) to (x + y - ymin, y), where
        ymin is the bottom of the axes bounding box in pixels.  A point on
        the bottom edge is not moved, and a point d pixels higher is moved d
        pixels to the right.  This is what makes x grid lines slope 45
        degrees to the right.

        ``self.transProjection`` is created once and afterwards updated in
        place.  Composite transforms built from it earlier (for example the
        one returned by ``get_xaxis_transform``, which tick/grid artists may
        hold on to) then pick up the new skew instead of keeping a stale
        copy.
        '''
        skew = np.array([[1.0, 1.0, -self.bbox.ymin],
                         [0.0, 1.0, 0.0],
                         [0.0, 0.0, 1.0]])

        projection = getattr(self, 'transProjection', None)
        if isinstance(projection, transforms.Affine2D):
            # set_matrix() invalidates the transform, so every composite
            # containing it is recomputed the next time it is used.
            projection.set_matrix(skew)
        else:
            self.transProjection = transforms.Affine2D(skew)

        # Full data transform: ordinary data->display, then the skew.
        self.transData.set(self._transDataNonskew + self.transProjection)

    def _set_lim_and_transforms(self):
        """
        This is called once when the plot is created to set up all the
        transforms for the data, text and grids.
        """
        # Get the standard transform setup from the Axes base class
        Axes._set_lim_and_transforms(self)

        # Save the unskewed data transform for our own use when regenerating
        # the data transform. The user might want this as well
        self._transDataNonskew = self.transData

        # Create a wrapper for the data transform, so that any object that
        # grabs this transform will see an updated version when we change it
        self.transData = transforms.TransformWrapper(
            transforms.IdentityTransform())

        # Use the helper method to actually set the skewed data transform
        self._update_data_transform()

    def get_xaxis_transform(self, which='grid'):
        """
        Get the transformation used for drawing x-axis labels, ticks
        and gridlines.  The x-direction is in data coordinates and the
        y-direction is in axis coordinates.

        We override here so that the x-axis gridlines get properly
        transformed for the skewed plot.

        *which* is accepted because matplotlib's axis code may pass it
        (e.g. 'grid', 'tick1', 'tick2').  The same skewed transform is
        returned for every value.
        """
        return self._xaxis_transform + self.transProjection

# Make the projection known to matplotlib so it can be requested by name,
# e.g. ``fig.add_subplot(111, projection='skewx')``.
register_projection(SkewXAxes)

# --- Demonstration of the 'skewx' projection --------------------------------
import matplotlib.pyplot as plt

fig = plt.figure(1, figsize=(6.5875, 6.2125))
ax = fig.add_subplot(111, projection='skewx')

# Grid lines on the freshly created (current) axes.
ax.grid(True)

# y axis: ten fixed ticks from 100 to 1000, plain numeric labels.
ax.set_yticks(np.linspace(100, 1000, 10))
ax.yaxis.set_major_formatter(ScalarFormatter())

# x axis: a tick every 10 units from -80 to 60.
ax.xaxis.set_major_locator(FixedLocator(np.arange(-80, 65, 10)))

# View limits.  The y range is given high-to-low, so larger y values appear
# at the bottom of the plot.
ax.set_xlim(-40, 45)
ax.set_ylim(1050, 100)

plt.savefig('skew.png', dpi=100)
plt.show()
