# -*- coding: utf-8 -*-

# Copyright (c) 2020 - 2024 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing checks for potential XSS vulnerability.
"""

#
# This is a modified version of the one found in the bandit package.
#
# Original Copyright 2018 Victor Torre
#
# SPDX-License-Identifier: Apache-2.0
#

import ast

import AstUtilities


def getChecks():
    """
    Function to get the checks implemented by this module.

    The result maps an AST node type name to the checkers registered for it.
    Each checker is given as a pair of the checker function and the tuple of
    message codes it may report.

    @return dictionary containing checker lists containing checker function and
        list of codes
    @rtype dict
    """
    # The XSS check is only registered for "Call" nodes.
    callCheckers = [
        (checkDjangoXssVulnerability, ("S703",)),
    ]
    return {"Call": callCheckers}


def checkDjangoXssVulnerability(reportError, context, config):  # noqa: U100
    """
    Function to check for potential XSS vulnerability.

    The check only runs if something like 'django.utils.safestring' is
    imported and the called function is one of the 'mark safe' helpers. A call
    whose first positional argument is not a string literal is handed over to
    checkPotentialRisk() for a closer look.

    @param reportError function to be used to report errors
    @type func
    @param context security context object
    @type SecurityContext
    @param config dictionary with configuration data (unused)
    @type dict
    """
    if not context.isModuleImportedLike("django.utils.safestring"):
        return

    affectedFunctions = (
        "mark_safe",
        "SafeText",
        "SafeUnicode",
        "SafeString",
        "SafeBytes",
    )
    if context.callFunctionName not in affectedFunctions:
        return

    positionalArgs = context.node.args
    if not positionalArgs:
        # A call without positional arguments (e.g. 'mark_safe()' or a
        # keyword only call) has no first argument to inspect. Indexing it
        # would raise an IndexError and abort the check run.
        return

    if not AstUtilities.isString(positionalArgs[0]):
        checkPotentialRisk(reportError, context.node)


def checkPotentialRisk(reportError, node):
    """
    Function to check a given node for a potential XSS vulnerability.

    The first positional argument of the call is analysed. Only three shapes
    can be rated as secure: a variable name, a call and a '%' operation with a
    string on the left side. Everything else is reported as issue 'S703'.

    @param reportError function to be used to report errors
    @type func
    @param node node to be checked
    @type ast.Call
    """

    def enclosingScope():
        """
        Function to find the nearest enclosing module or function definition.

        @return enclosing scope node
        @rtype ast.Module or ast.FunctionDef
        """
        scope = node._securityParent
        while not isinstance(scope, (ast.Module, ast.FunctionDef)):
            scope = scope._securityParent
        return scope

    argument = node.args[0]
    secure = False

    if isinstance(argument, ast.Name):
        scope = enclosingScope()
        # A variable that is a parameter of the enclosing function is never
        # rated as secure.
        isParameter = isinstance(scope, ast.FunctionDef) and any(
            param.arg == argument.id for param in scope.args.args
        )
        if not isParameter:
            secure = evaluateVar(argument, scope, node.lineno)
    elif isinstance(argument, ast.Call):
        secure = evaluateCall(argument, enclosingScope())
    elif isinstance(argument, ast.BinOp):
        leftIsString = AstUtilities.isString(argument.left)
        if isinstance(argument.op, ast.Mod) and leftIsString:
            # "literal % values" is rated like "literal.format(values)"
            scope = enclosingScope()
            secure = evaluateCall(transform2call(argument), scope)

    if not secure:
        # the line number is reported reduced by one
        reportError(node.lineno - 1, node.col_offset, "S703", "M", "H")


class DeepAssignation:
    """
    Class to perform a deep analysis of an assign.

    Given the name node of a variable, the class searches statements
    (descending into function, 'with', 'try', 'if', 'for' and 'while' bodies)
    for the expressions assigned to that variable.
    """

    def __init__(self, varName, ignoreNodes=None):
        """
        Constructor

        @param varName name node of the variable to be tracked (its 'id'
            attribute is compared against assignment targets)
        @type ast.Name
        @param ignoreNodes node types to be skipped by the analysis
        @type type or list of type or tuple of type
        """
        self.__varName = varName

        # isinstance() only accepts a type or a tuple of types. Normalize the
        # given value once, so that a list does not raise a TypeError later.
        if not ignoreNodes:
            self.__ignoreNodes = ()
        elif isinstance(ignoreNodes, (list, tuple, set, frozenset)):
            self.__ignoreNodes = tuple(ignoreNodes)
        else:
            self.__ignoreNodes = (ignoreNodes,)

    def isAssignedIn(self, items):
        """
        Public method to check, if the variable is assigned to.

        @param items list of nodes to check against
        @type list of ast.AST
        @return flat list of nodes assigned to the variable
        @rtype list of ast.AST
        """
        assigned = []
        for astInst in items:
            newAssigned = self.isAssigned(astInst)
            if newAssigned:
                if isinstance(newAssigned, (list, tuple)):
                    assigned.extend(newAssigned)
                else:
                    assigned.append(newAssigned)

        return assigned

    def isAssigned(self, node):
        """
        Public method to check assignment against a given node.

        @param node node to check against
        @type ast.AST
        @return False, if the node does not assign to the variable. Otherwise
            the assigned value node, the 'with' node binding the variable or,
            for compound statements, a list of assigned nodes (which may be
            empty).
        @rtype bool or ast.AST or list of ast.AST
        """
        assigned = False
        if self.__ignoreNodes and isinstance(node, self.__ignoreNodes):
            return assigned

        varId = self.__varName.id

        if isinstance(node, ast.Expr):
            assigned = self.isAssigned(node.value)
        elif isinstance(node, ast.FunctionDef):
            # Function parameters are ast.arg nodes carrying the name in
            # their 'arg' attribute.
            if any(param.arg == varId for param in node.args.args):
                # If is param the assignations are not affected
                return assigned

            assigned = self.isAssignedIn(node.body)
        elif isinstance(node, ast.With):
            # The variable is bound by the statement, if any of its items
            # uses it as the 'as' target (not just the last one).
            if any(
                getattr(withitem.optional_vars, "id", None) == varId
                for withitem in node.items
            ):
                assigned = node
            else:
                assigned = self.isAssignedIn(node.body)
        elif isinstance(node, ast.Try):
            assigned = []
            assigned.extend(self.isAssignedIn(node.body))
            assigned.extend(self.isAssignedIn(node.handlers))
            assigned.extend(self.isAssignedIn(node.orelse))
            assigned.extend(self.isAssignedIn(node.finalbody))
        elif isinstance(node, ast.ExceptHandler):
            assigned = []
            assigned.extend(self.isAssignedIn(node.body))
        elif isinstance(node, (ast.If, ast.For, ast.While)):
            assigned = []
            assigned.extend(self.isAssignedIn(node.body))
            assigned.extend(self.isAssignedIn(node.orelse))
        elif (
            isinstance(node, ast.AugAssign)
            and isinstance(node.target, ast.Name)
            and node.target.id == varId
        ):
            assigned = node.value
        elif isinstance(node, ast.Assign) and node.targets:
            # Chained assignments ("a = b = value") have several targets;
            # every one of them may be the tracked variable.
            for target in node.targets:
                value = self.__matchTarget(target, node.value)
                if value is not None:
                    assigned = value
                    break

        return assigned

    def __matchTarget(self, target, value):
        """
        Private method to get the value assigned to the variable by a target.

        @param target assignment target to be checked
        @type ast.AST
        @param value value node of the assignment
        @type ast.AST
        @return node assigned to the variable or None, if the target does not
            assign to the variable
        @rtype ast.AST or None
        """
        varId = self.__varName.id

        if isinstance(target, ast.Name):
            return value if target.id == varId else None

        if isinstance(target, ast.Tuple) and isinstance(value, ast.Tuple):
            for pos, element in enumerate(target.elts):
                # Tuple targets may contain attributes, subscripts or starred
                # elements, which have no 'id'; the value tuple may be
                # shorter than the target tuple.
                if (
                    isinstance(element, ast.Name)
                    and element.id == varId
                    and pos < len(value.elts)
                ):
                    return value.elts[pos]

        return None


def evaluateVar(xssVar, parent, until, ignoreNodes=None):
    """
    Function to evaluate a variable node for potential XSS vulnerability.

    The statements of the parent's body located before line 'until' are
    searched in order for assignments to the variable. Each assignment found
    updates the rating, so the last one examined decides, unless an
    assignment of an unsupported shape ends the search with a negative result.

    @param xssVar variable node to be checked
    @type ast.Name
    @param parent parent node whose body is searched
    @type ast.AST
    @param until end line number to evaluate variable against (statements
        starting at this line or later are not examined)
    @type int
    @param ignoreNodes node types to ignore (handed to DeepAssignation)
    @type list of ast.AST
    @return flag indicating a secure evaluation
    @rtype bool
    """
    if not isinstance(xssVar, ast.Name):
        return False

    if isinstance(parent, ast.FunctionDef):
        for param in parent.args.args:
            if param.arg == xssVar.id:
                return False  # Params are not secure

    secure = False
    analyser = DeepAssignation(xssVar, ignoreNodes)
    for statement in parent.body:
        if statement.lineno >= until:
            break

        source = analyser.isAssigned(statement)
        if not source:
            continue

        if AstUtilities.isString(source):
            secure = True
        elif isinstance(source, ast.Name):
            secure = evaluateVar(source, parent, source.lineno, ignoreNodes)
        elif isinstance(source, ast.Call):
            secure = evaluateCall(source, parent, ignoreNodes)
        elif isinstance(source, (list, tuple)):
            # Several candidate values (e.g. from the branches of a compound
            # statement): each one must be a string or a secure variable.
            allSecure = True
            for candidate in source:
                if AstUtilities.isString(candidate):
                    continue
                if isinstance(candidate, ast.Name) and evaluateVar(
                    candidate, parent, statement.lineno, ignoreNodes
                ):
                    continue
                allSecure = False
                break

            secure = allSecure
            if not secure:
                break
        else:
            secure = False
            break

    return secure


def evaluateCall(call, parent, ignoreNodes=None):
    """
    Function to evaluate a call node for potential XSS vulnerability.

    Only a call of the 'format' method on a string without keyword arguments
    can be rated as secure, and only if each of its positional arguments is a
    string, a secure variable, a secure call or a starred list/tuple whose
    elements fulfill these conditions.

    @param call call node to be checked
    @type ast.Call
    @param parent parent node
    @type ast.AST
    @param ignoreNodes node types to ignore (handed on to evaluateVar)
    @type list of ast.AST
    @return flag indicating a secure evaluation
    @rtype bool
    """
    isStringFormat = (
        isinstance(call, ast.Call)
        and isinstance(call.func, ast.Attribute)
        and AstUtilities.isString(call.func.value)
        and call.func.attr == "format"
    )
    if not isStringFormat or call.keywords:
        return False

    # Work on a copy because starred arguments append their elements to the
    # list while it is being iterated; the loop picks them up later on.
    pending = list(call.args)
    for argument in pending:
        if AstUtilities.isString(argument):
            continue

        if isinstance(argument, ast.Name):
            if not evaluateVar(argument, parent, call.lineno, ignoreNodes):
                return False
        elif isinstance(argument, ast.Call):
            if not evaluateCall(argument, parent, ignoreNodes):
                return False
        elif isinstance(argument, ast.Starred) and isinstance(
            argument.value, (ast.List, ast.Tuple)
        ):
            pending.extend(argument.value.elts)  # noqa: M538
        else:
            return False

    return True


def transform2call(var):
    """
    Function to transform a variable node to a call node.

    A '%' operation with a string on the left side ("literal % values") is
    converted to the equivalent looking call "literal.format(values)". A tuple
    on the right side supplies one call argument per element, anything else
    becomes the single argument. The call carries the line number of the
    operation and an empty keywords list.

    @param var variable node
    @type ast.BinOp
    @return call node or None, if the node is not a '%' operation with a
        string on the left side
    @rtype ast.Call or None
    """
    if not isinstance(var, ast.BinOp):
        return None

    isMod = isinstance(var.op, ast.Mod)
    isLeftStr = AstUtilities.isString(var.left)
    if not (isMod and isLeftStr):
        return None

    if isinstance(var.right, ast.Tuple):
        # copy, so the call does not share its argument list with the tuple
        callArgs = list(var.right.elts)
    else:
        callArgs = [var.right]

    # Pass all required fields to the node constructors; creating nodes with
    # omitted required fields is deprecated as of Python 3.13.
    newCall = ast.Call(
        func=ast.Attribute(value=var.left, attr="format", ctx=ast.Load()),
        args=callArgs,
        keywords=[],
    )
    newCall.lineno = var.lineno

    return newCall
