#!/usr/bin/python
# encoding: utf-8
#
# Copyright 2011-2013 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tool to supervise the launch of other binaries."""


import errno
import getopt
import logging
import logging.handlers
import os
import random
import signal
import subprocess
import sys
import tempfile
import time


class Error(Exception):
    """Base class for the exceptions defined in this module."""


class ExecuteError(Error):
    """Raised by Supervisor.execute() when the child process cannot be started."""


class OptionError(Error):
    """Raised when the command line options cannot be parsed."""


class TimeoutError(Error):
    """Raised by Supervisor.execute() when the child outlives its timeout.

    Note: on Python 3 this name shadows the builtin TimeoutError inside this
    module.
    """


# Exit codes that trigger the error-exec command when the
# error-exec-exit-codes option is not supplied.
DEFAULT_ERROR_EXEC_EXIT_CODES = [1]
# Value stored as Supervisor.exit_status when the child is killed on timeout.
EXIT_STATUS_TIMEOUT = -99
# Seconds Supervisor.killPid() waits after each signal it sends.
KILL_WAIT_SECS = 1


class Supervisor(object):
    """Launch a child process and supervise it until it exits.

    Supported options (see setOptions()):
        error-exec: str, shell command run by cleanup() when the child's
            exit status is one of the error-exec exit codes. When set, the
            child's stdout and stderr are captured into temporary files.
        error-exec-exit-codes: str, comma-delimited integer exit codes that
            trigger error-exec; DEFAULT_ERROR_EXEC_EXIT_CODES when unset.
        timeout: int, seconds after which the child is killed.
        delayrandom: int, upper bound in seconds for a random delay applied
            before the child is started.
        stdout, stderr, debug: stored but not read by this class.
    """

    def __init__(self, delayrandom_abort=False):
        """Init.

        Args:
            delayrandom_abort: bool, default False. If True, sending
                a SIGUSR1 to the process will stop any initial delayrandom
                from continuing to countdown, and will immediately end the
                delay. Note that setting this on multiple Supervisor instances
                in one process might not work too well depending on the
                timing of the execute() calls, see execute().
        """
        self.options = {
            'error-exec': None,
            'error-exec-exit-codes': None,
            'timeout': None,
            'delayrandom': None,
            'stdout': None,
            'stderr': None,
            'debug': None,
        }
        self.exit_status = None
        self.delayrandom_abort = delayrandom_abort
        # Cleared by signalHandler() to cut the random delay short.
        self.continue_sleeping = True
        # Capture files for the child's output; created by execute() only
        # when the error-exec option is set, released by cleanup().
        self.stdout = None
        self.stderr = None

    def setOptions(self, **kwargs):
        """Set one or more options.

        Args:
            **kwargs: option name -> value. Names containing '-' have to be
                passed via dict expansion, e.g. **{'error-exec': '...'}.
        """
        self.options.update(kwargs)

    def signalHandler(self, signum, dummy_frame):
        """Signal handler; SIGUSR1 requests the end of the random delay."""
        if signum == signal.SIGUSR1:
            self.continue_sleeping = False

    def _sleepUnlessAborted(self, secs):
        """Sleep up to secs seconds, returning early once aborted.

        Sleeps in slices of at most one second and re-checks
        self.continue_sleeping between slices, so the abort does not depend
        on time.sleep() itself being interrupted by the signal.
        """
        deadline = time.time() + secs
        while self.continue_sleeping:
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            time.sleep(min(1, remaining))

    def execute(self, args):
        """Exec.

        Applies the optional random delay, starts the child in its own
        process group and polls it once per second until it exits or the
        timeout expires. The result is available from getExitStatus().

        Args:
            args: list, arguments to execute, args[0] is binary name
        Raises:
            ExecuteError: the child could not be started; exit_status is
                set to 127.
            TimeoutError: the timeout expired; the child's process group
                was signalled and exit_status is EXIT_STATUS_TIMEOUT.
        """
        logging.debug('execute(%s)', str(args))

        self.continue_sleeping = True
        installed_handler = False
        if self.delayrandom_abort:
            # A second Supervisor process will not take over the previous
            # Supervisor process who is holding this signal now.
            if signal.getsignal(signal.SIGUSR1) == signal.SIG_DFL:
                signal.signal(signal.SIGUSR1, self.signalHandler)
                installed_handler = True

        try:
            max_secs = self.options.get('delayrandom')
            if max_secs:
                random_secs = random.randrange(0, max_secs)
                logging.debug(
                    'Applying random delay up to %s seconds: %s',
                    max_secs, random_secs)
                self._sleepUnlessAborted(random_secs)

            if self.delayrandom_abort and not self.continue_sleeping:
                logging.debug('Awoken from random delay by signal')
        finally:
            # Only undo what this call installed; a handler owned by someone
            # else is left untouched.
            if installed_handler:
                signal.signal(signal.SIGUSR1, signal.SIG_DFL)

        if self.options['error-exec']:
            # Parse error-exec-exit-codes, or set default if not provided.
            # Done before the capture files are created so that a malformed
            # list does not leave temporary files behind.
            exit_codes = self.options['error-exec-exit-codes']
            if exit_codes:
                self.error_exec_codes = [int(i) for i in exit_codes.split(',')]
            else:
                self.error_exec_codes = DEFAULT_ERROR_EXEC_EXIT_CODES
            self.stdout = tempfile.NamedTemporaryFile()
            self.stderr = tempfile.NamedTemporaryFile()
            stdout_pipe = self.stdout
            stderr_pipe = self.stderr
        else:
            stdout_pipe = None
            stderr_pipe = None

        try:
            proc = subprocess.Popen(
                args,
                # Own process group, so killPid() can signal the child and
                # everything it spawned in one go.
                preexec_fn=lambda: os.setpgid(os.getpid(), os.getpid()),
                stdout=stdout_pipe,
                stderr=stderr_pipe,
            )
        except OSError as e:
            self.exit_status = 127
            raise ExecuteError(str(e))

        self.exit_status = None
        self.continue_sleeping = True

        timeout = self.options.get('timeout')
        start_time = time.time()

        try:
            # this loop is constructed this way, rather than using alarm or
            # something, to facilitate future features, e.g. pipe
            # stderr/stdout to syslog.
            while True:
                exit_status = proc.poll()
                if exit_status is not None:
                    self.exit_status = exit_status
                    break

                if timeout and (time.time() - start_time) > timeout:
                    raise TimeoutError

                time.sleep(1)
        except TimeoutError:
            logging.critical('Timeout error executing %s', ' '.join(args))
            self.killPid(proc.pid)
            self.exit_status = EXIT_STATUS_TIMEOUT
            raise

    def killPid(self, pid):
        """Kill a pid, aggressively if necessary.

        Sends SIGTERM to the process group led by pid, waits KILL_WAIT_SECS
        and, if the child has not been reaped by then, sends SIGKILL and
        waits again. The SIGCHLD handler in place before the call is
        restored before returning.

        Args:
            pid: int, pid of the child, which is also its process group id
        """
        exited = {}

        def _on_sigchld(signum, dummy_frame):
            """Reap the child and record that it has exited."""
            if signum != signal.SIGCHLD:
                return
            try:
                reaped_pid, dummy_status = os.waitpid(pid, os.WNOHANG)
            except OSError:
                # Nothing to reap; never let this escape into whatever code
                # happened to be running when the signal arrived.
                return
            # waitpid() returns pid 0 when the child has not exited yet.
            if reaped_pid == pid:
                exited[pid] = True

        previous_handler = signal.signal(signal.SIGCHLD, _on_sigchld)
        if previous_handler is None:
            # Handler was not installed from Python; fall back to default.
            previous_handler = signal.SIG_DFL

        try:
            logging.warning('Sending SIGTERM to %d', pid)
            os.kill(-1 * pid, signal.SIGTERM)  # *-1 = entire process group
            time.sleep(KILL_WAIT_SECS)
            if pid in exited:
                return
            logging.warning('Sending SIGKILL to %d', pid)
            os.kill(-1 * pid, signal.SIGKILL)
            time.sleep(KILL_WAIT_SECS)
        except OSError as e:
            if e.errno == errno.ESRCH:
                logging.warning('pid %d died on its own', pid)
                return
            logging.critical('killPid: %s', str(e))
        finally:
            signal.signal(signal.SIGCHLD, previous_handler)

        if pid in exited:
            return

        logging.debug('pid %d will not die', pid)

    def getExitStatus(self):
        """Return the exit status recorded by execute(), None if not set."""
        return self.exit_status

    def cleanup(self):
        """Handle errors and call error-exec specified bin.

        Does nothing unless the error-exec option is set and execute() has
        created the capture files. If the recorded exit status is one of the
        error-exec exit codes, the error-exec command is run through
        /bin/sh with the {EXIT}, {TIMEOUT}, {STDOUT} and {STDERR} tokens
        replaced. The capture files are closed afterwards, also when the
        error-exec command itself raises.
        """
        if not self.options['error-exec']:
            return

        if self.stdout is None or self.stderr is None:
            # execute() never got as far as creating the capture files.
            return

        try:
            if self.exit_status in self.error_exec_codes:
                did_timeout = int(self.exit_status == EXIT_STATUS_TIMEOUT)
                arg_str = self.options['error-exec']
                arg_str = arg_str.replace('{EXIT}', str(self.exit_status))
                arg_str = arg_str.replace('{TIMEOUT}', str(did_timeout))
                arg_str = arg_str.replace('{STDOUT}', self.stdout.name)
                arg_str = arg_str.replace('{STDERR}', self.stderr.name)
                args = ('/bin/sh', '-c', arg_str)
                error_supv = Supervisor()
                error_supv.setOptions(timeout=5 * 3600)
                error_supv.execute(args)
        finally:
            self.stdout.close()
            self.stdout = None
            self.stderr.close()
            self.stderr = None


