from forensics.win32.mshtml import mshtml_types
from forensics.win32.mshtml import TP_FLAGS
from forensics.win32.mshtml import TAGDESC,TAGDESC_FLAGS,HASEND
from forensics.object2 import Object
from forensics.object import read_value, obj_size, get_obj_offset

class CDoc(Object):
    """Wrapper for the "CDoc" structure that adds a lookaside lookup."""
    hasMembers = True
    name = "CDoc"

    def __new__(typ, *args, **kwargs):
        # Allocate straight through object.__new__, so any __new__ logic on
        # the Object base class is bypassed; the constructor arguments are
        # not consumed here.
        return object.__new__(typ)

    def get_lookaside(self, key):
        """Return the value stored in the lookaside hash table for key.

        The slot is chosen as ``key`` modulo the table's element count and
        the ``hvalue`` of the entry in that slot is returned as-is.
        NOTE(review): the stored key is not compared against ``key`` and no
        probing is done -- confirm that collisions are not a concern.
        """
        slot = key % self.Lookaside.numElements
        return self.Lookaside.pHashTable[slot].hvalue
        
class CHtPvPv(Object):
    """Wrapper for the "CHtPvPv" structure exposing its entries as a list."""
    hasMembers = True
    name = "CHtPvPv"

    def __new__(typ, *args, **kwargs):
        # Allocate straight through object.__new__, so any __new__ logic on
        # the Object base class is bypassed.
        return object.__new__(typ)

    def getHashTable(self):
        """Return a list of ``_HTENTRY`` objects, one per table slot.

        The table base address is read as a pointer from the ``pHashTable``
        member; ``numElements`` consecutive entries are then instantiated,
        each ``obj_size('_HTENTRY')`` bytes after the previous one.
        """
        base = read_value(self.vm, 'pointer', self.get_member_offset('pHashTable'))
        stride = obj_size(mshtml_types, '_HTENTRY')
        return [
            Object("_HTENTRY", base + (slot * stride), self.vm, profile=self.profile)
            for slot in range(self.numElements)
        ]

    # Attribute-style access: obj.pHashTable yields the materialized entry list.
    pHashTable = property(fget=getHashTable)

class CMarkup(Object):
    """Wrapper for the "CMarkup" structure: element cache and text access."""
    hasMembers = True
    name = "CMarkup"

    # Lookaside index used by elem_cache() to locate the CElemCache.
    TOP_ELEM = 5

    def __new__(typ, *args, **kwargs):
        # Allocate straight through object.__new__, so any __new__ logic on
        # the Object base class is bypassed.
        obj = object.__new__(typ)
        return obj

    def elem_cache(self):
        """Return the CElemCache object found through the TOP_ELEM lookaside."""
        elem_addr = self.get_lookaside(self.TOP_ELEM)
        return Object("CElemCache", elem_addr, self.vm, profile=self.profile)

    def get_top_element(self):
        """Return the ``pHtml`` member of the element cache."""
        return self.elem_cache().pHtml

    def get_lookaside(self, idx):
        """Look up lookaside slot idx for this markup in the owning document.

        The key handed to the document's lookaside table is this object's
        offset plus ``idx * 4``.
        NOTE(review): the factor 4 is presumably the 32-bit pointer size --
        confirm before using on other architectures.
        """
        return self.pSecurity.pDoc.get_lookaside(self.offset+(idx*4))

    def get_text(self, treepos):
        """Get the text for a TreePos.

        The text is read from the first element of ``txtArray``, which holds
        two bytes per character with a gap at byte offset ``gapOff``. The
        requested range may lie entirely before the gap, entirely after it,
        or straddle it.

        Returns the decoded text when the memory read yields data; otherwise
        whatever falsy value the read returned is passed through unchanged.

        Raises UnsupportedException when the requested range does not fit in
        the first text array element.
        """
        cp = treepos.get_cp()
        text_len = treepos.get_textlen()
        # Character offsets/lengths converted to byte offsets/lengths.
        cpb = cp*2
        text_blen = text_len*2

        # FIXME: should probably figure out the index. Using 0 for now.
        # Maybe this happens if cpb > txtElem.bufLen ?
        txtElem = self.txtArray.get(0, 'CTextArrayElem')

        if cp + text_len > txtElem.nChars:
            # NOTE(review): UnsupportedException is not imported in the
            # visible part of this file -- confirm it is in scope.
            raise UnsupportedException(
                "FIXME: Hmm, looks like not all text fits in the first element"
            )

        buf_base = txtElem.pBuff.v()
        gap_off = txtElem.gapOff
        # Bytes of the buffer not occupied by characters, i.e. the gap.
        gap_size = txtElem.bufLen - txtElem.nChars*2

        if cpb >= gap_off: # In the second half
            txt = self.vm.read(buf_base + cpb + gap_size, text_blen)
        elif cpb + text_blen <= gap_off: # In the first half
            # "<=": text that ends exactly where the gap begins needs no
            # (zero-length) read from the far side of the gap.
            txt = self.vm.read(buf_base + cpb, text_blen)
        else: # Crosses gap
            # [ stuff ][ our text ][ gap ][ our text ][ stuff ]
            #          ^- cpb      ^      ^- gapOff+gap_size
            #                      `- gapOff
            st_len = gap_off - cpb
            ed_len = text_blen - st_len
            txt =  self.vm.read(buf_base + cpb, st_len)
            txt += self.vm.read(buf_base + gap_off + gap_size, ed_len)

        if txt:
            # 'backslashreplace' is only usable for *encoding* on Python 2
            # (decoding support arrived in Python 3.5); on malformed data it
            # raises TypeError instead of substituting. 'replace' works for
            # decoding everywhere.
            txt = txt.decode('utf-16-le', 'replace')
        return txt

