import sympy
import re
from sympy import matrix_symbols, simplify, factor, expand, apart, expand_trig
from antlr4 import InputStream, CommonTokenStream
from antlr4.error.ErrorListener import ErrorListener

try:
    from gen.PSParser import PSParser
    from gen.PSLexer import PSLexer
    from gen.PSListener import PSListener
except Exception:
    from .gen.PSParser import PSParser
    from .gen.PSLexer import PSLexer
    from .gen.PSListener import PSListener

from sympy.printing.str import StrPrinter

from sympy.parsing.sympy_parser import parse_expr

import hashlib

is_real = None

frac_type = r'\frac'

variances = {}
var = {}

VARIABLE_VALUES = {}


def set_real(value):
    global is_real
    is_real = value


def set_variances(vars):
    global variances
    variances = vars
    global var
    var = {}
    for variance in vars:
        var[str(variance)] = vars[variance]


def latex2sympy(sympy: str, variable_values={}):
    # record frac
    global frac_type
    if sympy.find(r'\frac') != -1:
        frac_type = r'\frac'
    if sympy.find(r'\dfrac') != -1:
        frac_type = r'\dfrac'
    if sympy.find(r'\tfrac') != -1:
        frac_type = r'\tfrac'
    sympy = sympy.replace(r'\dfrac', r'\frac')
    sympy = sympy.replace(r'\tfrac', r'\frac')
    # Translate Transpose
    sympy = sympy.replace(r'\mathrm{T}', 'T', -1)
    # Translate Derivative
    sympy = sympy.replace(r'\mathrm{d}', 'd', -1).replace(r'{\rm d}', 'd', -1)
    # Translate Matrix
    sympy = sympy.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}', -1).replace(r'\end{matrix}\right]', r'\end{bmatrix}', -1)
    # Translate Permutation
    sympy = re.sub(r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}", r"\\frac{(\1)!}{((\1)-(\2))!}", sympy)
    # Remove \displaystyle
    sympy = sympy.replace(r'\displaystyle', ' ', -1)
    # Remove \quad
    sympy = sympy.replace(r'\quad', ' ', -1).replace(r'\qquad', ' ', -1).replace(r'~', ' ', -1).replace(r'\,', ' ', -1)
    # Remove $
    sympy = sympy.replace(r'$', ' ', -1)

    # variable values
    global VARIABLE_VALUES
    if len(variable_values) > 0:
        VARIABLE_VALUES = variable_values
    else:
        VARIABLE_VALUES = {}

    # setup listener
    matherror = MathErrorListener(sympy)

    # stream input
    stream = InputStream(sympy)
    lex = PSLexer(stream)
    lex.removeErrorListeners()
    lex.addErrorListener(matherror)

    tokens = CommonTokenStream(lex)
    parser = PSParser(tokens)

    # remove default console error listener
    parser.removeErrorListeners()
    parser.addErrorListener(matherror)

    # process the input
    return_data = None
    math = parser.math()

    # if a list
    if math.relation_list():
        return_data = []

        # go over list items
        relation_list = math.relation_list().relation_list_content()
        for list_item in relation_list.relation():
            expr = convert_relation(list_item)
            return_data.append(expr)

    # if not, do default
    else:
        relation = math.relation()
        return_data = convert_relation(relation)

    return return_data


class MathErrorListener(ErrorListener):
    def __init__(self, src):
        super(ErrorListener, self).__init__()
        self.src = src

    def syntaxError(self, recog, symbol, line, col, msg, e):
        fmt = "%s\n%s\n%s"
        marker = "~" * col + "^"

        if msg.startswith("missing"):
            err = fmt % (msg, self.src, marker)
        elif msg.startswith("no viable"):
            err = fmt % ("I expected something else here", self.src, marker)
        elif msg.startswith("mismatched"):
            names = PSParser.literalNames
            expected = [names[i] for i in e.getExpectedTokens() if i < len(names)]
            if len(expected) < 10:
                expected = " ".join(expected)
                err = (fmt % ("I expected one of these: " + expected,
                              self.src, marker))
            else:
                err = (fmt % ("I expected something else here", self.src, marker))
        else:
            err = fmt % ("I don't understand this", self.src, marker)
        raise Exception(err)


