#!/usr/bin/python3 -u
#
# The Spacewalk Acquire Method
#
# Author:  Simon Lukasik <xlukas08 [at] stud.fit.vutbr.cz>
# Date:    2011-01-01
# License: GPLv2
#
# Copyright (c) 1999--2012 Red Hat, Inc.
#
# This software is licensed to you under the GNU General Public License,
# version 2 (GPLv2). There is NO WARRANTY for this software, express or
# implied, including the implied warranties of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. You should have received a copy of GPLv2
# along with this software; if not, see
# http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt.


from __future__ import print_function

import sys
import logging
from logging.handlers import RotatingFileHandler
import hashlib

import warnings
import requests
from urllib3.exceptions import SubjectAltNameWarning

# Suppress the legacy "md5 module is deprecated" warning and urllib3's
# SubjectAltNameWarning so they do not clutter the method's output.
# NOTE(review): urllib3 2.x no longer defines SubjectAltNameWarning, so the
# import above it fails there — confirm the urllib3 version this targets.
warnings.filterwarnings("ignore", message="the md5 module is deprecated; use hashlib instead")
warnings.filterwarnings("ignore", category=SubjectAltNameWarning)
# Make modules under /usr/share/rhn/ importable (presumably where
# up2date_client is installed); this must precede the up2date_client imports
# below.
sys.path.append("/usr/share/rhn/")

from urllib.parse import urlparse
from up2date_client import config
from up2date_client import rhnChannel
from up2date_client import up2dateAuth
from up2date_client import up2dateErrors


# Path of the rotating log file written by the spacewalk acquire method.
LOG_FILE_PATH = '/var/log/apt-spacewalk.log'


class pkg_acquire_method:
    """
    Slightly modified Python variant of apt-pkg/acquire-method.

    A skeleton class implementing only the most basic parts of the apt
    method protocol: it parses numbered, header-style messages from stdin
    and prints status/result messages to stdout.

    Subclasses must provide ``fetch(msg)`` (called for message 600) and
    ``fail(message)`` (called when ``fetch`` raises); ``run()`` uses both.
    """
    # Becomes True once stdin is exhausted; no further messages are read after.
    __eof = False

    def __init__(self):
        # Announce the method's capabilities to apt right away.
        print("100 Capabilities\nVersion: 1.0\nSingle-Instance: true\n\n", end='')

    def __get_next_msg(self):
        """
        Read and parse the next protocol message from stdin.

        Apt communicates with its methods using a text protocol similar to
        HTTP: a status line ``<number> <text>``, then ``Key: Value`` header
        lines, terminated by an empty line.

        Returns a dict holding ``_number`` (int), ``_text`` (str) and one
        entry per header line, or None once stdin is exhausted.
        """
        if self.__eof:
            return None
        result = {}
        line = sys.stdin.readline()
        # Skip empty separator lines between messages.
        while line == '\n':
            line = sys.stdin.readline()
        if not line:
            self.__eof = True
            return None
        s = line.split(" ", 1)
        result['_number'] = int(s[0])
        result['_text'] = s[1].strip()

        while not self.__eof:
            line = sys.stdin.readline()
            if not line:
                # EOF right after headers: return what we have, stop reading.
                self.__eof = True
                return result
            if line == '\n':
                return result
            s = line.split(":", 1)
            result[s[0]] = s[1].strip()

    def __dict2msg(self, msg):
        """
        Render a dict as HTTP-like ``Key: Value`` lines.

        Entries whose value is None are omitted; other values are formatted
        with ``%s`` so non-string values (e.g. ints) are accepted too.
        """
        return "".join("%s: %s\n" % (key, value)
                       for key, value in msg.items()
                       if value is not None)

    def status(self, **kwargs):
        """Send a ``102 Status`` message built from the keyword arguments."""
        print("102 Status\n%s\n" % self.__dict2msg(kwargs), end='')

    def uri_start(self, msg):
        """Send a ``200 URI Start`` message built from *msg*."""
        print("200 URI Start\n%s\n" % self.__dict2msg(msg), end='')

    def uri_done(self, msg):
        """Send a ``201 URI Done`` message built from *msg*."""
        print("201 URI Done\n%s\n" % self.__dict2msg(msg), end='')

    def uri_failure(self, msg):
        """Send a ``400 URI Failure`` message built from *msg*."""
        print("400 URI Failure\n%s\n" % self.__dict2msg(msg), end='')

    def run(self):
        """
        Process requests from stdin until input ends.

        Message 600 is passed to ``self.fetch``; any exception it raises is
        logged and reported back through ``self.fail``.

        Returns 0 when stdin is exhausted, or 100 as soon as a message with
        any number other than 600 is received.
        """
        while True:
            msg = self.__get_next_msg()
            if msg is None:
                return 0
            if msg['_number'] != 600:
                return 100
            try:
                self.fetch(msg)
            except up2dateErrors.Error as e:
                # Must come before the generic handler: up2dateErrors.Error is
                # presumably an Exception subclass, so listing it second (as
                # before) made this branch unreachable.
                logging.exception('Failed to fetch due to up2dateErrors, msg=%s', msg)
                self.fail("%s: %s" % (e.__class__.__name__, e))
            except Exception as e:
                logging.exception('Failed to fetch due to unknown reason, msg=%s', msg)
                self.fail("%s: %s" % (e.__class__.__name__, e))



