Skip to content

Commit

Permalink
InvalidationTracker: Invalidate code across all threads
Browse files Browse the repository at this point in the history
When thread management was moved to the frontend, invalidation moved
from being a global operation to per-thread but the WOW64 backend wasn't
updated to account for this. Now for any invalidation event loop over
all threads tracked by the frontend and invalidate the appropriate
range.
  • Loading branch information
bylaws committed Apr 18, 2024
1 parent d92580b commit 7ea1d42
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 24 deletions.
35 changes: 29 additions & 6 deletions Source/Windows/Common/InvalidationTracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,28 @@

#include <FEXCore/Utils/LogManager.h>
#include <FEXCore/Utils/TypeDefines.h>
#include <FEXCore/Utils/SignalScopeGuards.h>
#include <FEXCore/Core/Context.h>
#include <FEXCore/Debug/InternalThreadState.h>
#include "InvalidationTracker.h"
#include <windef.h>
#include <winternl.h>

namespace FEX::Windows {
InvalidationTracker::InvalidationTracker(FEXCore::Context::Context& CTX, const std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*>& Threads)
: CTX {CTX}
, Threads {Threads} {}

void InvalidationTracker::HandleMemoryProtectionNotification(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size,
ULONG Prot) {
const auto AlignedBase = Address & FEXCore::Utils::FEX_PAGE_MASK;
const auto AlignedSize = (Address - AlignedBase + Size + FEXCore::Utils::FEX_PAGE_SIZE - 1) & FEXCore::Utils::FEX_PAGE_MASK;

if (Prot & (PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE)) {
Thread->CTX->InvalidateGuestCodeRange(Thread, AlignedBase, AlignedSize);
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, AlignedBase, AlignedSize);
}
}

if (Prot & PAGE_EXECUTE_READWRITE) {
Expand All @@ -28,15 +36,21 @@ void InvalidationTracker::HandleMemoryProtectionNotification(FEXCore::Core::Inte
}
}

void InvalidationTracker::InvalidateContainingSection(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, bool Free) {
void InvalidationTracker::InvalidateContainingSection(uint64_t Address, bool Free) {
MEMORY_BASIC_INFORMATION Info;
if (NtQueryVirtualMemory(NtCurrentProcess(), reinterpret_cast<void*>(Address), MemoryBasicInformation, &Info, sizeof(Info), nullptr)) {
return;
}

const auto SectionBase = reinterpret_cast<uint64_t>(Info.AllocationBase);
const auto SectionSize = reinterpret_cast<uint64_t>(Info.BaseAddress) + Info.RegionSize - reinterpret_cast<uint64_t>(Info.AllocationBase);
Thread->CTX->InvalidateGuestCodeRange(Thread, SectionBase, SectionSize);

{
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, SectionBase, SectionSize);
}
}

if (Free) {
std::scoped_lock Lock(RWXIntervalsLock);
Expand All @@ -47,7 +61,13 @@ void InvalidationTracker::InvalidateContainingSection(FEXCore::Core::InternalThr
void InvalidationTracker::InvalidateAlignedInterval(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size, bool Free) {
const auto AlignedBase = Address & FEXCore::Utils::FEX_PAGE_MASK;
const auto AlignedSize = (Address - AlignedBase + Size + FEXCore::Utils::FEX_PAGE_SIZE - 1) & FEXCore::Utils::FEX_PAGE_MASK;
Thread->CTX->InvalidateGuestCodeRange(Thread, AlignedBase, AlignedSize);

{
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, AlignedBase, AlignedSize);
}
}

if (Free) {
std::scoped_lock Lock(RWXIntervalsLock);
Expand Down Expand Up @@ -75,7 +95,7 @@ void InvalidationTracker::ReprotectRWXIntervals(uint64_t Address, uint64_t Size)
} while (Address < End);
}