def convert_relation(rel):
    if rel.expr():
        return convert_expr(rel.expr())

    lh = convert_relation(rel.relation(0))
    rh = convert_relation(rel.relation(1))
    if rel.LT():
        return sympy.StrictLessThan(lh, rh, evaluate=False)
    elif rel.LTE():
        return sympy.LessThan(lh, rh, evaluate=False)
    elif rel.GT():
        return sympy.StrictGreaterThan(lh, rh, evaluate=False)
    elif rel.GTE():
        return sympy.GreaterThan(lh, rh, evaluate=False)
    elif rel.EQUAL():
        return sympy.Eq(lh, rh, evaluate=False)
    elif rel.ASSIGNMENT():
        # !Use Global variances
        if lh.is_Symbol:
            # set value
            variances[lh] = rh
            var[str(lh)] = rh
            return rh
        else:
            # find the symbols in lh - rh
            equation = lh - rh
            syms = equation.atoms(sympy.Symbol)
            if len(syms) > 0:
                # Solve equation
                result = []
                for sym in syms:
                    values = sympy.solve(equation, sym)
                    for value in values:
                        result.append(sympy.Eq(sym, value, evaluate=False))
                return result
            else:
                return sympy.Eq(lh, rh, evaluate=False)
    elif rel.IN():
        # !Use Global variances
        if hasattr(rh, 'is_Pow') and rh.is_Pow and hasattr(rh.exp, 'is_Mul'):
            n = rh.exp.args[0]
            m = rh.exp.args[1]
            if n in variances:
                n = variances[n]
            if m in variances:
                m = variances[m]
            rh = sympy.MatrixSymbol(lh, n, m)
            variances[lh] = rh
            var[str(lh)] = rh
        else:
            raise Exception("Don't support this form of definition of matrix symbol.")
        return lh
    elif rel.UNEQUAL():
        return sympy.Ne(lh, rh, evaluate=False)


def convert_expr(expr):
    if expr.additive():
        return convert_add(expr.additive())


def convert_elementary_transform(matrix, transform):
    if transform.transform_scale():
        transform_scale = transform.transform_scale()
        transform_atom = transform_scale.transform_atom()
        k = None
        num = int(transform_atom.NUMBER().getText()) - 1
        if transform_scale.expr():
            k = convert_expr(transform_scale.expr())
        elif transform_scale.group():
            k = convert_expr(transform_scale.group().expr())
        elif transform_scale.SUB():
            k = -1
        else:
            k = 1
        if transform_atom.LETTER_NO_E().getText() == 'r':
            matrix = matrix.elementary_row_op(op='n->kn', row=num, k=k)
        elif transform_atom.LETTER_NO_E().getText() == 'c':
            matrix = matrix.elementary_col_op(op='n->kn', col=num, k=k)
        else:
            raise Exception('Row and col don\'s match')

    elif transform.transform_swap():
        first_atom = transform.transform_swap().transform_atom()[0]
        second_atom = transform.transform_swap().transform_atom()[1]
        first_num = int(first_atom.NUMBER().getText()) - 1
        second_num = int(second_atom.NUMBER().getText()) - 1
        if first_atom.LETTER_NO_E().getText() != second_atom.LETTER_NO_E().getText():
            raise Exception('Row and col don\'s match')
        elif first_atom.LETTER_NO_E().getText() == 'r':
            matrix = matrix.elementary_row_op(op='n<->m', row1=first_num, row2=second_num)
        elif first_atom.LETTER_NO_E().getText() == 'c':
            matrix = matrix.elementary_col_op(op='n<->m', col1=first_num, col2=second_num)
        else:
            raise Exception('Row and col don\'s match')

    elif transform.transform_assignment():
        first_atom = transform.transform_assignment().transform_atom()
        second_atom = transform.transform_assignment().transform_scale().transform_atom()
        transform_scale = transform.transform_assignment().transform_scale()
        k = None
        if transform_scale.expr():
            k = convert_expr(transform_scale.expr())
        elif transform_scale.group():
            k = convert_expr(transform_scale.group().expr())
        elif transform_scale.SUB():
            k = -1
        else:
            k = 1
        first_num = int(first_atom.NUMBER().getText()) - 1
        second_num = int(second_atom.NUMBER().getText()) - 1
        if first_atom.LETTER_NO_E().getText() != second_atom.LETTER_NO_E().getText():
            raise Exception('Row and col don\'s match')
        elif first_atom.LETTER_NO_E().getText() == 'r':
            matrix = matrix.elementary_row_op(op='n->n+km', k=k, row1=first_num, row2=second_num)
        elif first_atom.LETTER_NO_E().getText() == 'c':
            matrix = matrix.elementary_col_op(op='n->n+km', k=k, col1=first_num, col2=second_num)
        else:
            raise Exception('Row and col don\'s match')

    return matrix


def convert_matrix(matrix):
    # build matrix
    row = matrix.matrix_row()
    tmp = []
    rows = 0
    mat = None

    for r in row:
        tmp.append([])
        for expr in r.expr():
            tmp[rows].append(convert_expr(expr))
        rows = rows + 1

    mat = sympy.Matrix(tmp)

    if hasattr(matrix, 'MATRIX_XRIGHTARROW') and matrix.MATRIX_XRIGHTARROW():
        transforms_list = matrix.elementary_transforms()
        if len(transforms_list) == 1:
            for transform in transforms_list[0].elementary_transform():
                mat = convert_elementary_transform(mat, transform)
        elif len(transforms_list) == 2:
            # firstly transform top of xrightarrow
            for transform in transforms_list[1].elementary_transform():
                mat = convert_elementary_transform(mat, transform)
            # firstly transform bottom of xrightarrow
            for transform in transforms_list[0].elementary_transform():
                mat = convert_elementary_transform(mat, transform)

    return mat


