#!/usr/bin/env python3
# coding: utf-8

# Copyright (C) 1994-2021 Altair Engineering, Inc.
# For more information, contact Altair at www.altair.com.
#
# This file is part of both the OpenPBS software ("OpenPBS")
# and the PBS Professional ("PBS Pro") software.
#
# Open Source License Information:
#
# OpenPBS is free software. You can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the
# Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# OpenPBS is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public
# License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Commercial License Information:
#
# PBS Pro is commercially licensed software that shares a common core with
# the OpenPBS software.  For a copy of the commercial license terms and
# conditions, go to: (http://www.pbspro.com/agreement.html) or contact the
# Altair Legal Department.
#
# Altair's dual-license business model allows companies, individuals, and
# organizations to create proprietary derivative works of OpenPBS and
# distribute them - whether embedded or bundled with other software -
# under a commercial license agreement.
#
# Use of Altair's trademarks, including but not limited to "PBS™",
# "OpenPBS®", "PBS Professional®", and "PBS Pro™" and Altair's logos is
# subject to Altair's trademark licensing policies.


import getopt
import logging
import errno

import ptl
from ptl.lib.pbs_testlib import *
from ptl.utils.pbs_testsuite import PBS_GROUPS

# trap SIGINT and SIGPIPE


def trap_exceptions(etype, value, tb):
    """
    Exception hook that silences Ctrl-C and broken-pipe errors.

    The default interpreter hook is put back in place first. A
    ``KeyboardInterrupt``, or an ``IOError`` whose ``errno`` is ``EPIPE``,
    is then dropped without any output; every other exception is passed
    on to the default hook.

    :param etype: class of the uncaught exception
    :param value: the uncaught exception instance
    :param tb: traceback associated with the exception
    """
    default_hook = sys.__excepthook__
    sys.excepthook = default_hook
    is_interrupt = issubclass(etype, KeyboardInterrupt)
    # value.errno is only consulted for IOError-derived exceptions
    is_broken_pipe = (not is_interrupt and issubclass(etype, IOError) and
                      value.errno == errno.EPIPE)
    if not (is_interrupt or is_broken_pipe):
        default_hook(etype, value, tb)


# install the hook for the lifetime of the script
sys.excepthook = trap_exceptions


