Skip to content

Commit

Permalink
[IR][CodeGen] Add new SizedCapabilityType for use in CodeGen
Browse files Browse the repository at this point in the history
EVT::getTypeForEVT currently returns a PointerType for MVT::cN, but this
has a couple of issues. The first issue is we have to hard-code the
address space, though that's not such a big deal given we do that
elsewhere too. The bigger issue is that, when we later pass that to
MVT::getVT or EVT::getEVT, it doesn't know what the right size is, so
returns MVT::cPTR instead, which is not a true value type, and is
supposed to only be used by TableGen. This has been seen to confuse
TargetLoweringBase::getTypeConversion, as when presented with a vector
of capability pointers it can end up trying to recreate a smaller vector
of the same type, but this trips up various assertions for the MVT::cPTR
as both the IR methods and the code here are expecting to be dealing
with actual value types.

Borrowing the idea of TypedPointerType (DXILPointerTyID) a bit,
introduce a new IR type, SizedCapabilityType, to represent a fixed-size
capability during CodeGen, which allows lossless roundtripping from MVT
to Type and back.

This fixes building cheritest, which has crashed since the introduction
of cPTR due to cPTR not being a value type, mirroring iPTR, unlike the
old iFATPTRAny which was its own weird beast, but wouldn't have tripped
up these assertions. It probably didn't do the most sensible things
though.

Fixes: 7aa7f2e ("[CodeGen] Rework MVT representation of capabilities and add type inference")
  • Loading branch information
jrtc27 committed Jun 4, 2024
1 parent 9e20010 commit 5e65240
Show file tree
Hide file tree
Showing 13 changed files with 121 additions and 12 deletions.
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/DataLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,8 @@ inline TypeSize DataLayout::getTypeSizeInBits(Type *Ty) const {
getTypeSizeInBits(VTy->getElementType()).getFixedSize();
return TypeSize(MinBits, EltCnt.isScalable());
}
case Type::SizedCapabilityTyID:
return TypeSize::Fixed(cast<SizedCapabilityType>(Ty)->getBitWidth());
default:
llvm_unreachable("DataLayout::getTypeSizeInBits(): Unsupported type");
}
Expand Down
35 changes: 35 additions & 0 deletions llvm/include/llvm/IR/DerivedTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,41 @@ class PointerType : public Type {
}
};

/// Class to represent fixed-size capability types during CodeGen.
class SizedCapabilityType : public Type {
friend class LLVMContextImpl;

protected:
explicit SizedCapabilityType(LLVMContext &C, unsigned NumBits)
: Type(C, SizedCapabilityTyID) {
setSubclassData(NumBits);
}

public:
/// This enum is just used to hold constants we need for SizedCapabilityType.
enum {
MIN_CAP_BITS = 64, ///< Minimum number of bits that can be specified
MAX_CAP_BITS = 256 ///< Maximum number of bits that can be specified
///< Note that bit width is stored in the Type classes SubclassData field
///< which has 24 bits.
};

/// This static method is the primary way of constructing an IntegerType.
/// If an IntegerType with the same NumBits value was previously instantiated,
/// that instance will be returned. Otherwise a new one will be created. Only
/// one instance with a given NumBits value is ever created.
/// Get or create an IntegerType instance.
static SizedCapabilityType *get(LLVMContext &C, unsigned NumBits);

/// Get the number of bits in this IntegerType
unsigned getBitWidth() const { return getSubclassData(); }

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *T) {
return T->getTypeID() == SizedCapabilityTyID;
}
};