def add_flat(lh, rh):
    if hasattr(lh, 'is_Add') and lh.is_Add or hasattr(rh, 'is_Add') and rh.is_Add:
        args = []
        if hasattr(lh, 'is_Add') and lh.is_Add:
            args += list(lh.args)
        else:
            args += [lh]
        if hasattr(rh, 'is_Add') and rh.is_Add:
            args = args + list(rh.args)
        else:
            args += [rh]
        return sympy.Add(*args, evaluate=False)
    else:
        return sympy.Add(lh, rh, evaluate=False)


def mat_add_flat(lh, rh):
    if hasattr(lh, 'is_MatAdd') and lh.is_MatAdd or hasattr(rh, 'is_MatAdd') and rh.is_MatAdd:
        args = []
        if hasattr(lh, 'is_MatAdd') and lh.is_MatAdd:
            args += list(lh.args)
        else:
            args += [lh]
        if hasattr(rh, 'is_MatAdd') and rh.is_MatAdd:
            args = args + list(rh.args)
        else:
            args += [rh]
        return sympy.MatAdd(*[arg.doit() for arg in args], evaluate=False)
    else:
        return sympy.MatAdd(lh.doit(), rh.doit(), evaluate=False)


def mul_flat(lh, rh):
    if hasattr(lh, 'is_Mul') and lh.is_Mul or hasattr(rh, 'is_Mul') and rh.is_Mul:
        args = []
        if hasattr(lh, 'is_Mul') and lh.is_Mul:
            args += list(lh.args)
        else:
            args += [lh]
        if hasattr(rh, 'is_Mul') and rh.is_Mul:
            args = args + list(rh.args)
        else:
            args += [rh]
        return sympy.Mul(*args, evaluate=False)
    else:
        return sympy.Mul(lh, rh, evaluate=False)


def mat_mul_flat(lh, rh):
    if hasattr(lh, 'is_MatMul') and lh.is_MatMul or hasattr(rh, 'is_MatMul') and rh.is_MatMul:
        args = []
        if hasattr(lh, 'is_MatMul') and lh.is_MatMul:
            args += list(lh.args)
        else:
            args += [lh]
        if hasattr(rh, 'is_MatMul') and rh.is_MatMul:
            args = args + list(rh.args)
        else:
            args += [rh]
        return sympy.MatMul(*[arg.doit() for arg in args], evaluate=False)
    else:
        if hasattr(lh, 'doit') and hasattr(rh, 'doit'):
            return sympy.MatMul(lh.doit(), rh.doit(), evaluate=False)
        elif hasattr(lh, 'doit') and not hasattr(rh, 'doit'):
            return sympy.MatMul(lh.doit(), rh, evaluate=False)
        elif not hasattr(lh, 'doit') and hasattr(rh, 'doit'):
            return sympy.MatMul(lh, rh.doit(), evaluate=False)
        else:
            return sympy.MatMul(lh, rh, evaluate=False)


def convert_add(add):
    if add.ADD():
        lh = convert_add(add.additive(0))
        rh = convert_add(add.additive(1))

        if lh.is_Matrix or rh.is_Matrix:
            return mat_add_flat(lh, rh)
        else:
            return add_flat(lh, rh)
    elif add.SUB():
        lh = convert_add(add.additive(0))
        rh = convert_add(add.additive(1))

        if lh.is_Matrix or rh.is_Matrix:
            return mat_add_flat(lh, mat_mul_flat(-1, rh))
        else:
            # If we want to force ordering for variables this should be:
            # return Sub(lh, rh, evaluate=False)
            if not rh.is_Matrix and rh.func.is_Number:
                rh = -rh
            else:
                rh = mul_flat(-1, rh)
            return add_flat(lh, rh)
    else:
        return convert_mp(add.mp())


def convert_mp(mp):
    if hasattr(mp, 'mp'):
        mp_left = mp.mp(0)
        mp_right = mp.mp(1)
    else:
        mp_left = mp.mp_nofunc(0)
        mp_right = mp.mp_nofunc(1)

    if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT():
        lh = convert_mp(mp_left)
        rh = convert_mp(mp_right)

        if lh.is_Matrix or rh.is_Matrix:
            return mat_mul_flat(lh, rh)
        else:
            return mul_flat(lh, rh)
    elif mp.DIV() or mp.CMD_DIV() or mp.COLON():
        lh = convert_mp(mp_left)
        rh = convert_mp(mp_right)
        if lh.is_Matrix or rh.is_Matrix:
            return sympy.MatMul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False)
        else:
            return sympy.Mul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False)
    elif mp.CMD_MOD():
        lh = convert_mp(mp_left)
        rh = convert_mp(mp_right)
        if rh.is_Matrix:
            raise Exception("Cannot perform modulo operation with a matrix as an operand")
        else:
            return sympy.Mod(lh, rh, evaluate=False)
    else:
        if hasattr(mp, 'unary'):
            return convert_unary(mp.unary())
        else:
            return convert_unary(mp.unary_nofunc())