def usage():
    """
    Print the command-line help text to standard output.

    The program name shown on the first line is the basename of
    ``sys.argv[0]``. Each list entry below is one fragment of the help
    text; the fragments are concatenated without any separator, so every
    fragment carries its own newline(s).
    """
    msg = [
        'Usage: ' + os.path.basename(sys.argv[0]) + ' [OPTION]\n\n',
        '-t <hostnames>: comma-separated hosts to operate on. Defaults '
        'to localhost\n',
        '-l <log level>: one of DEBUG, INFO, ERROR, FATAL, WARNING\n',
        '\n',
        '--log-conf=<file>: logging config file\n',
        '--snap=<pbs_snapshot>: Mimic pbs_snapshot\n',
        '\t --acct-logs=<path to acct logs>: path to accounting logs'
        ', used to create users & groups\n',
        '--revert-config: revert services to their default '
        'configuration\n',
        '\t --scheduler: operate on scheduler\n',
        '\t --server: operate on server\n',
        '\t --mom: operate on MoM\n',
        '\t --del-hooks=<True|False>: If True delete hooks.'
        ' Defaults to True\n',
        '\t --del-queues=<True|False>: Delete non-default queues.'
        ' Defaults to True\n',
        '--save-config=<config>: save configuration to file\n',
        '\t revert-config and save-config can operate on the following\n',
        '\t --scheduler: operate on scheduler\n',
        '\t --server: operate on server\n',
        '\t --mom: operate on MoM\n',
        '--load-config=<config>: load configuration from saved file.\n',
        '\n',

        '--vnodify: define vnodes using the following suboptions:\n',
        '\t-a <attrs>: comma separated list of attributes to set '
        'on vnodes.\n',
        '\t            format: <name>=<value>. Defaults to 8 cpus '
        '8gb of mem.\n',
        '\t-A: set additive mode, leave vnode definitions in '
        'place. \n',
        '\t    Default is to clear all existing vnode definition '
        'files.\n',
        '\t-d <y|n>: if y, delete all server nodes. Defaults to y.\n',
        '\t-f <filename>: use output of pbsnodes -av from file as '
        'definition\n',
        '\t-P <num>: number of vnodes per host\n',
        '\t-o <filename>: output vnode definition to filename\n',
        '\t-M <mom>: MoM to operate on, format <host>@<path/to/conf>.\n'
        '\t          Defaults to localhost.\n',
        '\t-N <num vnodes>: number of vnodes to create. No default.\n',
        '\t-n <name>: name of the natural vnode to create. '
        'Defaults to MoM FQDN\n',
        '\t-p <name>: prefix of name of node to create. '
        'Output format: \n',
        # placeholder was previously printed as the unbalanced "[<num]"
        '\t           prefix followed by [<num>]. Defaults to vnode\n',
        '\t-r <y|n>: restart MoM or not, defaults to y\n',
        '\t-s: if set, share vnodes on the host. '
        'Default is "standalone" hosts\n',
        '\t-u: if set, allocate the natural vnode\n',
        '\n',

        '--multi-mom: Define and create multiple MoMs on a host\n',
        '\t--create=<num>: number of MoMs to create. No default.\n',
        '\t--restart=<[seq]>: restart MoMs in sequence\n',
        '\t--stop=<[seq]>: stop MoMs in sequence\n',
        '\t--serverhost=<host>: hostname of server, defaults to '
        'localhost\n',
        '\t--home-prefix=<path>: prefix to PBS_HOME directory, defaults\n'
        '\t                      to /var/spool/PBS_m\n',
        '\t--conf-prefix=<path>: prefix to pbs.conf file. Defaults to\n'
        '\t                      /etc/pbs.conf.m\n',
        '\t--init-port=<number>: initial port to allocate. Defaults to\n'
        '\t                      15011\n',
        '\t--step-port=<number>: step for port sequence. Defaults to 2\n',
        '\n',
        '--switch-version=<version>: switch to a given installed '
        'version of PBS\n',
        '\tcurrently only works for "vanilla" installs, i.e, not '
        'developer installs\n',
        '\tbased on /etc/pbs.conf and "default" PBS_EXEC\n',
        '\n',
        # single space between "are" and "defined" (was doubled)
        '--check-ug: verifies whether test users and groups are '
        'defined as expected.\n               Note that -t option '
        'will be ignored.\n',
        '--make-ug: create users and groups to match what is expected\n',
        '           Note that -t option will be ignored\n',
        '--del-ug: delete users and groups which is expected for PTL\n',
        '           Note that -t option will be ignored\n',
        '--version: print version number and exit\n',
    ]

    print("".join(msg))


def process_config(hosts, process_obj, conf_file=None, type='default',
                   delqueues=False, delhooks=False):
    """
    Revert, load or save the configuration of PBS services on each host.

    :param hosts: iterable of hostnames to operate on
    :param process_obj: container of object types to process; checked for
                        MGR_OBJ_SCHED, MGR_OBJ_SERVER and MGR_OBJ_NODE
    :param conf_file: file handed to the load/save calls
    :param type: one of 'default' (revert to defaults), 'load' or 'save';
                 any other value results in no service call
    :param delqueues: forwarded to the server's revert_to_defaults
    :param delhooks: forwarded to the server's revert_to_defaults
    """
    known_types = ('default', 'load', 'save')
    for host in hosts:
        svr_obj = Server(host)

        if MGR_OBJ_SCHED in process_obj and type in known_types:
            sched = Scheduler(svr_obj, host)
            if type == 'default':
                sched.revert_to_defaults()
            elif type == 'load':
                sched.load_configuration(conf_file)
            else:
                sched.save_configuration(conf_file)

        if MGR_OBJ_SERVER in process_obj and type in known_types:
            # a separate Server object is used for the server operations
            server = Server(host)
            if type == 'default':
                server.revert_to_defaults(delhooks=delhooks,
                                          delqueues=delqueues)
            elif type == 'load':
                server.load_configuration(conf_file)
            else:
                server.save_configuration(conf_file)

        if MGR_OBJ_NODE in process_obj and type in known_types:
            mom = MoM(svr_obj, host)
            if type == 'default':
                mom.revert_to_defaults()
            elif type == 'load':
                mom.load_configuration(conf_file)
            else:
                mom.save_configuration(conf_file)


def process_attributes(attrs):
    """
    Convert a comma-separated list of ``<name>=<value>`` pairs to a dict.

    Each pair is split on its first ``=`` only, so a value may itself
    contain ``=`` characters.

    :param attrs: string of the form ``name1=value1,name2=value2``
    :returns: dictionary mapping each attribute name to its string value
    :raises SystemExit: with exit code 1, after logging an error, when an
                        item does not contain ``=``
    """
    nattrs = {}
    for a in attrs.split(','):
        name, sep, value = a.partition('=')
        if not sep:
            logging.error('attributes must be of the form' +
                          ' <name>=<value>')
            sys.exit(1)
        nattrs[name] = value
    return nattrs


