#!/opt/alt/python37/bin/python3

"""
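Usage: logs-at [-q] <time-query> [<log-files>...]
Print the log lines that fall within the time range given by
<time-query>, sorted by time, from <log-files>. If no files are given,
all text files under the current working directory (searched
recursively) are used. With -q, informational and warning messages are
suppressed.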
## Use case
View close-in-time events from different log files in chronological
order, to make it easier to see which events are concurrent and to
establish a timeline. For example, the command may be run from the rpm
tests' artifacts directory.
## Time query format
<time-query> is a profile of ISO 8601:
    YYYY[-MM[-DD[ HH[:MM[:SS[.ffffff]]]]]]
that defines a time range, e.g.:
    $ logs-at '2020-07-14 03:14'
will print logs from the current working directory (including subdirs)
starting from `2020-07-14T03:14:00.000000Z` until
`2020-07-14T03:15:00.000000Z`.
A timestamp `2020-07-14 03:14:15` corresponds to the
`2020-07-14T03:14:15.000000Z--2020-07-14T03:14:16.000000Z` time
range, and so on.
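The time range may also be specified explicitly:
    <start-time>--<end-time>
where the start and end times are given in RFC 3339 format, e.g.
`2020-07-14T03:14:00Z--2020-07-14T03:15:00Z`.
The range includes the start time and excludes the end time.
All times are in UTC.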
"""

import datetime as DT
import logging
import re
import sys
from contextlib import ExitStack, suppress
from collections import namedtuple
from heapq import merge
from itertools import islice
from operator import attrgetter
from pathlib import Path


__all__ = ["main"]
# Maximum number of leading lines of a file that may lack a recognizable
# time before the file is skipped altogether (see parse_events()).
JUNK_LINE_COUNT = 1000
logger = logging.getLogger("logs-at")
# A raw (bytes) log line together with its extracted time and source path.
Event = namedtuple("Event", ["time", "path", "line"])


def main():
    """Print the log lines that fall within the queried time range.

    Command line: ``logs-at [-q] <time-query> [<log-files>...]`` (see the
    module docstring for the query format).  ``-q`` raises the logging
    level to ERROR.  Matching lines from all files are merged in time
    order and written to stdout as ``<path>: <line>``.  Exits with the
    usage text if the arguments or the time query are invalid.
    """
    # parse command-line args
    try:
        del sys.argv[sys.argv.index("-q")]
    except ValueError:
        quiet = False
    else:
        quiet = True
    if len(sys.argv) < 2:
        sys.exit(__doc__)
    time_query = sys.argv[1]
    # validate the query before the potentially slow directory scan
    start_time, end_time = _query_time_range(time_query)
    if len(sys.argv) == 2:
        # all text files in the current directory & its subdirs recursively
        log_paths = [
            path
            for path in Path().rglob("*")
            if contains_text(path)  # skip binary files (and unreadable paths)
        ]
    else:
        log_paths = [Path(arg) for arg in sys.argv[2:]]
    logging.basicConfig(level=logging.ERROR if quiet else logging.INFO)
    logger.info("time range: [%s, %s), log paths: %s", start_time, end_time, log_paths)
    # print sorted log lines from given time range
    # note: assume logs are sorted within individual file, so merging the
    # per-file streams yields a globally sorted stream
    # https://stackoverflow.com/a/16954837/4279
    with ExitStack() as stack:
        # drop lines outside the time range, print the rest in chronological order
        sorted_files = [
            (  # Drop events that are not in the [start, end) time interval
                event
                for event in parse_events(
                    stack.enter_context(path.open("rb")), path
                )
                if start_time <= event.time < end_time
            )
            for path in log_paths
        ]
        write = sys.stdout.buffer.write
        for _, path, line in merge(*sorted_files, key=attrgetter("time")):
            write(bytes(path))
            write(b": ")
            write(line)


def _query_time_range(time_query):
    """Return ``(start, end)`` of the half-open range selected by *time_query*.

    ``<start-time>--<end-time>`` is parsed with parse_time(); any other
    query is a (possibly truncated) timestamp, see _implicit_time_range().
    Exits with the usage text if the query is malformed.
    """
    start_time_str, sep, end_time_str = time_query.partition("--")
    try:
        if sep:  # explicit start/end times
            return parse_time(start_time_str), parse_time(end_time_str)
        tokens = re.findall(r"\d+", start_time_str)
        if not (1 <= len(tokens) <= 7):
            sys.exit(f"Wrong time query: {tokens}\n{__doc__}")
        return _implicit_time_range(tokens)
    except (ValueError, OverflowError) as err:
        # e.g. month 13, the 31st of a shorter month, an empty end time
        sys.exit(f"Wrong time query: {err}\n{__doc__}")


def _implicit_time_range(tokens):
    """Return ``(start, end)`` for a timestamp split into 1-7 digit *tokens*.

    The tokens are year, month, day, hour, minute, second and a decimal
    fraction of a second.  The range spans one unit of the last field
    given: ``2020`` selects the whole year and ``2020-07-14 03:14`` one
    minute.  A fraction is scaled by its number of digits: ``.5`` selects
    [0.5 s, 0.6 s), while ``.123456`` selects one microsecond.  An omitted
    month or day defaults to 1.

    :raises ValueError: if the fields don't form a valid datetime.
    :raises OverflowError: if the range end lies beyond ``datetime.max``.
    """
    fields = [int(token) for token in tokens[:6]]
    fields += [1] * (3 - len(fields))  # datetime() requires year, month, day
    precision = len(tokens)
    if precision == 7:
        fraction = tokens[6]
        if len(fraction) > 6:
            raise ValueError(f"at most 6 fractional digits are supported: .{fraction}")
        fields.append(int(fraction.ljust(6, "0")))  # ".5" -> 500000 us
    start = DT.datetime(*fields)
    if precision == 1:  # year
        end = start.replace(year=start.year + 1)
    elif precision == 2:  # month: December rolls over into the next year
        if start.month == 12:
            end = start.replace(year=start.year + 1, month=1)
        else:
            end = start.replace(month=start.month + 1)
    elif precision == 7:  # one unit of the last fractional digit given
        end = start + DT.timedelta(microseconds=10 ** (6 - len(tokens[6])))
    else:  # day .. second: timedelta handles the carry into larger fields
        unit = {3: "days", 4: "hours", 5: "minutes", 6: "seconds"}[precision]
        end = start + DT.timedelta(**{unit: 1})
    return start, end


def contains_text(path: Path, size=1024) -> bool:
    """Return True if the first *size* bytes of *path* look like text.

    Mirrors the heuristic of the file(1) utility: the sample is text when
    every byte is in 0x20-0xFF except DEL (0x7F), or is one of the
    control characters BEL, BS, TAB, LF, FF, CR, ESC, GS.  An empty file
    counts as text.  Paths that can't be opened or read (missing files,
    directories, permission errors) are reported as not text.
    """
    # https://stackoverflow.com/a/7392391/4279
    allowed = set(range(0x20, 0x100))
    allowed.discard(0x7F)  # DEL is not a text character
    allowed.update((7, 8, 9, 10, 12, 13, 27, 29))
    try:
        with path.open("rb") as stream:
            sample = stream.read(size)
    except OSError:
        return False
    # deleting every allowed byte leaves only the "binary" ones behind
    leftover = sample.translate(None, bytes(allowed))
    return len(leftover) == 0


def parse_time(time_string: str):
    """Parse an RFC 3339 time string into a naive UTC datetime.

    A trailing ``Z``/``z`` designates UTC.  Times with an explicit UTC
    offset are converted to UTC.  Times without an offset are returned
    unchanged, i.e. they are taken to be UTC already.  The result carries
    no tzinfo so that it can be compared with the naive times extracted
    from log lines; comparing aware and naive datetimes raises TypeError.

    :raises ValueError: if *time_string* is not a valid ISO 8601 time.
    """
    if time_string[-1:] in ("Z", "z"):
        # datetime.fromisoformat() before Python 3.11 rejects a "Z" suffix
        time_string = time_string[:-1] + "+00:00"
    parsed = DT.datetime.fromisoformat(time_string)
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(DT.timezone.utc).replace(tzinfo=None)
    return parsed


def extract_logtime_audit_log(line: bytes, prev_time):
    """Extract time from an audit log *line*: ``audit(time_stamp:ID)``.

    e.g.: ``type=SYSCALL msg=audit(1594696455.123:42): ...`` where
    ``time_stamp`` is seconds since the Unix epoch.  Returns a naive UTC
    datetime.  *prev_time* is unused (uniform extractor interface).

    :raises ValueError: if there is no ``audit(...)`` marker, or its
        timestamp is not a number or not representable as a datetime.
    """
    m = re.search(br"audit\(([^:]+)", line)
    if not m:
        raise ValueError("Expected line with audit(time_stamp:ID). Got %s" % (line,))
    timestamp = float(m.group(1))  # ValueError if not a number
    try:
        return DT.datetime.fromtimestamp(timestamp, DT.timezone.utc).replace(tzinfo=None)
    except (OverflowError, OSError) as err:
        # Out-of-range timestamps (e.g. 1e400): report them as ValueError,
        # the only exception parse_events() treats as "not this format".
        raise ValueError("Invalid audit timestamp: %r" % (m.group(1),)) from err


def extract_logtime_console_log(line: bytes, prev_time):
    """Return the time of a console.log *line*.

    The first seven runs of digits (year, month, day, hour, minute,
    second, fraction of a second) form the timestamp, e.g.::

        INFO    [2020-07-08 22:40:36,798]

    *prev_time* is unused; it is accepted for a uniform extractor interface.

    :raises ValueError: if the line does not hold such a timestamp.
    """
    field_count = 7
    return _tokens2datetime(line, ntokens=field_count)


def _tokens2datetime(line: bytes, ntokens):
    """Build a datetime from the first *ntokens* runs of digits in *line*.

    The runs are used in ``datetime()`` argument order: year, month, day,
    hour, minute, second and, as the 7th, a decimal fraction of a second.
    That fraction is right-padded with zeros to six digits, so ``,798``
    means 798000 microseconds.

    :raises ValueError: if *line* has fewer than *ntokens* digit runs, or
        they don't form a valid (or representable) datetime.
    """
    fields = re.findall(br"\d+", line)[:ntokens]
    found = len(fields)
    if found != ntokens:
        raise ValueError("Expected %d tokens. Got %d in %s" % (ntokens, found, line))
    if found > 6:
        # a fraction, not a microsecond count: ",5" is half a second
        fields[6] = fields[6].ljust(6, b"0")
    numbers = [int(field) for field in fields]
    try:
        return DT.datetime(*numbers)
    except OverflowError as err:  # e.g. a year too large for a C int
        raise ValueError from err


def extract_logtime_syslog(line: bytes, prev_time):
    """Extract time from a syslog-style *line*, e.g.: Jun 22 02:11:43

    A line that doesn't start with a ``Mon DD HH:MM:SS`` timestamp may
    instead start with a date only, in one of several formats (``06 Apr``,
    ``Apr-06-2021``, ``2021-Apr-06``, ...).  Its time of day is then
    midnight.  Formats without a year take the year of *prev_time*.
    parse_events() passes the time of the previous line, or the current
    time for the first lines.

    :raises ValueError: if the line doesn't start with a supported timestamp.
    """
    # Decode as ASCII so that character offsets equal byte offsets; any
    # non-ASCII byte becomes U+FFFD and simply fails to match below.
    head = line[:15].decode("ascii", "replace")
    # The leading word(s) of the first 11 characters, e.g. "Jun 22" out of
    # "Jun 22 02:1", or "Apr-06-2021".
    date_only = head[:11].rsplit(" ", 1)[0]
    # %b - Month as locale’s abbreviated name. (e.g. Sep)
    # %d - Day of the month as a decimal number. (e.g. 30)
    # %H:%M:%S - Time of day. (e.g. 02:11:43)
    # %Y - Year with century as a decimal number. (e.g. 2013)
    candidates = [(head.rstrip(), "%b %d %H:%M:%S")]  # Jun 22 02:11:43
    candidates += [
        (date_only, dt_format)
        for dt_format in (
            "%b %d",  # Apr 6 (06)
            "%d %b",  # 06 Apr
            "%b-%d",  # Apr-06
            "%d-%b",  # 06-Apr
            "%b-%d-%Y",  # Apr-06-2021
            "%d-%b-%Y",  # 06-Apr-2021
            "%Y-%b-%d",  # 2021-Apr-06
            "%Y-%d-%b",  # 2021-06-Apr
            "%Y%d%b",  # 202106Apr
            "%Y%b%d",  # 2021Apr06
            "%d%b%Y",  # 06Apr2021
        )
    ]
    for text, dt_format in candidates:
        if "%Y" not in dt_format:
            # Parse together with the year so that e.g. "Feb 29" is valid in
            # leap years (strptime's default year, 1900, is not one); formats
            # that carry their own year keep it.
            text = "%04d %s" % (prev_time.year, text)
            dt_format = "%Y " + dt_format
        try:
            return DT.datetime.strptime(text, dt_format)
        except ValueError:
            pass  # not this format; try the next one
    raise ValueError("Invalid date format in log")


def extract_logtime_modsec_audit_log(line: bytes, prev_time):
    """Extract time from a modsec audit log *line*.

    e.g.: ``[05/Aug/2020:20:19:42 +0000]`` as the first 28 bytes of the
    line.  The UTC offset is honoured: the result is a naive datetime in
    UTC, like the times from the other extractors.  *prev_time* is unused.

    :raises ValueError: if the line does not start with such a timestamp
        (undecodable bytes raise UnicodeDecodeError, a ValueError subclass).
    """
    # 28 == len("[05/Aug/2020:20:19:42 +0000]")
    stamp = DT.datetime.strptime(line[:28].decode(), "[%d/%b/%Y:%H:%M:%S %z]")
    return stamp.astimezone(DT.timezone.utc).replace(tzinfo=None)


def extract_logtime_realtime_av_log(line: bytes, prev_time):
    """Return the time at the start of a realtime AV log *line*.

    e.g.: ``2020/08/05 14:36:57``.  The first 19 bytes must contain six
    runs of digits (year, month, day, hour, minute, second).
    *prev_time* is unused; it is accepted for a uniform extractor interface.

    :raises ValueError: if they don't form a valid timestamp.
    """
    prefix = line[:len("2020/08/05 14:36:57")]  # 19 bytes
    return _tokens2datetime(prefix, ntokens=6)


def extract_logtime_apache_error_log(line: bytes, prev_time):
    """Return the time that opens an Apache error log *line*.

    e.g.: ``[Mon Jul 27 01:58:55.334883 2020]``, which must be exactly the
    first 33 bytes of the line.  *prev_time* is unused; it is accepted for
    a uniform extractor interface.

    :raises ValueError: if those bytes aren't such a timestamp (including
        undecodable bytes).
    """
    bracketed = line[:33].decode()
    layout = "[%a %b %d %H:%M:%S.%f %Y]"
    return DT.datetime.strptime(bracketed, layout)


def extract_logtime_http_access_log(line: bytes, prev_time):
    """Extract time from an HTTP access log *line*.

    e.g.: ``10.51.48.21 - - [11/Aug/2020:00:34:13 +0000] "GET / ..."``.
    Only lines starting with a dotted IPv4-like address are recognized;
    the time is the first ``[...]`` group after it.  The UTC offset is
    honoured: the result is a naive datetime in UTC, like the times from
    the other extractors.  *prev_time* is unused.

    :raises ValueError: if no such timestamp is found.
    """
    m = re.match(br"\d+\.\d+\.\d+\.\d+[^[]+\[([^]]+)", line)
    if not m:
        raise ValueError("Can't find time in {!r}".format(line))
    # 11/Aug/2020:00:34:13 +0000
    stamp = DT.datetime.strptime(m.group(1).decode(), "%d/%b/%Y:%H:%M:%S %z")
    return stamp.astimezone(DT.timezone.utc).replace(tzinfo=None)


def parse_events(loglines, path):
    """Yield ``Event(time, path, line)`` for the lines of one log file.

    The time format is detected once per file.  Each of the first
    JUNK_LINE_COUNT lines of *loglines* is tried against every known
    extractor until one succeeds; that extractor is then used for the
    rest of the file.  Lines before the first timestamped line are not
    yielded.  If no line in that window has a recognizable time, nothing
    is yielded (a warning is logged unless the window held at most one
    line).  After detection, a line whose time can't be extracted gets
    the time of the previous line instead; such failures are logged at
    most once per ~1000 lines.
    """
    extractors = (
        extract_logtime_audit_log,
        extract_logtime_console_log,
        extract_logtime_syslog,
        extract_logtime_modsec_audit_log,
        extract_logtime_realtime_av_log,
        extract_logtime_apache_error_log,
        extract_logtime_http_access_log,
    )

    def detect(candidate, prev_time):
        """Return ``(time, extractor)`` for the first extractor parsing *candidate*.

        If none does, ``time`` is None and ``extractor`` is the last one tried.
        """
        for extractor in extractors:
            try:
                return extractor(candidate, prev_time), extractor
            except ValueError:
                continue  # not this format; try the next one
        return None, extractor

    lines = iter(loglines)
    # passed as prev_time while probing (the syslog extractor takes its year)
    now = DT.datetime.utcnow()
    line_no = 0
    # Phase 1: detect the time format from the first timestamped line.
    for line_no, line in enumerate(islice(lines, JUNK_LINE_COUNT), start=1):
        stamp, extract = detect(line, now)
        if stamp:
            break
    else:
        if line_no > 1:
            logger.warning("%s: supported time format not found", path)
        return
    yield Event(stamp, path, line)
    # Phase 2: the shared iterator resumes right after the detected line;
    # keep counting line numbers from there.
    warning_interval = 1000  # lines to stay quiet after a warning
    quiet_until = line_no
    for line_no, line in enumerate(lines, start=line_no + 1):
        try:
            stamp = extract(line, stamp)
        except ValueError as err:
            if line_no > quiet_until:
                quiet_until = line_no + warning_interval
                logger.warning("%s:%s can't get time, reason: %s", path, line_no, err)
            # the previous line's time is kept in `stamp`
        yield Event(stamp, path, line)


if __name__ == "__main__":
    # run the CLI only when executed as a script, not when imported
    main()
