from doctype import Node

class indexType:
    """Marker wrapper that flags a value as a positional index.

    Consumers in this module test ``isinstance(x, indexType)`` and read
    ``value``: 1 selects the first element of a sequence, -1 the last one;
    any other value makes them raise.
    """

    def __init__(self, value):
        self.value = value

    def get(self):
        """Return the wrapped index value."""
        wrapped = self.value
        return wrapped

class Doer:
    """Base class for the evaluator nodes in this module.

    Fields, all optional:
      ARG  -- operation code selecting the variant a subclass performs
              (compared against the module-level ``*__*`` constants);
      DATA -- literal value returned by ``dElem``;
      VAR  -- value returned by ``dVariable``;
      arg1, arg2 -- child nodes whose ``do(subdoc, doc)`` a subclass evaluates.

    Subclasses implement ``do(subdoc, doc)`` and return a list
    (``xPath.do`` takes only ``doc``).
    """

    def __init__(self, ARG=None, DATA=None, VAR=None, arg1=None, arg2=None):
        self.ARG, self.DATA, self.VAR = ARG, DATA, VAR
        self.arg1, self.arg2 = arg1, arg2

# Operation codes stored in ``Doer.ARG``; each ``d*`` class below compares
# ``self.ARG`` against the codes relevant to it.

# used by dCmp
cmp__gt = 1

# used by dCollect
collect__attributes__cat = 2
collect__attributes__graad = 3
collect__attributes__index = 4
collect__attributes__lemma = 5
collect__attributes__pt = 6
collect__child__node = 7
collect__descendant__or__self__node = 8
collect__descendant__or__self__type__node = 9

# used by dEqual
equal__is = 10

# used by dFunction
function__count__1__args = 11
function__not__1__args = 12

# Truth values as lists: non-empty means true, empty means false (``test``
# checks ``len(...) > 0``). dAnd, dOr and dFunction return these module-level
# lists by reference, so callers must never mutate them.
nTRUE = [True]
nFALSE = []

class dAnd(Doer):
    """Logical AND of ``arg1`` and ``arg2``; an operand is true when its result is non-empty.

    ``arg2`` is evaluated only if ``arg1`` is true. Returns the shared
    ``nTRUE`` or ``nFALSE`` list.
    """

    def do(self, subdoc, doc):
        for operand in (self.arg1, self.arg2):
            if len(operand.do(subdoc, doc)) == 0:
                return nFALSE
        return nTRUE

class dArg(Doer):
    """Evaluate ``arg1`` (and ``arg2`` when present) and return one flattened list.

    Nested lists/tuples are expanded and ``None`` items dropped by ``flatten``;
    ``arg1``'s items come first.
    """

    def do(self, subdoc, doc):
        # TODO: why flatten?
        values = flatten(self.arg1.do(subdoc, doc))
        if self.arg2 is None:
            return values
        return values + flatten(self.arg2.do(subdoc, doc))

class dCmp(Doer):
    """Existential comparison between the results of ``arg1`` and ``arg2``.

    Only ``cmp__gt`` is supported. The result holds one ``True`` per pair
    ``(a1, a2)`` with ``a1 > a2``, so it is empty when no pair matches.

    Raises:
        TypeError: for any other ``ARG`` (after both operands are evaluated).
    """

    def do(self, subdoc, doc):
        left = self.arg1.do(subdoc, doc)
        right = self.arg2.do(subdoc, doc)
        if self.ARG != cmp__gt:
            raise TypeError("dCmp: unknown ARG")
        return [True for a in left for b in right if a > b]

