#!/usr/bin/python3

""" Create a JSON allowlist/policy from given files as input """

import argparse
import binascii
import copy
import json
import sys

from cryptography import x509
from cryptography.hazmat import backends

from keylime.ima import file_signatures, ima

# Path of the IMA measurement list read by default; used as the default value
# of the --ima-measurement-list option.
IMA_MEASUREMENT_LIST = "/sys/kernel/security/ima/ascii_runtime_measurements"
# Default value of the --ignored-keyrings option: nothing is ignored unless requested.
IGNORED_KEYRINGS = []


def is_x509_cert(bindata):
    """Tell whether ``bindata`` can be loaded as a DER-encoded x509 certificate.

    Returns True when the DER certificate loader accepts the data and False
    when it rejects the data with a ValueError.
    """
    try:
        x509.load_der_x509_certificate(bindata, backend=backends.default_backend())
    except ValueError:
        # The loader signals unparsable input with ValueError.
        return False
    return True


def process_flat_allowlist(allowlist_file, hashes_map):
    """Merge the entries of a flat allowlist file into ``hashes_map``.

    Every non-blank line is expected to consist of a checksum followed by a
    path, separated by whitespace; the path is everything after the first run
    of whitespace and may therefore itself contain spaces. Lines that cannot
    be split into these two parts are reported and skipped.

    :param allowlist_file: path of the plain-text allowlist to read
    :param hashes_map: dict mapping a path to its list of checksums; it is
        updated in place
    :returns: tuple ``(hashes_map, ret)`` where ``ret`` is 0 on success and 1
        when the file is missing or not readable
    """
    ret = 0
    try:
        with open(allowlist_file, "r", encoding="utf-8") as fobj:
            for line in fobj:
                line = line.strip()
                if not line:
                    continue
                pieces = line.split(None, 1)
                if len(pieces) != 2:
                    print(f"Skipping line that was split into {len(pieces)} parts, expected 2: {line}")
                    # Without skipping here the unpacking below would raise ValueError.
                    continue
                checksum_hash, path = pieces
                hashes_map.setdefault(path, []).append(checksum_hash)
    except (PermissionError, FileNotFoundError) as ex:
        print(f"An error occurred while accessing the allowlist: {ex}")
        ret = 1
    return hashes_map, ret


def get_hashes_from_measurement_list(ima_measurement_list_file, hashes_map):
    """Collect file digests from an ASCII IMA measurement list into ``hashes_map``.

    Only entries whose template field (third space-separated field) is
    ``ima-sig`` or ``ima-ng`` are used. For those, the part of the fourth
    field after the first ``:`` is recorded as the checksum of the path found
    in the fifth field. Lines with too few fields or with a digest field that
    has no ``:`` are reported and skipped.

    :param ima_measurement_list_file: path of the measurement list to read
    :param hashes_map: dict mapping a path to its list of checksums; it is
        updated in place
    :returns: tuple ``(hashes_map, ret)`` where ``ret`` is 0 on success and 1
        when the file is missing or not readable
    """
    ret = 0
    try:
        with open(ima_measurement_list_file, "r", encoding="utf-8") as fobj:
            for line in fobj:
                # Drop the line terminator: on 5-field lines it would otherwise
                # end up as part of the path used as the map key.
                line = line.rstrip("\n")
                pieces = line.split(" ")
                if len(pieces) < 5:
                    print(f"Skipping line that was split into {len(pieces)} pieces, expected at least 5: {line}")
                    continue
                if pieces[2] not in ("ima-sig", "ima-ng"):
                    continue
                digest_parts = pieces[3].split(":")
                if len(digest_parts) < 2:
                    print(f"Skipping line with a digest field that is not of the form <algorithm>:<hash>: {line}")
                    continue
                # FIXME: filenames with spaces may be problematic
                checksum_hash = digest_parts[1]
                path = pieces[4]
                hashes_map.setdefault(path, []).append(checksum_hash)
    except (PermissionError, FileNotFoundError) as ex:
        print(f"An error occurred: {ex}")
        ret = 1
    return hashes_map, ret