class CTreePos(Object):
    """Wrapper for the "CTreePos" structure with tree-walking helpers.

    The methods below rely on the ``flags``, ``pChild``, ``pNext``,
    ``nchLeft`` and ``txtFlags`` members and on the bit masks in TP_FLAGS.
    """
    hasMembers = True
    name = "CTreePos"

    def __new__(typ, *args, **kwargs):
        # Allocate straight through object.__new__, so any __new__ logic on
        # the Object base class is bypassed.
        obj = object.__new__(typ)
        return obj

    def get_treenode(self):
        """Gets the containing CTreeNode.

        When TP_BEGIN (or TP_END) is set, this position is taken to be the
        ``tpBegin`` (or ``tpEnd``) member embedded in a CTreeNode, so the
        node starts at this object's offset minus that member's offset.
        Returns None when neither flag is set.
        """
        if self.flags & TP_FLAGS['TP_BEGIN']:
            off, _ = get_obj_offset(mshtml_types, ['CTreeNode', 'tpBegin'])
        elif self.flags & TP_FLAGS['TP_END']:
            off, _ = get_obj_offset(mshtml_types, ['CTreeNode', 'tpEnd'])
        else:
            return None

        return Object("CTreeNode", self.offset - off, self.vm, profile=self.profile)

    def next_treepos(self):
        """
        Get the next tree position in DOM order.

        Implements the algorithm in mshtml!CTreeNode::NextTreePos.
        NOTE(review): the symbol is presumably mshtml!CTreePos::NextTreePos,
        since the walk is over tree positions -- confirm against symbols.

        Returns the next position, or None when a NULL link is reached
        while climbing or when the candidate's own pNext is NULL.
        """
        
        child = self.pChild
        if child.v() == 0 or (child.flags & TP_FLAGS['TP_FIRST'] and
                              child.flags & TP_FLAGS['TP_LAST']):
            # No child, or a child carrying both TP_FIRST and TP_LAST:
            # follow pNext links until a position with TP_FIRST is found.
            #print "child == NULL or has both TP_FIRST & TP_LAST"
            cur = self
            while not cur.flags & TP_FLAGS['TP_FIRST']:
                #print "%s is TP_FIRST, continuing to %s" % (cur, cur.pNext)
                cur = cur.pNext
                if cur.v() == 0:
                    #print "Reached a point where cur.pNext was NULL, returning NULL"
                    return None
            # If that position is not also TP_LAST, one extra pNext hop is
            # taken before reading the candidate.
            if not cur.flags & TP_FLAGS['TP_LAST']:
                #print "%s is TP_LAST, proceeding to %s" % (cur, cur.pNext)
                cur = cur.pNext
            ret = cur.pNext
            #print "ret is %s" % ret
            # The candidate is only returned if it has a non-NULL pNext.
            if ret.pNext.v() != 0:
                return ret
            else:
                #print "ret.pNext is NULL, returning NULL"
                return None
        else:
            # Child exists and does not carry both flags.
            #print "child != NULL and only one flag"
            if child.flags & TP_FLAGS['TP_FIRST']:
                #print "child has TP_FIRST, child = child.pNext => %s = %s" % (child, child.pNext)
                child = child.pNext
            cur = child.pChild
            #print "cur = %s" % child.pChild
            if cur.v() == 0: return child
            # Descend through pChild links for as long as the child carries
            # TP_FIRST; the last position visited is the result.
            while cur.flags & TP_FLAGS['TP_FIRST']:
                #print "%s is TP_FIRST, continuing to %s" % (cur, cur.pChild)
                child = cur
                cur = child.pChild
                if cur.v() == 0:
                    #print "cur = child.pChild was NULL, returning child = %s" % child
                    return child
            return child

    def get_parent(self):
        """Return the parent position.

        With TP_LAST set the parent is ``pNext``; otherwise it is
        ``pNext.pNext``.
        NOTE(review): presumably pNext of a non-last position is its
        sibling, whose own pNext is the parent -- confirm.
        """
        if self.flags & TP_FLAGS['TP_LAST']:
            return self.pNext
        else:
            return self.pNext.pNext
    
    def get_cp(self):
        """
        Get the offset into the markup text buffer where
        this TreePos's text lives.

        Implementation of mshtml!CTreePos::GetCp.

        Starts from this position's ``nchLeft`` and walks up through
        get_parent() until a NULL parent is reached. Whenever the position
        just left did not have TP_FIRST set, the parent's own length
        (text length for TP_TEXT positions, 1 for begin/end positions with
        TP_TXTINFO) and the parent's ``nchLeft`` are added.
        """

        isFirst = self.flags & TP_FLAGS['TP_FIRST']
        nch = self.nchLeft
        cur = self.get_parent()

        while cur.v() != 0:
            if not isFirst:
                if not cur.flags & (TP_FLAGS['TP_BEGIN'] |
                                    TP_FLAGS['TP_END']):
                    # Neither begin nor end: only text positions contribute.
                    if cur.flags & TP_FLAGS['TP_TEXT']:
                        nch += cur.get_textlen()
                else:
                    # Begin/end position: counts one character if TP_TXTINFO.
                    if cur.flags & TP_FLAGS['TP_TXTINFO']:
                        nch += 1
                nch += cur.nchLeft
            # Remember whether the position we are leaving had TP_FIRST set.
            isFirst = cur.flags & TP_FLAGS['TP_FIRST']
            cur = cur.get_parent()

        return nch

    def get_textlen(self):
        """
        Return the text length of this TreePos.

        Returns the number of characters in this TreePos. The
        TreePos *must* be a text node (check the TP_TEXT flag).
        The length is taken from the low 26 bits of ``txtFlags``.
        """

        return self.txtFlags & 0x3FFFFFF

    def has_elem(self):
        """Return the truthiness of get_treenode() as a bool."""
        return bool(self.get_treenode())    

    def is_text(self):
        """Return True if the TP_TEXT flag is set."""
        return bool(self.flags & TP_FLAGS['TP_TEXT'])

    def is_ptr(self):
        """Return True if the TP_PTR flag is set."""
        return bool(self.flags & TP_FLAGS['TP_PTR'])

    def elem_tag(self):
        """Return the tag string of the containing element.

        Returns "[disconnected]" when there is no containing tree node.
        The tag is rendered as a closing tag when TP_END is set.
        """
        tn = self.get_treenode()
        if not tn: return "[disconnected]"
        return tn.pElement.tagstr(self.flags & TP_FLAGS['TP_END'])