def convert_unary(unary):
    if hasattr(unary, 'unary'):
        nested_unary = unary.unary()
    else:
        nested_unary = unary.unary_nofunc()
    if hasattr(unary, 'postfix_nofunc'):
        first = unary.postfix()
        tail = unary.postfix_nofunc()
        postfix = [first] + tail
    else:
        postfix = unary.postfix()

    if unary.ADD():
        return convert_unary(nested_unary)
    elif unary.SUB():
        tmp_convert_nested_unary = convert_unary(nested_unary)
        if tmp_convert_nested_unary.is_Matrix:
            return mat_mul_flat(-1, tmp_convert_nested_unary, evaluate=False)
        else:
            if tmp_convert_nested_unary.func.is_Number:
                return -tmp_convert_nested_unary
            else:
                return mul_flat(-1, tmp_convert_nested_unary)
    elif postfix:
        return convert_postfix_list(postfix)


def convert_postfix_list(arr, i=0):
    if i >= len(arr):
        raise Exception("Index out of bounds")

    res = convert_postfix(arr[i])

    if isinstance(res, sympy.Expr) or isinstance(res, sympy.Matrix) or res is sympy.S.EmptySet:
        if i == len(arr) - 1:
            return res  # nothing to multiply by
        else:
            # multiply by next
            rh = convert_postfix_list(arr, i + 1)

            if res.is_Matrix or rh.is_Matrix:
                return mat_mul_flat(res, rh)
            else:
                return mul_flat(res, rh)
    elif ((isinstance(res, tuple) or isinstance(res, list)) and res[0] != 'derivative') or isinstance(res, dict):
        return res
    else:  # must be derivative
        wrt = res[1]
        if i == len(arr) - 1:
            raise Exception("Expected expression for derivative")
        else:
            expr = convert_postfix_list(arr, i + 1)
            return sympy.Derivative(expr, wrt)


def do_subs(expr, at):
    if at.expr():
        at_expr = convert_expr(at.expr())
        syms = at_expr.atoms(sympy.Symbol)
        if len(syms) == 0:
            return expr
        elif len(syms) > 0:
            sym = next(iter(syms))
            return expr.subs(sym, at_expr)
    elif at.equality():
        lh = convert_expr(at.equality().expr(0))
        rh = convert_expr(at.equality().expr(1))
        return expr.subs(lh, rh)


def convert_postfix(postfix):
    if hasattr(postfix, 'exp'):
        exp_nested = postfix.exp()
    else:
        exp_nested = postfix.exp_nofunc()

    exp = convert_exp(exp_nested)
    for op in postfix.postfix_op():
        if op.BANG():
            if isinstance(exp, list):
                raise Exception("Cannot apply postfix to derivative")
            exp = sympy.factorial(exp, evaluate=False)
        elif op.eval_at():
            ev = op.eval_at()
            at_b = None
            at_a = None
            if ev.eval_at_sup():
                at_b = do_subs(exp, ev.eval_at_sup())
            if ev.eval_at_sub():
                at_a = do_subs(exp, ev.eval_at_sub())
            if at_b is not None and at_a is not None:
                exp = add_flat(at_b, mul_flat(at_a, -1))
            elif at_b is not None:
                exp = at_b
            elif at_a is not None:
                exp = at_a
        elif op.transpose():
            try:
                exp = exp.T
            except:
                try:
                    exp = sympy.transpose(exp)
                except:
                    pass
                pass

    return exp


def convert_exp(exp):
    if hasattr(exp, 'exp'):
        exp_nested = exp.exp()
    else:
        exp_nested = exp.exp_nofunc()

    if exp_nested:
        base = convert_exp(exp_nested)
        if isinstance(base, list):
            raise Exception("Cannot raise derivative to power")
        if exp.atom():
            exponent = convert_atom(exp.atom())
        elif exp.expr():
            exponent = convert_expr(exp.expr())
        return sympy.Pow(base, exponent, evaluate=False)
    else:
        if hasattr(exp, 'comp'):
            return convert_comp(exp.comp())
        else:
            return convert_comp(exp.comp_nofunc())


def convert_comp(comp):
    if comp.group():
        return convert_expr(comp.group().expr())
    elif comp.norm_group():
        return convert_expr(comp.norm_group().expr()).norm()
    elif comp.abs_group():
        return sympy.Abs(convert_expr(comp.abs_group().expr()), evaluate=False)
    elif comp.floor_group():
        return handle_floor(convert_expr(comp.floor_group().expr()))
    elif comp.ceil_group():
        return handle_ceil(convert_expr(comp.ceil_group().expr()))
    elif comp.atom():
        return convert_atom(comp.atom())
    elif comp.frac():
        return convert_frac(comp.frac())
    elif comp.binom():
        return convert_binom(comp.binom())
    elif comp.matrix():
        return convert_matrix(comp.matrix())
    elif comp.det():
        # !Use Global variances
        return convert_matrix(comp.det()).subs(variances).det()
    elif comp.func():
        return convert_func(comp.func())