def process_ima_buf_in_measurement_list(
    ima_measurement_list_file, ignored_keyrings, get_keyrings, keyrings_map, get_ima_buf, ima_buf_map
):
    """Process ima-buf entries and get the keyrings map from key-related entries
    and ima_buf map from the rest

    Only lines made of exactly 6 space-separated fields whose template field
    (third field) is ``ima-buf`` are used; other 6-field lines are ignored and
    lines with a different field count are reported and skipped. The sixth
    field is hex-decoded; when the decoded data is an x509 certificate the
    entry is treated as a key and can only go to ``keyrings_map``, otherwise
    it can only go to ``ima_buf_map``.

    :param ima_measurement_list_file: path of the measurement list to read
    :param ignored_keyrings: collection of names (fifth field) whose key
        entries must not be recorded
    :param get_keyrings: when false, key entries are dropped
    :param keyrings_map: dict mapping a name to its list of checksums for key
        entries; updated in place
    :param get_ima_buf: when false, non-key entries are dropped
    :param ima_buf_map: dict mapping a name to its list of checksums for
        non-key entries; updated in place
    :returns: tuple ``(keyrings_map, ima_buf_map, ret)`` where ``ret`` is 0 on
        success and 1 when the file is missing or not readable
    """
    ret = 0
    try:
        with open(ima_measurement_list_file, "r", encoding="utf-8") as fobj:
            for line in fobj:
                # Only affects the last field; keeps the line terminator out of messages.
                line = line.rstrip("\n")
                # FIXME: filenames with spaces may be problematic
                pieces = line.split(" ")
                if len(pieces) != 6:
                    print(f"Skipping line that was split into {len(pieces)} pieces, expected 6: {line}")
                    continue
                if pieces[2] != "ima-buf":
                    continue
                digest_parts = pieces[3].split(":")
                if len(digest_parts) < 2:
                    print(f"Skipping line with a digest field that is not of the form <algorithm>:<hash>: {line}")
                    continue
                checksum_hash = digest_parts[1]
                path = pieces[4]

                try:
                    bindata = binascii.unhexlify(pieces[5].strip())
                except ValueError:
                    # binascii.Error is a ValueError subclass; non-ASCII input raises
                    # plain ValueError. Either way the buffer is not hex data and
                    # therefore cannot be a key.
                    bindata = None

                # check whether buf's bindata contains a key; if so, we will only
                # append it to 'keyrings', never to 'ima-buf'
                if bindata and is_x509_cert(bindata):
                    if get_keyrings and path not in ignored_keyrings:
                        keyrings_map.setdefault(path, []).append(checksum_hash)
                    continue

                if get_ima_buf:
                    ima_buf_map.setdefault(path, []).append(checksum_hash)
    except (PermissionError, FileNotFoundError) as ex:
        print(f"An error occurred: {ex}")
        ret = 1
    return keyrings_map, ima_buf_map, ret


def process_signature_verification_keys(verification_keys, policy):
    """Add the keys found in the given files to the policy's verification keyring.

    The keyring is initialised from the policy's existing "verification-keys"
    entry when there is one, otherwise a new keyring is created. Files that
    yield no key, or an unsupported one, are reported and skipped. The
    serialised keyring is stored back under "verification-keys" and the
    (possibly updated) policy is returned.
    """
    # No key files given: the policy stays as it is.
    if not verification_keys:
        return policy

    if policy.get("verification-keys"):
        keyring = file_signatures.ImaKeyring().from_string(policy["verification-keys"])
        if not keyring:
            print("Could not create IMAKeyring from JSON")
    else:
        keyring = file_signatures.ImaKeyring()

    # Without a usable keyring there is nothing to add the keys to.
    if not keyring:
        return policy

    for key_file in verification_keys:
        try:
            pubkey, keyidv2 = file_signatures.get_pubkey_from_file(key_file)
            if pubkey:
                keyring.add_pubkey(pubkey, keyidv2)
            else:
                print(f"File '{key_file}' is not a file with a key")
        except ValueError as err:
            print(f"File '{key_file}' does not have a supported key: {err}")

    serialized_keyring = keyring.to_string()
    if serialized_keyring:
        policy["verification-keys"] = serialized_keyring

    return policy


