x86-64 / Windows 下正确的上下文切换

问题描述 投票:0回答:1

我正在为 x86-64 实现自己的纤程(fiber)库。一方面是因为缺少跨平台的标准上下文切换接口(GCC/Linux 有 makecontext,它以 void * 作为可变参数;Windows 有自己的 Fiber API,只接受一个 void * 参数),另一方面也是把它当作 API 设计与实现的练习。在我的 API 中,协程函数接受 2 个参数:协程上下文和一个 void * 参数,因此我需要弄清楚底层是如何工作的。下面先从调用方的 API(C 代码)讲起。

/* Register snapshot for one side of a context switch, filled and consumed by
 * _coprepare / __cotransfer.
 * NOTE(review): the field order must agree with the OFF_* constants used by
 * the assembly (their definitions are not shown here) — keep them in sync. */
struct win64_mcontext {
  U64 rdi, rsi, rbx, rbp, r12, r13, r14, r15; /* non-volatile GPRs */
  U64 rax, rsp, rip;                          /* rax: no visible code stores it; rsp/rip: where to resume */
  U64 rcx, rdx, r8, r9;                       /* argument registers, reloaded on resume */
  /* NOTE(review): xmm6-xmm15 are also callee-saved on Win64 and have no slot here. */
};

/* One coroutine and the context of whoever drives it. */
struct coroutine {
  struct win64_mcontext caller; /* saved by coenter/coresume, resumed by coyield */
  struct win64_mcontext callee; /* the coroutine's own context, initialised by coprepare */
  U32 state;                    /* not read or written by any code shown here */
};

/*
 * Allocates a coroutine and prepares its callee context so that entering it
 * runs 'func' on 'stack' (stack_size bytes). The caller's current registers
 * are recorded in the coroutine's caller context.
 *
 * On allocation failure *co is set to NULL and nothing else happens, so the
 * caller must check *co before use.
 * NOTE(review): no matching release function is visible here; presumably the
 * caller frees *co with free().
 */
void coprepare(struct coroutine **co,
           void *stack, U64 stack_size, cofunc_t func)
{
  /* calloc zero-fills, so 'state' starts at 0 instead of being indeterminate. */
  struct coroutine *c = calloc(1, sizeof *c); /* TODO: replace with something cheaper */

  *co = c;
  if (c == NULL) {
    return;
  }
  _coprepare(&c->caller, &c->callee, stack, stack_size, func);
}

/* Starts the coroutine: the current context is saved into co->caller and
 * execution continues from co->callee (prepared by coprepare).
 * NOTE(review): __cotransfer as written reloads RCX/RDX/R8/R9 from the target
 * context, so enter_arg is not actually delivered to the coroutine. */
void coenter(struct coroutine *co, void *enter_arg)
{
  struct win64_mcontext *const from = &co->caller;
  struct win64_mcontext *const to = &co->callee;

  _coenter(from, to, enter_arg);
}

/* Suspends the running coroutine: its context is saved into co->callee and
 * execution continues from co->caller.
 * NOTE(review): __cotransfer as written reloads RCX/RDX/R8/R9 from the target
 * context, so yield_arg is not actually delivered to the resumed side. */
void coyield(struct coroutine *co, void *yield_arg)
{
  struct win64_mcontext *const from = &co->callee;
  struct win64_mcontext *const to = &co->caller;

  _coyield(from, to, yield_arg);
}

/* Resumes a coroutine that yielded: the current context is saved into
 * co->caller and execution continues from co->callee.
 * Always returns 0 for now; no status is propagated yet. */
int coresume(struct coroutine *co)
{
  struct win64_mcontext *const from = &co->caller;
  struct win64_mcontext *const to = &co->callee;

  _coresume(from, to);
  return 0; /* punt this for now */
}

下面是驱动这一切的汇编代码。_coenter、_coyield 和 _coresume 都只实现为一条跳转指令:

jmp __cotransfer

