Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WOW64 backend code invalidation fixes #3589

Merged
merged 4 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions Source/Windows/Common/InvalidationTracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,27 @@

#include <FEXCore/Utils/LogManager.h>
#include <FEXCore/Utils/TypeDefines.h>
#include <FEXCore/Utils/SignalScopeGuards.h>
#include <FEXCore/Core/Context.h>
#include <FEXCore/Debug/InternalThreadState.h>
#include "InvalidationTracker.h"
#include <windef.h>
#include <winternl.h>

namespace FEX::Windows {
void InvalidationTracker::HandleMemoryProtectionNotification(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size,
ULONG Prot) {
InvalidationTracker::InvalidationTracker(FEXCore::Context::Context& CTX, const std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*>& Threads)
: CTX {CTX}
, Threads {Threads} {}

void InvalidationTracker::HandleMemoryProtectionNotification(uint64_t Address, uint64_t Size, ULONG Prot) {
const auto AlignedBase = Address & FEXCore::Utils::FEX_PAGE_MASK;
const auto AlignedSize = (Address - AlignedBase + Size + FEXCore::Utils::FEX_PAGE_SIZE - 1) & FEXCore::Utils::FEX_PAGE_MASK;

if (Prot & (PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE)) {
Thread->CTX->InvalidateGuestCodeRange(Thread, AlignedBase, AlignedSize);
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, AlignedBase, AlignedSize);
}
}

if (Prot & PAGE_EXECUTE_READWRITE) {
Expand All @@ -28,26 +35,43 @@ void InvalidationTracker::HandleMemoryProtectionNotification(FEXCore::Core::Inte
}
}

void InvalidationTracker::InvalidateContainingSection(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, bool Free) {
void InvalidationTracker::InvalidateContainingSection(uint64_t Address, bool Free) {
MEMORY_BASIC_INFORMATION Info;
if (NtQueryVirtualMemory(NtCurrentProcess(), reinterpret_cast<void*>(Address), MemoryBasicInformation, &Info, sizeof(Info), nullptr)) {
return;
}

const auto SectionBase = reinterpret_cast<uint64_t>(Info.AllocationBase);
const auto SectionSize = reinterpret_cast<uint64_t>(Info.BaseAddress) + Info.RegionSize - reinterpret_cast<uint64_t>(Info.AllocationBase);
Thread->CTX->InvalidateGuestCodeRange(Thread, SectionBase, SectionSize);
auto SectionSize = reinterpret_cast<uint64_t>(Info.BaseAddress) + Info.RegionSize - SectionBase;

while (!NtQueryVirtualMemory(NtCurrentProcess(), reinterpret_cast<void*>(SectionBase + SectionSize), MemoryBasicInformation, &Info,
sizeof(Info), nullptr) &&
reinterpret_cast<uint64_t>(Info.AllocationBase) == SectionBase) {
SectionSize += Info.RegionSize;
}
{
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, SectionBase, SectionSize);
}
}

if (Free) {
std::scoped_lock Lock(RWXIntervalsLock);
RWXIntervals.Remove({SectionBase, SectionBase + SectionSize});
}
}

void InvalidationTracker::InvalidateAlignedInterval(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size, bool Free) {
void InvalidationTracker::InvalidateAlignedInterval(uint64_t Address, uint64_t Size, bool Free) {
const auto AlignedBase = Address & FEXCore::Utils::FEX_PAGE_MASK;
const auto AlignedSize = (Address - AlignedBase + Size + FEXCore::Utils::FEX_PAGE_SIZE - 1) & FEXCore::Utils::FEX_PAGE_MASK;
Thread->CTX->InvalidateGuestCodeRange(Thread, AlignedBase, AlignedSize);

{
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, AlignedBase, AlignedSize);
}
}

if (Free) {
std::scoped_lock Lock(RWXIntervalsLock);
Expand Down Expand Up @@ -75,7 +99,7 @@ void InvalidationTracker::ReprotectRWXIntervals(uint64_t Address, uint64_t Size)
} while (Address < End);
}

