
import matplotlib.text as mtext
import matplotlib.axes as maxes
import matplotlib.transforms as mtransforms
import matplotlib.artist as martist
import matplotlib.axis as maxis
import matplotlib.ticker as mticker

import matplotlib
from matplotlib  import rcParams
#from matplotlib.collections import LineCollection
import matplotlib.lines as mlines
import matplotlib.artist as artist
import matplotlib.cbook as cbook

import matplotlib.pyplot as plt
import numpy as np

class SubAxes(maxes.Axes):
    """
    A single pane of a multi-pane container.

    The pane owns its own patch and frame, but shares the ``xaxis`` /
    ``yaxis`` objects and the ``dataLim`` / ``viewLim`` of its parent
    container, so every pane shows the same data coordinates.  Images and
    artists added to the container are drawn in every pane (see
    :meth:`draw`).

    Required keyword arguments
    --------------------------
    container : the parent multi-pane container.
    multipane_num : int, index of this pane inside the container.
    """
    def __init__(self, *kl, **kw):
        # Both keywords are required: pop() without a default raises
        # KeyError if either one is missing.
        self.parent_container = kw.pop("container")
        # Set before Axes.__init__ because set_figure() (which Axes.__init__
        # presumably invokes) reads master_axes.
        self.master_axes = self.parent_container
        self._axes_num = kw.pop("multipane_num")
        maxes.Axes.__init__(self, *kl, **kw)

        # Replace the axis objects created by Axes.__init__ with the
        # container's, so ticks and scales are shared by all panes.
        self.xaxis = self.master_axes.xaxis
        self.yaxis = self.master_axes.yaxis

    def apply_aspect(self, position=None):
        """
        Place this pane in its cell of the container's aspect-corrected box.

        The container box is divided into an ``nrows x ncols`` grid whose
        cells are separated by ``axes_pad_inch`` (converted to figure
        fraction).  *position* is accepted for compatibility with
        :meth:`matplotlib.axes.Axes.apply_aspect` but is ignored; the pane
        position is always derived from the container.
        """
        container = self.parent_container
        active_bbox = container.get_master_position()

        figW, figH = self.get_figure().get_size_inches()
        nrows, ncols = container.get_geometry()
        pad_inch = container.get_axes_pad_inch()

        # Padding expressed as a fraction of the figure width / height.
        pad_x = pad_inch / figW
        pad_y = pad_inch / figH

        x0, y0 = active_bbox.xmin, active_bbox.ymin
        dx = (active_bbox.width - (ncols - 1.) * pad_x) / ncols
        dy = (active_bbox.height - (nrows - 1.) * pad_y) / nrows

        ix, iy = container.get_col_row(self._axes_num)

        # Row 0 is the top row, but figure coordinates grow upwards.
        iy = (nrows - 1) - iy
        self.set_position([x0 + ix * (dx + pad_x), y0 + iy * (dy + pad_y),
                           dx, dy], "active")

    def set_figure(self, fig):
        """
        Set the :class:`~matplotlib.axes.Axes` figure.

        Accepts a :class:`~matplotlib.figure.Figure` instance.  The pane
        keeps its own bbox but aliases the container's ``dataLim`` and
        ``viewLim`` so that limits are shared by all panes.
        """
        martist.Artist.set_figure(self, fig)

        # Each pane keeps its own bbox: reusing the master's bbox draws
        # fine, but breaks interactive use such as panning.
        self.bbox = mtransforms.TransformedBbox(self._position, fig.transFigure)

        self.dataLim = self.master_axes.dataLim
        self.viewLim = self.master_axes.viewLim

        # transScale is deliberately not copied from the master: it is
        # recreated by _set_lim_and_transforms anyway, and the override of
        # _update_transScale keeps it in sync with the shared axes.
        self._set_lim_and_transforms()

    def _update_transScale(self):
        # Rebuild transScale from the (shared) axis transforms so scale
        # changes on the container propagate to this pane.
        self.transScale.set(
            mtransforms.blended_transform_factory(
                self.xaxis.get_transform(), self.yaxis.get_transform()))
        # Guard: presumably this can run during Axes.__init__ before the
        # ``lines`` list exists.
        if hasattr(self, "lines"):
            for line in self.lines:
                # Cached transformed paths are stale after a scale change.
                line._transformed_path.invalidate()

    def draw(self, renderer):
        """
        Draw this pane together with the container's images and artists.

        The pane's ``artists`` / ``images`` lists are temporarily extended
        with the container's content (whose own draw is a no-op).  They are
        restored afterwards even if drawing fails.
        """
        self.apply_aspect(self.get_position(True))
        # The xaxis/yaxis belong to the container; moving the container
        # onto this pane presumably makes the shared axes lay out here.
        self.parent_container.set_position(self.get_position(False),
                                           "active")

        images, artists = self.parent_container.get_images_and_artists(axes=self)

        orig_images = self.images
        orig_artists = self.artists
        self.artists = artists + orig_artists
        self.images = images + orig_images
        try:
            maxes.Axes.draw(self, renderer)
        finally:
            # Always restore, so a failed draw does not leave the
            # container's artists permanently attached to this pane.
            self.artists = orig_artists
            self.images = orig_images
        

        