def main():
    """Parse the command line, assemble the runtime policy and emit it as JSON.

    The policy starts from a copy of the empty runtime policy template,
    optionally takes over the supported sections of a base policy, is then
    extended with digests (from a flat allowlist or the IMA measurement
    list), keyrings / ima-buf entries and signature verification keys, and is
    finally written to the output file or printed to stdout.

    The process always terminates through ``sys.exit``: with status 1 as soon
    as one of the steps reports an error, with status 0 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="This is an experimental tool for adding items to a Keylime's IMA runtime policy"
    )
    parser.add_argument(
        "-B",
        "--base-policy",
        action="store",
        dest="base_policy",
        help="Merge new data into the given JSON runtime policy",
    )
    parser.add_argument(
        "-k", "--keyrings", action="store_true", dest="get_keyrings", help="Create keyrings policy entries"
    )
    parser.add_argument(
        "-b",
        "--ima-buf",
        action="store_true",
        dest="get_ima_buf",
        help="Process ima-buf entries other than those related to keyrings",
    )
    parser.add_argument("-a", "--allowlist", action="store", dest="allowlist", help="Use given plain-text allowlist")
    parser.add_argument(
        "-m",
        "--ima-measurement-list",
        action="store",
        dest="ima_measurement_list",
        default=IMA_MEASUREMENT_LIST,
        help="Use given IMA measurement list for keyrings and critical "
        f"data extraction rather than {IMA_MEASUREMENT_LIST}",
    )
    parser.add_argument(
        "-i",
        "--ignored-keyrings",
        action="append",
        dest="ignored_keyrings",
        default=IGNORED_KEYRINGS,
        help="Ignored the given keyring; this option may be passed multiple times",
    )
    parser.add_argument(
        "-o",
        "--output",
        action="store",
        dest="output",
        help="File to write JSON policy into; default is to print to stdout",
    )
    parser.add_argument(
        "--no-hashes", action="store_true", dest="no_hashes", help="Do not add any hashes to the policy"
    )
    parser.add_argument(
        "-A",
        "--add-ima-signature-verification-key",
        action="append",
        dest="ima_signature_keys",
        default=[],
        help="Add the given IMA signature verification key to the Keylime-internal 'tenant_keyring'; "
        "the key should be an x509 certificate in DER or PEM format but may also be a public or private key file; "
        "this option may be passed multiple times",
    )
    args = parser.parse_args()

    # Work on a private copy: the policy is modified below and the shared
    # module-level template must stay untouched.
    policy = copy.deepcopy(ima.EMPTY_RUNTIME_POLICY)
    policy["ima"]["ignored_keyrings"] = args.ignored_keyrings

    ret = 0

    if args.base_policy:
        try:
            with open(args.base_policy, "r", encoding="utf-8") as fobj:
                basepol = fobj.read()
            base_policy = json.loads(basepol)

            # Cherry-pick from base policy what is supported and merge into policy
            policy["digests"] = base_policy.get("digests", {})
            policy["excludes"] = base_policy.get("excludes", [])
            policy["keyrings"] = base_policy.get("keyrings", {})
            policy["ima-buf"] = base_policy.get("ima-buf", {})
            # NOTE: the base policy's ignored keyrings replace (they do not
            # extend) the ones given on the command line.
            ignored_keyrings = base_policy.get("ima", {}).get("ignored_keyrings", [])
            policy["ima"]["ignored_keyrings"] = ignored_keyrings
            policy["verification-keys"] = base_policy.get("verification-keys", "")
        except (PermissionError, FileNotFoundError) as ex:
            print(f"An error occurred while loading the policy: {ex}")
            ret = 1
        except json.decoder.JSONDecodeError as ex:
            print(f"An error occurred while converting the policy to a JSON object: {ex}")
            ret = 1
    if ret:
        sys.exit(ret)

    # Add the digests map either from the allowlist, if given, or the IMA measurement list.
    if args.allowlist:
        policy["digests"], ret = process_flat_allowlist(args.allowlist, policy["digests"])
    elif not args.no_hashes:
        policy["digests"], ret = get_hashes_from_measurement_list(args.ima_measurement_list, policy["digests"])
    if ret:
        sys.exit(ret)

    if args.get_keyrings or args.get_ima_buf:
        policy["keyrings"], policy["ima-buf"], ret = process_ima_buf_in_measurement_list(
            args.ima_measurement_list,
            policy["ima"]["ignored_keyrings"],
            args.get_keyrings,
            policy["keyrings"],
            args.get_ima_buf,
            policy["ima-buf"],
        )

    if ret:
        sys.exit(ret)

    policy = process_signature_verification_keys(args.ima_signature_keys, policy)

    jsonpolicy = json.dumps(policy)
    if args.output:
        try:
            with open(args.output, "w", encoding="utf-8") as fobj:
                fobj.write(jsonpolicy)
        except (PermissionError, FileNotFoundError) as ex:
            print(f"An error occurred while writing the policy: {ex}")
            ret = 1
    else:
        print(jsonpolicy)

    sys.exit(ret)


# Run the tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