def parseOpts(argv):
    """Parse argv and return options and arguments.

    Option names are stored without their leading '--'. The values of
    --timeout and --delayrandom are converted to int; every other value is
    kept as the string getopt returned ('' for flags such as --debug).

    Args:
        argv: list, all argv parameters
    Returns:
        (dict of options, list extra args besides options)
    Raises:
        OptionError: an option is unknown or lacks its argument, or
            --timeout / --delayrandom is not a non-negative integer
    """
    long_options = [
        'timeout=', 'delayrandom=', 'debug', 'help',
        'error-exec=', 'error-exec-exit-codes=',
    ]
    try:
        argopts, args = getopt.gnu_getopt(argv, '', long_options)
    except getopt.GetoptError as e:
        raise OptionError(str(e))

    options = {}
    for k, v in argopts:
        name = k[2:]
        if k in ('--timeout', '--delayrandom'):
            # Report bad numbers as option errors instead of letting a bare
            # ValueError escape to the caller.
            try:
                number = int(v)
            except ValueError:
                raise OptionError(
                    'option %s requires an integer argument: %r' % (k, v))
            if number < 0:
                raise OptionError(
                    'option %s must not be negative: %r' % (k, v))
            options[name] = number
        else:
            options[name] = v
    return options, args


def Usage():
    """Print usage."""
    # print(...) with a single argument behaves the same on Python 2 and 3.
    # The error-exec time limit below mirrors the timeout that
    # Supervisor.cleanup() passes to the error-exec supervisor
    # (5 * 3600 seconds).
    print("""supervisor [options] [--] [path to executable] [arguments]

    options:
        --timeout n
            after n seconds, terminate the executable
        --delayrandom n
            delay the execution of executable by random seconds up to n
        --error-exec "path and options string"
            exec path when executable returns non zero exit status.

            in this mode the stdout and stderr from the supervised
            executable are recorded to temp files.

            the path and options string can include tokens which will be
            replaced with values.  note the braces {} should be included.

                {EXIT} = exit status
                {TIMEOUT} = 1 or 0, timeout did or did not occur
                {STDOUT} = path to stdout file
                {STDERR} = path to stderr file

            the error-exec bin may use the stdout, stderr files while it is
            executing, but it should assume they will disappear when
            the error-exec bin returns with any exit status.

            the bin should not run more than 5 hours or it will be
            terminated.
        --error-exec-exit-codes "1,100,203"
            comma-delimited list of integer exit status codes.  If the
            supervised script exits with one of these codes, the error-exec
            executable will be run.  Default: "1"
        --debug
            enable debugging output, all logs to stderr and not syslog.
        --help
            this text
        --
            use the -- to separate supervisor options from arguments to the
            executable which will appear as options.
    """)


