Vectorize Dask Xarray
Vectorization for a dask-backed xarray object works only when we pass the dask="parallelized" parameter to xarray.apply_ufunc.
As we saw in the note, dask.array.apply_gufunc passes each chunk (block) of the array to the custom function. When vectorize=True is given, the custom function is additionally wrapped in np.vectorize, which loops over the non-core dimensions of each block.
So the flow of execution is as follows:
xarray.apply_ufunc
|
V
dask.array.apply_gufunc
|
V
np.vectorize
|
V
custom function
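The last two stages of this flow (np.vectorize calling the custom function) can be reproduced with plain NumPy. The sketch below is an illustration, not xarray's or dask's actual code. It assumes a single block shaped like one frequency chunk of the example that follows, with the core dimension time already moved to the last axis, and it uses the gufunc signature "(t),()->(t)" (the name t is arbitrary). The adder here is a simplified, hypothetical variant that records shapes instead of printing them.

```python
import numpy as np

calls = []  # shapes received by adder, recorded instead of printed


def adder(arr, value):
    calls.append(arr.shape)
    return arr + value


# One block shaped like a single frequency chunk, with the core dimension
# "time" already last: (baselineid: 1, frequency: 1, polarisation: 4, time: 2)
block = np.zeros((1, 1, 4, 2), dtype=np.complex64)

# "(t),()->(t)": a 1-D core dimension plus a scalar in, a 1-D core dimension out
vadder = np.vectorize(adder, signature="(t),()->(t)", otypes=[np.complex64])
out = vadder(block, 1)

print(out.shape)   # (1, 1, 4, 2)
print(len(calls))  # 4
print(calls[0])    # (2,)
```

np.vectorize calls adder once per combination of the leading (loop) axes, here 1 * 1 * 4 = 4 times, and each call sees a 1-D array of length 2.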
For example,
>>> data
<xarray.DataArray 'vis' (time: 2, baselineid: 1, frequency: 3, polarisation: 4)> Size: 192B
dask.array<rechunk-merge, shape=(2, 1, 3, 4), dtype=complex64, chunksize=(2, 1, 1, 4), chunktype=numpy.ndarray>
Coordinates:
* time (time) float64 16B 4.454e+09 4.454e+09
* frequency (frequency) float64 24B 1.281e+08 1.282e+08 1.282e+08
* polarisation (polarisation) <U2 32B 'XX' 'XY' 'YX' 'YY'
* baselineid (baselineid) int64 8B 0
Attributes:
units: Jy
>>> data.chunksizes
Frozen({'time': (2,), 'baselineid': (1,), 'frequency': (1, 1, 1), 'polarisation': (4,)})
where data is a dask-backed xarray DataArray chunked along frequency into 3 chunks of 1 frequency each.
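The chunk layout can be checked with a little arithmetic: every combination of per-dimension chunks is one dask block. The sketch below uses only the standard library and copies the chunk sizes reported by data.chunksizes; it is an illustration, not an xarray or dask API.

```python
from itertools import product

# Chunk sizes copied from the data.chunksizes output above
chunksizes = {
    "time": (2,),
    "baselineid": (1,),
    "frequency": (1, 1, 1),
    "polarisation": (4,),
}

# Each block takes one chunk length from every dimension
blocks = [dict(zip(chunksizes, sizes)) for sizes in product(*chunksizes.values())]

print(len(blocks))  # 3
print(blocks[0])    # {'time': 2, 'baselineid': 1, 'frequency': 1, 'polarisation': 4}
```

Only frequency is split, so there are 3 blocks and all of them have the same shape.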
>>> def adder(arr, value):
...     print('------------------- START ---------------------')
...     print(type(arr))
...     print(arr.shape)
...     print(arr)
...     print('-------------------- END ------------------------')
...     return arr + value
...
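We want xarray to vectorize the adder function so that each call receives a numpy array with time as its only dimension.
We can achieve this with core dimensions and loop dimensions: time is declared as the core dimension of the first input (the scalar 1 has none), so every other dimension becomes a loop dimension that np.vectorize iterates over.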
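The split between core and loop dimensions is plain bookkeeping. The sketch below (standard library only, with dimension sizes copied from data) computes which dimensions are looped over, how many calls that implies, and the shape each call sees. It relies on the documented apply_ufunc behaviour of moving core dimensions to the last axes; it is an illustration, not xarray's internal code.

```python
from math import prod

# Dimension sizes of the example DataArray
dims = {"time": 2, "baselineid": 1, "frequency": 3, "polarisation": 4}

# input_core_dims=[["time"], []]: only the first input has a core dimension
core_dims = ["time"]

# Every non-core dimension is a loop dimension
loop_dims = [d for d in dims if d not in core_dims]

# Loop dimensions come first, core dimensions are moved to the end
output_dims = loop_dims + core_dims

n_calls = prod(dims[d] for d in loop_dims)
call_shape = tuple(dims[d] for d in core_dims)

print(output_dims)  # ['baselineid', 'frequency', 'polarisation', 'time']
print(n_calls)      # 12
print(call_shape)   # (2,)
```

The computed ordering matches the (baselineid, frequency, polarisation, time) dimension order of the result shown below.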
>>> xr.apply_ufunc(adder, data, 1, input_core_dims=[["time"], []], output_core_dims=[["time"]], dask="parallelized", vectorize=True).compute()
------------------- START ---------------------
<class 'numpy.ndarray'>
(1,)
[1.+0.j]
-------------------- END ------------------------
------------------- START ---------------------
------------------- START ---------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(2,)
(2,)
(2,)
[0.+0.j 0.+0.j]
[0.+0.j 0.+0.j]
-------------------- END ------------------------
-------------------- END ------------------------
[0.+0.j 0.+0.j]
-------------------- END ------------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
------------------- START ---------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
(2,)
<class 'numpy.ndarray'>
(2,)
(2,)
[0.+0.j 0.+0.j]
[0.+0.j 0.+0.j]
-------------------- END ------------------------
-------------------- END ------------------------
[0.+0.j 0.+0.j]
------------------- START ---------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
-------------------- END ------------------------
(2,)
<class 'numpy.ndarray'>
------------------- START ---------------------
[0.+0.j 0.+0.j]
-------------------- END ------------------------
(2,)
<class 'numpy.ndarray'>
[0.+0.j 0.+0.j]
------------------- START ---------------------
<class 'numpy.ndarray'>
-------------------- END ------------------------
(2,)
(2,)
------------------- START ---------------------
[0.+0.j 0.+0.j]
-------------------- END ------------------------
<class 'numpy.ndarray'>
(2,)
[0.+0.j 0.+0.j]
[0.+0.j 0.+0.j]
-------------------- END ------------------------
-------------------- END ------------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
(2,)
[0.+0.j 0.+0.j]
-------------------- END ------------------------
Out[70]:
<xarray.DataArray 'vis' (baselineid: 1, frequency: 3, polarisation: 4, time: 2)> Size: 192B
array([[[[1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j]],

        [[1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j]],

        [[1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j],
         [1.+0.j, 1.+0.j]]]], dtype=complex64)
Coordinates:
* time (time) float64 16B 4.454e+09 4.454e+09
* frequency (frequency) float64 24B 1.281e+08 1.282e+08 1.282e+08
* polarisation (polarisation) <U2 32B 'XX' 'XY' 'YX' 'YY'
* baselineid (baselineid) int64 8B 0
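If we observe the logs, adder is called 13 times. The first call is a test run: it receives an array of shape (1,) holding 1+0j, which is consistent with dask calling the function once on a tiny array of ones to infer the output dtype, since output_dtypes was not given. The remaining 12 calls do the real work.
There are 12 of them because the data has dimensions (time: 2, baselineid: 1, frequency: 3, polarisation: 4) and time is the only core dimension, so the total number of calls is baselineid * frequency * polarisation = 1 * 3 * 4 = 12 (equivalently, 4 calls for each of the 3 frequency chunks). Each call gets a 1-D numpy array along time with 2 values (from time: 2).
The START/END blocks in the log are interleaved because dask's default threaded scheduler processes the chunks concurrently.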
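Because print output from concurrent calls interleaves, the count is easier to verify by recording calls instead of printing them. The sketch below emulates the pipeline with NumPy and the standard library only: it assumes the example's layout with time already moved last and zero-valued data, splits the array into the three frequency blocks itself, and uses a ThreadPoolExecutor as a stand-in for dask's threaded scheduler. It does not reproduce the extra test run, and process_block is a hypothetical helper, not an xarray or dask API.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

import numpy as np

lock = threading.Lock()
call_shapes = []  # shapes seen by adder, recorded instead of printed


def adder(arr, value):
    # Append under a lock so concurrent calls cannot clash
    with lock:
        call_shapes.append(arr.shape)
    return arr + value


def process_block(block):
    # Roughly what each block goes through: np.vectorize loops over the
    # leading (loop) axes and passes 1-D "time" slices to adder
    vadder = np.vectorize(adder, signature="(t),()->(t)", otypes=[np.complex64])
    return vadder(block, 1)


# Example layout with the core dimension moved last:
# (baselineid: 1, frequency: 3, polarisation: 4, time: 2), filled with zeros
data = np.zeros((1, 3, 4, 2), dtype=np.complex64)

# Three blocks of one frequency each, matching frequency chunks (1, 1, 1)
blocks = np.split(data, 3, axis=1)

# Run the blocks concurrently, loosely like dask's default threaded scheduler
with ThreadPoolExecutor(max_workers=3) as pool:
    results = list(pool.map(process_block, blocks))

out = np.concatenate(results, axis=1)
print(out.shape)         # (1, 3, 4, 2)
print(len(call_shapes))  # 12
```

Recording under a lock keeps the bookkeeping intact even though the three blocks run at the same time, which is exactly what makes the printed logs above hard to read.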