What are the `X, Y, Z` arguments of `Axes3D.plot_wireframe`


I am trying to use the `plot_wireframe` function to display my data.

I have a 2D input grid X (`x.shape == (2, 500, 500)`, i.e. two stacked 500 × 500 coordinate grids)
and the output of a function applied to it, f(X), which I'll call Z (`z.shape == (500, 500)`).
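The shapes described above can be checked with a short numpy-only sketch. It rebuilds the grid with the same `np.mgrid` call used in the full code; the names `X`, `Y` and `Z` are illustrative labels for the two unpacked coordinate grids and the function values.

```python
import numpy as np

# Same grid as in the question: 500 steps of 0.01 along each axis.
x = np.mgrid[0:5:0.01, 0:5:0.01]

# x stacks two coordinate grids: x[0] holds the first coordinate, x[1] the second.
X, Y = x[0], x[1]

# Same function as in the question, evaluated on the whole grid at once.
Z = 1 / (1 + np.sum(x, axis=0))

print(x.shape, X.shape, Y.shape, Z.shape)
```

Each of `X`, `Y` and `Z` comes out as a 500 × 500 array, while `x` itself carries the extra leading axis of length 2.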

What do I need to pass as the `X`, `Y`, `Z` arguments of `plot_wireframe` in order to plot my X → Z function?

Full code for reference:

```python
import numpy as np
import matplotlib.pyplot as plt

def f(x: np.ndarray) -> np.ndarray:
    """Evaluate 1 / (1 + x0 + x1) at every grid point (sums over axis 0)."""
    return 1 / (1 + np.sum(x, axis=0))

x = np.mgrid[0:5:0.01, 0:5:0.01]  # shape (2, 500, 500)

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_wireframe(x, x, f(x), color='black')  # How do I plot here?
```

Thank you!

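I have figured it out: `X`, `Y` and `Z` must all be 2D arrays of the same shape, where `X[i, j]` and `Y[i, j]` are the coordinates of a grid point and `Z[i, j]` is the function value there. So the two coordinate grids stacked in `x` have to be passed separately:

```python
ax.plot_wireframe(x[0, :], x[1, :], f(x), color='black')
```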
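Putting it together, a complete version of the script might look like the sketch below. It keeps the original grid and function; the `rcount`/`ccount` arguments (which cap how many wireframe lines are drawn in each direction), the axis labels and the `plt.show()` call are additions for illustration, not part of the original answer.

```python
import matplotlib.pyplot as plt
import numpy as np


def f(x: np.ndarray) -> np.ndarray:
    """Evaluate 1 / (1 + x0 + x1) on a stacked (2, N, M) coordinate grid."""
    return 1 / (1 + np.sum(x, axis=0))


x = np.mgrid[0:5:0.01, 0:5:0.01]  # shape (2, 500, 500)
Z = f(x)                          # shape (500, 500)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# X and Y are the two coordinate grids; all three arrays share the shape (500, 500).
# rcount/ccount limit the number of lines drawn per direction (50 is also
# matplotlib's default), which keeps a dense 500 x 500 grid readable.
ax.plot_wireframe(x[0], x[1], Z, rcount=50, ccount=50, color="black")

ax.set_xlabel("x[0]")
ax.set_ylabel("x[1]")
ax.set_zlabel("f(x)")
plt.show()
```

Because `np.mgrid` puts the first coordinate along axis 0, `x[0]` varies down the rows and `x[1]` across the columns; `np.meshgrid(..., indexing='ij')` produces the same pair of grids if you prefer building them from 1D arrays.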