class MultiPaneContainerBase(maxes.SubplotBase):
    """
    Subplot that lays out a grid of :class:`SubAxes` panes sharing its data.

    The container itself is never drawn (:meth:`draw` is a no-op).  Its
    x/y axes, data limits and plotted artists are shared with every pane,
    and each pane draws the container's artists in its own grid cell.
    Panes are numbered from the top-left corner, along rows
    (``pane_direction="row"``) or down columns
    (``pane_direction="column"``).

    Meant to be combined with a concrete Axes class through
    :func:`multipane_container_class_factory`, which provides the
    ``_axes_class`` attribute.
    """
    def get_col_row(self, n):
        """
        Return the ``(col, row)`` grid cell of pane number *n*.

        With ``pane_direction="column"`` panes fill each column before
        moving on to the next; otherwise they fill each row first.
        """
        if self.pane_direction == "column":
            col, row = divmod(n, self._nrows)
        else:
            row, col = divmod(n, self._ncols)

        return col, row

    def __init__(self,
                 fig, subplot_pos=(1, 1, 1),
                 nrows_ncols=(1, 1),
                 n_pane=None,
                 pane_direction="row",
                 **kwargs):
        """
        Parameters
        ----------
        fig : Figure the container and its panes are added to.
        subplot_pos : tuple unpacked into ``SubplotBase.__init__``.
        nrows_ncols : ``(nrows, ncols)`` size of the pane grid.
        n_pane : number of panes to create; defaults to ``nrows * ncols``.
        pane_direction : ``"row"`` or ``"column"`` pane numbering order.
        axes_pad_inch : (keyword only) gap between panes in inches,
            default 0.02.
        Remaining keyword arguments are forwarded to
        ``SubplotBase.__init__``.

        Raises
        ------
        ValueError
            If ``nrows_ncols`` is not a pair, ``n_pane`` is outside
            ``1..nrows*ncols`` or ``pane_direction`` is invalid.
        """
        axes_pad_inch = kwargs.pop("axes_pad_inch", 0.02)

        # Explicit checks (not asserts) so validation survives ``python -O``.
        if len(nrows_ncols) != 2:
            raise ValueError("nrows_ncols must be a (nrows, ncols) pair, "
                             "got %r" % (nrows_ncols,))

        self._nrows, self._ncols = nrows_ncols

        n_max = self._nrows * self._ncols
        if n_pane is None:
            n_pane = n_max
        elif not 0 < n_pane <= n_max:
            raise ValueError("n_pane must be between 1 and %d, got %r"
                             % (n_max, n_pane))

        self.n_pane = n_pane

        self.__axes_pad_inch = axes_pad_inch

        if pane_direction not in ("column", "row"):
            raise ValueError("pane_direction must be 'column' or 'row', "
                             "got %r" % (pane_direction,))
        self.pane_direction = pane_direction

        maxes.SubplotBase.__init__(self, fig, *subplot_pos, **kwargs)

        # Equal data aspect; get_aspect_applied_position leaves the box
        # untouched when the aspect is 'auto'.
        self.set_aspect(1.)

        self.axes_all = []
        self.axes_column = [[] for _ in range(self._ncols)]
        self.axes_row = [[] for _ in range(self._nrows)]

        for i in range(self.n_pane):
            ax = SubAxes(fig, self.get_position(),
                         container=self, multipane_num=i)
            fig.add_axes(ax)

            self.axes_all.append(ax)
            col, row = self.get_col_row(i)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

            # Panes share the container's xaxis/yaxis, so this hides the
            # shared tick labels; per-pane visibility is decided at draw
            # time by get_images_and_artists.
            martist.setp(ax.get_xticklabels(), visible=False)
            martist.setp(ax.get_yticklabels(), visible=False)

        # Bottom pane of the first column ("lower-left corner"); it shows
        # both x and y labels.
        self.axes_llc = self.axes_column[0][-1]

    def __getitem__(self, i):
        return self.axes_all[i]

    def get_geometry(self):
        """Return the pane grid size as ``(nrows, ncols)``."""
        return self._nrows, self._ncols

    def get_aspect_applied_position(self, position=None):
        """
        Return *position* shrunk and anchored to honour the axes aspect.

        Unlike :meth:`Axes.apply_aspect` this does not move the axes; it
        only computes and returns the box.  *position* defaults to the
        original (non-active) position.  Only ``adjustable='box'`` is
        supported.

        Raises
        ------
        NotImplementedError
            If the effective adjustable is not ``'box'``.  Axis sharing
            forces it to ``'datalim'``.
        """
        if position is None:
            position = self.get_position(True)

        aspect = self.get_aspect()
        if aspect == 'auto':
            return position

        A = 1 if aspect == 'equal' else aspect

        # Ensure at drawing time that any Axes involved in axis-sharing
        # does not have its position changed.
        if self in self._shared_x_axes or self in self._shared_y_axes:
            self._adjustable = 'datalim'

        if self._adjustable != 'box':
            raise NotImplementedError(
                "only adjustable='box' is supported, got %r"
                % (self._adjustable,))

        figW, figH = self.get_figure().get_size_inches()
        fig_aspect = figH / figW
        box_aspect = A * self.get_data_ratio()
        pb = position.frozen()
        pb1 = pb.shrunk_to_aspect(box_aspect, pb, fig_aspect)
        return pb1.anchored(self.get_anchor(), pb)

    def get_master_position(self):
        """Return the container's original position after aspect correction."""
        return self.get_aspect_applied_position(self.get_position(True))

    def set_axes_pad_inch(self, axes_pad_inch):
        """Set the gap between neighbouring panes, in inches."""
        self.__axes_pad_inch = axes_pad_inch

    def get_axes_pad_inch(self):
        """Return the gap between neighbouring panes, in inches."""
        return self.__axes_pad_inch

    def get_data_ratio(self):
        """
        Return the height/width ratio the whole container box should have.

        Each pane shows data with ratio ``ysize/xsize`` (as returned by
        :meth:`Axes.get_data_ratio`), and panes are separated by
        ``axes_pad_inch``.  The largest pane width *k* (inches) that fits
        the current box is computed, and the ratio of the resulting total
        grid height to total grid width (padding included) is returned.
        """
        figW, figH = self.get_figure().get_size_inches()
        ysize_xsize = maxes.Axes.get_data_ratio(self)
        bb = self.get_position(True)
        nrows, ncols = self.get_geometry()
        pad_inch = self.get_axes_pad_inch()

        # Pane width (inches) allowed by the box width ...
        k_width = (figW * bb.width - (ncols - 1.) * pad_inch) / ncols
        # ... and by the box height, converted to a width via the data ratio.
        k_height = (figH * bb.height - (nrows - 1.) * pad_inch) / ysize_xsize / nrows
        k = min(k_width, k_height)

        total_height = k * ysize_xsize * nrows + (nrows - 1.) * pad_inch
        total_width = k * ncols + (ncols - 1.) * pad_inch
        return total_height / total_width

    def get_images_and_artists(self, axes):
        """
        Return ``(images, artists)`` of the container to draw in pane *axes*.

        If the wrapped axes class provides ``get_images_and_artists`` it is
        used.  Otherwise the container's lines, patches, texts, tables,
        artists and collections are collected; legends are not included.

        Side effect: the shared xaxis/yaxis tick labels and axis labels
        are shown or hidden for *axes*.  X labels appear only on the bottom
        pane of each column, and y labels only on panes of the first column.
        """
        if hasattr(self._axes_class, "get_images_and_artists"):
            images, artists = self._axes_class.get_images_and_artists(self)
        else:
            artists = []
            artists.extend(self.lines)
            artists.extend(self.patches)
            artists.extend(self.texts)
            artists.extend(self.tables)
            artists.extend(self.artists)
            # No legend: the container's own legend is used instead.
            artists.extend(self.collections)

            images = self.images

        def set_labels_visible(axis, visible):
            martist.setp(axis.get_ticklabels(), visible=visible)
            martist.setp(axis.label, visible=visible)

        # Bottom pane of every non-empty column.  Columns are empty when
        # n_pane < nrows*ncols, so they must be skipped instead of
        # indexed with [-1] (which raised IndexError).
        bottom_panes = [column[-1] for column in self.axes_column if column]

        set_labels_visible(self.xaxis, axes in bottom_panes)
        set_labels_visible(self.yaxis, axes in self.axes_column[0])

        return images, artists

    def draw(self, renderer):
        """Intentionally a no-op: each pane draws the container's content."""
        pass

    def axes_iter(self):
        """Return an iterator over all panes, in pane-number order."""
        return iter(self.axes_all)

        