def common_users_groups_ops():
    """
    Compare the expected test groups/users with what the system reports.

    The actual memberships are queried through
    ``DshUtils.group_memberships`` for the names of all PBS_GROUPS; the
    expected memberships come from each group's ``users`` attribute.

    :returns: tuple ``(gm_expected, gm_actual, g_create, u_create)`` where
              ``g_create`` lists the expected groups absent from
              ``gm_actual`` and ``u_create`` lists, without duplicates,
              the users of those groups plus the users missing from the
              groups that are present
    """
    du = DshUtils()
    gm_actual = du.group_memberships([str(g) for g in PBS_GROUPS])
    gm_expected = {grp: grp.users for grp in PBS_GROUPS}

    g_create = []
    u_create = []
    for grp, members in gm_expected.items():
        gname = str(grp)
        if gname in gm_actual:
            # group exists: only users not yet part of it are candidates
            present = gm_actual[gname]
            candidates = [m for m in members if str(m) not in present]
        else:
            # group is missing: create it along with all of its users
            g_create.append(grp)
            candidates = members
        for m in candidates:
            if m not in u_create:
                u_create.append(m)
    return (gm_expected, gm_actual, g_create, u_create)


def check_users_groups():
    """
    Report whether the expected test users and groups are all in place.

    When ``common_users_groups_ops`` finds a group or user to create, the
    expected and the actual group memberships are printed to stdout.

    :returns: True if there is nothing to create, False otherwise
    """
    gm_expected, gm_actual, g_create, u_create = common_users_groups_ops()
    if not g_create and not u_create:
        return True

    report = ['Expected (format is <group name>: <user> [, <user2>...) ']
    report.extend(str(grp) + ': ' + ', '.join([str(u) for u in users])
                  for grp, users in gm_expected.items())
    report.append('\n')
    report.append('Actual: ')
    report.extend(name + ': ' + ', '.join(members)
                  for name, members in gm_actual.items())
    print('\n'.join(report))
    return False


def make_users_groups():
    """
    Create the groups and users that ``common_users_groups_ops`` reports
    as needing creation.

    Groups are added first, each with its ``gid``; users are then added
    with their ``uid``, their first group as primary group and their
    full group list. Both calls are made with ``logerr=False``.

    :returns: True, unconditionally
    """
    du = DshUtils()
    findings = common_users_groups_ops()
    groups_to_add, users_to_add = findings[2], findings[3]
    for grp in groups_to_add:
        du.groupadd(grp, grp.gid, logerr=False)
    for usr in users_to_add:
        du.useradd(name=usr, uid=usr.uid, gid=usr.groups[0],
                   groups=usr.groups, logerr=False)
    return True


def delete_users_groups():
    """
    Delete the users and groups found in the actual group memberships.

    Every member listed in the actual memberships returned by
    ``common_users_groups_ops`` is deleted first, then every group that
    appears as a key of that mapping. Both calls are made with
    ``logerr=False``.

    :returns: True, unconditionally
    """
    du = DshUtils()
    gm_actual = common_users_groups_ops()[1]
    for members in gm_actual.values():
        for member in members:
            du.userdel(member, logerr=False)
    for group_name in gm_actual:
        du.groupdel(group_name, logerr=False)
    return True


