Skip to content

Commit

Permalink
Shader reflection, store info inside compiled mshdr files
Browse files Browse the repository at this point in the history
  • Loading branch information
xezno committed Dec 20, 2024
1 parent 8e94cd9 commit 87fb92e
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 39 deletions.
40 changes: 36 additions & 4 deletions Source/Mocha.Common/Resources/ShaderInfo.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,40 @@
namespace Mocha.Common;
using Mocha.Glue;
using System.Runtime.InteropServices;

namespace Mocha.Common;

public enum ShaderReflectionType
{
Unknown,

Buffer,
Texture,
Sampler
}

[StructLayout( LayoutKind.Sequential )]
public struct ShaderReflectionBinding
{
public int Set { get; set; }
public int Binding { get; set; }
public ShaderReflectionType Type { get; set; }
public string Name { get; set; }
}

public struct ShaderReflectionInfo
{
public UtilArray Bindings { get; set; }
}

public struct ShaderStageInfo
{
public int[] Data { get; set; }
public ShaderReflectionInfo Reflection { get; set; }
}

public struct ShaderInfo
{
public int[] VertexShaderData { get; set; }
public int[] FragmentShaderData { get; set; }
public int[] ComputeShaderData { get; set; }
public ShaderStageInfo Vertex { get; set; }
public ShaderStageInfo Fragment { get; set; }
public ShaderStageInfo Compute { get; set; }
}
16 changes: 8 additions & 8 deletions Source/Mocha.Engine/Render/Assets/Material.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ public Material( string path )