import new

# Cache of generated container classes, keyed by the axes class they wrap.
_subplot_classes = {}


def multipane_container_class_factory(axes_class=None):
    """
    Return a container class deriving from MultiPaneContainerBase and
    *axes_class*.

    *axes_class* defaults to :class:`matplotlib.axes.Axes` and is assumed
    to be an Axes subclass.  Generated classes are cached, so the same
    *axes_class* always yields the same class object and a new class does
    not have to be written by hand for every type of Axes.  The wrapped
    class is stored as the ``_axes_class`` class attribute.
    """
    if axes_class is None:
        axes_class = maxes.Axes

    new_class = _subplot_classes.get(axes_class)
    if new_class is None:
        # The three-argument type() builds the class directly and works on
        # both Python 2 and 3, unlike the deprecated ``new.classobj``.
        new_class = type("%sSubplot" % (axes_class.__name__,),
                         (MultiPaneContainerBase, axes_class),
                         {'_axes_class': axes_class})
        _subplot_classes[axes_class] = new_class

    return new_class


# This is provided for backward compatibility
MultiPane_Subplot = multipane_container_class_factory()





#Subplot = maxes.subplot_class_factory(Axes)


def test():
    """Placeholder self-test hook; currently performs no checks."""
    return None

if __name__ == "__main__":
    fig = plt.figure(1)
    fig.clf()

    # A 3x2 grid of panes numbered along rows; pass
    # pane_direction="column" to number them down the columns instead.
    mp = MultiPane_Subplot(fig, subplot_pos=(1, 1, 1),
                           nrows_ncols=(3, 2),
                           pane_direction="row",
                           axes_pad_inch=0.0)

    fig.add_subplot(mp)

    def make_test_image(n, m, cx, cy):
        """Return an (m, n) array of squared distances from (cx, cy)."""
        iy, ix = np.indices((m, n), dtype="d")
        return (ix - cx) ** 2 + (iy - cy) ** 2

    # Any plot command issued on the container is drawn in every pane.
    mp.contour(make_test_image(30, 20, 15, 10), origin="lower")

    # mp[i] is the i-th pane; the numbering order is controlled by
    # pane_direction.
    mp[0].imshow(make_test_image(30, 20, 0, 0),
                 interpolation="nearest",
                 origin="lower")

    # Every other pane gets an image centred on a random point.
    for pane in mp.axes_all[1:]:
        cx, cy = np.random.rand(2)
        pane.imshow(make_test_image(30, 20, 30 * cx, 20 * cy),
                    interpolation="nearest",
                    origin="lower")

    mp.set_xlim(-3, 33)
    mp.set_ylim(-3, 23)

    mp.axes_llc.set_xlabel("x-axis")
    mp.axes_llc.set_ylabel("y-axis")

    plt.draw()
    plt.savefig("t")