bool InvalidationTracker::HandleRWXAccessViolation(FEXCore::Core::InternalThreadState* Thread, uint64_t FaultAddress) {
bool InvalidationTracker::HandleRWXAccessViolation(uint64_t FaultAddress) {
const bool NeedsInvalidate = [&](uint64_t Address) {
std::unique_lock Lock(RWXIntervalsLock);
const bool Enclosed = RWXIntervals.Query(Address).Enclosed;
Expand All @@ -93,7 +117,10 @@ bool InvalidationTracker::HandleRWXAccessViolation(FEXCore::Core::InternalThread

if (NeedsInvalidate) {
// RWXIntervalsLock cannot be held during invalidation
Thread->CTX->InvalidateGuestCodeRange(Thread, FaultAddress & FEXCore::Utils::FEX_PAGE_MASK, FEXCore::Utils::FEX_PAGE_SIZE);
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, FaultAddress & FEXCore::Utils::FEX_PAGE_MASK, FEXCore::Utils::FEX_PAGE_SIZE);
}
return true;
}
return false;
Expand Down
17 changes: 12 additions & 5 deletions Source/Windows/Common/InvalidationTracker.h
Original file line number Diff line number Diff line change
@@ -1,28 +1,35 @@
// SPDX-License-Identifier: MIT
// FIXME TODO put in cpp
#pragma once

#include "IntervalList.h"
#include <mutex>
#include <unordered_map>

namespace FEXCore::Core {
struct InternalThreadState;
}

namespace FEXCore::Context {
class Context;
}

namespace FEX::Windows {
/**
* @brief Handles SMC and regular code invalidation
*/
class InvalidationTracker {
public:
void HandleMemoryProtectionNotification(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size, ULONG Prot);
void InvalidateContainingSection(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, bool Free);
void InvalidateAlignedInterval(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size, bool Free);
InvalidationTracker(FEXCore::Context::Context& CTX, const std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*>& Threads);
void HandleMemoryProtectionNotification(uint64_t Address, uint64_t Size, ULONG Prot);
void InvalidateContainingSection(uint64_t Address, bool Free);
void InvalidateAlignedInterval(uint64_t Address, uint64_t Size, bool Free);
void ReprotectRWXIntervals(uint64_t Address, uint64_t Size);
bool HandleRWXAccessViolation(FEXCore::Core::InternalThreadState* Thread, uint64_t FaultAddress);
bool HandleRWXAccessViolation(uint64_t FaultAddress);

private:
IntervalList<uint64_t> RWXIntervals;
std::mutex RWXIntervalsLock;
FEXCore::Context::Context& CTX;
const std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*>& Threads;
};
} // namespace FEX::Windows
2 changes: 0 additions & 2 deletions Source/Windows/Defs/ntdll.def
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ EXPORTS
NtSuspendThread
NtGetContextThread
NtContinue
NtQueryVirtualMemory
NtProtectVirtualMemory
__wine_dbg_output
__wine_unix_call
61 changes: 37 additions & 24 deletions Source/Windows/WOW64/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ desc: Implements the WOW64 BT module API using FEXCore
#include <atomic>
#include <mutex>
#include <utility>
#include <unordered_set>
#include <unordered_map>
#include <ntstatus.h>
#include <windef.h>
#include <winternl.h>
Expand Down Expand Up @@ -94,11 +94,12 @@ fextl::unique_ptr<FEXCore::Context::Context> CTX;
fextl::unique_ptr<FEX::DummyHandlers::DummySignalDelegator> SignalDelegator;
fextl::unique_ptr<WowSyscallHandler> SyscallHandler;

FEX::Windows::InvalidationTracker InvalidationTracker;
std::optional<FEX::Windows::InvalidationTracker> InvalidationTracker;
std::optional<FEX::Windows::CPUFeatures> CPUFeatures;

std::mutex ThreadSuspendLock;
std::unordered_set<DWORD> InitializedWOWThreads; // Set of TIDs, `ThreadSuspendLock` must be locked when accessing
std::mutex ThreadCreationMutex;
// Map of TIDs to their FEX thread state, `ThreadCreationMutex` must be locked when accessing
std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*> Threads;

std::pair<NTSTATUS, TLS> GetThreadTLS(HANDLE Thread) {
THREAD_BASIC_INFORMATION Info;
Expand Down Expand Up @@ -393,7 +394,7 @@ class WowSyscallHandler : public FEXCore::HLE::SyscallHandler, public FEXCore::A
}

void MarkGuestExecutableRange(FEXCore::Core::InternalThreadState* Thread, uint64_t Start, uint64_t Length) override {
InvalidationTracker.ReprotectRWXIntervals(Start, Length);
InvalidationTracker->ReprotectRWXIntervals(Start, Length);
}
};

Expand Down Expand Up @@ -422,14 +423,16 @@ void BTCpuProcessInit() {
CTX->SetSignalDelegator(SignalDelegator.get());
CTX->SetSyscallHandler(SyscallHandler.get());
CTX->InitCore();
InvalidationTracker.emplace(*CTX, Threads);
CPUFeatures.emplace(*CTX);
}

NTSTATUS BTCpuThreadInit() {
GetTLS().ThreadState() = CTX->CreateThread(0, 0);
auto* Thread = CTX->CreateThread(0, 0);
GetTLS().ThreadState() = Thread;

std::scoped_lock Lock(ThreadSuspendLock);
InitializedWOWThreads.emplace(GetCurrentThreadId());
std::scoped_lock Lock(ThreadCreationMutex);
Threads.emplace(GetCurrentThreadId(), Thread);
return STATUS_SUCCESS;
}

Expand All @@ -446,8 +449,8 @@ NTSTATUS BTCpuThreadTerm(HANDLE Thread) {
}

const auto ThreadTID = reinterpret_cast<uint64_t>(Info.ClientId.UniqueThread);
std::scoped_lock Lock(ThreadSuspendLock);
InitializedWOWThreads.erase(ThreadTID);
std::scoped_lock Lock(ThreadCreationMutex);
Threads.erase(ThreadTID);
}

