Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: [SPIR-V] Add support for vk::buffer_ref attribute #5089

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tools/clang/include/clang/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define LLVM_CLANG_AST_ASTCONTEXT_H

#include "clang/AST/ASTTypeTraits.h"
#include "clang/AST/Attr.h"
#include "clang/AST/CanonicalType.h"
#include "clang/AST/CommentCommandTraits.h"
#include "clang/AST/Decl.h"
Expand Down Expand Up @@ -2490,6 +2491,17 @@ class ASTContext : public RefCountedBase<ASTContext> {
};

llvm::StringMap<SectionInfo> SectionInfos;

// Buffer Reference Utilities
public:
bool IsBufferRef(const ValueDecl *D) const;
bool IsBufferRefDecl(const Decl *D) const;
bool IsBufferRefTypeDef(QualType T) const;
clang::VKBufferRefAttr *GetBufferRefAttr(const Expr *base) const;
clang::VKBufferRefAttr *GetBufferRefAttr(const ValueDecl *decl) const;
uint32_t BufferRefByteSize() const;
uint32_t BufferRefByteAlign() const;
clang::CanQualType BufferRefProxyType() const;
};

/// \brief Utility function for constructing a nullary selector.
Expand Down
8 changes: 8 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,14 @@ def VKLocation : InheritableAttr {
let Documentation = [Undocumented];
}