其余部分如下:
;;; _coprepare(struct win64_mcontext *old, struct win64_mcontext *new,
;;;            void *stack, U64 stack_size,
;;;            cofunc_t func);
;;; Records the caller's non-volatile GPRs in 'old' and initialises 'new' so
;;; that a later __cotransfer into it continues on top of 'stack'.
;;; RCX     -> old
;;; RDX     -> new
;;; R8      -> stack
;;; R9      -> stack_size
;;; RSP + ? -> func
;;; NOTE(review): under the Win64 ABI the 5th argument is the qword at
;;; [RSP + 40] on entry (return address 8 bytes + 32-byte shadow space).
;;; NOTE(review): the argument slots (OFF_RCX/RDX/R8/R9) of 'new' are never
;;; written here, yet __cotransfer loads them when first entering 'new'.
_coprepare proc
    ;; save non-volatile GPRs in 'old'
    mov [RCX + OFF_RSI], RSI
    mov [RCX + OFF_RDI], RDI
    mov [RCX + OFF_RBP], RBP
    mov [RCX + OFF_RBX], RBX
    mov [RCX + OFF_R12], R12
    mov [RCX + OFF_R13], R13
    mov [RCX + OFF_R14], R14
    mov [RCX + OFF_R15], R15

    ;; save stack frame info in 'old'
    ;; NOTE(review): the saved RIP is the address of _coyield, not a point in
    ;; the caller, so transferring back to 'old' enters _coyield instead of
    ;; returning from _coprepare.
    mov R10, RSP
    mov R11, OFFSET _coyield

    mov [RCX + OFF_RSP], R10
    mov [RCX + OFF_RIP], R11

    ;; init non-volatile GPRs in 'new'
    ;; NOTE(review): the new RSP is the raw top of the stack. No slot is
    ;; reserved for a return address or the 32-byte shadow space, and nothing
    ;; adjusts it so that RSP+8 is 16-byte aligned when 'func' starts.
    ;; BUG(review): the LEA into R11 computes an address relative to the
    ;; caller's RBP; it does not load 'func'. That address is stored as
    ;; new.RIP below, so __cotransfer later jumps into stack memory.
    ;; Loading the 5th argument would be `mov R11, [RSP + 40]`.
    lea R10, [R8 + R9]       ; new RSP = stack + stack_size (top of the new stack)
    lea R11, [RBP - 32]  ; intended: load func (see BUG above)

    xor EAX, EAX             ; RAX = 0 (the 32-bit write zero-extends)
    mov [RDX + OFF_RSI], RAX
    mov [RDX + OFF_RDI], RAX
    mov [RDX + OFF_RBX], RAX
    mov [RDX + OFF_RBP], R10 ; new RBP = top of the new stack
    mov [RDX + OFF_R12], RAX
    mov [RDX + OFF_R13], RAX
    mov [RDX + OFF_R14], RAX
    mov [RDX + OFF_R15], RAX

    mov [RDX + OFF_RSP], R10
    mov [RDX + OFF_RIP], R11 ; where __cotransfer will continue for 'new'

    ret
_coprepare endp