CTX->DestroyThread(TLS.ThreadState());
Expand Down Expand Up @@ -550,10 +553,10 @@ NTSTATUS BTCpuSuspendLocalThread(HANDLE Thread, ULONG* Count) {
return Err;
}

std::scoped_lock Lock(ThreadSuspendLock);
std::scoped_lock Lock(ThreadCreationMutex);

// If the thread hasn't yet been initialized, suspend it without special handling as it wont yet have entered the JIT
if (!InitializedWOWThreads.contains(ThreadTID)) {
if (!Threads.contains(ThreadTID)) {
return NtSuspendThread(Thread, Count);
}

Expand Down Expand Up @@ -615,15 +618,22 @@ NTSTATUS BTCpuResetToConsistentState(EXCEPTION_POINTERS* Ptrs) {
if (Exception->ExceptionCode == EXCEPTION_ACCESS_VIOLATION) {
const auto FaultAddress = static_cast<uint64_t>(Exception->ExceptionInformation[1]);

if (InvalidationTracker.HandleRWXAccessViolation(GetTLS().ThreadState(), FaultAddress)) {
LogMan::Msg::DFmt("Handled self-modifying code: pc: {:X} fault: {:X}", Context->Pc, FaultAddress);
NtContinue(Context, FALSE);
}

if (Context::HandleSuspendInterrupt(Context, FaultAddress)) {
LogMan::Msg::DFmt("Resumed from suspend");
NtContinue(Context, FALSE);
}

bool HandledRWX = false;
if (GetTLS().ThreadState()) {
std::scoped_lock Lock(ThreadCreationMutex);
HandledRWX = InvalidationTracker->HandleRWXAccessViolation(FaultAddress);
}

if (HandledRWX) {
LogMan::Msg::DFmt("Handled self-modifying code: pc: {:X} fault: {:X}", Context->Pc, FaultAddress);
NtContinue(Context, FALSE);
}
}

if (!IsAddressInJit(Context->Pc)) {
Expand All @@ -645,29 +655,32 @@ NTSTATUS BTCpuResetToConsistentState(EXCEPTION_POINTERS* Ptrs) {
}

void BTCpuFlushInstructionCache2(const void* Address, SIZE_T Size) {
InvalidationTracker.InvalidateAlignedInterval(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), false);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->InvalidateAlignedInterval(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), false);
}

void BTCpuNotifyMemoryAlloc(void* Address, SIZE_T Size, ULONG Type, ULONG Prot) {
InvalidationTracker.HandleMemoryProtectionNotification(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address),
static_cast<uint64_t>(Size), Prot);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->HandleMemoryProtectionNotification(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), Prot);
}

void BTCpuNotifyMemoryProtect(void* Address, SIZE_T Size, ULONG NewProt) {
InvalidationTracker.HandleMemoryProtectionNotification(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address),
static_cast<uint64_t>(Size), NewProt);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->HandleMemoryProtectionNotification(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), NewProt);
}

void BTCpuNotifyMemoryFree(void* Address, SIZE_T Size, ULONG FreeType) {
std::scoped_lock Lock(ThreadCreationMutex);
if (!Size) {
InvalidationTracker.InvalidateContainingSection(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), true);
InvalidationTracker->InvalidateContainingSection(reinterpret_cast<uint64_t>(Address), true);
} else if (FreeType & MEM_DECOMMIT) {
InvalidationTracker.InvalidateAlignedInterval(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), true);
InvalidationTracker->InvalidateAlignedInterval(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), true);
}
}

void BTCpuNotifyUnmapViewOfSection(void* Address, ULONG Flags) {
InvalidationTracker.InvalidateContainingSection(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), true);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->InvalidateContainingSection(reinterpret_cast<uint64_t>(Address), true);
}

BOOLEAN WINAPI BTCpuIsProcessorFeaturePresent(UINT Feature) {
Expand Down
Loading