#!/usr/bin/env python3
# coding: utf-8

# Copyright (C) 1994-2021 Altair Engineering, Inc.
# For more information, contact Altair at www.altair.com.
#
# This file is part of both the OpenPBS software ("OpenPBS")
# and the PBS Professional ("PBS Pro") software.
#
# Open Source License Information:
#
# OpenPBS is free software. You can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the
# Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# OpenPBS is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public
# License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Commercial License Information:
#
# PBS Pro is commercially licensed software that shares a common core with
# the OpenPBS software.  For a copy of the commercial license terms and
# conditions, go to: (http://www.pbspro.com/agreement.html) or contact the
# Altair Legal Department.
#
# Altair's dual-license business model allows companies, individuals, and
# organizations to create proprietary derivative works of OpenPBS and
# distribute them - whether embedded or bundled with other software -
# under a commercial license agreement.
#
# Use of Altair's trademarks, including but not limited to "PBS™",
# "OpenPBS®", "PBS Professional®", and "PBS Pro™" and Altair's logos is
# subject to Altair's trademark licensing policies.


import os
import sys
import getopt
import errno
import logging
import ptl
import time
import tarfile

from getopt import GetoptError
from threading import Thread
from pathlib import Path

from ptl.lib.pbs_testlib import PtlConfig
from ptl.utils.pbs_snaputils import PBSSnapUtils, ObfuscateSnapshot
from ptl.utils.pbs_cliutils import CliUtils
from ptl.utils.pbs_dshutils import DshUtils


def trap_exceptions(etype, value, tb):
    """
    Exception hook that silences Ctrl-C and broken pipe errors

    KeyboardInterrupt, and IOError whose errno is EPIPE, are dropped
    without any output. Every other exception is handed over to the
    interpreter's default exception hook.

    :param etype - class of the uncaught exception
    :type etype - type
    :param value - the uncaught exception instance
    :type value - BaseException
    :param tb - traceback that belongs to the exception
    :type tb - traceback
    """
    # While this hook is running, fall back to the default hook so that
    # an exception raised in here cannot re-enter this function forever
    sys.excepthook = sys.__excepthook__

    interrupted = issubclass(etype, KeyboardInterrupt)
    broken_pipe = (not interrupted and issubclass(etype, IOError) and
                   value.errno == errno.EPIPE)
    if not (interrupted or broken_pipe):
        sys.__excepthook__(etype, value, tb)

    # Re-arm this hook for any uncaught exception that comes later
    sys.excepthook = trap_exceptions


# Install the hook at import time so that uncaught KeyboardInterrupt and
# EPIPE errors are handled by trap_exceptions() from here on
sys.excepthook = trap_exceptions


def usage():
    """
    Print the pbs_snapshot command line help to stdout
    """
    print("""
Usage: pbs_snapshot -o <path to existing output directory> [OPTION]

    Take snapshot of a PBS system and optionally capture logs for diagnostics

    -H <hostname>                     primary hostname to operate on
                                      Defaults to local host
    -l <loglevel>                     set log level to one of INFO, INFOCLI,
                                      INFOCLI2, DEBUG, DEBUG2, WARNING, ERROR
                                      or FATAL
    -h, --help                        display this usage message
    --basic                           Capture only basic config & state data
    --daemon-logs=<num days>          number of daemon logs to collect
    --accounting-logs=<num days>      number of accounting logs to collect
    --additional-hosts=<hostname>     collect data from additional hosts
                                      'hostname' is a comma separated list
    --map=<file>                      file to store the map of obfuscated data
    --obfuscate                       obfuscates sensitive data
    --with-sudo                       Uses sudo to capture privileged data
    --version                         print version number and exit
    --obf-snap                        path to existing snapshot to obfuscate
""")


