Vectorize Dask Xarray
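Vectorization (vectorize=True) for a dask-backed xarray object works only when we pass dask="parallelized" to xr.apply_ufunc. With the default, dask="forbidden", apply_ufunc raises an error as soon as it receives a dask array.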
As we saw in the previous note, dask.array.apply_gufunc passes each chunk (block) of the array to the custom function. When vectorize=True, xarray wraps the custom function in np.vectorize, so the function is applied to one core-dimension slice at a time.
So the flow of execution is as follows:
xarray.apply_ufunc
|
V
dask.array.apply_gufunc
|
V
np.vectorize
|
V
custom function
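The layering above can be imitated with NumPy alone. In this minimal sketch, fake_apply_gufunc is a hypothetical stand-in for dask.array.apply_gufunc: it just splits the array into blocks along time and calls the function once per block. The gufunc signature "(n),()->(n)" is an assumption about the kind of signature xarray builds from input_core_dims=[["lon"], []] and output_core_dims=[["lon"]]; the core-dimension name itself is arbitrary. The toy data is not the "air" data.

```python
import numpy as np


def add_wrapper(d, value):
    # Simplified custom function (no prints): receives one 1-D core slice along "lon"
    return d + value


def fake_apply_gufunc(func, arr, value, n_chunks=2):
    """Hypothetical stand-in for dask.array.apply_gufunc.

    Splits `arr` into `n_chunks` blocks along axis 0 (time), calls `func`
    once per block, and concatenates the results back together.
    """
    blocks = np.array_split(arr, n_chunks, axis=0)
    return np.concatenate([func(block, value) for block in blocks], axis=0)


# vectorize=True: wrap the custom function in np.vectorize with a gufunc
# signature; "(n)" is the core dim (lon) and "()" is the scalar argument
vectorized = np.vectorize(add_wrapper, signature="(n),()->(n)")

# Toy (time, lat, lon) array with arbitrary values
data = np.arange(2 * 5 * 5, dtype=float).reshape(2, 5, 5)

# Emulated flow: apply_gufunc (per chunk) -> np.vectorize -> add_wrapper
result = fake_apply_gufunc(vectorized, data, 1)
print(result.shape)  # (2, 5, 5)
```

Each block of shape (1, 5, 5) goes through np.vectorize, which in turn calls add_wrapper on 5-element slices, so the custom function never sees a whole chunk.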
For example,
import xarray as xr

# data: a small dask-backed DataArray with two chunks along time
data
# <xarray.DataArray 'air' (time: 2, lat: 5, lon: 5)> Size: 400B dask.array<rechunk-merge, shape=(2, 5, 5), dtype=float64, chunksize=(1, 5, 5), chunktype=numpy.ndarray>

def add_wrapper(d, value):
    # Print what the custom function actually receives on each call
    print('------------------- START ---------------------')
    print(type(d))
    print(d.shape)
    print(d)
    print('-------------------- END ------------------------')
    return d + value

xr.apply_ufunc(
    add_wrapper,
    data,
    1,
    input_core_dims=[["lon"], []],  # lon is the core dim of data; the scalar 1 has none
    output_core_dims=[["lon"]],
    dask="parallelized",
    vectorize=True,
).compute()
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (1,)
# [1.]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [242.1 242.7 243.1 243.39 243.6 ]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [243.6 244.1 244.2 244.1 243.7]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [253.2 252.89 252.1 250.8 249.3 ]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [269.7 269.4 268.6 267.4 266. ]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [272.5 271.5 270.4 269.4 268.5]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [241.2 242.5 243.5 244. 244.1]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [243.8 244.5 244.7 244.2 243.39]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [250. 249.8 248.89 247.5 246. ]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [266.5 267.1 267.1 266.7 265.9]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [274.5 274.29 274.1 274. 273.79]
# -------------------- END ------------------------
# <xarray.DataArray 'air' (time: 2, lat: 5, lon: 5)> Size: 400B
# array([[[242.2 , 243.5 , 244.5 , 245. , 245.1 ],
# [244.8 , 245.5 , 245.7 , 245.2 , 244.39],
# [251. , 250.8 , 249.89, 248.5 , 247. ],
# [267.5 , 268.1 , 268.1 , 267.7 , 266.9 ],
# [275.5 , 275.29, 275.1 , 275. , 274.79]],
# [[243.1 , 243.7 , 244.1 , 244.39, 244.6 ],
# [244.6 , 245.1 , 245.2 , 245.1 , 244.7 ],
# [254.2 , 253.89, 253.1 , 251.8 , 250.3 ],
# [270.7 , 270.4 , 269.6 , 268.4 , 267. ],
# [273.5 , 272.5 , 271.4 , 270.4 , 269.5 ]]])
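The vectorized custom function is called 10 times with a 1-D array along lon. The data has two chunks, each of shape (1, 5, 5), and inside each chunk np.vectorize loops over the broadcast (non-core) dims time and lat. So each chunk produces 1 × 5 = 5 calls with a 5-element array, and the two chunks give 5 + 5 = 10 calls.
The extra first call, with the 1-element array [1.], is most likely dask inferring the output dtype: since output_dtypes was not given, dask.array.apply_gufunc calls the function once on a dummy array of ones in which every dimension has size 1. Note also that the chunks are not necessarily processed in order; in the output above, the slices of the second time step (242.1 ...) are printed before those of the first (241.2 ...).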
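The call count can be checked with a NumPy-only sketch that records the shape passed to each call of an np.vectorize-wrapped function. Step 1 imitates, as an assumption, the dtype-inference call dask makes on a dummy (1, 1, 1) array of ones; step 2 feeds two (1, 5, 5) blocks, mirroring the chunks of data. The zero-filled blocks are placeholders, not the real air values, and the signature "(n),()->(n)" is assumed as before.

```python
import math

import numpy as np

calls = []  # shape of the core slice received by each call


def add_wrapper(d, value):
    calls.append(d.shape)
    return d + value


vectorized = np.vectorize(add_wrapper, signature="(n),()->(n)")

# 1) Assumed dtype-inference step: one call on a (1, 1, 1) array of ones
vectorized(np.ones((1, 1, 1)), 1)

# 2) The real work: two chunks of shape (1, 5, 5) along time
chunks = [np.zeros((1, 5, 5)), np.zeros((1, 5, 5))]
for block in chunks:
    vectorized(block, 1)

# np.vectorize makes one call per element of the loop (non-core) dims,
# i.e. time x lat = 1 x 5 = 5 calls per chunk
real_calls = sum(math.prod(block.shape[:-1]) for block in chunks)
print(real_calls)          # 10
print(len(calls))          # 11
print(calls[0], calls[1])  # (1,) (5,)
```

The recorded shapes match the printed output: one (1,) call first, then ten (5,) calls.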