from forensics.object2 import Object
from forensics.object import *
from forensics.win32.lists import list_entry
from forensics.win32.tasks import process_list
from distorm import Decode, Decode32Bits

# Structure definitions (vtypes) for the service descriptor table that
# find_RegisterWindowMessage reads. Each entry has the form
#   name: [struct size, {member: [offset, member type]}]
# The 4-byte member stride means pointers here are 32 bits wide.
ssdt_types = {
    '_SERVICE_DESCRIPTOR_TABLE': [0x40, {
        # Four consecutive 0x10-byte descriptor entries (4 * 0x10 == 0x40).
        'Descriptors': [0x0, ['array', 4, ['_SERVICE_DESCRIPTOR_ENTRY']]],
    }],
    '_SERVICE_DESCRIPTOR_ENTRY': [0x10, {
        'KiServiceTable': [0x0, ['pointer', ['void']]],
        'CounterBaseTable': [0x4, ['pointer', ['unsigned long']]],
        'ServiceLimit': [0x8, ['long']],
        'ArgumentTable': [0xc, ['pointer', ['unsigned char']]],
    }],
}

# System-call number of NtUserRegisterWindowMessage (== 0x11ED). The
# find_RegisterWindowMessage code uses bits 12-15 (here 1) to pick the
# descriptor entry and bits 0-11 (here 0x1ED) as the slot within that
# entry's KiServiceTable.
# NOTE(review): syscall numbers differ between Windows builds; confirm
# which build this value targets.
NtUserRegisterWindowMessage_idx = (0x1 << 12) | 0x1ED

def get_threads(proc, types):
    """Enumerate the threads of an _EPROCESS object.

    Walks the list anchored at proc.ThreadListHead with list_entry, treating
    each element as an _ETHREAD linked through its ThreadListEntry field.
    The process's own address space (proc.vm) and profile are used for the
    walk. The return value is whatever list_entry produces.
    """
    space = proc.vm
    profile = proc.profile
    list_head = proc.ThreadListHead.v()
    return list_entry(space, types, profile, list_head, "_ETHREAD",
                      fieldname="ThreadListEntry")

def find_gui_thread(addr_space, types, symtab, prof):
    """Return the first thread found that has a non-zero Win32Thread.

    Processes are enumerated with process_list(). A process is skipped when
    its address space (proc.vm) is unavailable. A thread qualifies when
    thrd.Tcb.Win32Thread is non-zero.

    Returns None when no qualifying thread is found.
    """
    for p in process_list(addr_space, types, symtab):
        proc = Object("_EPROCESS", p, addr_space, profile=prof)
        if not proc.vm:
            # No usable address space for this process; its threads
            # cannot be walked.
            continue
        for thrd in get_threads(proc, types):
            if thrd.Tcb.Win32Thread.v() != 0:
                return thrd
    return None

def find_RegisterWindowMessage(thrd):
    """Return the address of NtUserRegisterWindowMessage, or None.

    The service descriptor table is taken from thrd.Tcb.ServiceTable.
    NtUserRegisterWindowMessage_idx is split into a descriptor selector
    (bits 12-15) and a slot index (bits 0-11). The selected descriptor's
    KiServiceTable base is read, and the 4-byte pointer stored in that slot
    is returned.

    Returns None when the service table pointer is zero/None or when the
    KiServiceTable base cannot be read. Without this check the failure
    surfaced as a TypeError from ``None + int``.
    """
    service_table = thrd.Tcb.ServiceTable.v()
    if not service_table:
        return None

    entry = (NtUserRegisterWindowMessage_idx & 0xF000) >> 12
    idx = NtUserRegisterWindowMessage_idx & 0x0FFF

    # Start using thrd.vm here so we can read win32k session memory.
    w32table_base = read_obj(thrd.vm, ssdt_types,
                             ['_SERVICE_DESCRIPTOR_TABLE',
                              'Descriptors', entry, 'KiServiceTable'],
                             service_table)
    # NOTE(review): read_obj presumably yields None for unreadable memory;
    # treat that (or a null base) as "not found" instead of crashing.
    if not w32table_base:
        return None

    # Each KiServiceTable slot is a 4-byte function pointer.
    fn_off = w32table_base + idx * 4
    return read_value(thrd.vm, 'pointer', fn_off)