if __name__ == '__main__':
    # Command-line driver: parse the options into the variables below,
    # configure logging, then run the first requested operation in this
    # order: check-ug, del-ug, make-ug, revert-config, load-config,
    # save-config, vnodify, switch-version, multi-mom, snap.

    if len(sys.argv) < 2:
        usage()
        sys.exit(0)

    # vnodify options
    vnodify = False
    vnodeprefix = 'vnode'
    num_vnodes = None
    additive = False
    sharedhost = False
    filename = None
    attrs = "resources_available.ncpus=8,resources_available.mem=8gb"
    hostname = None
    conf_file = None
    restart = True
    delall = True
    natvnode = None
    usenatvnode = False
    vdefname = None
    # end of vnodify options

    hosts = None
    revert = False
    op = None
    loadconf = None
    saveconf = None
    logconf = None
    vnodes_per_host = 1
    delqueues = True
    delhooks = True
    lvl = logging.INFO

    switchversion = None

    check_ug = False
    make_ug = False
    del_ug = False

    multimom = False
    num_moms = None
    restart_moms = None
    stop_moms = None
    clienthost = None
    serverhost = None
    init_port = 15011
    step_port = 2

    import_jobs = False
    home_prefix = 'PBS_m'
    conf_prefix = 'pbs.conf_m'
    acct_logs = None

    as_snap = None

    process_obj = []
    vnodify_args = "a:d:f:N:n:o:P:p:M:l:r:v:Asu"
    generic_args = "l:t:h"
    largs = ["scheduler", "server", "mom", "revert-config", "load-config=",
             "save-config=", "vnodify", "import-jobs", "del-ug",
             "switch-version=", "log-conf=", "check-ug", "del-hooks=",
             "del-queues=", "version", "make-ug", "multi-mom", "clienthost=",
             "home-prefix=", "conf-prefix=", "serverhost=", "init-port=",
             "step-port=", "snap=", "acct-logs=", "create=", "restart=",
             "stop="]

    try:
        opts, args = getopt.getopt(sys.argv[1:], vnodify_args + generic_args,
                                   largs)
    except getopt.GetoptError:
        usage()
        sys.exit(1)

    for o, val in opts:
        if o == '-h':
            # -h is declared in generic_args; show the help and exit
            # cleanly instead of reporting it as unrecognized
            usage()
            sys.exit(0)
        elif o == '-l':
            lvl = CliUtils().get_logging_level(val)
        elif o == '-t':
            hosts = val
        elif o == '-a':
            attrs = val
        elif o == '-A':
            additive = True
        elif o == '-d':
            delall = val.startswith('y')
        elif o == '-f':
            filename = CliUtils.expand_abs_path(val)
        elif o == '-P':
            vnodes_per_host = int(val)
        elif o == '-p':
            vnodeprefix = val
        elif o == '-M':
            if '@' in val:
                # split on the first '@' only: format is <host>@<conf path>
                (hostname, conf_file) = val.split('@', 1)
            else:
                hostname = val
        elif o == '-N':
            num_vnodes = int(val)
        elif o == '-o':
            vdefname = val
        elif o == '-s':
            sharedhost = True
        elif o == '-r':
            # anything not starting with 'y' disables the MoM restart
            restart = val.startswith('y')
        elif o == '-n':
            natvnode = val
        elif o == '-u':
            usenatvnode = True
        elif o == '--check-ug':
            check_ug = True
        elif o == '--make-ug':
            make_ug = True
        elif o == '--del-ug':
            del_ug = True
        elif o == '--del-hooks':
            # NOTE(review): eval() of a command-line value; the documented
            # values are True|False. Consider a strict parser instead.
            delhooks = eval(val)
        elif o == '--del-queues':
            # NOTE(review): eval() of a command-line value, see --del-hooks
            delqueues = eval(val)
        elif o == '--snap':
            as_snap = CliUtils.expand_abs_path(val)
        elif o == '--acct-logs':
            # input() is the Python 3 spelling of Python 2's raw_input()
            confirm = input("--acct-logs will create users & groups "
                            "from accounting log trace\n"
                            "Ok to do so? (Y/N)")
            if confirm in ("Y", "y"):
                acct_logs = CliUtils.expand_abs_path(val)
            else:
                acct_logs = None
        elif o == '--import-jobs':
            import_jobs = True
        elif o == '--log-conf':
            logconf = val
        elif o == '--multi-mom':
            multimom = True
        elif o == '--create':
            num_moms = int(val)
        elif o == '--home-prefix':
            home_prefix = val
        elif o == '--conf-prefix':
            conf_prefix = val
        elif o == '--scheduler':
            process_obj.append(MGR_OBJ_SCHED)
        elif o == '--server':
            process_obj.append(MGR_OBJ_SERVER)
        elif o == '--mom':
            process_obj.append(MGR_OBJ_NODE)
        elif o == '--restart':
            # NOTE(review): eval() of a command-line value; expected to
            # yield a sequence of MoM indices.
            restart_moms = eval(val, {}, {})
        elif o == '--stop':
            # NOTE(review): eval() of a command-line value, see --restart
            stop_moms = eval(val, {}, {})
        elif o == '--vnodify':
            vnodify = True
        elif o == '--revert-config':
            revert = True
        elif o == '--load-config':
            loadconf = CliUtils.expand_abs_path(val)
        elif o == '--save-config':
            saveconf = CliUtils.expand_abs_path(val)
        elif o == '--serverhost':
            serverhost = val
        elif o == '--init-port':
            init_port = int(val)
        elif o == '--step-port':
            step_port = int(val)
        elif o == '--switch-version':
            switchversion = val
        elif o == '--version':
            print(ptl.__version__)
            sys.exit(0)
        else:
            sys.stderr.write("Unrecognized option " + o + "\n")
            usage()
            sys.exit(1)

    PtlConfig()

    if logconf:
        # "import logging" alone does not guarantee that the
        # logging.config submodule has been loaded
        import logging.config
        logging.config.fileConfig(logconf)
    else:
        logging.basicConfig(level=lvl)

    if hosts is None:
        hosts = [socket.gethostname()]
    else:
        hosts = hosts.split(',')

    # the users/groups operations exit right away; the exit status is 0
    # when the operation returns a true value and 1 otherwise
    if check_ug:
        sys.exit(0 if check_users_groups() else 1)

    if del_ug:
        sys.exit(0 if delete_users_groups() else 1)

    if make_ug:
        sys.exit(0 if make_users_groups() else 1)

    if revert:
        process_config(hosts, process_obj, type='default', delqueues=delqueues,
                       delhooks=delhooks)
    elif loadconf:
        # when loading configuration apply the saved configuration based
        # on what was saved regardless of what object types were passed in
        allobjs = [MGR_OBJ_SCHED, MGR_OBJ_SERVER, MGR_OBJ_NODE]
        process_config(hosts, allobjs, loadconf, type='load',
                       delqueues=delqueues, delhooks=delhooks)
    elif saveconf:
        if os.path.isfile(saveconf):
            answer = input('file ' + saveconf + ' exists, overwrite? '
                           '[y]/n: ')
            if answer == 'n':
                sys.exit(1)
        if not process_obj:
            # no object type given: save all of them
            process_obj = [MGR_OBJ_SERVER, MGR_OBJ_SCHED, MGR_OBJ_NODE]
        process_config(hosts, process_obj, saveconf, type='save',
                       delqueues=delqueues, delhooks=delhooks)
    elif vnodify:
        if filename:
            # vnode definitions come from a file (see -f)
            vdef = BatchUtils().file_to_vnodedef(filename)
            if vdef:
                svr_obj = Server(hostname, pbsconf_file=conf_file)
                MoM(svr_obj, hostname,
                    pbsconf_file=conf_file).insert_vnode_def(vdef)
        elif num_vnodes is None:
            logging.error('A number of vnodes to create is required\n')
            sys.exit(1)
        else:
            nattrs = process_attributes(attrs)
            for hostname in hosts:
                svr_obj = Server(hostname, pbsconf_file=conf_file)
                m = MoM(svr_obj, hostname, pbsconf_file=conf_file)
                m.create_vnodes(nattrs, num_vnodes,
                                additive, sharedhost, restart, delall,
                                natvnode, usenatvnode, fname=vdefname,
                                vnodes_per_host=vnodes_per_host)
    elif switchversion:
        pi = PBSInitServices()
        for host in hosts:
            pi.switch_version(host, switchversion)

    elif multimom:
        if num_moms is not None:
            if os.getuid() != 0:
                logging.error('Must be run as root')
                sys.exit(1)
            du = DshUtils()
            conf = du.parse_pbs_config(serverhost)
            serverhost = DshUtils().get_pbs_server_name(conf)
            s = Server(serverhost)
            nattrs = process_attributes(attrs)
            s.create_moms(num=num_moms, attrib=nattrs, conf_prefix=conf_prefix,
                          home_prefix=home_prefix, momhosts=hosts,
                          init_port=init_port, step_port=step_port)
        # --restart and --stop are handled independently: a MoM only gets
        # the action of the list it appears in, and neither list is
        # modified. Each MoM's conf file is /etc/<conf_prefix><index>.
        for seq, action in ((restart_moms, 'restart'),
                            (stop_moms, 'stop')):
            if not seq:
                continue
            for i in seq:
                c = os.path.join('/etc', conf_prefix + str(i))
                pi = PBSInitServices(serverhost, conf=c)
                if action == 'restart':
                    ret = pi.restart()
                else:
                    ret = pi.stop()
                if ret['rc'] != 0:
                    logging.error(ret['err'])
                del pi
    elif as_snap is not None:
        if os.getuid() != 0:
            logging.error('Must be run as root')
            sys.exit(1)
        Server(snap=as_snap).clusterize(conf_file, hosts, acct_logs=acct_logs,
                                        import_jobs=import_jobs)
