3d plotting

Hi,
I’m using surface_plot to view the results of solving the 2d wave equation. It works fine (code is below) except I would like to add a color bar and fix the limits on the vertical axis. When I add the color bar a new one is added in every iteration instead of overwriting the previous one, anyone know how I can prevent this?
Also is it possible to specify a view point when plotting?
Thanks
D

import matplotlib.pyplot as plt
import numpy as np
import pylab as py
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

pi = np.pi

#Set up grid.

fig = py.figure()
ax = Axes3D(fig)

nx = 50
nz = 50

X = np.arange(0, nx, 1)
Y = np.arange(0, nz, 1)
X, Y = np.meshgrid(X, Y)

nsteps = 100

Constants for equation.

c = 4000
dt = 1e-4
h = 1

Set up source.

xs = 0
zs = 0

#fig2 = py.figure()
ts = np.arange(dt,nstepsdt,dt)
s = 0.5
np.sin(2pi100*ts)
#py.plot(ts,s)
#py.show()

Homogeneous pressure field.

p = np.zeros([nx, nz, nsteps])

Solve equation.

for t in range(0,nsteps-1):

for z in range(0,nz-1):

    for x in range(0,nx-1):

        p[xs,zs,t] = s[t]

        k = (c*dt/h)**2

        p[x,z,t] = 2*p[x,z,t-1] - p[x,z,t-2] + k*(p[x+1,z,t-1]-4*p[x,z,t-1]+p[x-1,z,t-1]+p[x,z+1,t-1]+p[x,z-1,t-1])

snap = p[:,:,t]
surf = ax.plot_surface(X,Y,snap, rstride=1, cstride=1, cmap=cm.jet, linewidth=0)
#fig.colorbar(surf, shrink=0.5, aspect=5)
#py.draw()  
py.savefig('/home/davcra/Desktop/plots/2Dwave/'+str(t))