# Finds the second-to-last call
# call    _UserAddAtom@8  ; UserAddAtom(x,x)
# movzx   eax, ax
# call    __SEH_epilog
# retn    4
def find_UserAddAtom(addr_space, reg_addr):
    """Return the UserAddAtom address called by NtUserRegisterWindowMessage.

    Disassembles up to 0x1000 bytes at reg_addr as 32-bit x86 and finds the
    first RET. It then returns the target of the second-to-last CALL before
    that RET, matching the epilogue pattern shown above.

    Returns None when:
      * the code at reg_addr cannot be read,
      * no RET is decoded (scanning from the end of the buffer would only
        produce a garbage address),
      * fewer than two CALLs precede the RET, or
      * that CALL's operand is not a literal address (e.g. an indirect call).
    """
    fn_data = addr_space.read(reg_addr, 0x1000)
    if not fn_data:
        return None
    dis = [a[2] for a in Decode(reg_addr, fn_data, Decode32Bits)]

    ret_idx = None
    for i, d in enumerate(dis):
        if d.startswith("RET"):
            ret_idx = i
            break
    if ret_idx is None:
        return None

    calls = [d for d in dis[:ret_idx] if d.startswith("CALL")]
    if len(calls) < 2:
        return None

    try:
        # Direct calls decode as "CALL <address>"; int(..., 0) accepts the
        # 0x-prefixed hex form.
        return int(calls[-2].split()[1], 0)
    except (IndexError, ValueError):
        # Operand is not a literal address, so there is no static target.
        return None

# Finds the fifth PUSH
# mov     edi, edi
# push    ebp
# mov     ebp, esp
# push    ecx
# and     [ebp+var_4], 0
# lea     eax, [ebp+var_4]
# push    eax
# push    [ebp+arg_0]
# push    _UserAtomTableHandle
def find_UserAtomTableHandle(addr_space, uaa_addr):
    """Return the memory operand address of the fifth PUSH in UserAddAtom.

    Disassembles up to 0x1000 bytes at uaa_addr as 32-bit x86. The fifth
    PUSH instruction's third token, with its surrounding brackets removed,
    is parsed as an integer. For the pattern above that is the address of
    _UserAtomTableHandle.

    Returns None when:
      * the code cannot be read,
      * fewer than five PUSHes are decoded, or
      * the fifth PUSH does not have the expected "PUSH <size> [<addr>]"
        form.
    """
    fn_data = addr_space.read(uaa_addr, 0x1000)
    if not fn_data:
        return None
    dis = [a[2] for a in Decode(uaa_addr, fn_data, Decode32Bits)]

    pushes = [d for d in dis if d.startswith("PUSH")]
    if len(pushes) < 5:
        return None

    try:
        # e.g. "PUSH DWORD [0x...]" -> strip "[" and "]" from the operand.
        return int(pushes[4].split()[2][1:-1], 0)
    except (IndexError, ValueError):
        return None

def find_atomtable(addr_space, types, symtab, prof):
    """Locate the _UserAtomTableHandle address by code scanning.

    The chain is:
      1. find a thread with a Win32Thread;
      2. resolve NtUserRegisterWindowMessage through its service table;
      3. disassemble that to find UserAddAtom;
      4. disassemble UserAddAtom to find the pushed handle address.

    Disassembly reads go through the GUI thread's address space (thrd.vm).
    Returns None if any step fails, including when no GUI thread exists.
    Previously that case raised AttributeError on None.
    """
    thrd = find_gui_thread(addr_space, types, symtab, prof)
    if thrd is None:
        return None
    reg_addr = find_RegisterWindowMessage(thrd)
    if not reg_addr:
        return None
    uaa_addr = find_UserAddAtom(thrd.vm, reg_addr)
    if not uaa_addr:
        return None
    return find_UserAtomTableHandle(thrd.vm, uaa_addr)
