#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Module IntegrityRoutine contains the IntegrityRoutine class, which helps with the FIPS 140-2 build-time integrity routine.
This module is needed to calculate the HMAC and embed it, together with the related data, into the ELF file.
"""

import hmac
import hashlib
import bisect
import itertools
import binascii
from ELF import *

__author__ = "Vadym Stupakov"
__copyright__ = "Copyright (c) 2017 Samsung Electronics"
__credits__ = ["Vadym Stupakov"]
__version__ = "1.0"
__maintainer__ = "Vadym Stupakov"
__email__ = "v.stupakov@samsung.com"
__status__ = "Production"


class IntegrityRoutine(ELF):
    """
    Utils for fips-integrity process
    """
    def __init__(self, elf_file, readelf_path="readelf"):
        """
        Initialise the ELF base class for the file to be processed
        :param elf_file: passed through unchanged to ELF.__init__
        :param readelf_path: passed through unchanged to ELF.__init__ (defaults to "readelf")
        """
        super().__init__(elf_file, readelf_path)

    @staticmethod
    def __remove_all_dublicates(lst):
        """
        Removes all occurrences of tha same value. For instance: transforms [1, 2, 3, 1] -> [2, 3]
        :param lst: input list
        :return: lst w/o duplicates
        """
        to_remove = list()
        for i in range(len(lst)):
            it = itertools.islice(lst, i + 1, len(lst) - 1, None)
            for j, val in enumerate(it, start=i+1):
                if val == lst[i]:
                    to_remove.extend([lst[i], lst[j]])

        for el in to_remove:
            lst.remove(el)

    def get_reloc_gaps(self, start_addr, end_addr):
        """
        :param start_addr: start address :int
        :param end_addr: end address: int
        :returns list of relocation gaps like [[gap_start, gap_end], [gap_start, gap_end], ...]
        """
        all_relocs = self.get_relocs(start_addr, end_addr)
        relocs_gaps = list()
        for addr in all_relocs:
            relocs_gaps.append(addr)
            relocs_gaps.append(addr + 8)
        self.__remove_all_dublicates(relocs_gaps)
        relocs_gaps.sort()
        relocs_gaps = [[addr1, addr2] for addr1, addr2 in self.utils.pairwise(relocs_gaps)]
        return relocs_gaps

    def get_addrs_for_hmac(self, sec_sym_sequence, relocs_gaps=None):
        """
        Generate addresses for calculating HMAC
        :param sec_sym_sequence: [addr_start1, addr_end1, ..., addr_startN, addr_endN],
        :param relocs_gaps: [[start_gap_addr, end_gap_addr], [start_gap_addr, end_gap_addr]]
        :return: addresses for calculating HMAC: [[addr_start, addr_end], [addr_start, addr_end], ...]
        """
        addrs_for_hmac = list()
        for section_name, sym_names in sec_sym_sequence.items():
            if relocs_gaps is not None and section_name == ".rodata":
                for symbol in self.get_symbol_by_name(sym_names):
                    addrs_for_hmac.append(symbol.addr)
            else:
                for symbol in self.get_symbol_by_name(sym_names):
                    addrs_for_hmac.append(symbol.addr)
        addrs_for_hmac.extend(self.utils.flatten(relocs_gaps))
        addrs_for_hmac.sort()
        return [[item1, item2] for item1, item2 in self.utils.pairwise(addrs_for_hmac)]

    def embed_bytes(self, vaddr, in_bytes):
        """
        Write bytes to ELF file
        :param vaddr: virtual address in ELF
        :param in_bytes: byte array to write
        """
        offset = self.vaddr_to_file_offset(vaddr)
        with open(self.get_elf_file(), "rb+") as elf_file:
            elf_file.seek(offset)
            elf_file.write(in_bytes)

    def __update_hmac(self, hmac_obj, file_obj, file_offset_start, file_offset_end):
        """
        Update hmac from addrstart tp addr_end
        FIXMI: it needs to implement this function via fixed block size
        :param file_offset_start: could be string or int
        :param file_offset_end:   could be string or int
        """
        file_offset_start = self.utils.to_int(file_offset_start)
        file_offset_end = self.utils.to_int(file_offset_end)
        file_obj.seek(self.vaddr_to_file_offset(file_offset_start))
        block_size = file_offset_end - file_offset_start
        msg = file_obj.read(block_size)
        hmac_obj.update(msg)

    def get_hmac(self, offset_sequence, key, output_type="byte"):
        """
        Calculate HMAC
        :param offset_sequence: start and end addresses sequence [addr_start, addr_end], [addr_start, addr_end], ...]
        :param key HMAC key: string value
        :param output_type string value. Could be "hex" or "byte"
        :return: bytearray or hex string
        """
        digest = hmac.new(bytearray(key.encode("utf-8")), digestmod=hashlib.sha256)
        with open(self.get_elf_file(), "rb") as file:
            for addr_start, addr_end in offset_sequence:
                self.__update_hmac(digest, file, addr_start, addr_end)
        if output_type == "byte":
            return digest.digest()
        if output_type == "hex":
            return digest.hexdigest()

    def __find_nearest_symbol_by_vaddr(self, vaddr, method):
        """
        Find nearest symbol near vaddr
        :param vaddr:
        :return: idx of symbol from self.get_symbols()
        """
        symbol = self.get_symbol_by_vaddr(vaddr)
        if symbol is None:
            raise ValueError("Can't find symbol by vaddr")
        idx = method(list(self.get_symbols()), vaddr)
        return idx

    def find_rnearest_symbol_by_vaddr(self, vaddr):
        """
        Find right nearest symbol near vaddr
        :param vaddr:
        :return: idx of symbol from self.get_symbols()
        """
        return self.__find_nearest_symbol_by_vaddr(vaddr, bisect.bisect_right)

    def find_lnearest_symbol_by_vaddr(self, vaddr):
        """
        Find left nearest symbol near vaddr
        :param vaddr:
        :return: idx of symbol from self.get_symbols()
        """
        return self.__find_nearest_symbol_by_vaddr(vaddr, bisect.bisect_left)

    def find_symbols_between_vaddrs(self, vaddr_start, vaddr_end):
        """
        Returns list of symbols between two virtual addresses
        :param vaddr_start:
        :param vaddr_end:
        :return: [(Symbol(), Section)]
        """
        symbol_start = self.get_symbol_by_vaddr(vaddr_start)
        symbol_end = self.get_symbol_by_vaddr(vaddr_end)
        if symbol_start is None or symbol_end is None:
            raise ValueError("Error: Cannot find symbol by vaddr. vaddr should coincide with symbol address!")

        idx_start = self.find_lnearest_symbol_by_vaddr(vaddr_start)
        idx_end = self.find_lnearest_symbol_by_vaddr(vaddr_end)

        sym_sec = list()
        for idx in range(idx_start, idx_end):
            symbol_addr = list(self.get_symbols())[idx]
            symbol = self.get_symbol_by_vaddr(symbol_addr)
            section = self.get_section_by_vaddr(symbol_addr)
            sym_sec.append((symbol, section))

        sym_sec.sort(key=lambda x: x[0])
        return sym_sec

    @staticmethod
    def __get_skipped_bytes(symbol, relocs):
        """
        :param symbol: Symbol()
        :param relocs: [[start1, end1], [start2, end2]]
        :return: Returns skipped bytes and [[start, end]] addresses that show which bytes were skipped
        """
        symbol_start_addr = symbol.addr
        symbol_end_addr = symbol.addr + symbol.size
        skipped_bytes = 0
        reloc_addrs = list()
        for reloc_start, reloc_end in relocs:
            if reloc_start >= symbol_start_addr and reloc_end <= symbol_end_addr:
                skipped_bytes += reloc_end - reloc_start
                reloc_addrs.append([reloc_start, reloc_end])
            if reloc_start > symbol_end_addr:
                break

        return skipped_bytes, reloc_addrs

    def print_covered_info(self, sec_sym, relocs, print_reloc_addrs=False, sort_by="address", reverse=False):
        """
        Prints information about covered symbols in detailed table:
        |N| symbol name | symbol address     | symbol section | bytes skipped | skipped bytes address range      |
        |1| symbol      | 0xXXXXXXXXXXXXXXXX | .rodata        | 8             | [[addr1, addr2], [addr1, addr2]] |
        :param sec_sym: {section_name : [sym_name1, sym_name2]}
        :param relocs: [[start1, end1], [start2, end2]]
        :param print_reloc_addrs: print or not skipped bytes address range
        :param sort_by: method for sorting table. Could be: "address", "name", "section"
        :param reverse: sort order
        """
        if sort_by.lower() == "address":
            def sort_method(x): return x[0].addr
        elif sort_by.lower() == "name":
            def sort_method(x): return x[0].name
        elif sort_by.lower() == "section":
            def sort_method(x): return x[1].name
        else:
            raise ValueError("Invalid sort type!")
        table_format = "|{:4}| {:50} | {:18} | {:20} | {:15} |"
        if print_reloc_addrs is True:
            table_format += "{:32} |"

        print(table_format.format("N", "symbol name", "symbol address", "symbol section", "bytes skipped",
                                  "skipped bytes address range"))
        data_to_print = list()
        for sec_name, sym_names in sec_sym.items():
            for symbol_start, symbol_end in self.utils.pairwise(self.get_symbol_by_name(sym_names)):
                symbol_sec_in_range = self.find_symbols_between_vaddrs(symbol_start.addr, symbol_end.addr)
                for symbol, section in symbol_sec_in_range:
                    skipped_bytes, reloc_addrs = self.__get_skipped_bytes(symbol, relocs)
                    reloc_addrs_str = "["
                    for start_addr, end_addr in reloc_addrs:
                        reloc_addrs_str += "[{}, {}], ".format(hex(start_addr), hex(end_addr))
                    reloc_addrs_str += "]"
                    if symbol.size > 0:
                        data_to_print.append((symbol, section, skipped_bytes, reloc_addrs_str))

        skipped_bytes_size = 0
        symbol_covered_size = 0
        cnt = 0
        data_to_print.sort(key=sort_method, reverse=reverse)
        for symbol, section, skipped_bytes, reloc_addrs_str in data_to_print:
            cnt += 1
            symbol_covered_size += symbol.size
            skipped_bytes_size += skipped_bytes
            if print_reloc_addrs is True:
                print(table_format.format(cnt, symbol.name, hex(symbol.addr), section.name,
                                          self.utils.human_size(skipped_bytes), reloc_addrs_str))
            else:
                print(table_format.format(cnt, symbol.name, hex(symbol.addr), section.name,
                                          self.utils.human_size(skipped_bytes)))
        addrs_for_hmac = self.get_addrs_for_hmac(sec_sym, relocs)
        all_covered_size = 0
        for addr_start, addr_end in addrs_for_hmac:
            all_covered_size += addr_end - addr_start
        print("Symbol covered bytes len: {} ".format(self.utils.human_size(symbol_covered_size - skipped_bytes_size)))
        print("All covered bytes len   : {} ".format(self.utils.human_size(all_covered_size)))
        print("Skipped bytes len       : {} ".format(self.utils.human_size(skipped_bytes_size)))

    def dump_covered_bytes(self, vaddr_seq, out_file):
        """
        Dumps covered bytes
        :param vaddr_seq: [[start1, end1], [start2, end2]] start - end sequence of covered bytes
        :param out_file: file where will be stored dumped bytes
        """
        with open(self.get_elf_file(), "rb") as elf_fp:
            with open(out_file, "wb") as out_fp:
                for vaddr_start, vaddr_end, in vaddr_seq:
                    elf_fp.seek(self.vaddr_to_file_offset(vaddr_start))
                    out_fp.write(elf_fp.read(vaddr_end - vaddr_start))

    def make_integrity(self, sec_sym, module_name, debug=False, print_reloc_addrs=False, sort_by="address",
                       reverse=False):
        """
        Calculate HMAC and embed needed info
        Steps: collect the relocation gaps between the first_<module>_rodata and last_<module>_rodata
        symbols, build the address ranges to hash, compute the HMAC over them, then write into the
        ELF file the digest (at builtime_<module>_hmac), the hashed address ranges (at
        integrity_<module>_addrs) and the address of <module>_buildtime_address (at that same symbol).
        :param sec_sym: {sec_name: [sym_name1, sym_name2, ..., sym_nameN]} as accepted by get_addrs_for_hmac()
        :param module_name: module name that you want to make integrity. See Makefile targets
        :param debug: If True prints debug information and dumps the covered bytes to
                      covered_dump_for_<module>.bin
        :param print_reloc_addrs: If True, print relocation addresses that are skipped
        :param sort_by: sort method
        :param reverse: sort order
        """
        def addr_of(symbol_name):
            # Address of the named symbol, looked up anew on every call.
            return self.get_symbol_by_name(symbol_name).addr

        rodata_first = self.get_symbol_by_name("first_" + module_name + "_rodata")
        rodata_last = self.get_symbol_by_name("last_" + module_name + "_rodata")

        reloc_gaps = self.get_reloc_gaps(rodata_first.addr, rodata_last.addr)
        addrs_for_hmac = self.get_addrs_for_hmac(sec_sym, reloc_gaps)

        digest = self.get_hmac(addrs_for_hmac, "The quick brown fox jumps over the lazy dog")

        # Payloads, written in this order: HMAC digest, hashed ranges, build-time address.
        self.embed_bytes(addr_of("builtime_" + module_name + "_hmac"),
                         self.utils.to_bytearray(digest))

        self.embed_bytes(addr_of("integrity_" + module_name + "_addrs"),
                         self.utils.to_bytearray(addrs_for_hmac))

        buildtime_symbol = module_name + "_buildtime_address"
        self.embed_bytes(addr_of(buildtime_symbol),
                         self.utils.to_bytearray(addr_of(buildtime_symbol)))

        print("HMAC for \"{}\" module is: {}".format(module_name, binascii.hexlify(digest)))
        if debug:
            self.print_covered_info(sec_sym, reloc_gaps, print_reloc_addrs=print_reloc_addrs, sort_by=sort_by,
                                    reverse=reverse)
            dump_name = "covered_dump_for_" + module_name + ".bin"
            self.dump_covered_bytes(addrs_for_hmac, dump_name)

        print("FIPS integrity procedure has been finished for {}".format(module_name))