class BadSslCaCertConfig(Exception):
    """Raised when the up2date configuration has no usable ``sslCACert``."""


def get_ssl_ca_cert(up2date_cfg):
    """
    Return the CA certificate path used to verify the spacewalk server.

    :param up2date_cfg: up2date configuration; must support ``in`` and item
        access for the ``sslCACert`` key.
    :returns: the ``sslCACert`` value, or its first entry when it is a list.
    :raises BadSslCaCertConfig: if ``sslCACert`` is missing or empty.
    """
    if 'sslCACert' not in up2date_cfg or not up2date_cfg['sslCACert']:
        raise BadSslCaCertConfig(
            "sslCACert is not set in the up2date configuration")

    ca_certs = up2date_cfg['sslCACert']
    if isinstance(ca_certs, list):
        # Only one CA path can be handed to the HTTP client.
        logging.warning('Multiple sslCACert set, using first one: %s', ca_certs[0])
        return ca_certs[0]
    return ca_certs



class spacewalk_method(pkg_acquire_method):
    """
    Spacewalk acquire method.

    Handles apt fetch requests by downloading the requested path from the
    server configured in up2date. It authenticates with the registered
    system's login information.
    """
    up2date_cfg = None
    login_info = None
    current_url = None
    svr_channels = None
    http_headers = None
    base_channel = None
    conn = None
    # URI of the request being processed; None until fetch() sets it, so that
    # fail() still works (and simply omits URI) if fetch fails very early.
    uri = None
    # True once logging has been configured for this instance.
    __logging_configured = False
    not_registered_msg = 'This system is not registered with the spacewalk server'

    def fail(self, message = not_registered_msg):
        """Report a ``400 URI Failure`` for the current URI with *message*."""
        self.uri_failure({'URI': self.uri,
                          'Message': message})

    def __init_logging(self):
        """
        Send root logging to a rotating file, once per instance.

        ``logging.basicConfig`` ignores repeated calls. However, the handler
        argument is still constructed, which opens the log file, so without
        this guard every fetch would leak a file handle.
        """
        if self.__logging_configured:
            return
        self.__logging_configured = True
        logging.basicConfig(
            format='[%(asctime)s; %(name)s; %(levelname)s]: %(message)s',
            handlers=[
                RotatingFileHandler(
                    filename=LOG_FILE_PATH,
                    maxBytes=1 * 1024 * 1024,
                    backupCount=1
                )
            ],
            level=logging.DEBUG if self.up2date_cfg['debug'] else logging.WARNING
        )

        logging.debug('Debug mode enabled')

    def __load_config(self):
        """Load the up2date config and server URL on first use."""
        if self.up2date_cfg is None:
            self.up2date_cfg = config.initUp2dateConfig()
            self.up2date_server = config.getServerlURL()[0]
        # TODO: proxy settings

    def __login(self):
        """
        Obtain login info from up2dateAuth unless already logged in.

        Raises up2dateErrors.AuthenticationError when no login info is
        returned. A failed login is retried on the next call.
        """
        if not self.login_info:
            self.status(URI=self.uri, Message='Logging into the spacewalk server')
            self.login_info = up2dateAuth.getLoginInfo()
            if not self.login_info:
                raise up2dateErrors.AuthenticationError(self.not_registered_msg)
            self.status(URI=self.uri, Message='Logged in')

    def __init_channels(self):
        """Fetch channel details once; remember the parentless channel's label."""
        if self.svr_channels is None:
            self.svr_channels = rhnChannel.getChannelDetails()
            for channel in self.svr_channels:
                if channel['parent_channel'] == '':
                    self.base_channel = channel['label']

    def __init_headers(self):
        """
        Build the X-RHN-* request headers from the login info, once.

        Raises up2dateErrors.AuthenticationError if a required value is
        missing from the login info.
        """
        if self.http_headers is None:
            rhn_needed_headers = ['X-RHN-Server-Id',
                                  'X-RHN-Auth-User-Id',
                                  'X-RHN-Auth',
                                  'X-RHN-Auth-Server-Time',
                                  'X-RHN-Auth-Expire-Offset']
            headers = {}
            for header in rhn_needed_headers:
                if header not in self.login_info:
                    raise up2dateErrors.AuthenticationError(
                        "Missing required login information %s" % (header))
                headers[header] = str(self.login_info[header])
            headers['X-RHN-Transport-Capability'] = 'follow-redirects=3'
            # Assign only once complete, so a failure leaves nothing half-built.
            self.http_headers = headers

    def __transform_document(self, document):
        """Map the path given by apt to the real spacewalk URL path."""
        return '/GET-REQ' + document

    def fetch(self, msg):
        """
        Download ``msg['URI']`` from the spacewalk server into ``msg['Filename']``.

        According to the apt protocol, *msg* must contain 'URI' and
        'Filename'. Other possible keys are 'Last-Modified', 'Index-File' and
        'Fail-Ignore'; they are not used here.

        Progress and results go to apt via status/uri_start/uri_done. Errors
        are reported via uri_failure (HTTP errors) or fail() (foreign host).
        Other errors propagate to the caller.
        """
        self.uri = msg['URI']
        self.uri_parsed = urlparse(msg['URI'])
        self.filename = msg['Filename']

        self.__load_config()
        self.__init_logging()
        # Only serve URIs pointing at the configured spacewalk server.
        if self.uri_parsed.netloc != urlparse(self.up2date_server).netloc:
            return self.fail()
        self.__login()
        self.__init_channels()

        document = self.__transform_document(self.uri_parsed.path)

        self.__init_headers()

        ca_path = get_ssl_ca_cert(self.up2date_cfg)
        logging.info('Using %s as trusted CA path', ca_path)
        with requests.request("GET", self.up2date_server + document,
                              headers=self.http_headers,
                              verify=ca_path,
                              stream=True) as res:
            self.status(URI=self.uri, Message='Waiting for headers')

            if res.status_code != 200:
                self.uri_failure({'URI': self.uri,
                                  'Message': "%s  %s" % (res.status_code, res.reason),
                                  'FailReason': 'HttpError' + str(res.status_code)})
                # Stop here: continuing would report URI Done without data.
                return

            # These headers may be absent; None values are left out of the
            # messages sent to apt.
            size = res.headers.get('content-length')
            last_modified = res.headers.get('last-modified')
            self.uri_start({'URI': self.uri,
                            'Size': size,
                            'Last-Modified': last_modified})

            hash_sha256 = hashlib.sha256()
            hash_md5 = hashlib.md5()
            with open(self.filename, "wb") as out:
                for data in res.iter_content(4096):
                    hash_sha256.update(data)
                    hash_md5.update(data)
                    out.write(data)

        md5_hex = hash_md5.hexdigest()
        self.uri_done({'URI': self.uri,
                       'Filename': self.filename,
                       'Size': size,
                       'Last-Modified': last_modified,
                       'MD5-Hash': md5_hex,
                       'MD5Sum-Hash': md5_hex,
                       'SHA256-Hash': hash_sha256.hexdigest()})


if __name__ == '__main__':
    # Serve apt requests; run()'s return value becomes the exit status.
    # An interrupt (Ctrl-C) ends the method quietly instead of with a traceback.
    try:
        method = spacewalk_method()
        sys.exit(method.run())
    except KeyboardInterrupt:
        pass