;;; __cotransfer(struct win64_mcontext *old, struct win64_mcontext *new, void *trans_arg);
;;; Saves the current register state into 'old' and continues execution from
;;; the state recorded in 'new'. _coenter, _coyield and _coresume are each a
;;; single `jmp __cotransfer`, so [RSP] on entry is their caller's return address.
;;; RCX : old
;;; RDX : new
;;; R8  : trans_arg
;;; NOTE(review): xmm6-xmm15 are callee-saved under the Win64 ABI but are not
;;; saved or restored here, and the TEB's NT_TIB StackBase/StackLimit are not
;;; switched along with RSP.
__cotransfer proc
    ;; save non-volatile GPRs
    mov [RCX + OFF_RSI], RSI
    mov [RCX + OFF_RDI], RDI
    mov [RCX + OFF_RBX], RBX
    mov [RCX + OFF_RBP], RBP
    mov [RCX + OFF_R12], R12
    mov [RCX + OFF_R13], R13
    mov [RCX + OFF_R14], R14
    mov [RCX + OFF_R15], R15

    ;; save argument GPRs
    ;; NOTE(review): the argument registers are reloaded from 'new' further
    ;; down, so the resumed side sees the values it held when it last
    ;; switched away. This call's trans_arg (R8) is never handed to 'new'.
    mov [RCX + OFF_RCX], RCX
    mov [RCX + OFF_RDX], RDX
    mov [RCX + OFF_R8], R8
    mov [RCX + OFF_R9], R9

    ;; save stack frame info
    ;; BUG(review): `lea R11, [RSP]` records the *address* of the stack slot
    ;; that holds the return address, not the return address itself. When
    ;; 'old' is resumed, the push/ret at the end jumps to that stack address,
    ;; which crashes. The return address is `mov R11, [RSP]`. The matching
    ;; resume stack pointer is `lea R10, [RSP + 8]` (the caller's RSP once the
    ;; return address is consumed), and it must not be overwritten below.
    lea R10, [RSP - 8]  ; 8 bytes below the return-address slot
    lea R11, [RSP]      ; address of the return-address slot (see BUG above)

    mov [RCX + OFF_RSP], R10 ; superseded by the store after the stack switch
    mov [RCX + OFF_RIP], R11

    ;; switch stacks
    mov RAX, RSP
    mov RSP, [RDX + OFF_RSP]
    mov [RCX + OFF_RSP], RAX ; old.RSP = entry RSP (still pointing at the return address)

    ;; load non-volatile GPRs
    mov RSI, [RDX + OFF_RSI]
    mov RDI, [RDX + OFF_RDI]
    mov RBX, [RDX + OFF_RBX]
    mov RBP, [RDX + OFF_RBP]
    mov R12, [RDX + OFF_R12]
    mov R13, [RDX + OFF_R13]
    mov R14, [RDX + OFF_R14]
    mov R15, [RDX + OFF_R15]

    ;; load argument registers
    mov R10, RCX ; R10 is not read again in this routine
    mov R11, RDX ; keep 'new' addressable while RCX/RDX are overwritten

    mov RCX, [R11 + OFF_RCX]
    mov RDX, [R11 + OFF_RDX]
    mov R8,  [R11 + OFF_R8]
    mov R9,  [R11 + OFF_R9]

    ; push new return address
    mov RAX, [R11 + OFF_RIP]
    push RAX        
    ret ; jump to new.RIP; push + ret leaves RSP equal to new.RSP
__cotransfer endp

我错过了什么吗?它总是在 __cotransfer 的某个地方崩溃。我不知道在调试过程中我最终在哪里,所以我一定做错了什么,比如破坏 BP、IP 或 SP。我丢失了堆栈,因为我切换了它,而 MSVC 无法弄清楚我们现在在哪里。我很迷茫,我需要有此类事情经验的人的帮助。

c windows assembly x86-64
1个回答
2
投票

按照这种设计,我们需要 5 个例程:

  • 一个将当前线程转换为纤程(为当前线程分配纤程上下文),一个将纤程转换回线程(释放该上下文);
  • 一个创建新的纤程上下文,一个删除纤程上下文;
  • 以及一个切换到选定纤程上下文的例程。

纤程上下文中需要保存纤程当前的栈指针,以及指向其所分配栈的指针(删除纤程时用它释放栈)。从 Windows 的角度看,每个纤程还必须拥有自己的 `NT_TIB` 结构,并在切换纤程上下文时切换 `StackBase`、`StackLimit` 等字段。否则,异常处理以及栈的按需增长(把保留内存转为已提交内存并移动保护页)都将无法正常工作。因此 `NT_TIB` 也需要保存在纤程上下文中。纤程的寄存器则可以直接保存在它自己的栈上。

在 Windows 上,一个最小实现(当然,现成的实现已经存在)可以如下所示:

c/c++部分:

// Output of RtlCreateUserStack. MyCreateFiber reads StackBase, StackLimit and
// StackAllocationBase from it.
// NOTE(review): mirrors an undocumented ntdll structure; verify the layout
// against the targeted Windows version.
typedef struct _INITIAL_TEB
{
    PVOID OldStackBase;
    PVOID OldStackLimit;
    PVOID StackBase;           // highest address of the new stack; frames are built below it
    PVOID StackLimit;
    PVOID StackAllocationBase; // base of the reservation; later passed to RtlFreeUserStack
} INITIAL_TEB, *PINITIAL_TEB;

// Releases a stack created by RtlCreateUserStack, identified by the
// StackAllocationBase it reported in INITIAL_TEB.
// NOTE(review): undocumented ntdll export; the prototype is taken on trust.
extern "C"
NTSYSAPI 
NTSTATUS 
NTAPI RtlFreeUserStack  (   _In_ PVOID      AllocationBase  );

// Creates a new user-mode stack and reports its bounds in *InitialTeb.
// Returns an NTSTATUS; MyCreateFiber treats values >= 0 as success.
// NOTE(review): undocumented ntdll export; the prototype is taken on trust.
extern "C"
NTSYSAPI 
NTSTATUS 
NTAPI   
RtlCreateUserStack (
                    _In_opt_ SIZE_T CommittedStackSize, 
                    _In_opt_ SIZE_T MaximumStackSize, 
                    _In_opt_ ULONG_PTR ZeroBits, 
                    _In_ SIZE_T PageSize, 
                    _In_ ULONG_PTR ReserveAlignment, 
                    _Out_ PINITIAL_TEB InitialTeb);

// Per-fiber state used by SwitchToContext.
// The layout must match the FIBER_CONTEXT STRUCT in the assembly part
// (NT_TIB first, then StackPointer, then StackAllocationBase).
struct FIBER_CONTEXT
{
    NT_TIB Tib;                 // this fiber's copy of the TEB's NT_TIB while it is switched out
    PVOID StackPointer;         // saved RSP while the fiber is switched out
    PVOID StackAllocationBase;  // written into the TEB's DeallocationStack when switching to this fiber
};

extern "C"
{
    // Byte offset of the TEB's DeallocationStack field. It stays 0 until
    // Get_DeallocationStack_offset finds it, and SwitchToContext reads it.
    DWORD _G_DeallocationStack_ofs;
    // Assembly thunk that a new fiber first "returns" into (see MyCreateFiber).
    void __cdecl FiberStart();
    // Switches from the fiber recorded in NT_TIB.FiberData to ctx (assembly).
    void __fastcall SwitchToContext(FIBER_CONTEXT* ctx);
}

// Finds the byte offset of the TEB field that GetCurrentThreadStackLimits
// reports as the low stack limit (DeallocationStack). The result is stored in
// _G_DeallocationStack_ofs and returned. If the allocation fails or the field
// is not found within the first n pointer-sized slots, the global is left
// unchanged and its current value is returned.
//
// Technique: point NT_TIB.Self at a zeroed fake TEB, then write the fake
// block's own address into its slots from the top down, calling
// GetCurrentThreadStackLimits after each write. The first slot that makes
// the reported low limit equal the fake block's address is the field.
// NOTE(review): this relies on GetCurrentThreadStackLimits reading the TEB
// through NT_TIB.Self (TODO confirm per Windows version). While Self is
// redirected, any other code on this thread that goes through NtCurrentTeb()
// sees the fake block.
ULONG Get_DeallocationStack_offset(ULONG n = 0x1000)
{
    if (PNT_TIB FakeTeb = (PNT_TIB)LocalAlloc(LMEM_FIXED|LMEM_ZEROINIT, n * sizeof(PVOID)))
    {
        PNT_TIB Tib = (PNT_TIB)NtCurrentTeb();
        PVOID StackBase = Tib->StackBase;

        Tib->Self = FakeTeb;   // redirect TEB lookups through Self to the fake block

        void** ppv = (void**)FakeTeb + n;   // one past the last probed slot
        FakeTeb->StackBase = StackBase;     // keep the reported high limit unchanged

        ULONG_PTR Low, Hi;
        do 
        {
            *--ppv = FakeTeb;   // probe slot index n-1
            GetCurrentThreadStackLimits(&Low, &Hi);

            if ((void*)Hi != StackBase)
            {
                break;   // the high limit changed (presumably the StackBase slot was hit): stop probing
            }

            if ((void*)Low == FakeTeb)
            {
                _G_DeallocationStack_ofs = (n - 1)* sizeof(PVOID);   // byte offset of the probed slot
                break;
            }
        } while (--n);

        Tib->Self = Tib;   // restore the real TEB self-pointer

        LocalFree(FakeTeb);
    }

    return _G_DeallocationStack_ofs;
}

