Vectorize Dask Xarray

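For dask-backed xarray objects, vectorization (vectorize=True) works only when we also pass the dask="parallelized" parameter to xr.apply_ufunc. With the default, dask="forbidden", apply_ufunc raises an error as soon as it meets a dask array.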
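A quick way to see this requirement is to call xr.apply_ufunc on a small dask-backed DataArray with and without dask="parallelized". The array of zeros and the helper add are made up for illustration; this is a sketch, not the note's original data.

```python
import numpy as np
import xarray as xr

# Synthetic dask-backed DataArray: two (1, 5, 5) chunks of zeros (made-up data)
data = xr.DataArray(np.zeros((2, 5, 5)), dims=("time", "lat", "lon")).chunk({"time": 1})


def add(d, value):  # hypothetical helper for illustration
    return d + value


common = dict(input_core_dims=[["lon"], []], output_core_dims=[["lon"]], vectorize=True)

# Default dask="forbidden": apply_ufunc refuses dask-backed inputs
try:
    xr.apply_ufunc(add, data, 1, **common)
    forbidden_raised = False
except ValueError as err:
    forbidden_raised = True
    print("dask='forbidden' ->", err)

# dask="parallelized": accepted, and the result stays lazy (still chunked)
out = xr.apply_ufunc(add, data, 1, dask="parallelized", **common)
print(out.chunks)
print(float(out.compute().max()))
```

The first call fails with a ValueError; the second returns a chunked result that only runs the function when it is computed.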

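From the note, we saw that dask.array.apply_gufunc passes each chunk (block) of the array to the custom function. When vectorize=True is also given, the custom function is first wrapped in np.vectorize, so within each block it is called once per 1-D slice along the core dimension rather than once per block.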
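The middle layer can be tried directly with dask.array.apply_gufunc. This sketch uses a synthetic array with the same (1, 5, 5) chunking as the example below and records the shape the function receives with and without vectorize=True. The names arr and add_recording and the signature string "(n),()->(n)" are assumptions for illustration; output_dtypes is passed so that dask does not need an extra call to infer the output dtype.

```python
import dask.array as da
import numpy as np

# Synthetic 2 x 5 x 5 array split into two (1, 5, 5) blocks (made-up values)
arr = da.from_array(np.arange(50, dtype=float).reshape(2, 5, 5), chunks=(1, 5, 5))

seen = []  # shape of the first argument on every call


def add_recording(d, value):  # hypothetical recording function
    seen.append(d.shape)
    return d + value


# Core dimension n is the last axis; value is a scalar with core signature "()"
sig = "(n),()->(n)"

seen.clear()
da.apply_gufunc(add_recording, sig, arr, 1, output_dtypes=float).compute(scheduler="synchronous")
block_calls = list(seen)  # without vectorize: one call per block

seen.clear()
da.apply_gufunc(
    add_recording, sig, arr, 1, output_dtypes=float, vectorize=True
).compute(scheduler="synchronous")
slice_calls = list(seen)  # with vectorize: one call per 1-D slice along n

print("without vectorize:", block_calls)
print("with vectorize:", slice_calls)
```

Without vectorize the list should hold two (1, 5, 5) blocks; with vectorize=True it should hold ten (5,) slices.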

So the flow of execution is as follows:

xarray.apply_ufunc
        |
        V
dask.array.apply_gufunc
        |
        V
np.vectorize
        |
        V
custom function
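The bottom two layers of this chain can be imitated by hand with NumPy alone. This sketch splits a synthetic array into two (1, 5, 5) blocks (standing in for dask's chunks), wraps a recording function with np.vectorize and the gufunc signature "(n),()->(n)" (standing in for vectorize=True), and calls it once per block. The data, the name add_recording and the signature are assumptions for illustration; dask's real scheduling is more involved.

```python
import numpy as np

calls = []  # shape of d on every call of the custom function


def add_recording(d, value):  # hypothetical custom function
    calls.append(d.shape)
    return d + value


data = np.arange(50, dtype=float).reshape(2, 5, 5)  # synthetic stand-in for "air"

# np.vectorize with a gufunc signature: the last axis is the core dim (n),
# value is a scalar (); all leading axes become loop dimensions
vectorized = np.vectorize(add_recording, signature="(n),()->(n)")

# Imitate dask: split along time into two (1, 5, 5) blocks, process each
# block separately, then stitch the results back together
blocks = np.split(data, 2, axis=0)
result = np.concatenate([vectorized(block, 1) for block in blocks], axis=0)

print(len(calls), set(calls))  # → 10 {(5,)}
print(np.array_equal(result, data + 1))  # → True
```

Each block contributes five 1-D slices, so the function runs ten times in total, matching the count discussed for the xarray example.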

For example,

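import xarray as xr

# Assumed setup, not shown in the original note: a 2 x 5 x 5 slice of the
# "air" variable from xarray's air_temperature tutorial dataset, one time
# step per chunk (the dask task name in the repr may differ)
data = (
    xr.tutorial.open_dataset("air_temperature")["air"]
    .isel(time=slice(0, 2), lat=slice(0, 5), lon=slice(0, 5))
    .chunk({"time": 1})
)
data
# <xarray.DataArray 'air' (time: 2, lat: 5, lon: 5)> Size: 400B dask.array<rechunk-merge, shape=(2, 5, 5), dtype=float64, chunksize=(1, 5, 5), chunktype=numpy.ndarray>


def add_wrapper(d, value):
    # Print what each call receives; with vectorize=True, d is a 1-D slice along "lon"
    print('------------------- START ---------------------')
    print(type(d))
    print(d.shape)
    print(d)
    print('-------------------- END ------------------------')
    return d + value


xr.apply_ufunc(add_wrapper,
               data,
               1,                              # scalar argument, no core dims
               input_core_dims=[["lon"], []],  # "lon" is the core dim of data
               output_core_dims=[["lon"]],
               dask="parallelized",
               vectorize=True).compute()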
 
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (1,)
# [1.]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [242.1  242.7  243.1  243.39 243.6 ]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [243.6 244.1 244.2 244.1 243.7]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [253.2  252.89 252.1  250.8  249.3 ]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [269.7 269.4 268.6 267.4 266. ]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [272.5 271.5 270.4 269.4 268.5]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [241.2 242.5 243.5 244.  244.1]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [243.8  244.5  244.7  244.2  243.39]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [250.   249.8  248.89 247.5  246.  ]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [266.5 267.1 267.1 266.7 265.9]
# -------------------- END ------------------------
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (5,)
# [274.5  274.29 274.1  274.   273.79]
# -------------------- END ------------------------
# <xarray.DataArray 'air' (time: 2, lat: 5, lon: 5)> Size: 400B
# array([[[242.2 , 243.5 , 244.5 , 245.  , 245.1 ],
#         [244.8 , 245.5 , 245.7 , 245.2 , 244.39],
#         [251.  , 250.8 , 249.89, 248.5 , 247.  ],
#         [267.5 , 268.1 , 268.1 , 267.7 , 266.9 ],
#         [275.5 , 275.29, 275.1 , 275.  , 274.79]],
#
#        [[243.1 , 243.7 , 244.1 , 244.39, 244.6 ],
#         [244.6 , 245.1 , 245.2 , 245.1 , 244.7 ],
#         [254.2 , 253.89, 253.1 , 251.8 , 250.3 ],
#         [270.7 , 270.4 , 269.6 , 268.4 , 267.  ],
#         [273.5 , 272.5 , 271.4 , 270.4 , 269.5 ]]])

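The vectorized custom function is called 10 times on real data, each time with a 1-D array of the 5 lon values. The array has two chunks of shape (1, 5, 5), and time and lat are loop (broadcast) dimensions, so within each chunk np.vectorize makes one call per (time, lat) position: 1 × 5 = 5 calls per chunk, and 2 chunks × 5 = 10 calls in total. The extra first call, with shape (1,) and value [1.], does not come from the data: since output_dtypes was not given, dask.array.apply_gufunc infers the output dtype by calling the vectorized function once on dummy arrays of ones whose dimensions all have size 1. The order of the remaining calls can vary, because dask may process the chunks in parallel.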
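The call count can also be checked programmatically. This sketch rebuilds the example with synthetic data of the same shape and chunking (an assumption; it does not use the tutorial dataset), records each call instead of printing, and passes output_dtypes=[float] so that dask has no need for the dtype-inference call. The expected count is derived from the chunk sizes: for each chunk, the product of its sizes along the loop dimensions time and lat. The name add_recording is hypothetical.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in: same dims, shape and chunking as the "air" slice above,
# but made-up values (np.arange) instead of the tutorial data
data = xr.DataArray(
    np.arange(50, dtype=float).reshape(2, 5, 5),
    dims=("time", "lat", "lon"),
    name="air",
).chunk({"time": 1})

calls = []  # shape of d on every call of the custom function


def add_recording(d, value):
    calls.append(d.shape)
    return d + value


result = xr.apply_ufunc(
    add_recording,
    data,
    1,
    input_core_dims=[["lon"], []],
    output_core_dims=[["lon"]],
    dask="parallelized",
    vectorize=True,
    output_dtypes=[float],  # known output dtype, so dask need not infer it
).compute(scheduler="synchronous")  # run tasks one at a time in this thread

# Expected calls: for each chunk, one call per (time, lat) position,
# i.e. the product of the chunk's sizes along the loop dims time and lat
time_chunks, lat_chunks, _ = data.chunks
expected = sum(t * la for t in time_chunks for la in lat_chunks)

print("expected:", expected)
print("calls with a (5,) slice:", calls.count((5,)))
print("all calls:", len(calls))
```

Here expected works out to 1 × 5 + 1 × 5 = 10 and should match the number of (5,) calls; with output_dtypes set, the (1,) dummy call seen in the printed output above is not expected to appear.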

References

  1. https://tutorial.xarray.dev/advanced/apply_ufunc/dask_apply_ufunc.html#automatic-vectorizing