def create_snapshot_tar(snappath):
    """
    Pack a snapshot directory into a gzip compressed tar archive
    Warning: the snapshot directory is removed once it has been archived

    snappath: path to the snapshot directory to archive
    type: str

    return: str - path to the archive created, i.e. <snappath>.tgz
    """
    archive_path = snappath + ".tgz"
    # Store the snapshot under its own directory name inside the archive,
    # not under the full path it was captured at
    top_level_name = os.path.basename(snappath)

    with tarfile.open(archive_path, "w:gz") as archive:
        archive.add(snappath, arcname=top_level_name)

    # The archive now holds everything, drop the directory
    # ('du' is the DshUtils object created in the main section)
    du.rm(path=snappath, recursive=True, force=True)

    return archive_path


def get_snapshot_from_tar(logger, tarpath):
    """
    Extract the snapshot from tarfile given

    The archive is unpacked into the directory that contains the tar
    file. The snapshot's name is taken from the top-level path component
    of the first member of the archive.

    :param logger - Logging object used to report errors
    :type logger - logging.Logger
    :param tarpath - path to the tar
    :type tarpath - str

    :return tuple of (path to extracted snapshot, True/False),
        True: if the extraction was a success
        False: if nothing was extracted, either because there was
            already a snapshot by the name of the one being extracted
            (its path is still returned), or because the archive has no
            members at all (the path is None in that case)
    """
    parentdir = Path(tarpath).parent
    with tarfile.open(tarpath) as tar:
        members = tar.getnames()
        if not members:
            # An empty archive has no top-level directory to report,
            # indexing the member list would raise IndexError
            logger.error("Tar file %s is empty, no snapshot to extract" %
                         tarpath)
            return (None, False)

        main_snap = members[0].split(os.sep, 1)[0]
        main_snap = os.path.join(os.path.abspath(parentdir), main_snap)
        if os.path.isdir(main_snap):
            logger.error("Existing snapshot %s found, cannot extract tar" %
                         main_snap)
            return (main_snap, False)

        # NOTE(review): extractall() is called without an extraction
        # filter, so member paths in the archive are trusted as they are.
        # Consider filter="data" where the running Python supports it
        # if tars from untrusted sources have to be handled.
        tar.extractall(path=parentdir)

    return (main_snap, True)


def remotesnap_thread(logger, host):
    """
    Routine to capture snapshot from a remote host

    Runs pbs_snapshot on 'host' inside a temporary directory, copies the
    resulting tar into the local output directory as
    <host>_snapshot.tgz (and, when obfuscating, the map file as
    <host>_<map filename>), then removes the temporary directory from
    the remote host. Failures are logged, nothing is returned.

    Reads the settings out_dir, basic, daemon_logs, acct_logs, obfuscate,
    map_file and with_sudo which the main section defines as globals.

    :param logger - Logging object
    :type logger - logging.Logger
    :param host - the hostname for remote host
    :type host - str
    """
    logger.info("Capturing snapshot from host %s" % (host))

    du = DshUtils()

    # Get path to pbs_snapshot on remote host
    host_pbsconf = du.parse_pbs_config(hostname=host)
    try:
        pbs_exec_path = host_pbsconf["PBS_EXEC"]
    except KeyError:
        logger.error("Couldn't find PBS_EXEC on host %s"
                     ", won't capture snapshot on this host" % (
                         host))
        return
    host_pbssnappath = os.path.join(pbs_exec_path, "sbin",
                                    "pbs_snapshot")

    # Create a directory on the remote host with a unique name
    # We will create the snapshot here
    timestamp = str(int(time.time()))
    snap_home = "host_" + timestamp
    du.mkdir(hostname=host, path=snap_home)

    try:
        # Run pbs_snapshot on the remote host
        # The log counts are None when --basic was given without them,
        # only forward the ones that are set so that the remote side
        # never gets to parse the string "None" as a number
        cmd = [host_pbssnappath, "-o", snap_home]
        if basic:
            cmd.append("--basic")
        if daemon_logs is not None:
            cmd.append("--daemon-logs=" + str(daemon_logs))
        if acct_logs is not None:
            cmd.append("--accounting-logs=" + str(acct_logs))
        if obfuscate:
            cmd.extend(["--obfuscate", "--map=" + map_file])
        if with_sudo:
            cmd.append("--with-sudo")

        ret = du.run_cmd(hosts=host, cmd=cmd, logerr=False)
        if ret['rc'] != 0:
            logger.error("Error capturing snapshot from host %s" % (host))
            print(ret['err'])
            return

        # Get the snapshot tar filename from stdout
        # pbs_snapshot reports it on its last line of output, search
        # backwards to tolerate trailing lines without the marker
        marker = "Snapshot available at: "
        snaptarname = None
        for line in reversed(ret['out'] or []):
            if marker in line:
                snaptarname = line.split(marker, 1)[1].strip()
                break
        if not snaptarname:
            logger.error("Couldn't find the snapshot path in the output "
                         "of pbs_snapshot on host %s" % (host))
            return

        # Copy over the snapshot tar file as <hostname>_snapshot.tgz
        dest_path = os.path.join(out_dir, host + "_snapshot.tgz")
        ret = du.run_copy(srchost=host, src=snaptarname, dest=dest_path)
        if ret['rc'] != 0:
            logger.error("Error copying child snapshot from host %s" %
                         (host))

        # Copy over map file if any as 'host_<map filename>'
        # The remote host writes a map file only when it obfuscates
        if obfuscate and map_file is not None:
            mfilename = os.path.basename(map_file)
            dest_path = os.path.join(out_dir, host + "_" + mfilename)
            src_path = os.path.join(snap_home, map_file)
            ret = du.run_copy(srchost=host, src=src_path, dest=dest_path)
            if ret['rc'] != 0:
                logger.error("Error copying map file from host %s" % (host))
    finally:
        # Delete the snapshot home from remote host, on failures as well
        du.rm(hostname=host, path=snap_home, recursive=True, force=True)