def convert_atom(atom):
    if atom.atom_expr():
        atom_expr = atom.atom_expr()

        # find the atom's text
        atom_text = ''
        if atom_expr.LETTER_NO_E():
            atom_text = atom_expr.LETTER_NO_E().getText()
            if atom_text == "I":
                return sympy.I
        elif atom_expr.GREEK_CMD():
            atom_text = atom_expr.GREEK_CMD().getText()[1:].strip()
        elif atom_expr.OTHER_SYMBOL_CMD():
            atom_text = atom_expr.OTHER_SYMBOL_CMD().getText().strip()
        elif atom_expr.accent():
            atom_accent = atom_expr.accent()
            # get name for accent
            name = atom_accent.start.text
            # name = atom_accent.start.text[1:]
            # exception: check if bar or overline which are treated both as bar
            # if name in ["bar", "overline"]:
            #     name = "bar"
            # if name in ["vec", "overrightarrow"]:
            #     name = "vec"
            # if name in ["tilde", "widetilde"]:
            #     name = "tilde"
            # get the base (variable)
            base = atom_accent.base.getText()
            # set string to base+name
            atom_text = name + '{' + base + '}'

        # find atom's subscript, if any
        subscript_text = ''
        if atom_expr.subexpr():
            subexpr = atom_expr.subexpr()
            subscript = None
            if subexpr.expr():  # subscript is expr
                subscript = subexpr.expr().getText().strip()
            elif subexpr.atom():  # subscript is atom
                subscript = subexpr.atom().getText().strip()
            elif subexpr.args():  # subscript is args
                subscript = subexpr.args().getText().strip()
            subscript_inner_text = StrPrinter().doprint(subscript)
            if len(subscript_inner_text) > 1:
                subscript_text = '_{' + subscript_inner_text + '}'
            else:
                subscript_text = '_' + subscript_inner_text

        # construct the symbol using the text and optional subscript
        atom_symbol = sympy.Symbol(atom_text + subscript_text, real=is_real)
        # for matrix symbol
        matrix_symbol = None
        global var
        if atom_text + subscript_text in var:
            try:
                rh = var[atom_text + subscript_text]
                shape = sympy.shape(rh)
                matrix_symbol = sympy.MatrixSymbol(atom_text + subscript_text, shape[0], shape[1])
                variances[matrix_symbol] = variances[atom_symbol]
            except:
                pass

        # find the atom's superscript, and return as a Pow if found
        if atom_expr.supexpr():
            supexpr = atom_expr.supexpr()
            func_pow = None
            if supexpr.expr():
                func_pow = convert_expr(supexpr.expr())
            else:
                func_pow = convert_atom(supexpr.atom())
            return sympy.Pow(atom_symbol, func_pow, evaluate=False)

        return atom_symbol if not matrix_symbol else matrix_symbol
    elif atom.SYMBOL():
        s = atom.SYMBOL().getText().replace("\\$", "").replace("\\%", "")
        if s == "\\infty":
            return sympy.oo
        elif s == '\\pi':
            return sympy.pi
        elif s == '\\emptyset':
            return sympy.S.EmptySet
        else:
            raise Exception("Unrecognized symbol")
    elif atom.NUMBER():
        s = atom.NUMBER().getText().replace(",", "")
        try:
            sr = sympy.Rational(s)
            return sr
        except (TypeError, ValueError):
            return sympy.Number(s)
    elif atom.E_NOTATION():
        s = atom.E_NOTATION().getText().replace(",", "")
        try:
            sr = sympy.Rational(s)
            return sr
        except (TypeError, ValueError):
            return sympy.Number(s)
    elif atom.DIFFERENTIAL():
        var = get_differential_var(atom.DIFFERENTIAL())
        return sympy.Symbol('d' + var.name, real=is_real)
    elif atom.mathit():
        text = rule2text(atom.mathit().mathit_text())
        return sympy.Symbol(text, real=is_real)
    elif atom.VARIABLE():
        text = atom.VARIABLE().getText()
        is_percent = text.endswith("\\%")
        trim_amount = 3 if is_percent else 1
        name = text[10:]
        name = name[0:len(name) - trim_amount]

        # add hash to distinguish from regular symbols
        hash = hashlib.md5(name.encode()).hexdigest()
        symbol_name = name + hash

        # replace the variable for already known variable values
        if name in VARIABLE_VALUES:
            # if a sympy class
            if isinstance(VARIABLE_VALUES[name], tuple(sympy.core.all_classes)):
                symbol = VARIABLE_VALUES[name]

            # if NOT a sympy class
            else:
                symbol = parse_expr(str(VARIABLE_VALUES[name]))
        else:
            symbol = sympy.Symbol(symbol_name, real=is_real)

        if is_percent:
            return sympy.Mul(symbol, sympy.Pow(100, -1, evaluate=False), evaluate=False)

        # return the symbol
        return symbol

    elif atom.PERCENT_NUMBER():
        text = atom.PERCENT_NUMBER().getText().replace("\\%", "").replace(",", "")
        try:
            number = sympy.Rational(text)
        except (TypeError, ValueError):
            number = sympy.Number(text)
        percent = sympy.Rational(number, 100)
        return percent


def rule2text(ctx):
    stream = ctx.start.getInputStream()
    # starting index of starting token
    startIdx = ctx.start.start
    # stopping index of stopping token
    stopIdx = ctx.stop.stop

    return stream.getText(startIdx, stopIdx)


