apply_ufunc with dask arrays

apply_ufunc handles dask-backed xarray objects in the same manner as numpy-backed ones. However, it requires the `dask` argument to be set in order to work with dask-backed objects.

xr.apply_ufunc(func, xda, dask='allowed')  # func: the function to apply, xda: a dask-backed DataArray

dask parameter

  1. dask="forbidden" (the default) raises an error when a dask-backed object is passed.
  2. dask="allowed" passes the dask array straight to the function; use it when the function can handle dask arrays itself.
  3. dask="parallelized" is for functions that cannot handle dask arrays; xarray.apply_ufunc then applies the function to each chunk separately so that it works with dask arrays.
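The three modes can be compared on a small synthetic array. This is a minimal sketch, assuming xarray and dask are both installed; the array `xda` and the function `double` are made up for illustration and stand in for a real dataset such as ds.air.

```python
import numpy as np
import xarray as xr  # dask must also be installed for .chunk()

# Small synthetic stand-in for a dask-backed DataArray, chunked along "time".
xda = xr.DataArray(
    np.arange(24, dtype=float).reshape(4, 2, 3),
    dims=("time", "lat", "lon"),
).chunk({"time": 2})

def double(x):
    # Works on numpy and dask arrays alike, so dask="allowed" is safe here.
    return x * 2

# 1. "forbidden" (the default) rejects dask-backed input with a ValueError.
try:
    xr.apply_ufunc(double, xda)
    forbidden_raised = False
except ValueError:
    forbidden_raised = True

# 2. "allowed": the dask array itself is handed to double.
allowed = xr.apply_ufunc(double, xda, dask="allowed")

# 3. "parallelized": double is called on each numpy block separately.
parallel = xr.apply_ufunc(double, xda, dask="parallelized", output_dtypes=[float])
```

Both `allowed` and `parallel` stay lazy and keep the input's chunking; the difference is only in what the function receives.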

Working with non-dask functions

When a function doesn't support dask arrays, we can pass dask="parallelized" to xr.apply_ufunc.
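A minimal sketch of the difference, assuming xarray and dask are installed. `numpy_only_clip` is a hypothetical function that refuses anything but a numpy array (as a wrapper around compiled code might); with dask="allowed" it fails, while dask="parallelized" feeds it numpy blocks.

```python
import numpy as np
import xarray as xr  # dask must also be installed

seen_types = []

def numpy_only_clip(x):
    """Hypothetical numpy-only function: sets negative values to zero."""
    seen_types.append(type(x))  # record what the function was given
    if not isinstance(x, np.ndarray):
        raise TypeError(f"expected numpy.ndarray, got {type(x).__name__}")
    out = x.copy()
    out[out < 0] = 0
    return out

xda = xr.DataArray(
    np.linspace(-1.0, 1.0, 24).reshape(4, 2, 3),
    dims=("time", "lat", "lon"),
).chunk({"time": 2})

# dask="allowed" hands the whole dask array to the function, which rejects it.
try:
    xr.apply_ufunc(numpy_only_clip, xda, dask="allowed")
    allowed_failed = False
except TypeError:
    allowed_failed = True

# dask="parallelized" only ever passes numpy blocks to the function.
seen_types.clear()
clipped = xr.apply_ufunc(
    numpy_only_clip, xda, dask="parallelized", output_dtypes=[xda.dtype]
)
values = clipped.compute(scheduler="synchronous").values  # blocks are processed here
```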

How does it work?

  1. xarray.apply_ufunc uses dask.array.apply_gufunc internally.
  2. The underlying dask array is passed to apply_gufunc.
  3. apply_gufunc then passes each chunk, as a numpy array block, to the custom function.
  4. dask stitches together the output chunks returned by the custom function.
  5. xarray then attaches the metadata (dimension names and coordinates) to the output array.
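The chunk-by-chunk idea in steps 3 and 4 can be imitated with plain numpy. This is a toy sketch, not dask's implementation: `blockwise_apply` is a hypothetical helper that splits along the first (non-core) axis, calls the function on each block, and concatenates the results.

```python
import numpy as np

def blockwise_apply(func, arr, chunk_size, **kwargs):
    """Hypothetical helper imitating dask="parallelized" when the core
    dimension is the last axis: split arr into chunks along axis 0
    (a loop dimension), pass each numpy block to func, stitch the outputs."""
    blocks = [arr[i:i + chunk_size] for i in range(0, arr.shape[0], chunk_size)]
    outputs = [func(block, **kwargs) for block in blocks]
    return np.concatenate(outputs, axis=0)

# (time, lat, lon) with time split into chunks of 3; "lon" is the core dimension.
data = np.random.default_rng(0).normal(size=(10, 4, 5))
result = blockwise_apply(np.mean, data, chunk_size=3, axis=-1)

# Same answer as applying np.mean to the whole array at once.
print(result.shape, np.allclose(result, data.mean(axis=-1)))  # (10, 4) True
```

Because lon is whole inside every block, splitting along time does not change the result; that property is exactly what breaks when the core dimension itself is split.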

Warning

If the data is split into multiple chunks along a core dimension that the custom function operates on, apply_ufunc with dask="parallelized" raises an error.
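Why not just run the function on each chunk anyway? A plain-numpy sketch with made-up data shows the problem for a reduction such as mean: each chunk along the core dimension would produce its own partial result, and naively combining those partial results gives the wrong answer when the chunks differ in size.

```python
import numpy as np

time_series = np.arange(10.0)  # made-up data: one grid point, 10 time steps
# Chunks of size 3 along the core dimension "time": sizes 3, 3, 3, 1.
chunks = [time_series[i:i + 3] for i in range(0, time_series.size, 3)]

# Applying np.mean to each chunk separately yields one value per chunk
# instead of one value for the whole core dimension...
per_chunk_means = [np.mean(c) for c in chunks]  # [1.0, 4.0, 7.0, 9.0]
# ...and averaging those is wrong because the last chunk is shorter.
naive = np.mean(per_chunk_means)  # 5.25
correct = np.mean(time_series)    # 4.5
```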

For example,

ds.air
# <xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB dask.array<open_dataset-air, shape=(2920, 25, 53), dtype=float64, chunksize=(100, 25, 53), chunktype=numpy.ndarray>
 
 
xr.apply_ufunc(np.mean, ds.air, input_core_dims=[["time"]], kwargs={"axis": -1}, dask="parallelized")
 
# dimension time on 0th function argument to apply_ufunc with dask='parallelized' consists of multiple chunks, but is also a core dimension. To fix, either rechunk into a single array chunk along this dimension, i.e., ``.chunk(dict(time=-1))``, or pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` but beware that this may significantly increase memory usage.

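The data is chunked along the time dimension, so declaring time as a core dimension raises the error. To avoid it, the data has to be rechunked so that time forms a single chunk. One way is to let dask.array.apply_gufunc do the rechunking by passing allow_rechunk=True through dask_gufunc_kwargs. The mean function below also prints the type and shape of every block it receives.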

import numpy as np
import xarray as xr

def mean(d, **kwargs):
    # Print the type and shape of each block the function receives
    print('------------------- START ---------------------')
    print(type(d))
    print(d.shape)
    print('-------------------- END ------------------------')
    return np.mean(d, **kwargs)

# Let dask rechunk "time" into a single chunk before calling mean
xr.apply_ufunc(mean,
               ds.air,
               input_core_dims=[["time"]],
               kwargs={"axis": -1},
               dask_gufunc_kwargs={"allow_rechunk": True},
               dask="parallelized")

However, this can significantly increase memory usage, because each block now spans the entire time dimension.
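The alternative suggested by the error message is to rechunk explicitly with .chunk(dict(time=-1)) before calling apply_ufunc. This makes the memory cost visible and lets you split another dimension instead to keep blocks small. A sketch on a synthetic array, assuming xarray and dask are installed; `air` is a made-up stand-in for ds.air.

```python
import numpy as np
import xarray as xr  # dask must be installed for .chunk()

# Synthetic stand-in for ds.air, chunked along "time" like the original.
air = xr.DataArray(
    np.random.default_rng(0).normal(size=(10, 2, 3)),
    dims=("time", "lat", "lon"),
).chunk({"time": 4})

# One chunk along the core dimension "time"; split "lat" instead
# so each block stays small.
rechunked = air.chunk({"time": -1, "lat": 1})

result = xr.apply_ufunc(
    np.mean,
    rechunked,
    input_core_dims=[["time"]],
    kwargs={"axis": -1},  # xarray moves core dimensions to the last axis
    dask="parallelized",
    output_dtypes=[air.dtype],
)
```

The result is still lazy, with dimensions (lat, lon) and one chunk per lat value; calling result.compute() runs np.mean on each (1, 3, 10) block.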

To infer the output dtype, dask first calls the function once on a small dummy array (here of shape (1, 1, 1)), as shown below. The real chunks are processed only when the result is computed.

xr.apply_ufunc(mean,
               ds.air,
               input_core_dims=[["lon"]],
               kwargs={"axis": -1},
               dask="parallelized")
 
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (1, 1, 1)
# -------------------- END ------------------------
# <xarray.DataArray 'air' (time: 2920, lat: 25)> Size: 584kB dask.array<transpose, shape=(2920, 25), dtype=float64, chunksize=(100, 25), chunktype=numpy.ndarray>
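The dummy call can be imitated in a few lines. This is a simplified sketch, not dask's actual code: `infer_output_dtype` is a hypothetical helper that builds an array of ones with one element per dimension (hence the (1, 1, 1) shape printed above), calls the function once, and reads off the dtype of the result.

```python
import numpy as np

def infer_output_dtype(func, arr, **kwargs):
    """Hypothetical helper: call func once on a dummy array of ones with the
    same number of dimensions and dtype as arr; return the result's dtype."""
    dummy = np.ones((1,) * arr.ndim, dtype=arr.dtype)
    return np.asarray(func(dummy, **kwargs)).dtype

shapes_seen = []

def mean(d, **kwargs):
    shapes_seen.append(d.shape)  # record the block shape instead of printing
    return np.mean(d, **kwargs)

ints = np.zeros((4, 2, 3), dtype=np.int16)
dtype = infer_output_dtype(mean, ints, axis=-1)
print(shapes_seen, dtype)  # [(1, 1, 1)] float64
```

Here the dummy call reveals that averaging int16 data yields float64, without touching the real data.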