bool InvalidationTracker::HandleRWXAccessViolation(FEXCore::Core::InternalThreadState* Thread, uint64_t FaultAddress) {
bool InvalidationTracker::HandleRWXAccessViolation(uint64_t FaultAddress) {
const bool NeedsInvalidate = [&](uint64_t Address) {
std::unique_lock Lock(RWXIntervalsLock);
const bool Enclosed = RWXIntervals.Query(Address).Enclosed;
Expand All @@ -93,7 +113,10 @@ bool InvalidationTracker::HandleRWXAccessViolation(FEXCore::Core::InternalThread

if (NeedsInvalidate) {
// RWXIntervalsLock cannot be held during invalidation
Thread->CTX->InvalidateGuestCodeRange(Thread, FaultAddress & FEXCore::Utils::FEX_PAGE_MASK, FEXCore::Utils::FEX_PAGE_SIZE);
std::scoped_lock Lock(CTX.GetCodeInvalidationMutex());
for (auto Thread : Threads) {
CTX.InvalidateGuestCodeRange(Thread.second, FaultAddress & FHU::FEX_PAGE_MASK, FHU::FEX_PAGE_SIZE);
}
return true;
}
return false;
Expand Down
16 changes: 12 additions & 4 deletions Source/Windows/Common/InvalidationTracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,33 @@

#include "IntervalList.h"
#include <mutex>
#include <unordered_map>

namespace FEXCore::Core {
struct InternalThreadState;
}

namespace FEXCore::Context {
class Context;
}

namespace FEX::Windows {
/**
* @brief Handles SMC and regular code invalidation
*/
class InvalidationTracker {
public:
void HandleMemoryProtectionNotification(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size, ULONG Prot);
void InvalidateContainingSection(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, bool Free);
void InvalidateAlignedInterval(FEXCore::Core::InternalThreadState* Thread, uint64_t Address, uint64_t Size, bool Free);
InvalidationTracker(FEXCore::Context::Context& CTX, const std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*>& Threads);
void HandleMemoryProtectionNotification(uint64_t Address, uint64_t Size, ULONG Prot);
void InvalidateContainingSection(uint64_t Address, bool Free);
void InvalidateAlignedInterval(uint64_t Address, uint64_t Size, bool Free);
void ReprotectRWXIntervals(uint64_t Address, uint64_t Size);
bool HandleRWXAccessViolation(FEXCore::Core::InternalThreadState* Thread, uint64_t FaultAddress);
bool HandleRWXAccessViolation(uint64_t FaultAddress);

private:
IntervalList<uint64_t> RWXIntervals;
std::mutex RWXIntervalsLock;
FEXCore::Context::Context& CTX;
const std::unordered_map<DWORD, FEXCore::Core::InternalThreadState*>& Threads;
};
} // namespace FEX::Windows
39 changes: 25 additions & 14 deletions Source/Windows/WOW64/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ fextl::unique_ptr<FEXCore::Context::Context> CTX;
fextl::unique_ptr<FEX::DummyHandlers::DummySignalDelegator> SignalDelegator;
fextl::unique_ptr<WowSyscallHandler> SyscallHandler;

FEX::Windows::InvalidationTracker InvalidationTracker;
std::optional<FEX::Windows::InvalidationTracker> InvalidationTracker;
std::optional<FEX::Windows::CPUFeatures> CPUFeatures;

std::mutex ThreadCreationMutex;
Expand Down Expand Up @@ -394,7 +394,7 @@ class WowSyscallHandler : public FEXCore::HLE::SyscallHandler, public FEXCore::A
}

void MarkGuestExecutableRange(FEXCore::Core::InternalThreadState* Thread, uint64_t Start, uint64_t Length) override {
InvalidationTracker.ReprotectRWXIntervals(Start, Length);
InvalidationTracker->ReprotectRWXIntervals(Start, Length);
}
};

Expand Down Expand Up @@ -423,6 +423,7 @@ void BTCpuProcessInit() {
CTX->SetSignalDelegator(SignalDelegator.get());
CTX->SetSyscallHandler(SyscallHandler.get());
CTX->InitCore();
InvalidationTracker.emplace(*CTX, Threads);
CPUFeatures.emplace(*CTX);
}

Expand Down Expand Up @@ -617,15 +618,22 @@ NTSTATUS BTCpuResetToConsistentState(EXCEPTION_POINTERS* Ptrs) {
if (Exception->ExceptionCode == EXCEPTION_ACCESS_VIOLATION) {
const auto FaultAddress = static_cast<uint64_t>(Exception->ExceptionInformation[1]);

if (InvalidationTracker.HandleRWXAccessViolation(GetTLS().ThreadState(), FaultAddress)) {
LogMan::Msg::DFmt("Handled self-modifying code: pc: {:X} fault: {:X}", Context->Pc, FaultAddress);
NtContinue(Context, FALSE);
}

if (Context::HandleSuspendInterrupt(Context, FaultAddress)) {
LogMan::Msg::DFmt("Resumed from suspend");
NtContinue(Context, FALSE);
}

bool HandledRWX = false;
if (GetTLS().ThreadState()) {
std::scoped_lock Lock(ThreadCreationMutex);
HandledRWX = InvalidationTracker->HandleRWXAccessViolation(FaultAddress);
}

if (HandledRWX) {
LogMan::Msg::DFmt("Handled self-modifying code: pc: {:X} fault: {:X}", Context->Pc, FaultAddress);
NtContinue(Context, FALSE);
}
}

if (!IsAddressInJit(Context->Pc)) {
Expand All @@ -647,29 +655,32 @@ NTSTATUS BTCpuResetToConsistentState(EXCEPTION_POINTERS* Ptrs) {
}

void BTCpuFlushInstructionCache2(const void* Address, SIZE_T Size) {
InvalidationTracker.InvalidateAlignedInterval(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), false);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->InvalidateAlignedInterval(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), false);
}

void BTCpuNotifyMemoryAlloc(void* Address, SIZE_T Size, ULONG Type, ULONG Prot) {
InvalidationTracker.HandleMemoryProtectionNotification(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address),
static_cast<uint64_t>(Size), Prot);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->HandleMemoryProtectionNotification(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), Prot);
}

void BTCpuNotifyMemoryProtect(void* Address, SIZE_T Size, ULONG NewProt) {
InvalidationTracker.HandleMemoryProtectionNotification(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address),
static_cast<uint64_t>(Size), NewProt);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->HandleMemoryProtectionNotification(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), NewProt);
}

void BTCpuNotifyMemoryFree(void* Address, SIZE_T Size, ULONG FreeType) {
std::scoped_lock Lock(ThreadCreationMutex);
if (!Size) {
InvalidationTracker.InvalidateContainingSection(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), true);
InvalidationTracker->InvalidateContainingSection(reinterpret_cast<uint64_t>(Address), true);
} else if (FreeType & MEM_DECOMMIT) {
InvalidationTracker.InvalidateAlignedInterval(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), true);
InvalidationTracker->InvalidateAlignedInterval(reinterpret_cast<uint64_t>(Address), static_cast<uint64_t>(Size), true);
}
}

void BTCpuNotifyUnmapViewOfSection(void* Address, ULONG Flags) {
InvalidationTracker.InvalidateContainingSection(GetTLS().ThreadState(), reinterpret_cast<uint64_t>(Address), true);
std::scoped_lock Lock(ThreadCreationMutex);
InvalidationTracker->InvalidateContainingSection(reinterpret_cast<uint64_t>(Address), true);
}

BOOLEAN WINAPI BTCpuIsProcessorFeaturePresent(UINT Feature) {
Expand Down

0 comments on commit 7ea1d42

Please sign in to comment.