def obfuscate_snapshot_wrapper(snap_name, map_file=None, sudo_val=False):
    """
    Obfuscate the snapshot at the path given

    :param snap_name - path to the snapshot to obfuscate
    :type snap_name - str
    :param map_file - path to the file to store the obfuscation map in,
        defaults to 'obfuscate.map' in the parent directory of the
        snapshot
    :type map_file - str or None
    :param sudo_val - Value of the --with-sudo option
    :type sudo_val - bool
    """
    if map_file is None:
        # Validate the name itself: Path() never returns None and would
        # raise TypeError for a missing snapshot name
        if snap_name is None:
            sys.stderr.write("snapshot path not found")
            usage()
            sys.exit(1)
        map_file = os.path.join(Path(snap_name).parent, "obfuscate.map")
    obj = ObfuscateSnapshot()
    obj.obfuscate_snapshot(snap_name, map_file, sudo_val)


def capture_local_snap(sudo_val):
    """
    Capture a snapshot of the local host and obfuscate it if requested

    Uses the settings that the main section defines as globals
    (out_dir, basic, acct_logs, daemon_logs, log_path, obfuscate,
    map_file and additional_hosts).

    :param sudo_val - Value of the --with-sudo option
    :type sudo_val - bool

    :returns Name of the snapshot directory/tar file captured
    """
    # The snapshot has to stay a directory when it still needs to be
    # obfuscated or when snapshots of additional hosts get added to it
    keep_as_dir = obfuscate or additional_hosts is not None

    with PBSSnapUtils(out_dir, basic=basic, acct_logs=acct_logs,
                      daemon_logs=daemon_logs, create_tar=not keep_as_dir,
                      log_path=log_path, with_sudo=sudo_val) as snap_utils:
        captured = snap_utils.capture_all()

    if not obfuscate:
        return captured

    obfuscate_snapshot_wrapper(captured, map_file, sudo_val)
    if additional_hosts is None:
        # Nothing else gets added to the snapshot, create the tar now
        return create_snapshot_tar(captured)

    return captured


