Extend Dask Array
As discussed in the array design document, creating a dask Array object requires the following:
- A dask graph
- A name specifying a set of keys within that graph
- A chunks tuple giving chunk shape information
- A dtype
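The relationship between the name, the chunks tuple, and the keys of the graph can be checked with plain Python before any Array is built. The following is a minimal sketch, not part of dask: `expected_keys` is a hypothetical helper, the array name `"x"` and its two tasks are made up for illustration, and the dtype is left out because it does not affect the keys.

```python
from itertools import product


def expected_keys(name, chunks):
    """Return every key that an array called ``name`` with these chunks must define.

    Hypothetical helper for illustration; it is not a dask function.
    """
    # One range of block indices per dimension, sized by the number of blocks.
    block_ranges = [range(len(dim_chunks)) for dim_chunks in chunks]
    return [(name,) + index for index in product(*block_ranges)]


# A made-up 1d array "x" of length 5, split into blocks of sizes 2 and 3.
name = "x"
chunks = ((2, 3),)
dsk = {
    ("x", 0): (range, 0, 2),  # task producing the first block
    ("x", 1): (range, 2, 5),  # task producing the second block
}

# Every block implied by the chunks tuple needs a matching key in the graph.
missing = [key for key in expected_keys(name, chunks) if key not in dsk]

# The overall shape is the sum of the block sizes along each dimension.
shape = tuple(sum(dim_chunks) for dim_chunks in chunks)
```

Here `missing` is empty and `shape` is `(5,)`, so the graph, name, and chunks agree with one another.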
Often dask.array functions take other Array objects as inputs along with parameters, add tasks to a new dask dictionary, create a new chunks tuple, and then construct and return a new Array object. The hard parts are invariably creating the right tasks and creating a new chunks tuple. Careful review of the array design document is suggested.
Example eye
Consider this simple example with the eye function.
import numpy as np
from dask.array.core import Array, tokens

def eye(n, blocksize):
    # n is assumed to be a multiple of blocksize
    chunks = ((blocksize,) * (n // blocksize),
              (blocksize,) * (n // blocksize))

    name = 'eye' + next(tokens)  # unique identifier

    # identity blocks on the diagonal, zero blocks everywhere else
    dsk = {(name, i, j): (np.eye, blocksize)
                         if i == j else
                         (np.zeros, (blocksize, blocksize))
           for i in range(n // blocksize)
           for j in range(n // blocksize)}

    dtype = np.eye(0).dtype  # take dtype default from numpy

    return Array(dsk, name, chunks, dtype)
This example is particularly simple because it doesn’t take any Array objects as input.
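The chunks computation above only covers `n` that is an exact multiple of `blocksize`. One possible way to handle a remainder is sketched below; `eye_chunks` is a hypothetical helper, not part of dask, and the task dictionary would then also have to build the last row and column of blocks with the smaller size.

```python
def eye_chunks(n, blocksize):
    """Chunks tuple for an n-by-n array cut into blocks of at most ``blocksize``.

    Hypothetical helper for illustration; it is not a dask function.
    """
    full_blocks, remainder = divmod(n, blocksize)
    dim = (blocksize,) * full_blocks
    if remainder:
        dim += (remainder,)  # one final, smaller block
    # The array is square, so both dimensions share the same block sizes.
    return (dim, dim)


even = eye_chunks(8, 4)     # n divides evenly
ragged = eye_chunks(10, 4)  # n leaves a remainder of 2
```

In both cases the block sizes along each dimension sum to `n`, which is the invariant a chunks tuple has to satisfy.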
Example diag
Consider the function diag that takes a 1d vector and produces a 2d matrix with the values of the vector along the diagonal. Consider the case where the input is a 1d array with chunk sizes (2, 3, 4) in the first dimension like this:
[x_0, x_1], [x_2, x_3, x_4], [x_5, x_6, x_7, x_8]
We need to create a 2d matrix with chunks equal to ((2, 3, 4), (2, 3, 4)) where the ith block along the diagonal of the output is the result of calling np.diag on the ith block of the input and all other blocks are zero.
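To make the block layout concrete, the shape of every output block can be derived from the chunks tuple alone. This is a small sketch in plain Python; `block_layout` is a hypothetical name used only for illustration.

```python
def block_layout(chunks):
    """Map each 2d block index (i, j) to the shape of that block.

    Hypothetical helper for illustration; it is not a dask function.
    """
    row_sizes, col_sizes = chunks
    return {(i, j): (rows, cols)
            for i, rows in enumerate(row_sizes)
            for j, cols in enumerate(col_sizes)}


layout = block_layout(((2, 3, 4), (2, 3, 4)))

# Blocks on the diagonal are square; they hold np.diag of an input block.
diagonal = [layout[(i, i)] for i in range(3)]

# Off-diagonal blocks are rectangular and are filled with zeros.
top_right = layout[(0, 2)]
```

The diagonal blocks have shapes `(2, 2)`, `(3, 3)`, and `(4, 4)`, while an off-diagonal block such as `(0, 2)` has shape `(2, 4)`, which is why the zero blocks in the function below take their shape from two different entries of `v.chunks[0]`.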
def diag(v):
    """Construct a diagonal array, with ``v`` on the diagonal."""
    chunks = (v.chunks[0], v.chunks[0])  # repeat chunks twice

    name = 'diag' + next(tokens)  # unique identifier

    # np.diag of the ith input block on the diagonal, zeros elsewhere;
    # the zero blocks use v.dtype so that every block matches the declared dtype
    dsk = {(name, i, j): (np.diag, (v.name, i))
                         if i == j else
                         (np.zeros, (v.chunks[0][i], v.chunks[0][j]), v.dtype)
           for i in range(len(v.chunks[0]))
           for j in range(len(v.chunks[0]))}

    dsk.update(v.dask)  # include dask graph of the input

    dtype = v.dtype  # output has the same dtype as the input

    return Array(dsk, name, chunks, dtype)
>>> import dask.array as da
>>> x = da.arange(9, chunks=((2, 3, 4),))
>>> x
dask.array<arange-1, shape=(9,), chunks=((2, 3, 4),), dtype=int64>
>>> M = diag(x)
>>> M
dask.array<diag-2, shape=(9, 9), chunks=((2, 3, 4), (2, 3, 4)), dtype=int64>
>>> M.compute()
array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 2, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 3, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 4, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 5, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 6, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 7, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 8]])