def convert_frac(frac):
    diff_op = False
    partial_op = False
    lower_itv = frac.lower.getSourceInterval()
    lower_itv_len = lower_itv[1] - lower_itv[0] + 1
    if (frac.lower.start == frac.lower.stop and
            frac.lower.start.type == PSLexer.DIFFERENTIAL):
        wrt = get_differential_var_str(frac.lower.start.text)
        diff_op = True
    elif (lower_itv_len == 2 and
          frac.lower.start.type == PSLexer.SYMBOL and
          frac.lower.start.text == '\\partial' and
          (frac.lower.stop.type == PSLexer.LETTER_NO_E or frac.lower.stop.type == PSLexer.SYMBOL)):
        partial_op = True
        wrt = frac.lower.stop.text
        if frac.lower.stop.type == PSLexer.SYMBOL:
            wrt = wrt[1:]

    if diff_op or partial_op:
        wrt = sympy.Symbol(wrt, real=is_real)
        if (diff_op and frac.upper.start == frac.upper.stop and
            frac.upper.start.type == PSLexer.LETTER_NO_E and
                frac.upper.start.text == 'd'):
            return ('derivative', wrt)
        elif (partial_op and frac.upper.start == frac.upper.stop and
              frac.upper.start.type == PSLexer.SYMBOL and
              frac.upper.start.text == '\\partial'):
            return ('derivative', wrt)
        upper_text = rule2text(frac.upper)

        expr_top = None
        if diff_op and upper_text.startswith('d'):
            expr_top = latex2sympy(upper_text[1:])
        elif partial_op and frac.upper.start.text == '\\partial':
            expr_top = latex2sympy(upper_text[len('\\partial'):])
        if expr_top:
            return sympy.Derivative(expr_top, wrt)

    expr_top = convert_expr(frac.upper)
    expr_bot = convert_expr(frac.lower)
    if expr_top.is_Matrix or expr_bot.is_Matrix:
        return sympy.MatMul(expr_top, sympy.Pow(expr_bot, -1, evaluate=False), evaluate=False)
    else:
        return sympy.Mul(expr_top, sympy.Pow(expr_bot, -1, evaluate=False), evaluate=False)


def convert_binom(binom):
    expr_top = convert_expr(binom.upper)
    expr_bot = convert_expr(binom.lower)
    return sympy.binomial(expr_top, expr_bot)


