我正在为 x86-64 实现我自己的光纤库。其部分原因是缺乏跨平台的标准上下文切换(GCC/Linux 有 makecontext,它采用 void *s 作为可变参数,而 Windows 有其 Fiber API,它采用 1 void * arg)以及 API 设计和实现方面的学习练习。在我的 API 中,协程函数需要 2 个参数:协程上下文和 void * 参数,所以我正在学习它是如何工作的。我将从调用 API 开始,即 C。
struct win64_mcontext {
U64 rdi, rsi, rbx, rbp, r12, r13, r14, r15;
U64 rax, rsp, rip;
U64 rcx, rdx, r8, r9;
};
struct coroutine {
struct win64_mcontext caller;
struct win64_mcontext callee;
U32 state;
};
void coprepare(struct coroutine **co,
void *stack, U64 stack_size, cofunc_t func)
{
*co = malloc(sizeof **co); /* TODO: replace with something cheaper */
_coprepare(&(*co)->caller, &(*co)->callee, stack, stack_size, func);
}
void coenter(struct coroutine *co, void *enter_arg)
{
_coenter(&co->caller, &co->callee, enter_arg);
}
void coyield(struct coroutine *co, void *yield_arg)
{
_coyield(&co->callee, &co->caller, yield_arg);
}
int coresume(struct coroutine *co)
{
_coresume(&co->caller, &co->callee);
return 0; /* punt this for now */
}
这是驱动整个事情的组件。 _coenter、_coyield 和 _coresume 均作为
jmp __cotransfer
实现
;;; _coprepare(struct win64_mcontext *old, struct win64_mcontext *new,
;;; void *stack, U64 stack_size,
;;; cofunc_t func);
;;; RCX -> old
;;; RDX -> new
;;; R8 -> stack
;;; R9 -> stack_size
;;; RSP + ? -> func
_coprepare proc
;; save non-volatile GPRs in 'old'
mov [RCX + OFF_RSI], RSI
mov [RCX + OFF_RDI], RDI
mov [RCX + OFF_RBP], RBP
mov [RCX + OFF_RBX], RBX
mov [RCX + OFF_R12], R12
mov [RCX + OFF_R13], R13
mov [RCX + OFF_R14], R14
mov [RCX + OFF_R15], R15
;; save stack frame info in 'old'
mov R10, RSP
mov R11, OFFSET _coyield
mov [RCX + OFF_RSP], R10
mov [RCX + OFF_RIP], R11
;; init non-volatile GPRs in 'new'
lea R10, [R8 + R9] ; new RSP, = stack + stack_size
lea R11, [RBP - 32] ; load func
xor EAX, EAX
mov [RDX + OFF_RSI], RAX
mov [RDX + OFF_RDI], RAX
mov [RDX + OFF_RBX], RAX
mov [RDX + OFF_RBP], R10
mov [RDX + OFF_R12], RAX
mov [RDX + OFF_R13], RAX
mov [RDX + OFF_R14], RAX
mov [RDX + OFF_R15], RAX
mov [RDX + OFF_RSP], R10
mov [RDX + OFF_RIP], R11
ret
_coprepare endp
;;; __cotransfer(struct win64_context *old, struct win64_mcontext *new, void *trans_arg);
;;; RCX : old
;;; RDX : new
;;; R8 : trans_arg
__cotransfer proc
;; save non-volatile GPRs
mov [RCX + OFF_RSI], RSI
mov [RCX + OFF_RDI], RDI
mov [RCX + OFF_RBX], RBX
mov [RCX + OFF_RBP], RBP
mov [RCX + OFF_R12], R12
mov [RCX + OFF_R13], R13
mov [RCX + OFF_R14], R14
mov [RCX + OFF_R15], R15
;; save argument GPRs
mov [RCX + OFF_RCX], RCX
mov [RCX + OFF_RDX], RDX
mov [RCX + OFF_R8], R8
mov [RCX + OFF_R9], R9
;; save stack frame info
lea R10, [RSP - 8] ; save SP, exclude IP
lea R11, [RSP] ; save IP
mov [RCX + OFF_RSP], R10
mov [RCX + OFF_RIP], R11
;; switch stacks
mov RAX, RSP
mov RSP, [RDX + OFF_RSP]
mov [RCX + OFF_RSP], RAX
;; load non-volatile GPRs
mov RSI, [RDX + OFF_RSI]
mov RDI, [RDX + OFF_RDI]
mov RBX, [RDX + OFF_RBX]
mov RBP, [RDX + OFF_RBP]
mov R12, [RDX + OFF_R12]
mov R13, [RDX + OFF_R13]
mov R14, [RDX + OFF_R14]
mov R15, [RDX + OFF_R15]
;; load argument registers
mov R10, RCX
mov R11, RDX
mov RCX, [R11 + OFF_RCX]
mov RDX, [R11 + OFF_RDX]
mov R8, [R11 + OFF_R8]
mov R9, [R11 + OFF_R9]
; push new return address
mov RAX, [R11 + OFF_RIP]
push RAX
ret ; jump to new return address
__cotransfer endp
我错过了什么吗?它总是在 __cotransfer 的某个地方崩溃。我不知道在调试过程中我最终在哪里,所以我一定做错了什么,比如破坏 BP、IP 或 SP。我丢失了堆栈,因为我切换了它,而 MSVC 无法弄清楚我们现在在哪里。我很迷茫,我需要有此类事情经验的人的帮助。
根据设计,我们需要 5 个例程
在 Fiber 上下文中,我们需要保存 Fiber 的当前堆栈指针和指向其分配堆栈的指针(当我们决定删除 Fiber 时,释放它)。从 Windows 视图 - 我们还必须为每个 Fiber 拥有自己的
NT_TIB
结构,并在切换 Fiber 上下文时切换 StackBase
、StackLimit
等。否则,异常句柄和堆栈中的附加分配位置将不起作用(将保留内存转换为提交内存并移动保护页)。因此 NT_TIB
还需要保存在 Fiber 上下文中。 Fiber 的寄存器我们可以直接保存在堆栈中。
Windows 的最小实现(当然这里存在现成的实现)可以如下所示:
c/c++部分:
typedef struct _INITIAL_TEB
{
PVOID OldStackBase;
PVOID OldStackLimit;
PVOID StackBase;
PVOID StackLimit;
PVOID StackAllocationBase;
} INITIAL_TEB, *PINITIAL_TEB;
extern "C"
NTSYSAPI
NTSTATUS
NTAPI RtlFreeUserStack ( _In_ PVOID AllocationBase );
extern "C"
NTSYSAPI
NTSTATUS
NTAPI
RtlCreateUserStack (
_In_opt_ SIZE_T CommittedStackSize,
_In_opt_ SIZE_T MaximumStackSize,
_In_opt_ ULONG_PTR ZeroBits,
_In_ SIZE_T PageSize,
_In_ ULONG_PTR ReserveAlignment,
_Out_ PINITIAL_TEB InitialTeb);
struct FIBER_CONTEXT
{
NT_TIB Tib;
PVOID StackPointer;
PVOID StackAllocationBase;
};
extern "C"
{
DWORD _G_DeallocationStack_ofs;
void __cdecl FiberStart();
void __fastcall SwitchToContext(FIBER_CONTEXT* ctx);
}
ULONG Get_DeallocationStack_offset(ULONG n = 0x1000)
{
if (PNT_TIB FakeTeb = (PNT_TIB)LocalAlloc(LMEM_FIXED|LMEM_ZEROINIT, n * sizeof(PVOID)))
{
PNT_TIB Tib = (PNT_TIB)NtCurrentTeb();
PVOID StackBase = Tib->StackBase;
Tib->Self = FakeTeb;
void** ppv = (void**)FakeTeb + n;
FakeTeb->StackBase = StackBase;
ULONG_PTR Low, Hi;
do
{
*--ppv = FakeTeb;
GetCurrentThreadStackLimits(&Low, &Hi);
if ((void*)Hi != StackBase)
{
break;
}
if ((void*)Low == FakeTeb)
{
_G_DeallocationStack_ofs = (n - 1)* sizeof(PVOID);
break;
}
} while (--n);
Tib->Self = Tib;
LocalFree(FakeTeb);
}
return _G_DeallocationStack_ofs;
}
FIBER_CONTEXT* MyConvertThreadToFiber()
{
if (!_G_DeallocationStack_ofs && !Get_DeallocationStack_offset())
{
return 0;
}
if (FIBER_CONTEXT* ctx = new FIBER_CONTEXT)
{
((NT_TIB*)NtCurrentTeb())->FiberData = ctx;
return ctx;
}
return 0;
}
void MyConvertFiberToThread()
{
if (FIBER_CONTEXT* ctx = (FIBER_CONTEXT*)((NT_TIB*)NtCurrentTeb())->FiberData)
{
delete ctx;
((NT_TIB*)NtCurrentTeb())->FiberData = 0;
}
}
FIBER_CONTEXT* WINAPI MyCreateFiber(
__in SIZE_T dwStackSize,
__in PFIBER_START_ROUTINE lpStartAddress,
__in_opt PVOID lpParameter
)
{
INITIAL_TEB InitialTeb;
NTSTATUS status = RtlCreateUserStack(0, dwStackSize, 0, 0x1000, 0x10000, &InitialTeb);
if (0 <= status)
{
if (FIBER_CONTEXT* ctx = new FIBER_CONTEXT)
{
ctx->StackAllocationBase = InitialTeb.StackAllocationBase;
NT_TIB* Tib = ((NT_TIB*)NtCurrentTeb());
ctx->Tib.ArbitraryUserPointer = 0;
ctx->Tib.ExceptionList = 0;
ctx->Tib.FiberData = ctx;
ctx->Tib.StackBase = InitialTeb.StackBase;
ctx->Tib.StackLimit = InitialTeb.StackLimit;
ctx->Tib.SubSystemTib = Tib->SubSystemTib;
ctx->Tib.Self = Tib->Self;
void** StackBase = (void**)InitialTeb.StackBase;
ctx->StackPointer = StackBase - (4 + 1 + 8);
StackBase[-3] = lpStartAddress;
StackBase[-4] = lpParameter;
StackBase[-5] = FiberStart;
return ctx;
}
RtlFreeUserStack(InitialTeb.StackAllocationBase);
}
return 0;
}
VOID WINAPI MyDeleteFiber(FIBER_CONTEXT* ctx)
{
RtlFreeUserStack(ctx->StackAllocationBase);
delete ctx;
}
asm(针对 x64)实现部分:
NT_TIB STRUCT
ExceptionList DQ ?
StackBase DQ ?
StackLimit DQ ?
SubSystemTib DQ ?
FiberData DQ ?
ArbitraryUserPointer DQ ?
Self DQ ?
NT_TIB ENDS
FIBER_CONTEXT STRUCT
Tib NT_TIB <?>
StackPointer DQ ?
StackAllocationBase DQ ?
FIBER_CONTEXT ENDS
extern __imp_ExitThread:QWORD
extern _G_DeallocationStack_ofs: DWORD
.code
FiberStart proc
mov rcx,[rsp]
call qword ptr [rsp + 8]
mov ecx,eax
call [__imp_ExitThread]
FiberStart endp
SwitchToContext proc
push r15
push r14
push r13
push r12
push rsi
push rdi
push rbx
push rbp
mov rax,gs:[NT_TIB.Self] ; rax -> NT_TIB
mov rdx,[rax + NT_TIB.FiberData] ; current fiber data
mov [rdx + FIBER_CONTEXT.StackPointer],rsp ; save current rsp
mov rsp,[rcx + FIBER_CONTEXT.StackPointer] ; set new rsp
mov rbp,[rcx + FIBER_CONTEXT.StackAllocationBase]
mov ebx,_G_DeallocationStack_ofs
mov [rax + rbx], rbp
; save NT_TIB
lea rdi,[rdx + FIBER_CONTEXT.Tib]
mov rsi,rax
mov rdx,rcx
mov rcx, SIZEOF NT_TIB / SIZEOF QWORD
rep movsq
; set NT_TIB
mov rdi,rax
lea rsi,[rdx + FIBER_CONTEXT.Tib]
mov rcx, SIZEOF NT_TIB / SIZEOF QWORD
rep movsq
pop rbp
pop rbx
pop rdi
pop rsi
pop r12
pop r13
pop r14
pop r15
ret
SwitchToContext endp
END
以及使用示例:
struct FCTX
{
FIBER_CONTEXT* MainFiber, *WorkFiber;
PCSTR sz;
};
void WINAPI FiberProc(FCTX* ctx)
{
for (;;)
{
DbgPrint("%s\n", ctx->sz);
SwitchToContext(ctx->MainFiber);
}
}
void test()
{
FCTX ctx;
if (ctx.MainFiber = MyConvertThreadToFiber())
{
if (ctx.WorkFiber = MyCreateFiber(0, (PFIBER_START_ROUTINE)FiberProc, &ctx))
{
ctx.sz = "task #1";
SwitchToContext(ctx.WorkFiber);
ctx.sz = "task #2";
SwitchToContext(ctx.WorkFiber);
MyDeleteFiber(ctx.WorkFiber);
}
MyConvertFiberToThread();
}
}