// Makes the calling thread usable with SwitchToContext. It allocates a
// FIBER_CONTEXT for the thread and publishes it in NT_TIB.FiberData.
// Returns the new context, or 0 if the DeallocationStack offset cannot be
// determined or the allocation fails. Undo with MyConvertFiberToThread.
FIBER_CONTEXT* MyConvertThreadToFiber()
{
    // SwitchToContext needs the TEB offset of DeallocationStack; probe it once.
    if (!_G_DeallocationStack_ofs && !Get_DeallocationStack_offset())
    {
        return 0;
    }

    if (FIBER_CONTEXT* ctx = new FIBER_CONTEXT)
    {
        NT_TIB* Tib = (NT_TIB*)NtCurrentTeb();

        // SwitchToContext only *loads* DeallocationStack from the target
        // context and never saves it. The thread's own value must therefore
        // be captured here; otherwise switching back to this fiber would write
        // an indeterminate pointer into the TEB.
        ctx->StackAllocationBase = *(PVOID*)((char*)Tib + _G_DeallocationStack_ofs);

        // The first switch away from this fiber overwrites both of these.
        // Initialise them anyway so the context never holds indeterminate data.
        ctx->StackPointer = 0;
        Tib->FiberData = ctx;
        ctx->Tib = *Tib;

        return ctx;
    }

    return 0;
}

void MyConvertFiberToThread()
{
    if (FIBER_CONTEXT* ctx = (FIBER_CONTEXT*)((NT_TIB*)NtCurrentTeb())->FiberData)
    {
        delete ctx;
        ((NT_TIB*)NtCurrentTeb())->FiberData = 0;
    }
}