def convert_func(func):
    if func.func_normal_single_arg():
        if func.L_PAREN():  # function called with parenthesis
            arg = convert_func_arg(func.func_single_arg())
        else:
            arg = convert_func_arg(func.func_single_arg_noparens())

        name = func.func_normal_single_arg().start.text[1:]

        # change arc<trig> -> a<trig>
        if name in ["arcsin", "arccos", "arctan", "arccsc", "arcsec",
                    "arccot"]:
            name = "a" + name[3:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)
        elif name in ["arsinh", "arcosh", "artanh"]:
            name = "a" + name[2:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)
        elif name in ["arcsinh", "arccosh", "arctanh"]:
            name = "a" + name[3:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)
        elif name == "operatorname":
            operatorname = func.func_normal_single_arg().func_operator_name.getText()

            if operatorname in ["arsinh", "arcosh", "artanh"]:
                operatorname = "a" + operatorname[2:]
                expr = getattr(sympy.functions, operatorname)(arg, evaluate=False)
            elif operatorname in ["arcsinh", "arccosh", "arctanh"]:
                operatorname = "a" + operatorname[3:]
                expr = getattr(sympy.functions, operatorname)(arg, evaluate=False)
            elif operatorname == "floor":
                expr = handle_floor(arg)
            elif operatorname == "ceil":
                expr = handle_ceil(arg)
            elif operatorname == 'eye':
                expr = sympy.eye(arg)
            elif operatorname == 'rank':
                expr = sympy.Integer(arg.rank())
            elif operatorname in ['trace', 'tr']:
                expr = arg.trace()
            elif operatorname == 'rref':
                expr = arg.rref()[0]
            elif operatorname == 'nullspace':
                expr = arg.nullspace()
            elif operatorname == 'norm':
                expr = arg.norm()
            elif operatorname == 'cols':
                expr = [arg.col(i) for i in range(arg.cols)]
            elif operatorname == 'rows':
                expr = [arg.row(i) for i in range(arg.rows)]
            elif operatorname in ['eig', 'eigen', 'diagonalize']:
                expr = arg.diagonalize()
            elif operatorname in ['eigenvals', 'eigenvalues']:
                expr = arg.eigenvals()
            elif operatorname in ['eigenvects', 'eigenvectors']:
                expr = arg.eigenvects()
            elif operatorname in ['svd', 'SVD']:
                expr = arg.singular_value_decomposition()
        elif name in ["log", "ln"]:
            if func.subexpr():
                if func.subexpr().atom():
                    base = convert_atom(func.subexpr().atom())
                else:
                    base = convert_expr(func.subexpr().expr())
            elif name == "log":
                base = 10
            elif name == "ln":
                base = sympy.E
            expr = sympy.log(arg, base, evaluate=False)
        elif name in ["exp", "exponentialE"]:
            expr = sympy.exp(arg)
        elif name == "floor":
            expr = handle_floor(arg)
        elif name == "ceil":
            expr = handle_ceil(arg)
        elif name == 'det':
            expr = arg.det()

        func_pow = None
        should_pow = True
        if func.supexpr():
            if func.supexpr().expr():
                func_pow = convert_expr(func.supexpr().expr())
            else:
                func_pow = convert_atom(func.supexpr().atom())

        if name in ["sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh", "tanh"]:
            if func_pow == -1:
                name = "a" + name
                should_pow = False
            expr = getattr(sympy.functions, name)(arg, evaluate=False)

        if func_pow and should_pow:
            expr = sympy.Pow(expr, func_pow, evaluate=False)

        return expr

    elif func.func_normal_multi_arg():
        if func.L_PAREN():  # function called with parenthesis
            args = func.func_multi_arg().getText().split(",")
        else:
            args = func.func_multi_arg_noparens().split(",")

        args = list(map(lambda arg: latex2sympy(arg, VARIABLE_VALUES), args))
        name = func.func_normal_multi_arg().start.text[1:]

        if name == "operatorname":
            operatorname = func.func_normal_multi_arg().func_operator_name.getText()
            if operatorname in ["gcd", "lcm"]:
                expr = handle_gcd_lcm(operatorname, args)
            elif operatorname == 'zeros':
                expr = sympy.zeros(*args)
            elif operatorname == 'ones':
                expr = sympy.ones(*args)
            elif operatorname == 'diag':
                expr = sympy.diag(*args)
            elif operatorname == 'hstack':
                expr = sympy.Matrix.hstack(*args)
            elif operatorname == 'vstack':
                expr = sympy.Matrix.vstack(*args)
            elif operatorname in ['orth', 'ortho', 'orthogonal', 'orthogonalize']:
                if len(args) == 1:
                    arg = args[0]
                    expr = sympy.matrices.GramSchmidt([arg.col(i) for i in range(arg.cols)], True)
                else:
                    expr = sympy.matrices.GramSchmidt(args, True)
        elif name in ["gcd", "lcm"]:
            expr = handle_gcd_lcm(name, args)
        elif name in ["max", "min"]:
            name = name[0].upper() + name[1:]
            expr = getattr(sympy.functions, name)(*args, evaluate=False)

        func_pow = None
        should_pow = True
        if func.supexpr():
            if func.supexpr().expr():
                func_pow = convert_expr(func.supexpr().expr())
            else:
                func_pow = convert_atom(func.supexpr().atom())

        if func_pow and should_pow:
            expr = sympy.Pow(expr, func_pow, evaluate=False)

        return expr
    elif func.atom_expr_no_supexpr():
        # define a function
        f = sympy.Function(func.atom_expr_no_supexpr().getText())
        # args
        args = func.func_common_args().getText().split(",")
        if args[-1] == '':
            args = args[:-1]
        args = [latex2sympy(arg, VARIABLE_VALUES) for arg in args]
        # supexpr
        if func.supexpr():
            if func.supexpr().expr():
                expr = convert_expr(func.supexpr().expr())
            else:
                expr = convert_atom(func.supexpr().atom())
            return sympy.Pow(f(*args), expr, evaluate=False)
        else:
            return f(*args)
    elif func.FUNC_INT():
        return handle_integral(func)
    elif func.FUNC_SQRT():
        expr = convert_expr(func.base)
        if func.root:
            r = convert_expr(func.root)
            return sympy.Pow(expr, 1 / r, evaluate=False)
        else:
            return sympy.Pow(expr, sympy.S.Half, evaluate=False)
    elif func.FUNC_SUM():
        return handle_sum_or_prod(func, "summation")
    elif func.FUNC_PROD():
        return handle_sum_or_prod(func, "product")
    elif func.FUNC_LIM():
        return handle_limit(func)
    elif func.EXP_E():
        return handle_exp(func)


def convert_func_arg(arg):
    if hasattr(arg, 'expr'):
        return convert_expr(arg.expr())
    else:
        return convert_mp(arg.mp_nofunc())


def handle_integral(func):
    if func.additive():
        integrand = convert_add(func.additive())
    elif func.frac():
        integrand = convert_frac(func.frac())
    else:
        integrand = 1

    int_var = None
    if func.DIFFERENTIAL():
        int_var = get_differential_var(func.DIFFERENTIAL())
    else:
        for sym in integrand.atoms(sympy.Symbol):
            s = str(sym)
            if len(s) > 1 and s[0] == 'd':
                if s[1] == '\\':
                    int_var = sympy.Symbol(s[2:], real=is_real)
                else:
                    int_var = sympy.Symbol(s[1:], real=is_real)
                int_sym = sym
        if int_var:
            integrand = integrand.subs(int_sym, 1)
        else:
            # Assume dx by default
            int_var = sympy.Symbol('x', real=is_real)

    if func.subexpr():
        if func.subexpr().atom():
            lower = convert_atom(func.subexpr().atom())
        else:
            lower = convert_expr(func.subexpr().expr())
        if func.supexpr().atom():
            upper = convert_atom(func.supexpr().atom())
        else:
            upper = convert_expr(func.supexpr().expr())
        return sympy.Integral(integrand, (int_var, lower, upper))
    else:
        return sympy.Integral(integrand, int_var)


