Vectorize Dask Xarray

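Vectorization (vectorize=True in xr.apply_ufunc) for a dask-backed xarray object works only when we also pass the dask="parallelized" parameter. The default, dask="forbidden", raises an error when it meets a dask array.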

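As we saw in the previous note, with dask="parallelized" xarray hands the work to dask.array.apply_gufunc, which passes each chunk (block) to the custom function as a numpy array. When vectorize=True is also given, the custom function is wrapped with np.vectorize, so within each block it is called once per position along the loop (non-core) dimensions.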
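To see the np.vectorize step on its own, here is a minimal numpy-only sketch (no xarray or dask involved). It assumes a gufunc-style signature "(n),()->(n)": the first argument has one core dimension and the second is a scalar, so np.vectorize loops over all leading dimensions and calls the function once per position, essentially as a Python for loop. The calls list is a hypothetical helper used only to record the shapes the function receives.

```python
import numpy as np

calls = []  # shape of every array the wrapped function receives


def adder(arr, value):
    # Stand-in for the custom function: record the input shape, then add
    calls.append(arr.shape)
    return arr + value


# "(n),()->(n)": first argument has one core dimension n, second is a scalar;
# every leading dimension of the first argument is a loop dimension.
vadder = np.vectorize(adder, signature="(n),()->(n)")

block = np.zeros((4, 2), dtype=np.complex64)  # 4 loop positions, core length 2
result = vadder(block, 1)

print(len(calls))    # 4
print(calls[0])      # (2,)
print(result.shape)  # (4, 2)
```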

So the flow of execution is as follows:

xarray.apply_ufunc
        |
        V
dask.array.apply_gufunc
        |
        V
np.vectorize
        |
        V
custom function
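The flow above can be imitated in plain numpy to show where the per-block and per-element calls come from. This is a simplified, sequential sketch: apply_per_chunk is a hypothetical stand-in for dask.array.apply_gufunc that splits the array along one axis with np.array_split, applies the np.vectorize-wrapped function to each block, and concatenates the results. Real dask builds a lazy task graph, may run blocks concurrently and, when no output_dtypes is given, also makes a trial call to infer the output dtype; none of that is modelled here. The layout assumes the core dimension time has already been moved to the last axis.

```python
import numpy as np

shapes_seen = []  # shape of every array passed to adder


def adder(arr, value):
    # Same idea as the custom function, minus the printing
    shapes_seen.append(arr.shape)
    return arr + value


def apply_per_chunk(func, data, value, chunk_axis, n_chunks, signature):
    """Hypothetical, sequential stand-in for dask.array.apply_gufunc."""
    vfunc = np.vectorize(func, signature=signature)           # np.vectorize step
    blocks = np.array_split(data, n_chunks, axis=chunk_axis)  # chunk blocks
    results = [vfunc(block, value) for block in blocks]       # one pass per block
    return np.concatenate(results, axis=chunk_axis)


# (baselineid: 1, frequency: 3, polarisation: 4, time: 2), chunked on frequency
data = np.zeros((1, 3, 4, 2), dtype=np.complex64)
out = apply_per_chunk(adder, data, 1, chunk_axis=1, n_chunks=3,
                      signature="(t),()->(t)")

print(len(shapes_seen))  # 12
print(set(shapes_seen))  # {(2,)}
print(out.shape)         # (1, 3, 4, 2)
```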

For example,

>>> data
<xarray.DataArray 'vis' (time: 2, baselineid: 1, frequency: 3, polarisation: 4)> Size: 192B
dask.array<rechunk-merge, shape=(2, 1, 3, 4), dtype=complex64, chunksize=(2, 1, 1, 4), chunktype=numpy.ndarray>
Coordinates:
  * time          (time) float64 16B 4.454e+09 4.454e+09
  * frequency     (frequency) float64 24B 1.281e+08 1.282e+08 1.282e+08
  * polarisation  (polarisation) <U2 32B 'XX' 'XY' 'YX' 'YY'
  * baselineid    (baselineid) int64 8B 0
Attributes:
    units:    Jy
 
 
>>> data.chunksizes
Frozen({'time': (2,), 'baselineid': (1,), 'frequency': (1, 1, 1), 'polarisation': (4,)})

Here data is a dask-backed DataArray chunked along frequency into 3 chunks of 1 frequency each. Every other dimension fits in a single chunk, so each block has shape (2, 1, 1, 4).
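The block shapes follow directly from data.chunksizes: each block takes one chunk length from every dimension. A small pure-Python sketch, assuming the chunksizes mapping shown above; block_shapes is a hypothetical helper, not an xarray or dask function.

```python
from itertools import product


def block_shapes(chunksizes):
    """Shape of every block for an xarray-style chunksizes mapping
    (dimension name -> tuple of chunk lengths), one chunk per dimension."""
    return [tuple(combo) for combo in product(*chunksizes.values())]


chunksizes = {"time": (2,), "baselineid": (1,),
              "frequency": (1, 1, 1), "polarisation": (4,)}
shapes = block_shapes(chunksizes)

print(len(shapes))  # 3
print(shapes[0])    # (2, 1, 1, 4)
```

The result matches chunksize=(2, 1, 1, 4) in the DataArray repr above.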

Custom function

>>> def adder(arr, value):
...     # Print what each call receives, then add value
...     print('------------------- START ---------------------')
...     print(type(arr))
...     print(arr.shape)
...     print(arr)
...     print('-------------------- END ------------------------')
...     return arr + value

We want xarray to vectorize the adder function so that each call receives a numpy array whose only dimension is time.

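We can achieve this with core dimensions and loop dimensions. time is declared as the core dimension of the input and the output (input_core_dims=[["time"], []] and output_core_dims=[["time"]]; the scalar 1 has no core dimensions), and all remaining dimensions (baselineid, frequency, polarisation) become loop dimensions that np.vectorize iterates over.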
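Roughly, this is what the core-dimension arguments set up. The numpy sketch below is an approximation, not xarray's actual code: it derives the loop dimensions as every dimension not listed as core, moves time to the last axis (gufunc signatures put core dimensions last, and xarray likewise returns time as the last dimension), and uses the named signature "(time),()->(time)". The dims tuple and the lambda are illustrative assumptions.

```python
import numpy as np

dims = ("time", "baselineid", "frequency", "polarisation")
core_dims = ["time"]

# Loop dimensions: every dimension that is not a core dimension
loop_dims = [d for d in dims if d not in core_dims]

data = np.zeros((2, 1, 3, 4), dtype=np.complex64)
# Core dimensions go last in gufunc signatures, so move time to the end
moved = np.moveaxis(data, dims.index("time"), -1)

# Named core dimension "time"; the scalar second argument has none
vadder = np.vectorize(lambda arr, value: arr + value,
                      signature="(time),()->(time)")
out = vadder(moved, 1)

print(loop_dims)  # ['baselineid', 'frequency', 'polarisation']
print(out.shape)  # (1, 3, 4, 2)
```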

>>> xr.apply_ufunc(adder, data, 1, input_core_dims=[["time"], []], output_core_dims=[["time"]], dask="parallelized", vectorize=True)
 
------------------- START ---------------------
<class 'numpy.ndarray'>
(1,)
[1.+0.j]
-------------------- END ------------------------
------------------- START ---------------------
------------------- START ---------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(2,)
(2,)
(2,)
[0.+0.j 0.+0.j]
[0.+0.j 0.+0.j]
-------------------- END ------------------------
-------------------- END ------------------------
[0.+0.j 0.+0.j]
-------------------- END ------------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
------------------- START ---------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
(2,)
<class 'numpy.ndarray'>
(2,)
(2,)
[0.+0.j 0.+0.j]
[0.+0.j 0.+0.j]
-------------------- END ------------------------
-------------------- END ------------------------
[0.+0.j 0.+0.j]
------------------- START ---------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
-------------------- END ------------------------
(2,)
<class 'numpy.ndarray'>
------------------- START ---------------------
[0.+0.j 0.+0.j]
-------------------- END ------------------------
(2,)
<class 'numpy.ndarray'>
[0.+0.j 0.+0.j]
------------------- START ---------------------
<class 'numpy.ndarray'>
-------------------- END ------------------------
(2,)
(2,)
------------------- START ---------------------
[0.+0.j 0.+0.j]
-------------------- END ------------------------
<class 'numpy.ndarray'>
(2,)
[0.+0.j 0.+0.j]
[0.+0.j 0.+0.j]
-------------------- END ------------------------
-------------------- END ------------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
(2,)
[0.+0.j 0.+0.j]
-------------------- END ------------------------
Out[70]: 
<xarray.DataArray 'vis' (baselineid: 1, frequency: 3, polarisation: 4, time: 2)> Size: 192B
array([[[[1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j]],
 
        [[1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j]],
 
        [[1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j]]]], dtype=complex64)
Coordinates:
  * time          (time) float64 16B 4.454e+09 4.454e+09
  * frequency     (frequency) float64 24B 1.281e+08 1.282e+08 1.282e+08
  * polarisation  (polarisation) <U2 32B 'XX' 'XY' 'YX' 'YY'
  * baselineid    (baselineid) int64 8B 0

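The logs show that adder is called 13 times. The first call is a trial run: because no output_dtypes was given, dask.array.apply_gufunc calls the function once on a small dummy array to infer the output dtype (an array of ones with shape (1,), visible as [1.+0.j] in the first block). The other 12 calls do the actual work. With dimensions (time: 2, baselineid: 1, frequency: 3, polarisation: 4) and time as the only core dimension, the number of calls is baselineid * frequency * polarisation = 1 * 3 * 4 = 12, i.e. 4 calls for each of the 3 frequency chunks. In each of these calls, adder receives a 1-D numpy array of shape (2,), the time dimension. The START/END blocks are interleaved because dask processes the chunks concurrently. Note also that in the result, time has moved to the last dimension, after the loop dimensions.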
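The call count can be checked with a few lines of arithmetic. vectorized_call_count is a hypothetical helper: it multiplies the sizes of all non-core dimensions, which is how many times np.vectorize calls the function. The extra 1 for the trial call is taken from the log above.

```python
from math import prod


def vectorized_call_count(sizes, core_dims):
    """How often np.vectorize calls the function: product of loop-dim sizes."""
    return prod(n for dim, n in sizes.items() if dim not in core_dims)


sizes = {"time": 2, "baselineid": 1, "frequency": 3, "polarisation": 4}
total = vectorized_call_count(sizes, ["time"])
# One frequency chunk has frequency size 1, everything else unchanged
per_chunk = vectorized_call_count({**sizes, "frequency": 1}, ["time"])

print(total)      # 12
print(per_chunk)  # 4
print(total + 1)  # 13, including the dtype-inference trial call
```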