// Creates a fiber that will run lpStartAddress(lpParameter) on a fresh stack
// (dwStackSize is passed as RtlCreateUserStack's MaximumStackSize).
// Returns its context, or 0 on failure. The context's NT_TIB copies
// SubSystemTib and Self from the calling thread's TEB, so it is meant to be
// switched to on this thread. Free it with MyDeleteFiber.
// NOTE(review): if lpStartAddress returns, FiberStart calls ExitThread.
FIBER_CONTEXT* WINAPI MyCreateFiber(
                          __in      SIZE_T dwStackSize,
                          __in      PFIBER_START_ROUTINE lpStartAddress,
                          __in_opt  PVOID lpParameter
                          )
{
    INITIAL_TEB InitialTeb;
    // CommittedStackSize 0, MaximumStackSize dwStackSize, ZeroBits 0,
    // PageSize 0x1000, ReserveAlignment 0x10000.
    NTSTATUS status = RtlCreateUserStack(0, dwStackSize, 0, 0x1000, 0x10000, &InitialTeb);
    
    if (0 <= status)
    {
        // NOTE(review): standard C++ `new` throws on failure instead of
        // returning null; if it throws, the stack created above is leaked.
        if (FIBER_CONTEXT* ctx = new FIBER_CONTEXT)
        {
            ctx->StackAllocationBase = InitialTeb.StackAllocationBase;
            NT_TIB* Tib = ((NT_TIB*)NtCurrentTeb());
            
            // NT_TIB that SwitchToContext copies into the TEB when entering this fiber.
            ctx->Tib.ArbitraryUserPointer = 0;
            ctx->Tib.ExceptionList = 0;
            ctx->Tib.FiberData = ctx;
            ctx->Tib.StackBase = InitialTeb.StackBase;
            ctx->Tib.StackLimit = InitialTeb.StackLimit;
            ctx->Tib.SubSystemTib = Tib->SubSystemTib;
            ctx->Tib.Self = Tib->Self;

            // Initial frame consumed by the first SwitchToContext into this
            // fiber (indices are pointer-sized slots below StackBase):
            //   [-13..-6] popped into rbp, rbx, rdi, rsi, r12, r13, r14, r15 (not written here)
            //   [-5]      FiberStart      <- target of SwitchToContext's final ret
            //   [-4]      lpParameter     <- [rsp]   in FiberStart
            //   [-3]      lpStartAddress  <- [rsp+8] in FiberStart
            //   [-2..-1]  unused; [-4..-1] also serve as shadow space for FiberStart's call
            void** StackBase = (void**)InitialTeb.StackBase;
            ctx->StackPointer = StackBase - (4 + 1 + 8);   // 4 arg/shadow slots + return address + 8 GPRs
            StackBase[-3] = lpStartAddress;
            StackBase[-4] = lpParameter;
            StackBase[-5] = FiberStart;
            return ctx;
        }
        RtlFreeUserStack(InitialTeb.StackAllocationBase);
    }

    return 0;
}

// Destroys a fiber created by MyCreateFiber: its stack is released first,
// then its context. Calling this for the fiber that is currently running
// would free the stack in use.
VOID WINAPI MyDeleteFiber(FIBER_CONTEXT* ctx)
{
    PVOID const stackReservation = ctx->StackAllocationBase;

    RtlFreeUserStack(stackReservation);
    delete ctx;
}
 

asm(针对 x64)实现部分:

; x64 NT_TIB layout (7 qwords). SwitchToContext uses these offsets both
; against the TEB (via gs:) and against FIBER_CONTEXT.Tib.
NT_TIB STRUCT
    ExceptionList DQ ?
    StackBase DQ ?
    StackLimit DQ ?
    SubSystemTib DQ ?
    FiberData DQ ?
    ArbitraryUserPointer DQ ?
    Self DQ ?
NT_TIB ENDS

; Must match struct FIBER_CONTEXT in the C++ part.
FIBER_CONTEXT STRUCT
    Tib NT_TIB <?>
    StackPointer DQ ?
    StackAllocationBase DQ ?
FIBER_CONTEXT ENDS

extern __imp_ExitThread:QWORD            ; import pointer to ExitThread
extern _G_DeallocationStack_ofs: DWORD   ; defined in the C++ part

.code

; First code a new fiber runs. The initial frame built by MyCreateFiber makes
; SwitchToContext's final ret land here with [rsp] = lpParameter and
; [rsp+8] = lpStartAddress.
FiberStart proc
    mov     rcx,[rsp]               ; 1st argument = lpParameter
    call    qword ptr [rsp + 8]     ; lpStartAddress(lpParameter); [rsp..rsp+31] is its shadow space
    mov     ecx,eax                 ; exit code = low 32 bits of the return value
    call    [__imp_ExitThread]      ; does not return: the whole thread ends here
FiberStart endp