class dCollect(Doer):
    """Step into an attribute or along an axis from every item produced by ``arg1``.

    ``ARG`` selects what is collected from each input node ``r``:

    * ``collect__attributes__cat`` / ``graad`` / ``lemma`` / ``pt``: the
      attribute value when it is not ``""``;
    * ``collect__attributes__index``: ``r.index`` when it is greater than 0;
    * ``collect__child__node``: the sequence ``r.axChildren``;
    * ``collect__descendant__or__self__node`` and
      ``collect__descendant__or__self__type__node``: ``r.axDescendantsOrSelf``.

    Without a predicate (``arg2 is None``), the attribute values are returned,
    followed by all collected nodes. With a predicate, every collected sequence
    ``lst`` is filtered element-wise. ``arg2`` is evaluated with ``[e]`` as
    context. An ``indexType`` result of 1 / -1 appends the first / last element
    of ``lst``. Any other result item appends ``e``. Duplicates may therefore
    appear in the output.

    Raises:
        Exception: for an unknown ``ARG`` (only once an input item is seen),
            or for an ``indexType`` other than 1 or -1.
    """

    def do(self, subdoc, doc):
        lists = []    # node sequences, one per input node (axis steps)
        result1 = []  # attribute values (attribute steps)
        for r in self.arg1.do(subdoc, doc):
            if self.ARG == collect__attributes__cat:
                if r.cat != "":
                    result1.append(r.cat)
            elif self.ARG == collect__attributes__graad:
                if r.graad != "":
                    result1.append(r.graad)
            elif self.ARG == collect__attributes__index:
                if r.index > 0:
                    result1.append(r.index)
            elif self.ARG == collect__attributes__lemma:
                if r.lemma != "":
                    result1.append(r.lemma)
            elif self.ARG == collect__attributes__pt:
                if r.pt != "":
                    result1.append(r.pt)
            elif self.ARG == collect__child__node:
                lists.append(r.axChildren)
            elif self.ARG in (collect__descendant__or__self__type__node,
                              collect__descendant__or__self__node):
                lists.append(r.axDescendantsOrSelf)
            else:
                raise Exception("dCollect: unknown ARG")

        if self.arg2 is None:
            # No predicate: attribute values followed by all collected nodes.
            for lst in lists:
                result1.extend(lst)
            return result1

        # The attribute values form one more sequence to filter. They must be
        # added as a single entry (append, not extend): the loop below iterates
        # every entry of ``lists`` as a sequence. Loose scalar values would be
        # split into characters (strings) or raise TypeError (integer ``index``).
        if result1:
            lists.append(result1)

        result2 = []
        for lst in lists:
            for e in lst:
                for r2 in self.arg2.do([e], doc):
                    if isinstance(r2, indexType):
                        # Positional predicate: 1 = first, -1 = last of ``lst``.
                        if r2.value == 1:
                            result2.append(lst[0])
                        elif r2.value == -1:
                            result2.append(lst[-1])
                        else:
                            raise Exception("Collect: Missing case for plain index")
                    else:
                        # Any other predicate result item keeps ``e``.
                        result2.append(e)
        return result2

class dElem(Doer):
    """Literal node: return ``DATA`` unchanged.

    ``arg1`` is never evaluated, whether or not it is set.
    """

    def do(self, subdoc, doc):
        # TODO: arg1 is ignored -- is that correct? See for example: foo[@a + 10 = @b]
        #
        # What is arg1 for, then?
        #
        #   SORT
        #     COLLECT  'child' 'name' 'node' foo
        #       NODE
        #       PREDICATE
        #         EQUAL =
        #           PLUS +
        #             COLLECT  'attributes' 'name' 'node' a
        #               NODE
        #             ELEM Object is a number : 10
        #               COLLECT  'attributes' 'name' 'node' a
        #                 NODE
        #           COLLECT  'attributes' 'name' 'node' b
        #             NODE
        literal = self.DATA
        return literal

class dEqual(Doer):
    """Existential equality between the results of ``arg1`` and ``arg2``.

    Only ``equal__is`` is supported. The result holds one ``True`` per pair
    ``(a1, a2)`` with ``a1 == a2``.

    Raises:
        Exception: for any other ``ARG`` (before either operand is evaluated).
    """

    def do(self, subdoc, doc):
        if self.ARG != equal__is:
            raise Exception("Equal: Missing case")
        lhs = self.arg1.do(subdoc, doc)
        rhs = self.arg2.do(subdoc, doc)
        return [True for x in lhs for y in rhs if x == y]

class dFunction(Doer):
    """Built-in function applied to the result of ``arg1`` (an empty list if ``arg1`` is None).

    * ``function__count__1__args``: ``[number of items]``;
    * ``function__not__1__args``: ``nTRUE`` when the argument is empty, else ``nFALSE``.

    Raises:
        Exception: for any other ``ARG`` (after ``arg1`` is evaluated).
    """

    def do(self, subdoc, doc):
        args = self.arg1.do(subdoc, doc) if self.arg1 is not None else []

        if self.ARG == function__count__1__args:
            return [len(args)]
        if self.ARG == function__not__1__args:
            return nFALSE if len(args) else nTRUE
        raise Exception("Function: Missing case")

class dNode(Doer):
    """Context node: return the context list ``subdoc`` itself (not a copy)."""

    def do(self, subdoc, doc):
        context = subdoc
        return context

