Source code for daisy.scheduler

from __future__ import absolute_import

from .dependency_graph import DependencyGraph
from .ready_surface import ReadySurface
from .task_state import TaskState
from .processing_queue import ProcessingQueue
from .task import Task
from .block import BlockStatus

from typing import List
import collections
import logging

logger = logging.getLogger(__name__)


class Scheduler:
    """Track the state of tasks and hand out blocks that are ready to process.

    The Scheduler takes a list of tasks, and upon request will provide the
    next block available for processing.

    Args:
        tasks: The list of tasks to schedule. If any of the tasks have
            upstream dependencies these will be recursively enumerated and
            added to the scheduler.
        count_all_orphans: Whether to guarantee accurate counting of all
            orphans. This can be inefficient if your dependency tree is
            particularly deep rather than just wide, so consider flipping
            this to False if you are having performance issues. If False,
            orphaned blocks will be counted as "pending" in the task state
            since there is no way to tell the difference between the two
            types without enumerating all orphans.
    """

    def __init__(self, tasks: List["Task"], count_all_orphans=True):
        self.dependency_graph = DependencyGraph(tasks)
        self.ready_surface = ReadySurface(
            self.dependency_graph.downstream, self.dependency_graph.upstream
        )

        self.task_map = {}
        self.task_states = collections.defaultdict(TaskState)
        self.task_queues = collections.defaultdict(ProcessingQueue)

        # roots() maps task_id -> (num_roots, root_generator); every root
        # block counts as ready from the start.
        roots = self.dependency_graph.roots()
        for task_id, (num_roots, root_gen) in roots.items():
            self.task_states[task_id].ready_count += num_roots
            self.task_queues[task_id] = ProcessingQueue(num_roots, root_gen)

        for task in tasks:
            self.__init_task(task)

        self.count_all_orphans = count_all_orphans

    def get_ready_tasks(self) -> List["Task"]:
        """Return the tasks that currently have blocks available for
        scheduling, i.e. whose state has a positive ``ready_count``.
        """
        return [
            self.task_map[task_id]
            for task_id, task_state in self.task_states.items()
            if task_state.ready_count > 0
        ]

    def acquire_block(self, task_id):
        """Get a block that is ready to process for the given task.

        Blocks for which the task's check function reports that they are
        already done are marked as ``SUCCESS``, counted as skipped, released
        right away, and the next block from the queue is tried instead.

        Args:
            task_id: The task for which you want a block.

        Return:
            ``Block`` or None: A block that can be run without worry of
            conflicts, or None if the task's queue has no block to offer.
        """
        task_state = self.task_states[task_id]
        task_queue = self.task_queues[task_id]

        while True:
            block = task_queue.get_next()
            if block is None:
                return None

            # The block leaves the ready pool and is now being processed.
            task_state.ready_count -= 1
            task_state.processing_count += 1

            if not self.__precheck(block):
                task_state.started = True
                task_queue.processing_blocks.add(block.block_id)
                return block

            logger.debug(
                "Skipping block (%s); already processed.", block.block_id)
            block.status = BlockStatus.SUCCESS
            task_state.skipped_count += 1
            # release_block() removes the block from processing_blocks, so
            # it has to be registered there first.
            task_queue.processing_blocks.add(block.block_id)
            self.release_block(block)

    def release_block(self, block):
        """Update the scheduler with the final status of ``block``.

        Args:
            block(``Block``): Block of interest. Its ``status`` must be
                either ``BlockStatus.SUCCESS`` or ``BlockStatus.FAILED``.

        Return:
            ``dict``(``task_id`` -> ``TaskState``): Each task returned had
            its state changed by updating the status of the given block.
            i.e. if a task B was dependent on task A, and marking a block in
            A as solved made some blocks in B available for processing, task
            B would be returned with its state. A temporarily failed block
            is queued again and its own task is returned. A permanently
            failed block returns an empty dict.

        Raises:
            RuntimeError: If the block's status is neither ``SUCCESS`` nor
                ``FAILED``. The scheduler's bookkeeping is left untouched in
                that case.
        """
        task_id = block.task_id
        status = block.status

        # Validate before mutating anything, so that a block released with
        # an unexpected status does not silently vanish from the processing
        # set right before the error is raised.
        if status not in (BlockStatus.SUCCESS, BlockStatus.FAILED):
            raise RuntimeError(
                f"Unexpected status for released block: {block.status}")

        self.__remove_from_processing_blocks(block)

        if status == BlockStatus.SUCCESS:
            new_blocks = self.ready_surface.mark_success(block)
            self.task_states[task_id].completed_count += 1
            return self.__update_ready_queue(new_blocks)

        # status == BlockStatus.FAILED from here on
        task_queue = self.task_queues[task_id]
        max_retries = self.task_map[task_id].max_retries

        if task_queue.block_retries[block.block_id] >= max_retries:
            logger.debug("Marking %s as permanently failed", block)
            orphans = self.ready_surface.mark_failure(
                block, count_all_orphans=self.count_all_orphans
            )
            logger.debug("Number of orphans is %d", len(orphans))
            self.task_states[task_id].failed_count += 1
            for orphan in orphans:
                self.task_states[orphan.task_id].orphaned_count += 1
            return {}

        logger.debug("Marking %s as temporarily failed", block)
        self.__queue_ready_block(block)
        task_queue.block_retries[block.block_id] += 1
        return {task_id: self.task_states[task_id]}

    def __init_task(self, task):
        """Register ``task`` and, recursively, every task it requires."""
        if task.task_id not in self.task_map:
            self.task_map[task.task_id] = task
            num_blocks = self.dependency_graph.num_blocks(task.task_id)
            self.task_states[task.task_id].total_block_count = num_blocks

            for upstream_task in task.requires():
                self.__init_task(upstream_task)

    def __queue_ready_block(self, block, index=None):
        """Put ``block`` into its task's ready queue and count it as ready.

        The block is appended unless ``index`` is given, in which case it is
        inserted at that position.
        """
        ready_queue = self.task_queues[block.task_id].ready_queue
        if index is None:
            ready_queue.append(block)
        else:
            ready_queue.insert(index, block)
        self.task_states[block.task_id].ready_count += 1

    def __remove_from_processing_blocks(self, block):
        """Drop ``block`` from its task's processing set and decrement the
        task's processing count.
        """
        self.task_queues[block.task_id].processing_blocks.remove(
            block.block_id)
        self.task_states[block.task_id].processing_count -= 1

    def __update_ready_queue(self, ready_blocks):
        """Queue all ``ready_blocks`` and return the states of the tasks
        they belong to as ``task_id`` -> ``TaskState``.
        """
        updated_tasks = {}
        for ready_block in ready_blocks:
            self.__queue_ready_block(ready_block)
            updated_tasks[ready_block.task_id] = self.task_states[
                ready_block.task_id
            ]
        return updated_tasks

    def __precheck(self, block):
        """Return the result of the task's check function for ``block``.

        Returns False if the task has no check function or if the check
        function raises.
        """
        try:
            # pre_check can intermittently fail, so any exception is logged
            # and treated as "not yet processed".
            check_function = self.task_map[block.task_id].check_function
            if check_function is not None:
                return check_function(block)
            return False
        except Exception:
            logger.exception(
                "pre_check() exception for block %s", block.block_id)
            return False