NativeMaterial = new(
Path,
shaderFormat.Data.VertexShaderData.ToInterop(),
shaderFormat.Data.FragmentShaderData.ToInterop(),
shaderFormat.Data.Vertex.Data.ToInterop(),
shaderFormat.Data.Fragment.Data.ToInterop(),
Vertex.VertexAttributes.ToInterop(),
textures.ToInterop(),
SamplerType.Point,
Expand All @@ -85,8 +85,8 @@ public Material( string path )
var shaderFormat = Serializer.Deserialize<MochaFile<ShaderInfo>>( shaderFileBytes );

NativeMaterial.SetShaderData(
shaderFormat.Data.VertexShaderData.ToInterop(),
shaderFormat.Data.FragmentShaderData.ToInterop()
shaderFormat.Data.Vertex.Data.ToInterop(),
shaderFormat.Data.Fragment.Data.ToInterop()
);

NativeMaterial.Reload();
Expand Down Expand Up @@ -122,8 +122,8 @@ public Material( string shaderPath, VertexAttribute[] vertexAttributes, Texture?

NativeMaterial = new(
Path,
shaderFormat.Data.VertexShaderData.ToInterop(),
shaderFormat.Data.FragmentShaderData.ToInterop(),
shaderFormat.Data.Vertex.Data.ToInterop(),
shaderFormat.Data.Fragment.Data.ToInterop(),
vertexAttributes.ToInterop(),
textures.ToInterop(),
sampler,
Expand All @@ -144,8 +144,8 @@ public static Material FromShader( string shaderPath, VertexAttribute[] vertexAt

material.NativeMaterial = new(
material.Path,
shaderFormat.Data.VertexShaderData.ToInterop(),
shaderFormat.Data.FragmentShaderData.ToInterop(),
shaderFormat.Data.Vertex.Data.ToInterop(),
shaderFormat.Data.Fragment.Data.ToInterop(),
vertexAttributes.ToInterop()
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1162,23 +1162,27 @@ void VulkanRenderContext::CreateFullScreenTri()

// Vertex
{
if ( !ShaderCompiler::Instance().Compile( SHADER_TYPE_VERTEX, g_fullScreenTriVertexShader.c_str(), vertexShaderBits ) )
ShaderCompilerResult result;
if ( !ShaderCompiler::Instance().Compile( SHADER_TYPE_VERTEX, g_fullScreenTriVertexShader.c_str(), result ) )
{
ErrorMessage( "Fullscreen triangle vertex shader failed to compile." );
abort();
}
vertexShaderBits = result.ShaderData.GetData<unsigned int>();

pipelineInfo.shaderInfo.vertexShaderData = vertexShaderBits;
}

// Fragment
{
ShaderCompilerResult result;
if ( !ShaderCompiler::Instance().Compile(
SHADER_TYPE_FRAGMENT, g_fullScreenTriFragmentShader.c_str(), fragmentShaderBits ) )
SHADER_TYPE_FRAGMENT, g_fullScreenTriFragmentShader.c_str(), result ) )
{
ErrorMessage( "Fullscreen triangle fragment shader failed to compile." );
abort();
}
fragmentShaderBits = result.ShaderData.GetData<unsigned int>();

pipelineInfo.shaderInfo.fragmentShaderData = fragmentShaderBits;
}
Expand Down
10 changes: 10 additions & 0 deletions Source/Mocha.Host/Rendering/baserendercontext.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ struct PipelineInfo_t
bool renderToSwapchain = false;
};

// todo: remove me
struct RenderPushConstants
{
glm::vec4 data = glm::vec4{ 1.0f };
Expand All @@ -202,6 +203,15 @@ struct RenderPushConstants
glm::vec4 vLightInfoWS[4] = {};
};

// These need to be aligned
struct ViewConstants
{
glm::mat4 render_matrix = {}; // view/projection
glm::vec4 vCameraPosWS = {}; // camera pos
glm::vec4 vLightInfoWS[4] = {}; // light data
glm::vec4 data = {}; // misc data
};

struct GPUInfo
{
const char* gpuName = "Unnamed";
Expand Down
12 changes: 11 additions & 1 deletion Source/Mocha.Host/Rendering/rendermanager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,16 @@ void RenderManager::Shutdown()

void SceneMeshPass::Execute()
{
for ( auto& sceneMesh : m_meshes )
// Make a copy here so that nothing can affect what we draw while we're drawing it
std::vector<std::shared_ptr<SceneMesh>> meshes = std::vector<std::shared_ptr<SceneMesh>>( m_meshes );

/*std::vector<glm::mat4> objectMatrices = {};
for ( auto& sceneMesh : meshes )
{
objectMatrices.push_back( sceneMesh->m_transform.GetModelMatrix() );
}*/

for ( auto& sceneMesh : meshes )
{
bool materialWasDirty = false;

Expand Down Expand Up @@ -193,6 +202,7 @@ void SceneMeshPass::AddMesh( std::shared_ptr<SceneMesh> sceneMesh )
m_meshes.push_back( sceneMesh );
}

// todo: remove
void SceneMeshPass::SetConstants( RenderPushConstants constants )
{
m_constants = constants;
Expand Down
166 changes: 160 additions & 6 deletions Source/Mocha.Host/Rendering/shadercompiler.cpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,123 @@
#include "shadercompiler.h"

#include "baserendercontext.h"
#include <vector>
#include "spdlog/spdlog.h"

#include <vector>

using namespace slang;

ShaderCompiler::ShaderCompiler()
{
createGlobalSession( m_globalSession.writeRef() );
}

ShaderCompiler::~ShaderCompiler() { }
ShaderCompiler::~ShaderCompiler()
{
}

const char* KindToString( slang::TypeReflection::Kind kind )
{
switch ( kind )
{
case slang::TypeReflection::Kind::None:
return "None";
case slang::TypeReflection::Kind::Struct:
return "Struct";
case slang::TypeReflection::Kind::Array:
return "Array";
case slang::TypeReflection::Kind::Matrix:
return "Matrix";
case slang::TypeReflection::Kind::Vector:
return "Vector";
case slang::TypeReflection::Kind::Scalar:
return "Scalar";
case slang::TypeReflection::Kind::ConstantBuffer:
return "ConstantBuffer";
case slang::TypeReflection::Kind::Resource:
return "Resource";
case slang::TypeReflection::Kind::SamplerState:
return "SamplerState";
case slang::TypeReflection::Kind::TextureBuffer:
return "TextureBuffer";
case slang::TypeReflection::Kind::ShaderStorageBuffer:
return "ShaderStorageBuffer";
case slang::TypeReflection::Kind::ParameterBlock:
return "ParameterBlock";
case slang::TypeReflection::Kind::GenericTypeParameter:
return "GenericTypeParameter";
case slang::TypeReflection::Kind::Interface:
return "Interface";
case slang::TypeReflection::Kind::OutputStream:
return "OutputStream";
case slang::TypeReflection::Kind::Specialized:
return "Specialized";
case slang::TypeReflection::Kind::Feedback:
return "Feedback";
case slang::TypeReflection::Kind::Pointer:
return "Pointer";
case slang::TypeReflection::Kind::DynamicResource:
return "DynamicResource";
default:
return "Unknown";
}
}

const char* BindingTypeToString( slang::BindingType type )
{
switch ( type )
{
case slang::BindingType::Unknown:
return "Unknown";
case slang::BindingType::Sampler:
return "Sampler";
case slang::BindingType::Texture:
return "Texture";
case slang::BindingType::ConstantBuffer:
return "ConstantBuffer";
case slang::BindingType::ParameterBlock:
return "ParameterBlock";
case slang::BindingType::TypedBuffer:
return "TypedBuffer";
case slang::BindingType::RawBuffer:
return "RawBuffer";
case slang::BindingType::CombinedTextureSampler:
return "CombinedTextureSampler";
case slang::BindingType::InputRenderTarget:
return "InputRenderTarget";
case slang::BindingType::InlineUniformData:
return "InlineUniformData";
case slang::BindingType::RayTracingAccelerationStructure:
return "RayTracingAccelerationStructure";
case slang::BindingType::VaryingInput:
return "VaryingInput";
case slang::BindingType::VaryingOutput:
return "VaryingOutput";
case slang::BindingType::ExistentialValue:
return "ExistentialValue";
case slang::BindingType::PushConstant:
return "PushConstant";
case slang::BindingType::MutableFlag:
return "MutableFlag";
case slang::BindingType::MutableTexture:
return "MutableTexture";
case slang::BindingType::MutableTypedBuffer:
return "MutableTypedBuffer";
case slang::BindingType::MutableRawBuffer:
return "MutableRawBuffer";
case slang::BindingType::BaseMask:
return "BaseMask";
case slang::BindingType::ExtMask:
return "ExtMask";
default:
return "Unknown";
}
}

bool ShaderCompiler::Compile( const ShaderType shaderType, const char* pShader, std::vector<uint32_t>& outSpirv )
bool ShaderCompiler::Compile( const ShaderType shaderType, const char* pShader, ShaderCompilerResult& outResult )
{
outResult = ShaderCompilerResult();

Slang::ComPtr<ISession> session;

TargetDesc targetDesc{};
Expand All @@ -34,8 +137,7 @@ bool ShaderCompiler::Compile( const ShaderType shaderType, const char* pShader,
m_globalSession->createSession( sessionDesc, session.writeRef() );

Slang::ComPtr<IBlob> diagnostics;
IModule* module =
session->loadModuleFromSourceString( "Shader", "Shader.slang", pShader, diagnostics.writeRef() );
IModule* module = session->loadModuleFromSourceString( "Shader", "Shader.slang", pShader, diagnostics.writeRef() );

if ( diagnostics )
spdlog::error( "Shader compiler: {}", ( const char* )diagnostics->getBufferPointer() );
Expand Down Expand Up @@ -68,6 +170,58 @@ bool ShaderCompiler::Compile( const ShaderType shaderType, const char* pShader,

const uint32_t* data = static_cast<const uint32_t*>( kernelBlob->getBufferPointer() );
size_t wordCount = kernelBlob->getBufferSize() / sizeof( uint32_t );
outSpirv = std::vector<uint32_t>( data, data + wordCount );

outResult.ShaderData = UtilArray::FromVector( std::vector<uint32_t>( data, data + wordCount ) );

std::vector<ShaderReflectionBinding> reflectionBindings = {};
ShaderReflectionInfo shaderReflectionInfo = {};

{
slang::ProgramLayout* layout = program->getLayout( targetIndex );
auto globalScope = layout->getGlobalParamsVarLayout();
auto globalTypeLayout = globalScope->getTypeLayout();

int paramCount = globalTypeLayout->getFieldCount();
for ( int i = 0; i < paramCount; i++ )
{
auto param = globalTypeLayout->getFieldByIndex( i );
auto type = globalTypeLayout->getBindingRangeType( i );

// get binding info
size_t binding = param->getOffset( SLANG_PARAMETER_CATEGORY_DESCRIPTOR_TABLE_SLOT );
size_t set = param->getBindingSpace( SLANG_PARAMETER_CATEGORY_DESCRIPTOR_TABLE_SLOT );

SlangResourceShape shape = param->getType()->getResourceShape();

// get param name/type
const char* name = param->getName();
slang::TypeReflection::Kind kind = param->getType()->getKind();

spdlog::info( "[{}, {}] {} {}, {}", binding, set, KindToString( kind ), name, BindingTypeToString( type ) );

auto mochaReflectionType = SHADER_REFLECTION_TYPE_UNKNOWN;
switch ( type )
{
case slang::BindingType::Unknown:
mochaReflectionType = SHADER_REFLECTION_TYPE_UNKNOWN;
break;
case slang::BindingType::Texture:
mochaReflectionType = SHADER_REFLECTION_TYPE_TEXTURE;
break;
case slang::BindingType::Sampler:
mochaReflectionType = SHADER_REFLECTION_TYPE_SAMPLER;
break;
case slang::BindingType::ConstantBuffer:
mochaReflectionType = SHADER_REFLECTION_TYPE_BUFFER;
break;
}

reflectionBindings.push_back( ShaderReflectionBinding{
.Set = ( int )set, .Binding = ( int )binding, .Type = mochaReflectionType, .Name = name } );
}
}

outResult.ReflectionData = ShaderReflectionInfo { .Bindings = UtilArray::FromVector( reflectionBindings ) };

return true;
}
Loading

0 comments on commit 87fb92e

Please sign in to comment.