if __name__ == '__main__':

    # Entry point of pbs_snapshot: parse the options, set up logging and
    # then either obfuscate an existing snapshot or capture a new one
    #
    # NOTE: the variables below are deliberately left at module level,
    # the helper functions above read them as globals

    # Arguments to PBSSnapUtils
    out_dir = None
    primary_host = None
    log_level = "INFOCLI2"
    acct_logs = None
    daemon_logs = None
    additional_hosts = None
    map_file = None
    obfuscate = False
    log_file = "pbs_snapshot.log"
    with_sudo = False
    du = DshUtils()
    basic = False
    obf_snap = None

    PtlConfig()

    # Parse the options provided to pbs_snapshot
    try:
        sopt = "d:H:l:o:h"
        lopt = ["basic", "accounting-logs=", "daemon-logs=", "help",
                "additional-hosts=", "map=", "obfuscate", "with-sudo",
                "version", "obf-snap="]
        opts, args = getopt.getopt(sys.argv[1:], sopt, lopt)
    except GetoptError:
        usage()
        sys.exit(1)

    for o, val in opts:
        if o == "-o":
            out_dir = val
        elif o == "-H":
            primary_host = val
        elif o == "-l":
            log_level = val
        elif o == "-h" or o == "--help":
            usage()
            sys.exit(0)
        elif o == "--basic":
            basic = True
        elif o == "--accounting-logs":
            try:
                acct_logs = int(val)
            except ValueError:
                raise ValueError("Invalid value for --accounting-logs "
                                 "option, should be an integer")
        elif o == "--daemon-logs":
            try:
                daemon_logs = int(val)
            except ValueError:
                raise ValueError("Invalid value for --daemon-logs "
                                 "option, should be an integer")
        elif o == "--additional-hosts":
            additional_hosts = val
        elif o == "--map":
            map_file = val
        elif o == "--obfuscate":
            obfuscate = True
        elif o == "--with-sudo":
            with_sudo = True
        elif o == "--version":
            print(ptl.__version__)
            sys.exit(0)
        elif o == "--obf-snap":
            obf_snap = val
        else:
            sys.stderr.write("Unrecognized option")
            usage()
            sys.exit(1)

    if obf_snap:
        # Obfuscating an existing snapshot: the output goes next to it,
        # which overrides any -o value given
        obfuscate = True
        out_dir = str(Path(obf_snap).parent)

    # Check that parent of snapshot directory exists
    if out_dir is None:
        sys.stderr.write("-o option not provided")
        usage()
        sys.exit(1)
    elif not os.path.isdir(out_dir):
        sys.stderr.write("-o path should exist,"
                         " this is where the snapshot is captured")
        usage()
        sys.exit(1)

    # Log to <out_dir>/pbs_snapshot.log and additionally send the
    # messages of the 'ptl' logger to the console
    fmt = '%(asctime)-15s %(levelname)-8s %(message)s'
    level_int = CliUtils.get_logging_level(log_level)
    log_path = os.path.join(out_dir, log_file)
    logging.basicConfig(filename=log_path, filemode='w+',
                        level=level_int, format=fmt)
    stream_hdlr = logging.StreamHandler()
    stream_hdlr.setLevel(level_int)
    stream_hdlr.setFormatter(logging.Formatter(fmt))
    ptl_logger = logging.getLogger('ptl')
    ptl_logger.addHandler(stream_hdlr)
    ptl_logger.setLevel(level_int)

    if obfuscate is True:
        # find the parent directory of the snapshot
        # This will be used to store the map file
        out_abspath = os.path.abspath(out_dir)
        if map_file is None:
            map_file = os.path.join(out_abspath, "obfuscate.map")

    # Obfuscate an existing snapshot
    if obf_snap:
        istar = False
        if os.path.isfile(obf_snap):
            if tarfile.is_tarfile(obf_snap):
                istar = True
                obf_snap, ret = get_snapshot_from_tar(ptl_logger, obf_snap)
                if ret is False:
                    sys.exit(1)
            else:
                ptl_logger.error("Path is not a valid snapshot")
                sys.exit(1)
        elif not os.path.isdir(obf_snap):
            # Neither a tar file nor a snapshot directory
            ptl_logger.error("Path is not a valid snapshot")
            sys.exit(1)

        # The obfuscated snapshot will be at path: <obf_snap>_obf
        obfout = obf_snap + "_obf"
        if os.path.isfile(obfout + ".tgz"):
            ptl_logger.error("%s.tgz already exists, "
                             "delete it to create a new obfuscated snapshot" %
                             obfout)
            if istar:
                # If input was a tar file, then we'd have extracted it
                # So, delete the dir that was extracted from the input tar
                du.rm(path=obf_snap, force=True, recursive=True)
            sys.exit(1)
        # Obfuscate a copy so that the snapshot given is left untouched
        du.run_copy(src=obf_snap, dest=obfout, recursive=True)
        obfuscate_snapshot_wrapper(obfout, map_file, with_sudo)
        # Create the tar
        obfout = create_snapshot_tar(obfout)
        if istar:
            du.rm(path=obf_snap, force=True, recursive=True)

        print("Obfuscated snapshot at: " + str(obfout))
        sys.exit(0)

    if not basic:
        # Capture 5 days of daemon logs and 30 days of accounting logs
        # by default
        if daemon_logs is None:
            daemon_logs = 5
        if acct_logs is None:
            acct_logs = 30

    if additional_hosts is not None:
        # Capture snapshot of remote hosts in addition to the main host

        # Tolerate spaces around the commas and skip empty entries
        hostnames = [h.strip() for h in additional_hosts.split(",")
                     if h.strip()]
        remotesnap_threads = {}
        for host in hostnames:
            thread = Thread(target=remotesnap_thread, args=(
                ptl_logger, host))
            thread.start()
            remotesnap_threads[host] = thread

        # Capture snapshot of the main host in the meantime
        thread_p = None
        if not du.is_localhost(primary_host):
            # The main host is remote
            thread_p = Thread(target=remotesnap_thread,
                              args=(ptl_logger, primary_host))
            thread_p.start()
        else:
            # Capture a local snapshot
            main_snap = capture_local_snap(with_sudo)

        if thread_p is not None:
            # Let's get the main host's snapshot first
            thread_p.join()

            # We need to copy additional hosts' snapshots in
            # the main snapshot, so un-tar the snapshot
            p_snappath = os.path.join(out_dir, primary_host + "_snapshot.tgz")
            if not os.path.isfile(p_snappath):
                # remotesnap_thread has already logged what went wrong
                ptl_logger.error("Snapshot of host %s was not captured, "
                                 "cannot continue" % primary_host)
                sys.exit(1)
            main_snap, extracted = get_snapshot_from_tar(ptl_logger,
                                                         p_snappath)
            if not extracted:
                # Don't merge into, tar up and delete a directory that
                # was not created by this run
                sys.exit(1)

            # Delete the tar file
            du.rm(path=p_snappath, force=True)

        # Let's reconcile the child snapshots
        for host, thread in remotesnap_threads.items():
            thread.join()
            host_snappath = os.path.join(out_dir, host + "_snapshot.tgz")
            if os.path.isfile(host_snappath):
                # Move the tar file to the main snapshot
                du.run_copy(src=host_snappath, dest=main_snap)
                du.rm(path=host_snappath, force=True)

        # Finally, create a tar of the whole snapshot
        outtar = create_snapshot_tar(main_snap)
    elif not du.is_localhost(primary_host):
        # Capture snapshot of a remote host
        remotesnap_thread(ptl_logger, primary_host)
        p_snappath = os.path.join(out_dir, primary_host + "_snapshot.tgz")
        if not os.path.isfile(p_snappath):
            # remotesnap_thread has already logged what went wrong
            ptl_logger.error("Snapshot of host %s was not captured, "
                             "cannot continue" % primary_host)
            sys.exit(1)
        with tarfile.open(p_snappath) as tar:
            main_snap = tar.getnames()[0].split(os.sep, 1)[0]
            main_snap = os.path.join(os.path.abspath(out_dir), main_snap)

        outtar = main_snap + ".tgz"

        # remotesnap_thread named the tar as <hostname>_snapshot.tgz
        # rename it to the timestamp name of the original snapshot
        du.run_copy(src=p_snappath, dest=outtar)
        du.rm(path=p_snappath, force=True)
    else:
        # Capture snapshot of the local host
        outtar = capture_local_snap(with_sudo)

    if outtar is not None:
        # remotesnap_thread() looks for this prefix in the output of the
        # remote pbs_snapshot, keep the two in sync
        print("Snapshot available at: " + outtar)
