Turn off autoscale and optional subplots

In some cases I need to generate a figure with two subplots, one above
the other; in other cases, just a single plot.
After several attempts I ended up with what may be the ugliest code I've
ever written (or maybe not), shown below.

Is there a better way to do it?
I still haven't fully figured out how to get rid of the implicit global
state in matplotlib and use a more object-oriented/functional approach.

The second issue is that I can't disable autoscaling; I would just like
to see the grid with its true geometric proportions.

The obvious call would be matplotlib.pyplot.autoscale, but it is not
present in my version. Is there another way?

Thanks a lot,

--8<---------------cut here---------------start------------->8---
    def write_graph(self, output_file, title, show=False):
        def upper_graph():
            def draw_nodes(nlist, color):
                # this returns the list of all the patches, also from axes.patches
                nx.draw_networkx_nodes(self.graph, nodelist=nlist, pos=self.pos,
                                       node_color=color, node_size=700, alpha=0.6)

            draw_nodes(self.lands, 'red')
            draw_nodes(self.mobiles, 'green')
            simple_nodes = set(self.nodes) - set(self.lands + self.mobiles)
            draw_nodes(simple_nodes, 'blue')

            old_axes = plt.axis()
            sizes = old_axes[1] - old_axes[0], old_axes[3] - old_axes[2]
            offset = lambda x: int((float(x) / 10))
            new_axes = []

            for i in range(len(old_axes)):
                new_val = old_axes[i] + (((-1) ** (i + 1)) * offset(sizes[i % 2]))

            nx.draw_networkx_edges(self.graph, self.pos, edge_color='k', alpha=0.5)

            labels = dict((x, str(x)) for x in self.nodes)
            # for the landmarks also add its index
            for land in self.lands:
                labels[land] = "%d (%d)" % (land, self.lands.index(land))

            for recv in self.receivers:
                labels[recv] += " (R)"

            for send in self.senders:
                labels[send] += " (S)"

            nx.draw_networkx_labels(self.graph, self.pos, labels, font_size=8)


        if show:
            legend = self.coordinate_text()
            # remove useless ticks
            plt.text(0.05, 0, "\n".join(legend))
            plt.text(0.5, 0, str(self.event_buffer))

--8<---------------cut here---------------end--------------->8---