; void __fastcall SwitchToContext(FIBER_CONTEXT* ctx)       rcx = ctx
; Suspends the fiber whose context is in the TEB's NT_TIB.FiberData and
; resumes ctx:
;   1. push the non-volatile GPRs and save RSP in the current FIBER_CONTEXT;
;   2. load ctx's RSP and write ctx->StackAllocationBase into the TEB at
;      offset _G_DeallocationStack_ofs (DeallocationStack);
;   3. copy the live NT_TIB into the current context, then ctx's saved
;      NT_TIB into the TEB (that copy carries FiberData = ctx);
;   4. pop ctx's saved GPRs and return on ctx's stack.
; NOTE(review): DeallocationStack is only loaded here, never saved, so every
; context switched to must have StackAllocationBase initialised, including
; the one created for the converted thread.
; NOTE(review): xmm6-xmm15 are also non-volatile under the Win64 ABI but are
; not preserved. Saving them changes the frame size, and the initial frame
; built by MyCreateFiber would have to match.
SwitchToContext proc
    push    r15
    push    r14
    push    r13
    push    r12
    push    rsi
    push    rdi
    push    rbx
    push    rbp
    
    mov     rax,gs:[NT_TIB.Self]    ; rax -> NT_TIB (start of the TEB)
    
    mov     rdx,[rax + NT_TIB.FiberData]    ; current fiber's FIBER_CONTEXT
    
    mov     [rdx + FIBER_CONTEXT.StackPointer],rsp  ; save current rsp
    mov     rsp,[rcx + FIBER_CONTEXT.StackPointer]  ; set new rsp
    
    mov     rbp,[rcx + FIBER_CONTEXT.StackAllocationBase]
    mov     ebx,_G_DeallocationStack_ofs    ; 32-bit load zero-extends into rbx
    mov     [rax + rbx], rbp                ; TEB.DeallocationStack = ctx->StackAllocationBase

    ; save NT_TIB (live TEB -> current context); rep movsq assumes DF = 0
    lea     rdi,[rdx + FIBER_CONTEXT.Tib]
    mov     rsi,rax
    mov     rdx,rcx                         ; keep ctx; rcx becomes the qword count
    mov     rcx, SIZEOF NT_TIB / SIZEOF QWORD
    rep     movsq

    ; set NT_TIB (ctx -> live TEB)
    mov     rdi,rax
    lea     rsi,[rdx + FIBER_CONTEXT.Tib]
    mov     rcx, SIZEOF NT_TIB / SIZEOF QWORD
    rep     movsq
    
    ; restore ctx's registers (for a new fiber: the frame from MyCreateFiber)
    pop     rbp
    pop     rbx
    pop     rdi
    pop     rsi
    pop     r12
    pop     r13
    pop     r14
    pop     r15
    ret                                     ; resume ctx (or enter FiberStart for a new fiber)
SwitchToContext endp

END

以及使用示例:

// State shared between test() and the worker fiber running FiberProc.
struct FCTX 
{
    FIBER_CONTEXT* MainFiber;  // context of the converted thread; FiberProc switches back here
    FIBER_CONTEXT* WorkFiber;  // context created by MyCreateFiber for FiberProc
    PCSTR sz;                  // label printed by FiberProc each time it is resumed
};

// Body of the worker fiber. Each time it is resumed it prints ctx->sz and
// immediately switches back to ctx->MainFiber. It never returns.
void WINAPI FiberProc(FCTX* ctx)
{
    while (true)
    {
        PCSTR const label = ctx->sz;
        DbgPrint("%s\n", label);
        SwitchToContext(ctx->MainFiber);
    }
}

// Demo: converts the thread to a fiber, resumes FiberProc twice with
// different labels (each resume prints once and switches straight back),
// then tears everything down.
void test()
{
    FCTX ctx;

    ctx.MainFiber = MyConvertThreadToFiber();
    if (!ctx.MainFiber)
    {
        return;
    }

    ctx.WorkFiber = MyCreateFiber(0, (PFIBER_START_ROUTINE)FiberProc, &ctx);
    if (ctx.WorkFiber)
    {
        ctx.sz = "task #1";
        SwitchToContext(ctx.WorkFiber);

        ctx.sz = "task #2";
        SwitchToContext(ctx.WorkFiber);

        // The worker is parked inside SwitchToContext and is never resumed again.
        MyDeleteFiber(ctx.WorkFiber);
    }

    MyConvertFiberToThread();
}
© www.soinside.com 2019 - 2024. All rights reserved.