def handle_sum_or_prod(func, name):
    val = convert_mp(func.mp())
    iter_var = convert_expr(func.subeq().equality().expr(0))
    start = convert_expr(func.subeq().equality().expr(1))
    if func.supexpr().expr():  # ^{expr}
        end = convert_expr(func.supexpr().expr())
    else:  # ^atom
        end = convert_atom(func.supexpr().atom())

    if name == "summation":
        return sympy.Sum(val, (iter_var, start, end))
    elif name == "product":
        return sympy.Product(val, (iter_var, start, end))


def handle_limit(func):
    sub = func.limit_sub()
    if sub.LETTER_NO_E():
        var = sympy.Symbol(sub.LETTER_NO_E().getText(), real=is_real)
    elif sub.GREEK_CMD():
        var = sympy.Symbol(sub.GREEK_CMD().getText()[1:].strip(), real=is_real)
    elif sub.OTHER_SYMBOL_CMD():
        var = sympy.Symbol(sub.OTHER_SYMBOL_CMD().getText().strip(), real=is_real)
    else:
        var = sympy.Symbol('x', real=is_real)
    if sub.SUB():
        direction = "-"
    else:
        direction = "+"
    approaching = convert_expr(sub.expr())
    content = convert_mp(func.mp())

    return sympy.Limit(content, var, approaching, direction)


def handle_exp(func):
    if func.supexpr():
        if func.supexpr().expr():  # ^{expr}
            exp_arg = convert_expr(func.supexpr().expr())
        else:  # ^atom
            exp_arg = convert_atom(func.supexpr().atom())
    else:
        exp_arg = 1
    return sympy.exp(exp_arg)


def handle_gcd_lcm(f, args):
    """
    Return the result of gcd() or lcm(), as UnevaluatedExpr

    f: str - name of function ("gcd" or "lcm")
    args: List[Expr] - list of function arguments
    """

    args = tuple(map(sympy.nsimplify, args))

    # gcd() and lcm() don't support evaluate=False
    return sympy.UnevaluatedExpr(getattr(sympy, f)(args))


def handle_floor(expr):
    """
    Apply floor() then return the floored expression.

    expr: Expr - sympy expression as an argument to floor()
    """
    return sympy.functions.floor(expr, evaluate=False)


def handle_ceil(expr):
    """
    Apply ceil() then return the ceil-ed expression.

    expr: Expr - sympy expression as an argument to ceil()
    """
    return sympy.functions.ceiling(expr, evaluate=False)


def get_differential_var(d):
    text = get_differential_var_str(d.getText())
    return sympy.Symbol(text, real=is_real)


def get_differential_var_str(text):
    for i in range(1, len(text)):
        c = text[i]
        if not (c == " " or c == "\r" or c == "\n" or c == "\t"):
            idx = i
            break
    text = text[idx:]
    if text[0] == "\\":
        text = text[1:]
    return text


def latex(tex):
    global frac_type
    result = sympy.latex(tex)
    result = result.replace(r'\frac', frac_type, -1).replace(r'\dfrac', frac_type, -1).replace(r'\tfrac', frac_type, -1)
    result = result.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}', -1).replace(r'\end{matrix}\right]', r'\end{bmatrix}', -1)
    result = result.replace(r'\left', r'', -1).replace(r'\right', r'', -1)
    result = result.replace(r' )', r')', -1)
    result = result.replace(r'\log', r'\ln', -1)
    return result


def latex2latex(tex):
    result = latex2sympy(tex)
    # if result is a list or tuple or dict
    if isinstance(result, list) or isinstance(result, tuple) or isinstance(result, dict):
        return latex(result)
    else:
        return latex(simplify(result.subs(variances).doit().doit()))


# Set image value
latex2latex('i=I')
latex2latex('j=I')
# set Identity(i)
for i in range(1, 10):
    lh = sympy.Symbol(r'\bm{I}_' + str(i), real=False)
    lh_m = sympy.MatrixSymbol(r'\bm{I}_' + str(i), i, i)
    rh = sympy.Identity(i).as_mutable()
    variances[lh] = rh
    variances[lh_m] = rh
    var[str(lh)] = rh

if __name__ == '__main__':
    # latex2latex(r'A_1=\begin{bmatrix}1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8\end{bmatrix}')
    # latex2latex(r'b_1=\begin{bmatrix}1 \\ 2 \\ 3 \\ 4\end{bmatrix}')
    # tex = r"(x+2)|_{x=y+1}"
    # tex = r"\operatorname{zeros}(3)"
    tex = r"\frac{\mathrm{d}}{\mathrm{d}x}(x^{2}+x)"
    # print("latex2latex:", latex2latex(tex))
    math = latex2sympy(tex)
    # math = math.subs(variances)
    print("latex:", tex)
    # print("var:", variances)
    print("raw_math:", math)
    # print("math:", latex(math.doit()))
    # print("math_type:", type(math.doit()))
    # print("shape:", (math.doit()).shape)
    print("cal:", latex2latex(tex))
