apply_ufunc with Dask arrays

apply_ufunc handles dask-backed xarrays in the same manner as numpy-backed ones. However, it requires a dask argument to work with dask-backed xarrays.
xr.apply_ufunc(func, xda, dask='allowed')
The dask parameter

- dask='forbidden' (the default) raises an error when a dask-backed xarray is given.
- dask='allowed' is for when the given function is able to handle dask arrays as input.
- dask='parallelized' is for when the function is not able to handle dask arrays and we want xarray.apply_ufunc to make the function work on dask arrays chunk by chunk.
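A minimal sketch of the three modes (the DataArray here is made up for illustration):

import numpy as np
import xarray as xr

xda = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=["x", "y"]).chunk({"y": 1})

try:
    xr.apply_ufunc(np.square, xda, dask='forbidden')
except ValueError as err:
    print(err)  # dask-backed input is rejected outright

# np.square can handle dask arrays, so both of the following work:
xr.apply_ufunc(np.square, xda, dask='allowed')       # function sees the whole dask array
xr.apply_ufunc(np.square, xda, dask='parallelized')  # function sees one numpy chunk at a time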
Example

data is a dask-backed xarray chunked on the frequency dimension, with 3 chunks of 1 frequency each.
>>> data
<xarray.DataArray 'vis' (time: 2, baselineid: 1, frequency: 3, polarisation: 4)> Size: 192B
dask.array<rechunk-merge, shape=(2, 1, 3, 4), dtype=complex64, chunksize=(2, 1, 1, 4), chunktype=numpy.ndarray>
Coordinates:
* time (time) float64 16B 4.454e+09 4.454e+09
* frequency (frequency) float64 24B 1.281e+08 1.282e+08 1.282e+08
* polarisation (polarisation) <U2 32B 'XX' 'XY' 'YX' 'YY'
* baselineid (baselineid) int64 8B 0
Attributes:
units: Jy
>>> data.chunksizes
Frozen({'time': (2,), 'baselineid': (1,), 'frequency': (1, 1, 1), 'polarisation': (4,)})
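For reference, a DataArray with this shape and chunking could be built along these lines (a sketch; the real values and coordinates come from elsewhere):

import numpy as np
import xarray as xr

data = xr.DataArray(
    np.zeros((2, 1, 3, 4), dtype=np.complex64),
    dims=["time", "baselineid", "frequency", "polarisation"],
    name="vis",
).chunk({"frequency": 1})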
Define a custom function, adder, which logs the type and shape of whatever it receives:
>>> def adder(arr, value):
...     print('------------------- START ---------------------')
...     print(type(arr))
...     print(arr.shape)
...     print(arr)
...     print('-------------------- END ------------------------')
...     return arr + value
adder can actually handle dask arrays, since addition is defined on them, so dask='allowed' works:
>>> xr.apply_ufunc(adder, data, 1, dask="allowed")
------------------- START ---------------------
<class 'dask.array.core.Array'>
(2, 1, 3, 4)
dask.array<rechunk-merge, shape=(2, 1, 3, 4), dtype=complex64, chunksize=(2, 1, 1, 4), chunktype=numpy.ndarray>
-------------------- END ------------------------
<xarray.DataArray 'vis' (time: 2, baselineid: 1, frequency: 3, polarisation: 4)> Size: 192B
dask.array<add, shape=(2, 1, 3, 4), dtype=complex64, chunksize=(2, 1, 1, 4), chunktype=numpy.ndarray>
Coordinates:
* time (time) float64 16B 4.454e+09 4.454e+09
* frequency (frequency) float64 24B 1.281e+08 1.282e+08 1.282e+08
* polarisation (polarisation) <U2 32B 'XX' 'XY' 'YX' 'YY'
* baselineid (baselineid) int64 8B 0
So, with dask='allowed', adder gets the whole dask array in a single call.
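Note that the returned DataArray is still dask-backed (see the dask.array<add, ...> repr above); the addition itself only runs on compute:

result = xr.apply_ufunc(adder, data, 1, dask='allowed')
print(result.data)  # still a lazy dask array
result.compute()    # materializes the result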
Working with non-dask functions

When a function does not support dask arrays, we can use the dask="parallelized" argument to xr.apply_ufunc.

For example, adder can also be treated as a function that does not support dask arrays. In that case, we expect xarray to provide numpy arrays to adder.
>>> xr.apply_ufunc(adder, data, 1, dask="parallelized").compute()
------------------- START ---------------------
<class 'numpy.ndarray'>
(1, 1, 1, 1)
[[[[1.+0.j]]]]
-------------------- END ------------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
(2, 1, 1, 4)
[[[[0.+0.j 0.+0.j 0.+0.j 0.+0.j]]]
[[[0.+0.j 0.+0.j 0.+0.j 0.+0.j]]]]
-------------------- END ------------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
(2, 1, 1, 4)
[[[[0.+0.j 0.+0.j 0.+0.j 0.+0.j]]]
[[[0.+0.j 0.+0.j 0.+0.j 0.+0.j]]]]
-------------------- END ------------------------
------------------- START ---------------------
<class 'numpy.ndarray'>
(2, 1, 1, 4)
[[[[0.+0.j 0.+0.j 0.+0.j 0.+0.j]]]
[[[0.+0.j 0.+0.j 0.+0.j 0.+0.j]]]]
-------------------- END ------------------------
<xarray.DataArray 'vis' (time: 2, baselineid: 1, frequency: 3, polarisation: 4)> Size: 192B
array([[[[1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],
[1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],
[1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j]]],
[[[1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],
[1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],
[1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j]]]], dtype=complex64)
Coordinates:
* time (time) float64 16B 4.454e+09 4.454e+09
* frequency (frequency) float64 24B 1.281e+08 1.282e+08 1.282e+08
* polarisation (polarisation) <U2 32B 'XX' 'XY' 'YX' 'YY'
* baselineid (baselineid) int64 8B 0
If we observe the output, the adder function receives numpy arrays of shape (2, 1, 1, 4), and it ran 4 times: the first run, on a dummy array of shape (1, 1, 1, 1), is a test run dask performs to infer the output's meta information. In each of the remaining runs, adder gets one chunk as a numpy array.
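If the function changes the output dtype, it helps to declare it explicitly with apply_ufunc's output_dtypes argument rather than relying on inference alone. A small sketch (np.abs maps complex64 input to float32 output):

xr.apply_ufunc(np.abs,
               data,
               dask="parallelized",
               output_dtypes=[np.float32])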
How does it work?

- xarray.apply_ufunc uses dask.array.apply_gufunc internally.
- The underlying dask array is passed to dask.array.apply_gufunc.
- apply_gufunc then passes each chunk's numpy array to the custom function.
- dask stitches all the output chunks from the custom function back together.
- xarray then populates the meta information (dimensions, coordinates, attributes) on the output array, as sketched below.
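Roughly, the middle steps look like the following direct apply_gufunc call (a sketch, not xarray's exact internal invocation):

import dask.array as da
import numpy as np

darr = da.zeros((2, 1, 3, 4), dtype=np.complex64, chunks=(2, 1, 1, 4))

# "(),()->()" means elementwise over both inputs, so adder runs once per chunk
out = da.apply_gufunc(adder, "(),()->()", darr, 1, output_dtypes=darr.dtype)
out.compute()  # dask stitches the per-chunk outputs back together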
Warning

If the data is chunked along a core dimension that the custom function works on, executing such an operation raises an error.

For example, using the mean wrapper defined further below:
ds.air
# <xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB
# dask.array<open_dataset-air, shape=(2920, 25, 53), dtype=float64, chunksize=(100, 25, 53), chunktype=numpy.ndarray>
xr.apply_ufunc(mean, ds.air, input_core_dims=[["time"]], kwargs={"axis": -1}, dask="parallelized")
# dimension time on 0th function argument to apply_ufunc with dask='parallelized' consists of multiple chunks, but is also a core dimension. To fix, either rechunk into a single array chunk along this dimension, i.e., ``.chunk(dict(time=-1))``, or pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` but beware that this may significantly increase memory usage.
The data is chunked across the time dimension, so providing time as a core dimension raises the error. The data has to be rechunked if we want to avoid it. One option is to let dask do the rechunking, by passing allow_rechunk=True through to dask.array.apply_gufunc:
def mean(d, **kwargs):
    print('------------------- START ---------------------')
    print(type(d))
    print(d.shape)
    print('-------------------- END ------------------------')
    return np.mean(d, **kwargs)

xr.apply_ufunc(mean,
               ds.air,
               input_core_dims=[["time"]],
               kwargs={"axis": -1},
               dask_gufunc_kwargs={"allow_rechunk": True},
               dask="parallelized")
However, this may consume significantly more memory, as the error message warns.
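Alternatively, following the error message's suggestion, we can rechunk the data ourselves so that time is a single chunk:

xr.apply_ufunc(mean,
               ds.air.chunk(dict(time=-1)),  # one chunk along "time"
               input_core_dims=[["time"]],
               kwargs={"axis": -1},
               dask="parallelized")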
To determine the effect of the function on the input (the output's dtype and other meta information), dask first runs the function with size-1 dummy data, as shown below.
xr.apply_ufunc(mean,
               ds.air,
               input_core_dims=[["lon"]],
               kwargs={"axis": -1},
               dask="parallelized")
# ------------------- START ---------------------
# <class 'numpy.ndarray'>
# (1, 1, 1)
# -------------------- END ------------------------
# <xarray.DataArray 'air' (time: 2920, lat: 25)> Size: 584kB
# dask.array<transpose, shape=(2920, 25), dtype=float64, chunksize=(100, 25), chunktype=numpy.ndarray>
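If this dummy run is undesirable (for example, when the function has side effects), one possible workaround is to supply the output meta up front so dask can skip the inference call. This is a sketch that assumes dask_gufunc_kwargs forwards the meta keyword to dask.array.apply_gufunc:

xr.apply_ufunc(mean,
               ds.air,
               input_core_dims=[["lon"]],
               kwargs={"axis": -1},
               dask="parallelized",
               dask_gufunc_kwargs={"meta": np.empty((), dtype=ds.air.dtype)})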