Type *Type::getExtendedType() const {
assert(
isIntOrIntVectorTy() &&
Expand Down
8 changes: 6 additions & 2 deletions llvm/include/llvm/IR/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class Type {
FixedVectorTyID, ///< Fixed width SIMD vector type
ScalableVectorTyID, ///< Scalable SIMD vector type
DXILPointerTyID, ///< DXIL typed pointer used by DirectX target
SizedCapabilityTyID,///< Fixed-size CHERI capability type
};

private:
Expand Down Expand Up @@ -225,6 +226,9 @@ class Type {
/// True if this is an instance of PointerType.
bool isPointerTy() const { return getTypeID() == PointerTyID; }

/// True if this is an instance of PointerType.
bool isSizedCapabilityTy() const { return getTypeID() == SizedCapabilityTyID; }

/// True if this is an instance of an opaque PointerType.
bool isOpaquePointerTy() const;

Expand Down Expand Up @@ -272,8 +276,8 @@ class Type {
bool isSized(SmallPtrSetImpl<Type*> *Visited = nullptr) const {
// If it's a primitive, it is always sized.
if (getTypeID() == IntegerTyID || isFloatingPointTy() ||
getTypeID() == PointerTyID || getTypeID() == X86_MMXTyID ||
getTypeID() == X86_AMXTyID)
getTypeID() == PointerTyID || getTypeID() == SizedCapabilityTyID ||
getTypeID() == X86_MMXTyID || getTypeID() == X86_AMXTyID)
return true;
// If it is not something that can have a size (e.g. a function or label),
// it doesn't have a size.
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,8 @@ void ModuleBitcodeWriter::writeTypeTable() {
}
case Type::DXILPointerTyID:
llvm_unreachable("DXIL pointers cannot be added to IR modules");
case Type::SizedCapabilityTyID:
llvm_unreachable("Fixed-size capabilities cannot be added to IR modules");
}

// Emit the finished record.
Expand Down
12 changes: 7 additions & 5 deletions llvm/lib/CodeGen/ValueTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,10 +535,11 @@ Type *EVT::getTypeForEVT(LLVMContext &Context) const {
return ScalableVectorType::get(Type::getDoubleTy(Context), 8);
case MVT::Metadata: return Type::getMetadataTy(Context);
case MVT::c64:
return SizedCapabilityType::get(Context, 64);
case MVT::c128:
return SizedCapabilityType::get(Context, 128);
case MVT::c256:
// XXX: Hard-coded AS
return PointerType::get(Type::getInt8Ty(Context), 200);
return SizedCapabilityType::get(Context, 256);
}
// clang-format on
}
Expand All @@ -555,6 +556,8 @@ MVT MVT::getVT(Type *Ty, bool HandleUnknown){
return MVT::isVoid;
case Type::IntegerTyID:
return getIntegerVT(cast<IntegerType>(Ty)->getBitWidth());
case Type::SizedCapabilityTyID:
return getCapabilityVT(cast<SizedCapabilityType>(Ty)->getBitWidth());
case Type::HalfTyID: return MVT(MVT::f16);
case Type::BFloatTyID: return MVT(MVT::bf16);
case Type::FloatTyID: return MVT(MVT::f32);
Expand All @@ -565,9 +568,8 @@ MVT MVT::getVT(Type *Ty, bool HandleUnknown){
case Type::FP128TyID: return MVT(MVT::f128);
case Type::PPC_FP128TyID: return MVT(MVT::ppcf128);
case Type::PointerTyID: {
// FIXME: require a DataLayout here!
if (isCheriPointer(Ty, nullptr))
return MVT(MVT::cPTR);
// NB: Removing this upstream, so ensure we haven't made it worse
assert(!isCheriPointer(Ty, nullptr));
return MVT(MVT::iPTR);
}
case Type::FixedVectorTyID:
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/IR/AsmWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,9 @@ void TypePrinting::print(Type *Ty, raw_ostream &OS) {
// extra dependencies we just print the pointer's address here.
OS << "dxil-ptr (" << Ty << ")";
return;
case Type::SizedCapabilityTyID:
OS << 'c' << cast<SizedCapabilityType>(Ty)->getBitWidth();
return;
}
llvm_unreachable("Invalid TypeID");
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/IR/Core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,8 @@ LLVMTypeKind LLVMGetTypeKind(LLVMTypeRef Ty) {
return LLVMScalableVectorTypeKind;
case Type::DXILPointerTyID:
llvm_unreachable("DXIL pointers are unsupported via the C API");
case Type::SizedCapabilityTyID:
llvm_unreachable("Fixed-size capabilities are unsupported via the C API");
}
llvm_unreachable("Unhandled TypeID.");
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/IR/DataLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,8 @@ Align DataLayout::getAlignment(Type *Ty, bool abi_or_pref) const {
}
case Type::X86_AMXTyID:
return Align(64);
case Type::SizedCapabilityTyID:
return Align(cast<SizedCapabilityType>(Ty)->getBitWidth() / 8);
default:
llvm_unreachable("Bad type for getAlignment!!!");
}
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/IR/LLVMContextImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1484,6 +1484,7 @@ class LLVMContextImpl {
DenseMap<std::pair<Type *, ElementCount>, VectorType *> VectorTypes;
DenseMap<Type *, PointerType *> PointerTypes; // Pointers in AddrSpace = 0
DenseMap<std::pair<Type *, unsigned>, PointerType *> ASPointerTypes;
DenseMap<unsigned, SizedCapabilityType *> SizedCapabilityTypes;

/// ValueHandles - This map keeps track of all of the value handles that are
/// watching a Value*. The Value::HasValueHandle bit is used to know
Expand Down
30 changes: 25 additions & 5 deletions llvm/lib/IR/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ TypeSize Type::getPrimitiveSizeInBits() const {
assert(!ETS.isScalable() && "Vector type should have fixed-width elements");
return {ETS.getFixedSize() * EC.getKnownMinValue(), EC.isScalable()};
}
case Type::SizedCapabilityTyID:
return TypeSize::Fixed(cast<SizedCapabilityType>(this)->getBitWidth());
default: return TypeSize::Fixed(0);
}
}
Expand Down Expand Up @@ -674,7 +676,7 @@ VectorType *VectorType::get(Type *ElementType, ElementCount EC) {

bool VectorType::isValidElementType(Type *ElemTy) {
return ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() ||
ElemTy->isPointerTy();
ElemTy->isPointerTy() || ElemTy->isSizedCapabilityTy();
}

//===----------------------------------------------------------------------===//
Expand All @@ -684,8 +686,9 @@ bool VectorType::isValidElementType(Type *ElemTy) {
FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
assert(NumElts > 0 && "#Elements of a VectorType must be greater than 0");
assert(isValidElementType(ElementType) && "Element type of a VectorType must "
"be an integer, floating point, or "
"pointer type.");
"be an integer, floating point, "
"pointer, or sized capability "
"type.");

auto EC = ElementCount::getFixed(NumElts);

Expand All @@ -706,8 +709,9 @@ ScalableVectorType *ScalableVectorType::get(Type *ElementType,
unsigned MinNumElts) {
assert(MinNumElts > 0 && "#Elements of a VectorType must be greater than 0");
assert(isValidElementType(ElementType) && "Element type of a VectorType must "
"be an integer, floating point, or "
"pointer type.");
"be an integer, floating point, "
"pointer, or sized capability "
"type.");

auto EC = ElementCount::getScalable(MinNumElts);

Expand Down Expand Up @@ -784,3 +788,19 @@ bool PointerType::isValidElementType(Type *ElemTy) {
bool PointerType::isLoadableOrStorableType(Type *ElemTy) {
return isValidElementType(ElemTy) && !ElemTy->isFunctionTy();
}

//===----------------------------------------------------------------------===//
// SizedCapabilityType Implementation
//===----------------------------------------------------------------------===//

SizedCapabilityType *SizedCapabilityType::get(LLVMContext &C, unsigned NumBits) {
assert(NumBits >= MIN_CAP_BITS && "bitwidth too small");
assert(NumBits <= MAX_CAP_BITS && "bitwidth too large");

SizedCapabilityType *&Entry = C.pImpl->SizedCapabilityTypes[NumBits];

if (!Entry)
Entry = new (C.pImpl->Alloc) SizedCapabilityType(C, NumBits);

return Entry;
}
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,7 @@ void DXILBitcodeWriter::writeTypeTable() {
case Type::BFloatTyID:
case Type::X86_AMXTyID:
case Type::TokenTyID:
case Type::SizedCapabilityTyID:
llvm_unreachable("These should never be used!!!");
break;
case Type::VoidTyID:
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ unsigned HexagonTargetObjectFile::getSmallestAddressableSize(const Type *Ty,
case Type::X86_AMXTyID:
case Type::TokenTyID:
case Type::DXILPointerTyID:
case Type::SizedCapabilityTyID:
return 0;
}

Expand Down
34 changes: 34 additions & 0 deletions llvm/test/Transforms/SLPVectorizer/cheri-crash-cPTR-element.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -mattr=+cheri128 -mcpu=beri -passes="function(slp-vectorizer)" -S < %s | FileCheck %s

target datalayout = "E-m:e-pf200:128:128:128:64-i8:8:32-i16:16:32-i64:64-n32:64-S128"
target triple = "mips64c128-unknown-freebsd"

;; This previously crashed in TargetLoweringBase due to creating a vector which
;; gave an element type of MVT::cPTR when queried rather than MVT::c128.
define ptr @cmemcpy(ptr addrspace(200) %0, ptr addrspace(200) %1, i1 %tobool) {
; CHECK-LABEL: @cmemcpy(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[TOBOOL:%.*]], label [[IF_END12:%.*]], label [[DO_BODY45:%.*]]
; CHECK: if.end12:
; CHECK-NEXT: br label [[DO_BODY45]]
; CHECK: do.body45:
; CHECK-NEXT: [[SRC2_0:%.*]] = phi ptr addrspace(200) [ [[TMP0:%.*]], [[IF_END12]] ], [ [[TMP1:%.*]], [[ENTRY:%.*]] ]
; CHECK-NEXT: [[DST1_0:%.*]] = phi ptr addrspace(200) [ [[TMP1]], [[IF_END12]] ], [ [[TMP0]], [[ENTRY]] ]
; CHECK-NEXT: [[TMP2:%.*]] = load i64, ptr addrspace(200) [[SRC2_0]], align 8
; CHECK-NEXT: store i64 [[TMP2]], ptr addrspace(200) [[DST1_0]], align 8
; CHECK-NEXT: ret ptr null
;
entry:
br i1 %tobool, label %if.end12, label %do.body45

if.end12:
br label %do.body45

do.body45:
%src2.0 = phi ptr addrspace(200) [ %0, %if.end12 ], [ %1, %entry ]
%dst1.0 = phi ptr addrspace(200) [ %1, %if.end12 ], [ %0, %entry ]
%2 = load i64, ptr addrspace(200) %src2.0, align 8
store i64 %2, ptr addrspace(200) %dst1.0, align 8
ret ptr null
}

0 comments on commit 5e65240

Please sign in to comment.