class dOr(Doer):
    """Logical OR of ``arg1`` and ``arg2``; an operand is true when its result is non-empty.

    ``arg2`` is evaluated only if ``arg1`` is false. Returns the shared
    ``nTRUE`` or ``nFALSE`` list.
    """

    def do(self, subdoc, doc):
        for operand in (self.arg1, self.arg2):
            if len(operand.do(subdoc, doc)) > 0:
                return nTRUE
        return nFALSE

class dPredicate(Doer):
    """Evaluate ``arg1`` and optionally select one item of it by position.

    Without ``arg2``, or when ``arg1`` yields nothing, the result of ``arg1``
    is returned as is. Otherwise ``arg2`` is evaluated with that result as
    context. The ``value`` of its first item must be 1 (keep the first item)
    or -1 (keep the last item).

    Raises:
        Exception: for any other index value. (If ``arg2`` returns an empty
            list, the ``[0]`` lookup raises ``IndexError``.)
    """

    def do(self, subdoc, doc):
        candidates = self.arg1.do(subdoc, doc)
        if self.arg2 is None or len(candidates) == 0:
            return candidates

        position = self.arg2.do(candidates, doc)[0].value
        if position == 1:
            return [candidates[0]]
        if position == -1:
            return [candidates[-1]]
        raise Exception("Predicate arg2: Missing case for index {}".format(position))

class dRoot(Doer):
    """Return ``[doc.node]`` regardless of the context list.

    ``doc.node`` is presumably the document's top node; confirm against the
    document type.
    """

    def do(self, subdoc, doc):
        root = doc.node
        return [root]

def _drop_adjacent_duplicates(items):
    """Return a new list of *items* without elements equal to the previously kept one.

    On a sorted list this removes every duplicate in a single linear pass.
    """
    kept = []
    for item in items:
        if kept and item == kept[-1]:
            continue
        kept.append(item)
    return kept


class dSort(Doer):
    """Normalise the result of ``arg1`` into a sorted, duplicate-free list.

    * If the first item is a list, all items are concatenated first (each
      item must then be iterable).
    * ``Node`` items are ordered by ``id`` and deduplicated.
    * ``bool`` items collapse to the first one.
    * ``str`` / ``int`` items are sorted and deduplicated.

    The type of the first item decides which rule applies. Results with fewer
    than two items are returned unchanged (as a new list).

    Raises:
        Exception: when the first item has any other type.
    """

    def do(self, subdoc, doc):
        # Work on a copy. ``arg1`` may return a list it does not own: ``dNode``
        # hands back the context list and ``dVariable`` its stored ``VAR``.
        # Sorting or deduplicating in place would corrupt those for later use.
        result = list(self.arg1.do(subdoc, doc))
        if len(result) < 2:
            return result

        if isinstance(result[0], list):
            flat = []
            for r in result:
                flat.extend(r)
            result = flat
            if len(result) < 2:
                return result

        first = result[0]
        if isinstance(first, Node):
            result.sort(key=lambda k: k.id)
            return _drop_adjacent_duplicates(result)
        if isinstance(first, bool):
            # Must stay before the int check: bool is a subclass of int.
            return result[:1]
        if isinstance(first, (str, int)):
            result.sort()
            return _drop_adjacent_duplicates(result)
        raise Exception("Sort: Missing case for type {}".format(type(first)))


class dVariable(Doer):
    """Return the stored ``VAR``.

    A list is returned as the same object (not a copy); any other value is
    wrapped in a one-item list.
    """

    def do(self, subdoc, doc):
        value = self.VAR
        return value if isinstance(value, list) else [value]

class xPath(Doer):
    """Top node of an expression tree, as called by ``test`` and ``find``.

    Unlike the other nodes, ``do`` takes only the document. It evaluates
    ``arg1`` starting from an empty context list.
    """

    def do(self, doc):
        context = []
        return self.arg1.do(context, doc)

################################################################

def test(doc, xpath):
    return len(xpath.do(doc)) > 0

def find(doc, xpath):
    """Evaluate *xpath* against *doc* and return its result unchanged."""
    matches = xpath.do(doc)
    return matches

def flatten(aa):
    """Return a new flat list of the items in *aa*.

    ``None`` gives ``[]``. A value that is neither list nor tuple is wrapped as
    ``[aa]``. Nested lists/tuples are expanded recursively in order, and
    ``None`` items inside them are dropped. Strings are not split.
    """
    if aa is None:
        return []
    if not isinstance(aa, (list, tuple)):
        return [aa]
    out = []
    for item in aa:
        if isinstance(item, (list, tuple)):
            out += flatten(item)
        elif item is not None:
            out.append(item)
    return out