class CElement(Object):
    """Wrapper for the "CElement" structure with tag-description helpers.

    All lookups are keyed on the ``tagDescIdx`` member.
    """
    hasMembers = True
    # NOTE(review): this class attribute is rebound by the name() method
    # defined below, so CElement.name ends up being the method, not this
    # string. Left unchanged because callers rely on name() being callable.
    name = "CElement"

    def __new__(typ, *args, **kwargs):
        # Allocate straight through object.__new__, so any __new__ logic on
        # the Object base class is bypassed.
        obj = object.__new__(typ)
        return obj

    def name(self):
        """Return the tag name (column 0 of this element's TAGDESC row)."""
        return TAGDESC[self.tagDescIdx][0]

    def tagstr(self, end=False):
        """Return the element rendered as ``<NAME>`` or ``</NAME>``.

        end -- when truthy, render the closing form.
        Returns an empty string for elements whose tag name is empty.
        """
        name = self.name()
        if not name:
            return ""
        return "<%s%s>" % ("/" if end else "", name)

    def get_flags(self):
        """Return the flags (column 1 of this element's TAGDESC row)."""
        return TAGDESC[self.tagDescIdx][1]

    # Gets the CHtmlParseClass pointer
    # This apparently is set dynamically, so we'll
    # need to get it from the in-memory g_atagdesc :(
    # NOTE(review): currently returns the same TAGDESC column as
    # get_flags(); looks like a placeholder -- confirm.
    def get_hpc(self):
        return TAGDESC[self.tagDescIdx][1]

    def is_block(self):
        """Return True if the BLOCK_TAG flag is set for this element."""
        return bool(self.get_flags() & TAGDESC_FLAGS['BLOCK_TAG'])

    def has_end(self):
        """Return True if this element has an end tag.

        The element has an end tag only when its HASEND entry is truthy and
        the HASNOEND flag is not set.
        """
        # The former debug print for META elements was dropped: it wrote to
        # stdout on every call and had no effect on the result.
        if not HASEND[self.tagDescIdx]:
            return False
        if self.get_flags() & TAGDESC_FLAGS['HASNOEND']:
            return False
        return True

    def needs_nbsp(self):
        """Return the masked NEEDSNBSP flag (non-zero when set, not a bool)."""
        return self.get_flags() & TAGDESC_FLAGS['NEEDSNBSP']

class CArrayBase(Object):
    """Wrapper for the "CArrayBase" structure with typed element access."""
    hasMembers = True
    name = "CArrayBase"

    def __new__(typ, *args, **kwargs):
        # Allocate straight through object.__new__, so any __new__ logic on
        # the Object base class is bypassed.
        return object.__new__(typ)

    def get(self, i, typ):
        """Return element number i of the array, viewed as type typ.

        The element address is ``pStart`` plus ``i`` times ``elemSize``.
        No bounds check is performed on ``i``.
        """
        first = self.pStart.v()
        stride = self.elemSize
        return Object(typ, first + (stride * i), self.vm, profile=self.profile)