def processOpts(options, args):
    """Act on the parsed options and decide whether to start supervising.

    Prints the usage text when --help was given or no executable was named.
    Switches the root logger to DEBUG when --debug was given.

    Args:
        options: dict, options
        args: list, extra args
    Returns:
        True if supervisor startup should occur, False if not.
    """
    help_requested = options.get('help') is not None
    if help_requested or not args:
        Usage()
        return False

    debug_requested = options.get('debug') is not None
    if debug_requested:
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.DEBUG)

    return True


def setupSyslog():
    """Attach a syslog handler to the root logger.

    The handler writes to the '/var/run/syslog' socket, accepts records from
    DEBUG upwards and prefixes each message with file name, pid and level.
    """
    handler = logging.handlers.SysLogHandler('/var/run/syslog')
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(
        logging.Formatter(
            '%(filename)s[%(process)d]: %(levelname)s %(message)s'))
    logging.getLogger().addHandler(handler)


def main(argv):
    """Command line entry point.

    Parses the options, sets up logging (syslog unless --debug was given),
    runs the requested executable under a Supervisor and finally lets the
    Supervisor run its error-exec cleanup.

    Args:
        argv: list, full argument vector; argv[0] is the program name
    Returns:
        int, exit status for the process: 0 when only usage was printed,
        1 on option, execution or timeout errors, otherwise the exit status
        of the supervised executable.
    """
    try:
        options, args = parseOpts(argv[1:])
        if not processOpts(options, args):
            return 0
    except OptionError as e:
        logging.error(str(e))
        return 1

    if options.get('debug', None) is None:
        setupSyslog()

    try:
        sp = Supervisor(delayrandom_abort=True)
        sp.setOptions(**options)
    except Error as e:
        logging.exception('%s %s', e.__class__.__name__, str(e))
        return 1

    ex = 0

    try:
        sp.execute(args)
        ex = sp.getExitStatus()
    except TimeoutError:
        # execute() has already logged the timeout and killed the child.
        ex = 1
    except Error as e:
        logging.exception('%s %s', e.__class__.__name__, str(e))
        ex = 1

    sp.cleanup()
    return ex


if __name__ == '__main__':
    # Run as a script: hand main()'s return value to the shell as the
    # process exit status.
    exit_code = main(sys.argv)
    sys.exit(exit_code)