def VKBufferRef : InheritableAttr {
let Spellings = [CXX11<"vk", "buffer_ref">];
let Subjects = SubjectList<[Function, Var, ParmVar, Field, TypedefName], ErrorDiag>;
let Args = [IntArgument<"Number">];
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

def VKIndex : InheritableAttr {
let Spellings = [CXX11<"vk", "index">];
let Subjects = SubjectList<[Function, ParmVar, Field], ErrorDiag>;
Expand Down
5 changes: 3 additions & 2 deletions tools/clang/include/clang/SPIRV/AstTypeProbe.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ bool isAKindOfStructuredOrByteBuffer(QualType type);
/// \brief Returns true if the given type is the HLSL (RW)StructuredBuffer,
/// (RW)ByteAddressBuffer, {Append|Consume}StructuredBuffer, or a struct
/// containing one of the above.
bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type);
bool isOrContainsAKindOfStructuredOrByteBuffer(const ASTContext &astContext,
QualType type);

/// \brief Returns true if the given type is the HLSL Buffer type.
bool isBuffer(QualType type);
Expand Down Expand Up @@ -285,7 +286,7 @@ bool isOpaqueArrayType(QualType type);
/// (in a recursive away).
///
/// Note: legalization specific code
bool isOpaqueStructType(QualType type);
bool isOpaqueStructType(const ASTContext &astContext, QualType type);

/// \brief Returns true if the given type can use relaxed precision
/// decoration. Integer and float types with lower than 32 bits can be
Expand Down
8 changes: 6 additions & 2 deletions tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class SpirvBuilder {
/// this method for the variable itself.
SpirvVariable *addFnVar(QualType valueType, SourceLocation,
llvm::StringRef name = "", bool isPrecise = false,
SpirvInstruction *init = nullptr);
SpirvInstruction *init = nullptr, bool isRef = false);

/// \brief Ends building of the current function. All basic blocks constructed
/// from the beginning or after ending the previous function will be collected
Expand Down Expand Up @@ -208,7 +208,8 @@ class SpirvBuilder {
SpirvAccessChain *
createAccessChain(QualType resultType, SpirvInstruction *base,
llvm::ArrayRef<SpirvInstruction *> indexes,
SourceLocation loc, SourceRange range = {});
SourceLocation loc, SourceRange range = {},
uint32_t bufRefAlign = 0);
SpirvAccessChain *
createAccessChain(const SpirvType *resultType, SpirvInstruction *base,
llvm::ArrayRef<SpirvInstruction *> indexes,
Expand Down Expand Up @@ -690,6 +691,9 @@ class SpirvBuilder {
/// \brief Decorates the given target with noperspective
void decorateNoPerspective(SpirvInstruction *target, SourceLocation);

/// \brief Decorates the given target with aliased pointer
void decorateAliasedPointer(SpirvInstruction *target, SourceLocation);

/// \brief Decorates the given target with sample
void decorateSample(SpirvInstruction *target, SourceLocation);

Expand Down
73 changes: 73 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,21 @@ class SpirvContext {
const SpirvPointerType *getPointerType(const SpirvType *pointee,
spv::StorageClass);

const SpirvFwdPointerType *getFwdPointerType(unsigned pointee_id,
spv::StorageClass);

const SpirvFwdPointerType *getFwdPointerForDecl(const DeclContext *decl);

void registerFwdPointerForDecl(const SpirvFwdPointerType *fwd_ptr,
const DeclContext *decl);

void registerFwdPointerForPointer(const SpirvFwdPointerType *fwd_ptr,
const SpirvPointerType *ptr);
const SpirvFwdPointerType *getFwdPointerForPointer(
const SpirvPointerType *ptr);
const SpirvPointerType *getPointerForFwdPointer(
const SpirvFwdPointerType *fwd_ptr);

FunctionType *getFunctionType(const SpirvType *ret,
llvm::ArrayRef<const SpirvType *> param);

Expand Down Expand Up @@ -398,11 +413,50 @@ class SpirvContext {
const DeclContext *decl) {
assert(spvTy != nullptr && decl != nullptr);
spvStructTypeToDecl[spvTy] = decl;
registeredDecls.insert(decl);
}
const DeclContext *getStructDeclForSpirvType(const SpirvType *spvTy) {
return spvStructTypeToDecl[spvTy];
}

// Remember and query if struct_ty is part of buffer reference
void registerRefdStructType(const StructType* struct_ty) {
refdStructTypes.insert(struct_ty);
}
bool isRefdStructType(const StructType *struct_ty) {
return refdStructTypes.find(struct_ty) != refdStructTypes.end();
}

// Has decl been registered with a spv type yet?
bool isRegisteredStructDecl(const DeclContext* decl) {
return registeredDecls.find(decl) != registeredDecls.end();
}

// Is decl lowering in progress?
bool isInProgressStructDecl(const DeclContext *decl) {
return inProgressDecls.find(decl) != inProgressDecls.end();
}

// Set decl in-progress status
void setInProgressStructDecl(const DeclContext *decl,
const bool inProgress) {
if (inProgress)
inProgressDecls.insert(decl);
else
inProgressDecls.erase(decl);
return;
}

// Return forward pointee_id associated with decl
unsigned getPointeeIdForStructDecl(const DeclContext* decl) {
auto pid_itr = pointeeIdForDecl.find(decl);
if (pid_itr != pointeeIdForDecl.end())
return pid_itr->second;
auto pid = nextPointeeId++;
pointeeIdForDecl[decl] = pid;
return pid;
}

/// Function to add/get the mapping from a FunctionDecl to its DebugFunction.
void registerDebugFunctionForDecl(const FunctionDecl *decl,
SpirvDebugFunction *fn) {
Expand Down Expand Up @@ -455,6 +509,9 @@ class SpirvContext {
using SCToPtrTyMap =
llvm::DenseMap<spv::StorageClass, const SpirvPointerType *,
StorageClassDenseMapInfo>;
using SCToFwdPtrTyMap =
llvm::DenseMap<spv::StorageClass, const SpirvFwdPointerType *,
StorageClassDenseMapInfo>;

// Vector/matrix types for each possible element count.
// Type at index is for vector of index components. Index 0/1 is unused.
Expand All @@ -471,6 +528,19 @@ class SpirvContext {
llvm::SmallVector<const StructType *, 8> structTypes;
llvm::SmallVector<const HybridStructType *, 8> hybridStructTypes;
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
llvm::DenseMap<unsigned, SCToFwdPtrTyMap> fwdPointerTypes;
llvm::DenseMap<const DeclContext *, const SpirvFwdPointerType *>
fwdPointerForDecl;
llvm::DenseSet<const DeclContext *> registeredDecls;
llvm::DenseSet<const DeclContext *> inProgressDecls;
llvm::DenseSet<const StructType *> refdStructTypes;
llvm::DenseMap<const DeclContext *, unsigned> pointeeIdForDecl;
llvm::DenseMap<const SpirvPointerType *, const SpirvFwdPointerType *>
fwdPointerForPointer;
llvm::DenseMap<const SpirvFwdPointerType *, const SpirvPointerType *>
pointerForFwdPointer;
llvm::DenseMap<const StructType *, const SpirvPointerType *>
pointerForStruct;
llvm::SmallVector<const HybridPointerType *, 8> hybridPointerTypes;
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
llvm::DenseMap<unsigned, SpirvIntrinsicType*> spirvIntrinsicTypes;
Expand All @@ -491,6 +561,9 @@ class SpirvContext {
llvm::StringMap<RichDebugInfo> debugInfo;
SpirvDebugInstruction *currentLexicalScope;

// Ids for StructDecls pointee
unsigned nextPointeeId;

// Mapping from SPIR-V type to debug type instruction.
// The purpose is not to generate several DebugType* instructions for the same
// type if the type is used for several variables.
Expand Down
5 changes: 4 additions & 1 deletion tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,8 @@ class SpirvAccessChain : public SpirvInstruction {
SpirvAccessChain(QualType resultType, SourceLocation loc,
SpirvInstruction *base,
llvm::ArrayRef<SpirvInstruction *> indexVec,
SourceRange range = {});
SourceRange range = {},
uint32_t bufRefAlign = 0);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvAccessChain)

Expand All @@ -861,10 +862,12 @@ class SpirvAccessChain : public SpirvInstruction {

SpirvInstruction *getBase() const { return base; }
llvm::ArrayRef<SpirvInstruction *> getIndexes() const { return indices; }
uint32_t getBufferRefAlign() const { return bufferRefAlign; }

private:
SpirvInstruction *base;
llvm::SmallVector<SpirvInstruction *, 4> indices;
uint32_t bufferRefAlign;
};

/// \brief Atomic instructions.
Expand Down
26 changes: 25 additions & 1 deletion tools/clang/include/clang/SPIRV/SpirvType.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class SpirvType {
TK_RuntimeArray,
TK_Struct,
TK_Pointer,
TK_FwdPointer,
TK_Function,
TK_AccelerationStructureNV,
TK_RayQueryKHR,
Expand Down Expand Up @@ -384,6 +385,25 @@ class SpirvPointerType : public SpirvType {
spv::StorageClass storageClass;
};

/// Represents a SPIR-V forward pointer type.
class SpirvFwdPointerType : public SpirvType {
public:
SpirvFwdPointerType(unsigned pointee_id, spv::StorageClass sc)
: SpirvType(TK_FwdPointer), pointeeId(pointee_id), storageClass(sc) {}

static bool classof(const SpirvType *t) { return t->getKind() == TK_FwdPointer; }

spv::StorageClass getStorageClass() const { return storageClass; }

bool operator==(const SpirvFwdPointerType &that) const {
return pointeeId == that.pointeeId && storageClass == that.storageClass;
}

private:
unsigned pointeeId;
spv::StorageClass storageClass;
};

/// Represents a SPIR-V function type. None of the parameters nor the return
/// type is allowed to be a hybrid type.
class FunctionType : public SpirvType {
Expand Down Expand Up @@ -480,12 +500,14 @@ class HybridStructType : public HybridType {
public:
FieldInfo(QualType astType_, llvm::StringRef name_ = "",
clang::VKOffsetAttr *offset = nullptr,
bool is_buffer_ref = false,
hlsl::ConstantPacking *packOffset = nullptr,
const hlsl::RegisterAssignment *regC = nullptr,
bool precise = false,
llvm::Optional<BitfieldInfo> bitfield = llvm::None)
: astType(astType_), name(name_), vkOffsetAttr(offset),
packOffsetAttr(packOffset), registerC(regC), isPrecise(precise),
isBufferRef(is_buffer_ref), packOffsetAttr(packOffset),
registerC(regC), isPrecise(precise),
bitfield(std::move(bitfield)) {}

// The field's type.
Expand All @@ -494,6 +516,8 @@ class HybridStructType : public HybridType {
std::string name;
// vk::offset attributes associated with this field.
clang::VKOffsetAttr *vkOffsetAttr;
// vk::buffer_ref attribute is associated with this field.
bool isBufferRef;
// :packoffset() annotations associated with this field.
hlsl::ConstantPacking *packOffsetAttr;
// :register(c#) annotations associated with this field.
Expand Down
2 changes: 2 additions & 0 deletions tools/clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -1923,6 +1923,8 @@ class Sema {
AccessSpecifier AS,
AttributeList *MSPropertyAttr);

bool IsBufferRefDecltr(Declarator *D);

FieldDecl *CheckFieldDecl(DeclarationName Name, QualType T,
TypeSourceInfo *TInfo,
RecordDecl *Record, SourceLocation Loc,
Expand Down
61 changes: 58 additions & 3 deletions tools/clang/lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "clang/AST/ASTContext.h"
#include "CXXABI.h"
#include "clang/AST/ASTMutationListener.h"
#include "clang/AST/Attr.h"
#include "clang/AST/CharUnits.h"
#include "clang/AST/Comment.h"
#include "clang/AST/CommentCommandTraits.h"
Expand Down Expand Up @@ -3217,8 +3216,8 @@ ASTContext::getTypedefType(const TypedefNameDecl *Decl,

if (Canonical.isNull())
Canonical = getCanonicalType(Decl->getUnderlyingType());
TypedefType *newType = new(*this, TypeAlignment)
TypedefType(Type::Typedef, Decl, Canonical);
TypedefType *newType =
new (*this, TypeAlignment) TypedefType(Type::Typedef, Decl, Canonical);
Decl->TypeForDecl = newType;
Types.push_back(newType);
return QualType(newType, 0);
Expand Down Expand Up @@ -8906,3 +8905,59 @@ clang::LazyGenerationalUpdatePtr<
clang::LazyGenerationalUpdatePtr<
const Decl *, Decl *, &ExternalASTSource::CompleteRedeclChain>::makeValue(
const clang::ASTContext &Ctx, Decl *Value);

// Buffer Reference Utilities

bool
ASTContext::IsBufferRefDecl(const Decl *D) const {
return D->hasAttr<VKBufferRefAttr>();
}

bool
ASTContext::IsBufferRefTypeDef(QualType T) const {
auto typedef_type = T->getAs<TypedefType>();
if (!typedef_type)
return false;
auto decl = typedef_type->getDecl();
return decl->hasAttr<VKBufferRefAttr>();
}

bool ASTContext::IsBufferRef(const ValueDecl *D) const {
return IsBufferRefDecl(D) || IsBufferRefTypeDef(D->getType());
}

VKBufferRefAttr*
ASTContext::GetBufferRefAttr(const Expr *base) const {
VKBufferRefAttr *attr = nullptr;
if (const auto *arg = dyn_cast<DeclRefExpr>(base))
if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl()))
attr = varDecl->getAttr<VKBufferRefAttr>();
if (attr == nullptr) {
auto typedef_type = base->getType()->getAs<TypedefType>();
if (typedef_type) {
auto decl = typedef_type->getDecl();
attr = decl->getAttr<VKBufferRefAttr>();
}
}
return attr;
}

VKBufferRefAttr *ASTContext::GetBufferRefAttr(const ValueDecl *decl) const {
auto attr = decl->getAttr<VKBufferRefAttr>();
if (attr == nullptr) {
const auto fieldType = decl->getType();
const auto typedef_type = fieldType->getAs<TypedefType>();
if (typedef_type) {
auto decl = typedef_type->getDecl();
attr = decl->getAttr<VKBufferRefAttr>();
}
}
return attr;
}

uint32_t ASTContext::BufferRefByteSize() const { return 8u; }
uint32_t ASTContext::BufferRefByteAlign() const { return 8u; }

clang::CanQualType ASTContext::BufferRefProxyType() const {
return UnsignedLongLongTy;
}
5 changes: 5 additions & 0 deletions tools/clang/lib/AST/RecordLayoutBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2312,6 +2312,11 @@ MicrosoftRecordLayoutBuilder::getAdjustedElementInfo(
// Get the alignment of the field type's natural alignment, ignore any
// alignment attributes.
ElementInfo Info;
if (Context.IsBufferRef(FD)) {
std::tie(Info.Size, Info.Alignment) = Context.getTypeInfoInChars(
Context.BufferRefProxyType());
return Info;
}
std::tie(Info.Size, Info.Alignment) =
Context.getTypeInfoInChars(FD->getType()->getUnqualifiedDesugaredType());
// Respect align attributes on the field.
Expand Down
Loading