Source code for fermipy.jobs.scatter_gather

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Abstract interface for parallel execution of multiple jobs.

The main class is `ScatterGather`, which can submit many instances
of a job with different configurations.
"""
from __future__ import absolute_import, division, print_function

import sys
import time

import numpy as np

from fermipy.jobs.batch import get_batch_job_interface
from fermipy.jobs.job_archive import JobStatus,\
    JobStatusVector, JobDetails, JobArchive, JOB_STATUS_STRINGS
from fermipy.jobs.link import extract_arguments, Link

from fermipy.jobs import defaults

ACTIONS = ['run', 'resubmit', 'check_status', 'config', 'skip', 'clean']


class ScatterGather(Link):
    """ Class to dispatch several jobs in parallel and
    collect and merge the results.

    Sub-classes will need to generate configurations for the jobs that
    they launch.

    Parameters
    ----------
    clientclass : type
        Type of `Link` object managed by this class.

    job_time : int
        Estimated maximum time it takes to run a job.
        This is used to manage batch farm scheduling and
        checking for completion.
    """
    appname = 'dummy-sg'
    usage = "%s [options]" % (appname)
    description = "Run multiple analyses"
    clientclass = None
    job_time = 1500

    default_prefix_logfile = 'scatter'

    default_options = dict()
    default_options_base = dict(action=defaults.jobs['action'],
                                dry_run=defaults.jobs['dry_run'],
                                job_check_sleep=defaults.jobs['job_check_sleep'],
                                print_update=defaults.jobs['print_update'],
                                check_status_once=defaults.jobs['check_status_once'])

    def __init__(self, link, **kwargs):
        """C'tor

        Keyword arguments
        -----------------
        interface : `SysInterface` subclass
            Object used to interface with batch system

        usage : str
            Usage string for argument parser

        description : str
            Description string for argument parser

        job_archive : `fermipy.job_archive.JobArchive` [optional]
            Archive used to track jobs and associated data products.
            Defaults to None.

        scatter : `fermipy.chain.Link`
            Link run for the scatter stage.
            Defaults to None.
        """
        kwargs.setdefault('interface', get_batch_job_interface(self.job_time))
        self._scatter_link = link
        self.default_options.update(self.default_options_base.copy())
        Link.__init__(self, **kwargs)
        # Override the value of job_check_sleep to avoid excess waiting
        job_check_sleep = np.clip(int(self.job_time / 5), 60, 300)
        self.args['job_check_sleep'] = job_check_sleep
        self._job_configs = {}

    @property
    def scatter_link(self):
        """Return the `Link` object used in the scatter phase of processing"""
        return self._scatter_link

    @classmethod
    def _make_scatter_logfile_name(cls, key, linkname, job_config):
        """Hook to insert the name of a logfile into the input config"""
        logfile = job_config.get('logfile', "%s_%s_%s.log" %
                                 (cls.default_prefix_logfile, linkname, key))
        job_config['logfile'] = logfile
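    # A minimal sketch of a concrete sub-class (illustration only; the
    # names `MyLink` and `my-analysis-sg` are hypothetical):
    #
    #     class MyScatterGather(ScatterGather):
    #         appname = 'my-analysis-sg'
    #         usage = "%s [options]" % (appname)
    #         description = "Run my analysis on many inputs"
    #         clientclass = MyLink   # `Link` sub-class that runs one job
    #         job_time = 300         # seconds; the c'tor derives
    #                                # job_check_sleep = np.clip(job_time / 5, 60, 300)
    #
    #         def build_job_configs(self, args):
    #             ...  # see the build_job_configs hook below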
    @classmethod
    def create(cls, **kwargs):
        """Build and return a `ScatterGather` object"""
        linkname = kwargs.setdefault('linkname',
                                     cls.clientclass.linkname_default)
        # Don't use setdefault b/c we don't want to build a JobArchive
        # unless it is needed
        job_archive = kwargs.get('job_archive', None)
        if job_archive is None:
            job_archive = JobArchive.build_temp_job_archive()
            kwargs.setdefault('job_archive', job_archive)
        kwargs_client = dict(linkname=linkname,
                             link_prefix=kwargs.get('link_prefix', ''),
                             file_stage=kwargs.get('file_stage', None),
                             job_archive=job_archive)
        link = cls.clientclass.create(**kwargs_client)
        sg = cls(link, **kwargs)
        return sg
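    # Typical construction goes through this factory (sketch; assumes the
    # hypothetical sub-class `MyScatterGather` from above):
    #
    #     sg = MyScatterGather.create(linkname='my-analysis-sg')
    #
    # If no 'job_archive' keyword is supplied, a temporary `JobArchive`
    # is built to track the dispatched jobs.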
    @classmethod
    def main(cls):
        """Hook for command line interface to sub-classes"""
        link = cls.create()
        link._invoke(sys.argv[1:])
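    # Wiring up a console script (sketch; the module layout is hypothetical):
    #
    #     if __name__ == '__main__':
    #         MyScatterGather.main()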
    def build_job_configs(self, args):
        """Hook to build job configurations

        Sub-class implementation should return:

        job_configs : dict
            Dictionary of dictionaries passed to parallel jobs
        """
        raise NotImplementedError("ScatterGather.build_job_configs")
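    # A sketch of a sub-class implementation that makes one job per seed
    # (the 'seeds' argument and key format are illustrative assumptions):
    #
    #     def build_job_configs(self, args):
    #         job_configs = {}
    #         for seed in args['seeds']:
    #             key = "%06i" % seed
    #             job_configs[key] = dict(seed=seed)
    #         return job_configs
    #
    # `_build_job_dict` below merges each of these dictionaries over the
    # scatter `Link` arguments and registers one job per key.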
    def _latch_file_info(self):
        """Internal function to update the dictionaries
        keeping track of input and output files
        """
        self.files.file_dict.clear()
        self.sub_files.file_dict.clear()
        self.files.latch_file_info(self.args)
        self._scatter_link._update_sub_file_dict(self.sub_files)

    def _check_link_completion(self, link,
                               fail_pending=False, fail_running=False):
        """Internal function to check the completion of all the dispatched jobs

        Returns
        -------
        status_vect : `JobStatusVector`
            Vector that summarizes the number of jobs in various states.
        """
        status_vect = JobStatusVector()
        for job_key, job_details in link.jobs.items():
            # if job_details.status == JobStatus.failed:
            #     failed = True
            #     continue
            # elif job_details.status == JobStatus.done:
            #     continue
            if job_key.find(JobDetails.topkey) >= 0:
                continue
            job_details.status = self._interface.check_job(job_details)
            if job_details.status == JobStatus.pending:
                if fail_pending:
                    job_details.status = JobStatus.failed
            elif job_details.status == JobStatus.running:
                if fail_running:
                    job_details.status = JobStatus.failed
            status_vect[job_details.status] += 1
            link.jobs[job_key] = job_details
            link._set_status_self(job_details.jobkey, job_details.status)

        return status_vect

    def _build_job_dict(self):
        """Build a dictionary of `JobDetails` objects for the internal `Link`"""
        if self.args['dry_run']:
            status = JobStatus.unknown
        else:
            status = JobStatus.not_ready

        base_config = self.scatter_link.args
        for jobkey, job_config in sorted(self._job_configs.items()):
            full_job_config = base_config.copy()
            full_job_config.update(job_config)
            ScatterGather._make_scatter_logfile_name(
                jobkey, self.linkname, full_job_config)
            logfile = full_job_config.get('logfile')
            self._scatter_link._register_job(key=jobkey,
                                             job_config=full_job_config,
                                             logfile=logfile,
                                             status=status)

    def _run_link(self, stream=sys.stdout, dry_run=False,
                  stage_files=True, resubmit_failed=False):
        """Internal function that actually runs this link.

        This checks if input and output files are present.

        If input files are missing this will raise `OSError` if dry_run
        is False. If all output files are present this will skip execution.

        Parameters
        ----------
        stream : `file`
            Stream that this `Link` will print to,
            must have 'write' function.

        dry_run : bool
            Print command but do not run it.

        stage_files : bool
            Stage files to and from the scratch area.

        resubmit_failed : bool
            Resubmit failed jobs.
        """
        if resubmit_failed:
            self.args['action'] = 'resubmit'
        argv = self._make_argv()
        if dry_run:
            argv.append('--dry_run')
        self._invoke(argv, stream)

    def _invoke(self, argv, stream=sys.stdout):
        """Invoke this object to perform a particular action

        Parameters
        ----------
        argv : list
            List of command line arguments, passed to helper classes

        stream : `file`
            Stream that this function will print to,
            must have 'write' function.

        Returns
        -------
        status_vect : `JobStatusVector`
            Vector that summarizes the number of jobs in various states.
        """
        args = self._run_argparser(argv)

        if args.action not in ACTIONS:
            sys.stderr.write(
                "Unrecognized action %s, options are %s\n" %
                (args.action, ACTIONS))

        if args.action == 'skip':
            return JobStatus.no_job
        elif args.action in ['run', 'resubmit', 'check_status', 'config']:
            self._job_configs = self.build_job_configs(args.__dict__)

        self._interface._dry_run = args.dry_run

        if args.action == 'run':
            status_vect = self.run_jobs(stream)
        elif args.action == 'resubmit':
            status_vect = self.resubmit(stream)
        elif args.action == 'check_status':
            self._build_job_dict()
            status_vect = self.check_status(stream)
        elif args.action == 'config':
            self._build_job_dict()
            status_vect = JobStatusVector()
            status_vect[JobStatus.done] += 1

        return status_vect
    def update_args(self, override_args):
        """Update the arguments used to invoke the application

        Note that this will also update the dictionary of input
        and output files.

        Parameters
        ----------
        override_args : dict
            Dictionary of arguments to override the current values
        """
        self.args = extract_arguments(override_args, self.args)
        self._job_configs = self.build_job_configs(self.args)
        if not self._scatter_link.jobs:
            self._build_job_dict()
        self._latch_file_info()
    def clear_jobs(self, recursive=True):
        """Clear the self.jobs dictionary that contains information
        about jobs associated with this `ScatterGather`.

        If recursive is True this will include jobs from all
        internal `Link` objects.
        """
        if recursive:
            self._scatter_link.clear_jobs(recursive)
        self.jobs.clear()
    def get_jobs(self, recursive=True):
        """Return a dictionary with all the jobs

        If recursive is True this will include jobs from all
        internal `Link` objects.
        """
        if recursive:
            ret_dict = self.jobs.copy()
            ret_dict.update(self._scatter_link.get_jobs(recursive))
            return ret_dict
        return self.jobs
    def check_status(self, stream=sys.stdout,
                     check_once=False,
                     fail_pending=False, fail_running=False,
                     no_wait=False, do_print=True,
                     write_status=False):
        """Loop to check on the status of all the jobs in job dict.

        Parameters
        ----------
        stream : `file`
            Stream that this function will print to,
            must have 'write' function.

        check_once : bool
            Check status once and exit loop.

        fail_pending : `bool`
            If True, consider pending jobs as failed

        fail_running : `bool`
            If True, consider running jobs as failed

        no_wait : bool
            Do not sleep before checking jobs.

        do_print : bool
            Print summary stats.

        write_status : bool
            Write the status to the log file.

        Returns
        -------
        status_vect : `JobStatusVector`
            Vector that summarizes the number of jobs in various states.
        """
        running = True
        first = True

        if not check_once:
            if stream != sys.stdout:
                sys.stdout.write('Checking status (%is): ' %
                                 self.args['job_check_sleep'])
                sys.stdout.flush()

        status_vect = JobStatusVector()
        while running:
            if first:
                first = False
            elif self.args['dry_run']:
                break
            elif no_wait:
                pass
            else:
                stream.write("Sleeping %.0f seconds between status checks\n" %
                             self.args['job_check_sleep'])
                if stream != sys.stdout:
                    sys.stdout.write('.')
                    sys.stdout.flush()
                time.sleep(self.args['job_check_sleep'])

            status_vect = self._check_link_completion(self._scatter_link,
                                                      fail_pending,
                                                      fail_running)

            if self.args['check_status_once'] or check_once or no_wait:
                if do_print:
                    self.print_update(stream, status_vect)
                break

            if self.args['print_update']:
                if do_print:
                    self.print_update(stream, status_vect)

            if self._job_archive is not None:
                self._job_archive.write_table_file()

            n_total = status_vect.n_total
            n_done = status_vect.n_done
            n_failed = status_vect.n_failed
            if n_done + n_failed == n_total:
                running = False

        status = status_vect.get_status()
        if status in [JobStatus.failed, JobStatus.partial_failed]:
            if do_print:
                self.print_update(stream, status_vect)
                self.print_failed(stream)
            if write_status:
                self._write_status_to_log(status, stream)
        else:
            if write_status:
                self._write_status_to_log(0, stream)
        self._set_status_self(status=status)

        if not check_once:
            if stream != sys.stdout:
                sys.stdout.write("! %s\n" % (JOB_STATUS_STRINGS[status]))

        if self._job_archive is not None:
            self._job_archive.write_table_file()

        return status_vect
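    # One-shot, non-blocking status check (sketch):
    #
    #     status_vect = sg.check_status(check_once=True, no_wait=True)
    #     if status_vect.get_status() == JobStatus.partial_failed:
    #         sg.print_failed()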
    def run_jobs(self, stream=sys.stdout):
        """Function to dispatch jobs and collect results

        Parameters
        ----------
        stream : `file`
            Stream that this function will print to,
            must have 'write' function.

        Returns
        -------
        status_vect : `JobStatusVector`
            Vector that summarizes the number of jobs in various states.
        """
        self._build_job_dict()

        self._interface._dry_run = self.args['dry_run']
        scatter_status = self._interface.submit_jobs(self.scatter_link,
                                                     job_archive=self._job_archive,
                                                     stream=stream)
        if scatter_status == JobStatus.failed:
            return JobStatus.failed

        status_vect = self.check_status(stream, write_status=True)
        return status_vect
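    # Programmatic dispatch (sketch; 'seeds' is the hypothetical argument
    # from the build_job_configs example above):
    #
    #     sg.update_args(dict(seeds=[0, 1, 2]))
    #     status_vect = sg.run_jobs(sys.stdout)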
    def resubmit(self, stream=sys.stdout, fail_running=False):
        """Function to resubmit failed jobs and collect results

        Parameters
        ----------
        stream : `file`
            Stream that this function will print to,
            must have 'write' function.

        fail_running : `bool`
            If True, consider running jobs as failed

        Returns
        -------
        status_vect : `JobStatusVector`
            Vector that summarizes the number of jobs in various states.
        """
        self._build_job_dict()
        status_vect = self.check_status(stream, check_once=True,
                                        fail_pending=True,
                                        fail_running=fail_running)
        status = status_vect.get_status()
        if status == JobStatus.done:
            return status

        failed_jobs = self._scatter_link.get_failed_jobs(True, True)
        if failed_jobs:
            scatter_status = self._interface.submit_jobs(self._scatter_link,
                                                         failed_jobs,
                                                         job_archive=self._job_archive,
                                                         stream=stream)
            if scatter_status == JobStatus.failed:
                return JobStatus.failed

        status_vect = self.check_status(stream, write_status=True)
        if self.args['dry_run']:
            return JobStatus.unknown
        return status_vect
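    # Recovering from failures (sketch): pending jobs are treated as
    # failed, and only jobs marked failed are submitted again:
    #
    #     status_vect = sg.resubmit(fail_running=False)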
    def clean_jobs(self, recursive=False):
        """Clean up all the jobs associated with this object.

        If recursive is True this also cleans jobs dispatched by
        this object.
        """
        self._interface.clean_jobs(self.scatter_link,
                                   clean_all=recursive)
    def run(self, stream=sys.stdout, dry_run=False,
            stage_files=True, resubmit_failed=True):
        """Runs this `Link`.

        This version is intended to be overridden by sub-classes so
        as to provide a single function that behaves the same
        for all versions of `Link`.

        Parameters
        ----------
        stream : `file`
            Stream that this `Link` will print to,
            must have 'write' function

        dry_run : bool
            Print command but do not run it.

        stage_files : bool
            Copy files to and from scratch staging area.

        resubmit_failed : bool
            Flag for sub-classes to resubmit failed jobs.
        """
        self._run_link(stream, dry_run, stage_files, resubmit_failed)
    def print_summary(self, stream=sys.stdout, indent="", recurse_level=2):
        """Print a summary of the activity done by this `Link`.

        Parameters
        ----------
        stream : `file`
            Stream to print to

        indent : str
            Indentation at start of line

        recurse_level : int
            Number of recursion levels to print
        """
        Link.print_summary(self, stream, indent, recurse_level)
        if recurse_level > 0:
            recurse_level -= 1
            indent += "  "
            stream.write("\n")
            self._scatter_link.print_summary(stream, indent, recurse_level)
    def print_update(self, stream=sys.stdout, job_stats=None):
        """Print an update about the current number of jobs running"""
        if job_stats is None:
            job_stats = JobStatusVector()
            job_det_list = []
            job_det_list += self._scatter_link.jobs.values()

            for job_dets in job_det_list:
                if job_dets.status == JobStatus.no_job:
                    continue
                job_stats[job_dets.status] += 1

        stream.write("Status :\n  Total  : %i\n  Unknown: %i\n" %
                     (job_stats.n_total, job_stats[JobStatus.unknown]))
        stream.write("  Not Ready: %i\n  Ready: %i\n" %
                     (job_stats[JobStatus.not_ready], job_stats[JobStatus.ready]))
        stream.write("  Pending: %i\n  Running: %i\n" %
                     (job_stats[JobStatus.pending], job_stats[JobStatus.running]))
        stream.write("  Done: %i\n  Failed: %i\n" %
                     (job_stats[JobStatus.done], job_stats[JobStatus.failed]))
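    # Example of the summary this writes (counts are illustrative):
    #
    #     Status :
    #       Total  : 10
    #       Unknown: 0
    #       Not Ready: 0
    #       Ready: 0
    #       Pending: 2
    #       Running: 3
    #       Done: 5
    #       Failed: 0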
    def print_failed(self, stream=sys.stderr):
        """Print list of the failed jobs"""
        for job_key, job_details in sorted(self.scatter_link.jobs.items()):
            if job_details.status == JobStatus.failed:
                stream.write("Failed job %s\n  log = %s\n" %
                             (job_key, job_details.logfile))
    def run_analysis(self, argv):
        """Implemented by sub-classes to run a particular analysis"""
        raise RuntimeError(
            "run_analysis called for ScatterGather type object")