from __future__ import absolute_import
from .block import Block
from .coordinate import Coordinate
from .roi import Roi
import numpy as np
from itertools import product
import logging
import collections
from typing import List, Optional
logger = logging.getLogger(__name__)
class BlockwiseDependencyGraph:
    """Compute the dependency graph between the blocks of a single task.

    Conceptually the graph is a list with elements::

        (block, [upstream_blocks])

    per block, where ``block`` is the block with exclusive write access to its
    ``write_roi`` (and safe read access to its ``read_roi``), and
    ``upstream_blocks`` is a list of blocks that need to finish before this
    block can start (``[]`` if there are no upstream dependencies).

    Blocks are organized into *levels*: blocks within one level are spaced by
    ``_level_stride`` and therefore never conflict with each other; conflicts
    only exist between consecutive levels, which lets dependencies be computed
    on the fly from a small set of per-level conflict offsets.

    Args:

        task_id (``str``):

            The id of the task this graph belongs to. Used to tag all
            generated blocks.

        block_read_roi (`class:daisy.Roi`):

            The ROI every block needs to read data from. Will be shifted over
            the total ROI.

        block_write_roi (`class:daisy.Roi`):

            The ROI every block writes data to. Will be shifted over the
            total ROI in synchrony with ``block_read_roi``.

        read_write_conflict (``bool``):

            Whether the read and write ROIs are conflicting, i.e., accessing
            the same resource. If set to ``False``, all blocks can run at the
            same time in parallel. In this case, providing a ``read_roi`` is
            simply a means of convenience to ensure no out-of-bound accesses
            and to avoid re-computation of it in each block.

        fit (``string``):

            How to handle cases where shifting blocks by the size of
            ``block_write_roi`` does not tile the total ROI. Possible
            options are:

            "valid": Skip blocks that would lie outside of the total ROI.
            This is the default::

                |---------------------------|     total ROI

                |rrrr|wwwwww|rrrr|                block 1
                       |rrrr|wwwwww|rrrr|         block 2
                                                  no further block

            "overhang": Add all blocks that overlap with the total ROI, even
            if they leave it. Client code has to take care of safe access
            beyond the total ROI in this case.::

                |---------------------------|     total ROI

                |rrrr|wwwwww|rrrr|                block 1
                       |rrrr|wwwwww|rrrr|         block 2
                              |rrrr|wwwwww|rrrr|  block 3 (overhanging)

            "shrink": Like "overhang", but shrink the boundary blocks' read
            and write ROIs such that they are guaranteed to lie within the
            total ROI. The shrinking will preserve the context, i.e., the
            difference between the read ROI and write ROI stays the same.::

                |---------------------------|     total ROI

                |rrrr|wwwwww|rrrr|                block 1
                       |rrrr|wwwwww|rrrr|         block 2
                              |rrrr|www|rrrr|     block 3 (shrunk)

        total_read_roi (`class:daisy.Roi`, optional):

            The total ROI to read from. At least one of ``total_read_roi``
            and ``total_write_roi`` must be provided; the missing one is
            derived from the block context.

        total_write_roi (`class:daisy.Roi`, optional):

            The total ROI to write to. See ``total_read_roi``.
    """

    def __init__(
        self,
        task_id: str,
        block_read_roi: Roi,
        block_write_roi: Roi,
        read_write_conflict: bool,
        fit: str,
        total_read_roi: Optional[Roi] = None,
        total_write_roi: Optional[Roi] = None,
    ):
        self.block_read_roi = block_read_roi
        self.block_write_roi = block_write_roi

        # the context of a block: (lower, upper) padding of the write ROI
        # inside the read ROI
        self.read_write_context = (
            block_write_roi.begin - block_read_roi.begin,
            block_read_roi.end - block_write_roi.end,
        )

        if total_read_roi is not None and total_write_roi is not None:
            # if both total ROIs are given they must agree with the block
            # context, otherwise one of them is wrong
            total_context = (
                total_write_roi.begin - total_read_roi.begin,
                total_read_roi.end - total_write_roi.end,
            )
            assert total_context == self.read_write_context, (
                "total_read_roi and total_write_roi have context "
                f"{total_context}, which is unequal to the block context: "
                f"{self.read_write_context}"
            )

        if total_read_roi is not None:
            self.total_read_roi = total_read_roi
            # derive the write ROI by shrinking the read ROI by the context
            self.total_write_roi = self.total_read_roi.grow(
                -self.read_write_context[0], -self.read_write_context[1]
            )
        elif total_write_roi is not None:
            self.total_write_roi = total_write_roi
            # derive the read ROI by growing the write ROI by the context
            self.total_read_roi = self.total_write_roi.grow(
                self.read_write_context[0], self.read_write_context[1]
            )
        else:
            raise ValueError(
                "Either total_read_roi or total_write_roi must be provided!"
            )

        self.task_id = task_id
        self.read_write_conflict = read_write_conflict
        self.fit = fit

        # when computing block offsets, make sure to include blocks
        # on the upper boundary with a rounding term
        if self.fit == "overhang" or self.fit == "shrink":
            # want to round up if there is any write roi left
            self.rounding_term = (1,) * self.block_write_roi.dims
        else:
            # want to round up only if there is a full write block left.
            self.rounding_term = self.block_write_roi.shape

        # computed values
        self._level_stride = self.compute_level_stride()
        self._level_offsets = self.compute_level_offsets()
        self._level_conflicts = self.compute_level_conflicts()

    @property
    def num_levels(self):
        # one level per level offset; blocks within a level are independent
        return len(self._level_offsets)

    @property
    def num_blocks(self):
        # total number of blocks over all levels (counted, not enumerated)
        num_blocks = 0
        for level in range(self.num_levels):
            num_blocks += self._num_level_blocks(level)
        return num_blocks

    @property
    def inclusion_criteria(self):
        # predicate deciding whether a candidate block is part of the graph,
        # depending on the fit strategy
        # TODO: Can't we remove this entirely by pre computing the write_roi
        inclusion_criteria = {
            "valid": lambda b: self.total_write_roi.contains(b.write_roi),
            "overhang": lambda b: self.total_write_roi.contains(
                b.write_roi.begin
            ),
            "shrink": lambda b: self.shrink_possible(b),
        }[self.fit]
        return inclusion_criteria

    @property
    def fit_block(self):
        # post-processing applied to included blocks; only "shrink" modifies
        # boundary blocks, the other strategies are no-ops
        # TODO: Can't we remove this by pre computing the write_roi and
        # intersecting edge blocks with the write roi while making them?
        fit_block = {
            "valid": lambda b: b,  # noop
            "overhang": lambda b: b,  # noop
            "shrink": lambda b: self.shrink(b),
        }[self.fit]
        return fit_block

    def num_roots(self):
        # roots are exactly the blocks of level 0 (no upstream dependencies)
        return self._num_level_blocks(0)

    def _num_level_blocks(self, level):
        """Count the blocks of one level without enumerating them."""
        level_offset = self._level_offsets[level]

        # per-dimension block count: ceil-style division of the remaining
        # extent by the level stride, with the fit-dependent rounding term
        axis_blocks = [
            (e - lo + s - r) // s
            for lo, e, s, r in zip(
                level_offset,
                self.total_write_roi.shape,
                self._level_stride,
                self.rounding_term,
            )
        ]

        # plain Python product: arbitrary precision, unlike np.prod which
        # returns a fixed-width np.int64 and can silently overflow for very
        # large volumes
        num_blocks = 1
        for axis_count in axis_blocks:
            num_blocks *= axis_count

        logger.debug(
            "number of blocks for write_roi: %s, level (%d), "
            "offset (%s), and stride (%s): %d (per dim: %s)",
            self.total_write_roi,
            level,
            level_offset,
            self._level_stride,
            num_blocks,
            axis_blocks)

        return num_blocks

    def level_blocks(self, level):
        """Generate all (fitted) blocks of one level."""
        for block_offset in self._compute_level_block_offsets(level):
            block = Block(
                self.total_read_roi,
                self.block_read_roi + block_offset,
                self.block_write_roi + block_offset,
                task_id=self.task_id,
            )
            # TODO: We probably don't need to check every block for inclusion
            # and fit, but rather just the blocks on the total roi boundary.
            # can probably be handled when we calculate the block_dim_offsets
            if self.inclusion_criteria(block):
                yield self.fit_block(block)
            else:
                # _compute_level_block_offsets only generates in-bounds
                # offsets, so every candidate must pass the inclusion check
                raise RuntimeError("Unreachable!")

    def root_gen(self):
        """Generate all root blocks, i.e., blocks without dependencies."""
        # level 0 blocks have no upstream conflicts; level_blocks already
        # applies fit_block, so no further fitting is needed here
        yield from self.level_blocks(level=0)

    def _block_offset(self, block):
        # The block offset is the offset of the read roi relative to total roi
        block_offset = (
            block.read_roi.offset -
            self.total_read_roi.offset)
        return block_offset

    def _level(self, block):
        """Recover the level a block belongs to from its offset."""
        block_offset = self._block_offset(block)
        # the residue modulo the level stride identifies the level offset
        level_offset = block_offset % self._level_stride
        for i, offset in enumerate(self._level_offsets):
            if level_offset == offset:
                return i
        raise NotImplementedError(
            f"Should not be reachable! {level_offset} not in "
            f"{self._level_offsets}, stride: {self._level_stride}"
        )

    def downstream(self, block):
        """
        get all block_id's that are directly dependent on this block_id
        i.e. this block offset by all conflict offsets in the next level
        """
        level = self._level(block)
        next_level = level + 1
        if next_level >= self.num_levels:
            # last level: nothing depends on these blocks
            return []
        conflicts = []
        for conflict in self._level_conflicts[next_level]:
            # conflict offsets point from the next level back to this level,
            # so downstream blocks sit at -conflict from this block
            conflict_block = Block(
                total_roi=self.total_read_roi,
                read_roi=Roi(
                    block.read_roi.offset - conflict,
                    self.block_read_roi.shape,
                ),
                write_roi=Roi(
                    block.write_roi.offset - conflict,
                    self.block_write_roi.shape,
                ),
                task_id=self.task_id,
            )

            # TODO: We probably don't need to check every block for inclusion
            # and fit, but rather just the blocks on the total roi boundary
            if self.inclusion_criteria(conflict_block):
                conflicts.append(self.fit_block(conflict_block))

        # de-duplicate (several conflict offsets can hit the same block)
        conflicts = list(set(conflicts))
        return conflicts

    def upstream(self, block):
        """
        get all upstream block id's for a given block_id
        i.e. this block offset by all conflict offsets in this level
        """
        level = self._level(block)
        conflicts = []
        for conflict in self._level_conflicts[level]:
            conflict_block = Block(
                total_roi=self.total_read_roi,
                read_roi=Roi(
                    block.read_roi.offset + conflict,
                    self.block_read_roi.shape,
                ),
                write_roi=Roi(
                    block.write_roi.offset + conflict,
                    self.block_write_roi.shape,
                ),
                task_id=self.task_id,
            )

            # TODO: We probably don't need to check every block for inclusion
            # and fit, but rather just the blocks on the total roi boundary
            if self.inclusion_criteria(conflict_block):
                conflicts.append(self.fit_block(conflict_block))

        # de-duplicate (several conflict offsets can hit the same block)
        conflicts = list(set(conflicts))
        return conflicts

    def enumerate_all_dependencies(self):
        """Yield ``(block, upstream_blocks)`` for every block, level by
        level, so that dependencies are always yielded before dependents."""
        # NOTE: an earlier version eagerly materialized all block offsets of
        # all levels here (compute_level_block_offsets) without ever using
        # the result; blocks are generated lazily per level instead
        for level in range(self.num_levels):
            level_blocks = self.level_blocks(level)
            for block in level_blocks:
                yield (block, self.upstream(block))

    def compute_level_stride(self) -> Coordinate:
        """
        Get the stride that separates independent blocks in one level.
        """
        if not self.read_write_conflict:
            # no conflicts: blocks only need to tile the write ROI
            return self.block_write_roi.shape

        logger.debug(
            "Compute level stride for read ROI %s and write ROI %s.",
            self.block_read_roi,
            self.block_write_roi,
        )

        assert self.block_read_roi.contains(
            self.block_write_roi
        ), "Read ROI must contain write ROI."

        context_ul = (
            self.block_write_roi.begin -
            self.block_read_roi.begin)
        context_lr = (
            self.block_read_roi.end -
            self.block_write_roi.end)

        max_context = Coordinate(
            (max(ul, lr) for ul, lr in zip(context_ul, context_lr))
        )
        logger.debug("max context per dimension is %s", max_context)

        # this stride guarantees that blocks are independent, but not a
        # multiple of the write_roi shape. It would be impossible to tile
        # the output roi with blocks shifted by this min_level_stride
        min_level_stride = max_context + self.block_write_roi.shape
        logger.debug("min level stride is %s", min_level_stride)

        # to avoid overlapping write ROIs, increase the stride to the next
        # multiple of write shape
        write_shape = self.block_write_roi.shape
        level_stride = Coordinate((
            ((level - 1) // w + 1) * w
            for level, w in zip(min_level_stride, write_shape)
        ))

        # Handle case where min_level_stride > total_write_roi.
        # This case leads to levels with no blocks in them. This makes
        # calculating dependencies on the fly significantly more difficult
        write_roi_shape = self.total_write_roi.shape
        if self.fit == "valid":
            # round down to nearest multiple of the block size
            write_roi_shape -= write_roi_shape % self.block_write_roi.shape
        else:
            # round up to nearest multiple of the block size; in Python,
            # -x % s is already non-negative, so no second modulo is needed
            write_roi_shape += -write_roi_shape % self.block_write_roi.shape

        # cap the stride at the (rounded) total write extent per dimension
        level_stride = Coordinate(
            (min(a, b) for a, b in zip(level_stride, write_roi_shape)))

        logger.debug(
            "final level stride (multiples of write size) is %s",
            level_stride)

        return level_stride

    def compute_level_offsets(self) -> List[Coordinate]:
        """
        compute an offset for each level.
        """
        write_stride = self.block_write_roi.shape

        logger.debug(
            "Compute level offsets for level stride %s and write stride %s.",
            self._level_stride,
            write_stride,
        )

        # every combination of write-stride steps below the level stride is
        # its own level
        dim_offsets = [
            range(0, e, step)
            for e, step in zip(self._level_stride, write_stride)
        ]

        level_offsets = list(
            reversed([
                Coordinate(o)
                for o in product(*dim_offsets)])
        )

        logger.debug("level offsets: %s", level_offsets)

        return level_offsets

    def compute_level_conflicts(self) -> List[List[Coordinate]]:
        """
        For each level, compute the set of conflicts from previous levels.
        """
        level_conflict_offsets = []
        prev_level_offset = None

        for level, level_offset in enumerate(self._level_offsets):
            # get conflicts to previous level; level 0 has none by definition
            if prev_level_offset is not None and self.read_write_conflict:
                conflict_offsets = self.get_conflict_offsets(
                    level_offset, prev_level_offset, self._level_stride
                )
            else:
                conflict_offsets = []
            prev_level_offset = level_offset

            level_conflict_offsets.append(conflict_offsets)

        return level_conflict_offsets

    def _compute_level_block_offsets(self, level):
        """Lazily generate the global offsets of all blocks in one level."""
        level_offset = self._level_offsets[level]

        # all block offsets of the current level (relative to total ROI start)
        block_dim_offsets = [
            range(lo, e + 1 - r, s)
            for lo, e, s, r in zip(
                level_offset,
                self.total_write_roi.shape,
                self._level_stride,
                self.rounding_term,
            )
        ]

        for offset in product(*block_dim_offsets):
            # TODO: can we do this part lazily? This might be a lot of
            # Coordinates
            block_offset = Coordinate(offset)

            # convert to global coordinates
            block_offset += (
                self.total_read_roi.begin -
                self.block_read_roi.begin
            )
            yield block_offset

    def compute_level_block_offsets(self) -> List[List[Coordinate]]:
        """
        For each level, get the set of all offsets corresponding to blocks in
        this level.
        """
        level_block_offsets = []

        for level in range(self.num_levels):
            level_block_offsets.append(
                list(
                    self._compute_level_block_offsets(level)))

        return level_block_offsets

    def get_conflict_offsets(
            self,
            level_offset,
            prev_level_offset,
            level_stride):
        """Get the offsets to all previous level blocks that are in conflict
        with the current level blocks."""
        offset_to_prev = prev_level_offset - level_offset
        logger.debug("offset to previous level: %s", offset_to_prev)

        def get_offsets(op, ls):
            # per dimension, a conflict can come from the neighbor on either
            # side of the offset (unless the offset is zero)
            if op < 0:
                return [op, op + ls]
            elif op == 0:
                return [op]
            else:
                return [op - ls, op]

        conflict_dim_offsets = [
            get_offsets(op, ls) for op, ls in zip(offset_to_prev, level_stride)
        ]

        conflict_offsets = [
            Coordinate(o)
            for o in product(*conflict_dim_offsets)
        ]
        logger.debug(
            "conflict offsets to previous level: %s",
            conflict_offsets)

        return conflict_offsets

    def shrink_possible(self, block):
        # a boundary block can be shrunk iff its write ROI starts inside the
        # total write ROI
        return self.total_write_roi.contains(block.write_roi.begin)

    def shrink(self, block):
        """Ensure that read and write ROI are within total ROI by shrinking
        both. Size of context will be preserved."""
        w = self.total_write_roi.intersect(block.write_roi)
        r = self.total_read_roi.intersect(block.read_roi)

        shrunk_block = block.copy()
        shrunk_block.read_roi = r
        shrunk_block.write_roi = w

        return shrunk_block

    def get_subgraph_blocks(self, sub_roi, read_roi=False):
        """Return ids of blocks, as instantiated in the full graph, such that
        their total write rois fully cover `sub_roi`.
        The function API assumes that `sub_roi` and `total_roi` use world
        coordinates and `self.block_read_roi` and `self.block_write_roi` use
        relative coordinates.

        Args:

            sub_roi (`class:daisy.Roi`):

                The region of interest to cover with blocks.

            read_roi (``bool``, optional):

                If ``True``, return blocks whose *read* ROI overlaps with
                ``sub_roi`` instead of blocks whose write ROI covers it.
        """
        if read_roi:
            # if we want to get blocks whose read_roi overlaps with sub_roi
            # simply grow the sub_roi by the block context. That way we
            # only need to check if a blocks read_roi overlaps with sub_roi.
            # This is the same behavior as when we want write_roi overlap
            sub_roi = sub_roi.grow(
                self.read_write_context[0], self.read_write_context[1])

        # TODO: handle unsatisfiable sub_rois
        # i.e. sub_roi is outside of *total_write_roi
        # after accounting for padding
        sub_roi = sub_roi.intersect(self.total_write_roi)

        # get sub_roi relative to the write roi
        begin = sub_roi.begin - self.total_write_roi.offset
        end = sub_roi.end - self.total_write_roi.offset

        # convert to block coordinates. Handle upper block based on fit
        aligned_subroi = (
            begin // self.block_write_roi.shape,  # `floordiv`
            -(-end // self.block_write_roi.shape),  # `ceildiv`
        )

        # generate relative offsets of relevant write blocks
        block_dim_offsets = [
            range(lo, e, s)
            for lo, e, s in zip(
                aligned_subroi[0] * self.block_write_roi.shape,
                aligned_subroi[1] * self.block_write_roi.shape,
                self.block_write_roi.shape,
            )
        ]

        # generate absolute offsets
        block_offsets = [
            Coordinate(o) + self.total_read_roi.offset
            for o in product(*block_dim_offsets)
        ]

        blocks = [
            self.fit_block(
                Block(
                    self.total_read_roi,
                    self.block_read_roi + offset - self.block_read_roi.offset,
                    self.block_write_roi + offset - self.block_read_roi.offset,
                    task_id=self.task_id,
                )
            )
            for offset in block_offsets
        ]
        return [block for block in blocks if self.inclusion_criteria(block)]
class DependencyGraph:
    """Combine the per-task block dependency graphs of a set of tasks into
    one inter-task dependency graph.

    Args:

        tasks:

            The tasks to build the graph for. Each task must provide
            ``task_id``, ``requires()`` (its upstream tasks), ``read_roi``,
            ``write_roi``, ``read_write_conflict``, ``fit`` and
            ``total_roi``. Upstream tasks are added recursively.
    """

    def __init__(self, tasks):
        self.upstream_tasks = collections.defaultdict(set)
        self.downstream_tasks = collections.defaultdict(set)
        self.task_map = {}

        # block-level storage, filled by __enumerate_all_dependencies:
        # block_id -> block, and block_id -> set of block_ids
        self.blocks = {}
        self._upstream = collections.defaultdict(set)
        self._downstream = collections.defaultdict(set)

        for task in tasks:
            self.__add_task(task)

        self.task_dependency_graphs = {}
        for task in self.task_map.values():
            self.__add_task_dependency_graph(task)

    @property
    def task_ids(self):
        # ids of all tasks in the graph, including recursively added ones
        return self.task_map.keys()

    def num_blocks(self, task_id):
        """Total number of blocks of one task."""
        return self.task_dependency_graphs[task_id].num_blocks

    def upstream(self, block):
        """Get all blocks (intra- and inter-task) that ``block`` depends
        on, sorted by block index for determinism."""
        upstream = self.task_dependency_graphs[block.task_id].upstream(block)
        for upstream_task in self.upstream_tasks[block.task_id]:
            # blocks of upstream tasks whose write covers this block's read
            upstream.extend(
                self.task_dependency_graphs[upstream_task].get_subgraph_blocks(
                    block.read_roi, read_roi=False
                )
            )
        return sorted(
            upstream,
            key=lambda b: b.block_id[1],
        )

    def downstream(self, block):
        """Get all blocks (intra- and inter-task) that depend on ``block``,
        sorted by block index for determinism."""
        dep_graphs = self.task_dependency_graphs
        downstream = dep_graphs[block.task_id].downstream(block)
        for downstream_task in self.downstream_tasks[block.task_id]:
            # blocks of downstream tasks whose read overlaps this block's
            # write
            downstream.extend(
                dep_graphs[downstream_task].get_subgraph_blocks(
                    block.write_roi, read_roi=True
                )
            )
        return sorted(
            downstream,
            key=lambda b: b.block_id[1],
        )

    def root_tasks(self):
        """Ids of tasks without upstream dependencies."""
        return [
            task_id
            for task_id, upstream_tasks in self.upstream_tasks.items()
            if len(upstream_tasks) == 0
        ]

    def num_roots(self, task_id):
        """Number of dependency-free blocks of one task."""
        return self.task_dependency_graphs[task_id].num_roots()

    def root_gen(self, task_id):
        """Generator over the dependency-free blocks of one task."""
        return self.task_dependency_graphs[task_id].root_gen()

    def roots(self):
        """Map root task id -> (number of root blocks, root block
        generator)."""
        root_tasks = self.root_tasks()
        return {
            task_id: (self.num_roots(task_id), self.root_gen(task_id))
            for task_id in root_tasks
        }

    def __add_task(self, task):
        """Register a task and, recursively, all tasks it requires."""
        if task.task_id not in self.task_map:
            self.task_map[task.task_id] = task
            self.upstream_tasks[task.task_id] = set()
            self.downstream_tasks[task.task_id] = set()

            for upstream_task in task.requires():
                self.__add_task(upstream_task)
                self.upstream_tasks[task.task_id].add(upstream_task.task_id)
                self.downstream_tasks[upstream_task.task_id].add(task.task_id)

    def __add_task_dependency_graph(self, task):
        """Create dependency graph for a specific task"""
        # create intra task dependency graph
        self.task_dependency_graphs[task.task_id] = BlockwiseDependencyGraph(
            task.task_id,
            task.read_roi,
            task.write_roi,
            task.read_write_conflict,
            task.fit,
            total_read_roi=task.total_roi,
        )

    def __enumerate_all_dependencies(self):
        """Eagerly materialize every block and every dependency edge into
        ``self.blocks``, ``self._upstream`` and ``self._downstream``."""
        # enumerate all the blocks and their intra-task dependencies; blocks
        # are yielded level by level, so dependencies appear before their
        # dependents
        for task_id in self.task_ids:
            block_dependencies = self.task_dependency_graphs[
                task_id
            ].enumerate_all_dependencies()
            for block, upstream_blocks in block_dependencies:
                if block.block_id in self.blocks:
                    continue
                self.blocks[block.block_id] = block

                for upstream_block in upstream_blocks:
                    if upstream_block.block_id not in self.blocks:
                        raise RuntimeError(
                            "Block dependency %s is not found for task %s."
                            % (upstream_block.block_id, task_id)
                        )
                    self._downstream[upstream_block.block_id].add(
                        block.block_id)
                    self._upstream[block.block_id].add(
                        upstream_block.block_id)

        # enumerate all of the inter-task dependencies
        for task_id in self.task_ids:
            # add inter-task read-write dependency
            if len(self.upstream_tasks[task_id]):
                # iterate over this task's blocks collected above
                for block in self.blocks.values():
                    if block.task_id != task_id:
                        continue
                    roi = block.read_roi
                    upstream_blocks = []
                    for upstream_task_id in self.upstream_tasks[task_id]:
                        upstream_task_blocks = self.task_dependency_graphs[
                            upstream_task_id
                        ].get_subgraph_blocks(roi)
                        # extend with the blocks themselves (an earlier
                        # version appended the whole list as one element,
                        # which broke the loop below)
                        upstream_blocks.extend(upstream_task_blocks)

                    for upstream_block in upstream_blocks:
                        if upstream_block.block_id not in self.blocks:
                            raise RuntimeError(
                                "Block dependency %s is not found for task %s."
                                % (upstream_block.block_id, task_id)
                            )
                        self._downstream[upstream_block.block_id].add(
                            block.block_id)
                        self._upstream[block.block_id].add(
                            upstream_block.block_id)