Refactor SPIRV constant usage to allow for better driver optimization.

This commit is contained in:
Skyth 2024-10-25 17:31:15 +03:00
parent 85d1948655
commit 25b4f604af
2 changed files with 118 additions and 79 deletions

View file

@ -1,6 +1,8 @@
#define FLT_MIN asfloat(0xff7fffff) #define FLT_MIN asfloat(0xff7fffff)
#define FLT_MAX asfloat(0x7f7fffff) #define FLT_MAX asfloat(0x7f7fffff)
#define INPUT_LAYOUT_FLAG_HAS_R11G11B10_NORMAL (1 << 0)
#ifdef __spirv__ #ifdef __spirv__
struct PushConstants struct PushConstants
@ -12,32 +14,25 @@ struct PushConstants
[[vk::push_constant]] ConstantBuffer<PushConstants> g_PushConstants; [[vk::push_constant]] ConstantBuffer<PushConstants> g_PushConstants;
#define CONSTANT_BUFFER(NAME, REGISTER) struct NAME #define g_AlphaTestMode vk::RawBufferLoad<uint>(g_PushConstants.SharedConstants + 128)
#define PACK_OFFSET(REGISTER) #define g_AlphaThreshold vk::RawBufferLoad<float>(g_PushConstants.SharedConstants + 132)
#define g_Booleans vk::RawBufferLoad<uint>(g_PushConstants.SharedConstants + 136)
#define GET_CONSTANT(NAME) constants.NAME #define g_SwappedTexcoords vk::RawBufferLoad<uint>(g_PushConstants.SharedConstants + 140)
#define GET_SHARED_CONSTANT(NAME) sharedConstants.NAME #define g_InputLayoutFlags vk::RawBufferLoad<uint>(g_PushConstants.SharedConstants + 144)
#define g_EnableGIBicubicFiltering vk::RawBufferLoad<bool>(g_PushConstants.SharedConstants + 148)
#else #else
#define CONSTANT_BUFFER(NAME, REGISTER) cbuffer NAME : register(REGISTER, space4) #define DEFINE_SHARED_CONSTANTS() \
#define PACK_OFFSET(REGISTER) : packoffset(REGISTER) uint g_AlphaTestMode : packoffset(c8.x); \
float g_AlphaThreshold : packoffset(c8.y); \
#define GET_CONSTANT(NAME) NAME uint g_Booleans : packoffset(c8.z); \
#define GET_SHARED_CONSTANT(NAME) NAME uint g_SwappedTexcoords : packoffset(c8.w); \
uint g_InputLayoutFlags : packoffset(c9.x); \
bool g_EnableGIBicubicFiltering : packoffset(c9.y)
#endif #endif
#define INPUT_LAYOUT_FLAG_HAS_R11G11B10_NORMAL (1 << 0)
#define SHARED_CONSTANTS \
[[vk::offset(128)]] uint g_AlphaTestMode PACK_OFFSET(c8.x); \
[[vk::offset(132)]] float g_AlphaThreshold PACK_OFFSET(c8.y); \
[[vk::offset(136)]] uint g_Booleans PACK_OFFSET(c8.z); \
[[vk::offset(140)]] uint g_SwappedTexcoords PACK_OFFSET(c8.w); \
[[vk::offset(144)]] uint g_InputLayoutFlags PACK_OFFSET(c9.x); \
[[vk::offset(148)]] bool g_EnableGIBicubicFiltering PACK_OFFSET(c9.y)
Texture2D<float4> g_Texture2DDescriptorHeap[] : register(t0, space0); Texture2D<float4> g_Texture2DDescriptorHeap[] : register(t0, space0);
Texture3D<float4> g_Texture3DDescriptorHeap[] : register(t0, space1); Texture3D<float4> g_Texture3DDescriptorHeap[] : register(t0, space1);
TextureCube<float4> g_TextureCubeDescriptorHeap[] : register(t0, space2); TextureCube<float4> g_TextureCubeDescriptorHeap[] : register(t0, space2);

View file

@ -176,11 +176,11 @@ void ShaderRecompiler::recompile(const VertexFetchInstruction& instr, uint32_t a
case DeclUsage::Normal: case DeclUsage::Normal:
case DeclUsage::Tangent: case DeclUsage::Tangent:
case DeclUsage::Binormal: case DeclUsage::Binormal:
print("tfetchR11G11B10(GET_SHARED_CONSTANT(g_InputLayoutFlags), "); print("tfetchR11G11B10(g_InputLayoutFlags, ");
break; break;
case DeclUsage::TexCoord: case DeclUsage::TexCoord:
print("tfetchTexcoord(GET_SHARED_CONSTANT(g_SwappedTexcoords), "); print("tfetchTexcoord(g_SwappedTexcoords, ");
break; break;
} }
@ -254,7 +254,7 @@ void ShaderRecompiler::recompile(const TextureFetchInstruction& instr, bool bicu
if (instr.constIndex == 0 && instr.dimension == TextureDimension::Texture2D) if (instr.constIndex == 0 && instr.dimension == TextureDimension::Texture2D)
{ {
indent(); indent();
print("pixelCoord = getPixelCoord(GET_SHARED_CONSTANT({}_ResourceDescriptorIndex), ", constNamePtr); print("pixelCoord = getPixelCoord({}_ResourceDescriptorIndex, ", constNamePtr);
printSrcRegister(2); printSrcRegister(2);
out += ");\n"; out += ");\n";
} }
@ -298,7 +298,7 @@ void ShaderRecompiler::recompile(const TextureFetchInstruction& instr, bool bicu
if (bicubic) if (bicubic)
out += "Bicubic"; out += "Bicubic";
print("(GET_SHARED_CONSTANT({0}_ResourceDescriptorIndex), GET_SHARED_CONSTANT({0}_SamplerDescriptorIndex), ", constNamePtr); print("({0}_ResourceDescriptorIndex, {0}_SamplerDescriptorIndex, ", constNamePtr);
printSrcRegister(componentCount); printSrcRegister(componentCount);
switch (instr.dimension) switch (instr.dimension)
@ -428,13 +428,13 @@ void ShaderRecompiler::recompile(const AluInstruction& instr)
const char* constantName = reinterpret_cast<const char*>(constantTableData + findResult->second->name); const char* constantName = reinterpret_cast<const char*>(constantTableData + findResult->second->name);
if (findResult->second->registerCount > 1) if (findResult->second->registerCount > 1)
{ {
regFormatted = std::format("GET_CONSTANT({})[{}{}]", constantName, regFormatted = std::format("{}({}{})", constantName,
reg - findResult->second->registerIndex, instr.const0Relative ? (instr.constAddressRegisterRelative ? " + a0" : " + aL") : ""); reg - findResult->second->registerIndex, instr.const0Relative ? (instr.constAddressRegisterRelative ? " + a0" : " + aL") : "");
} }
else else
{ {
assert(!instr.const0Relative && !instr.const1Relative); assert(!instr.const0Relative && !instr.const1Relative);
regFormatted = std::format("GET_CONSTANT({})", constantName); regFormatted = constantName;
} }
} }
else else
@ -1045,8 +1045,7 @@ void ShaderRecompiler::recompile(const uint8_t* shaderData)
const auto constantTableContainer = reinterpret_cast<const ConstantTableContainer*>(shaderData + shaderContainer->constantTableOffset); const auto constantTableContainer = reinterpret_cast<const ConstantTableContainer*>(shaderData + shaderContainer->constantTableOffset);
constantTableData = reinterpret_cast<const uint8_t*>(&constantTableContainer->constantTable); constantTableData = reinterpret_cast<const uint8_t*>(&constantTableContainer->constantTable);
println("CONSTANT_BUFFER(Constants, b{})", isPixelShader ? 1 : 0); out += "#ifdef __spirv__\n\n";
out += "{\n";
bool isMetaInstancer = false; bool isMetaInstancer = false;
bool hasIndexCount = false; bool hasIndexCount = false;
@ -1056,10 +1055,6 @@ void ShaderRecompiler::recompile(const uint8_t* shaderData)
const auto constantInfo = reinterpret_cast<const ConstantInfo*>( const auto constantInfo = reinterpret_cast<const ConstantInfo*>(
constantTableData + constantTableContainer->constantTable.constantInfo + i * sizeof(ConstantInfo)); constantTableData + constantTableContainer->constantTable.constantInfo + i * sizeof(ConstantInfo));
assert(constantInfo->registerSet != RegisterSet::Int4);
if (constantInfo->registerSet == RegisterSet::Float4)
{
const char* constantName = reinterpret_cast<const char*>(constantTableData + constantInfo->name); const char* constantName = reinterpret_cast<const char*>(constantTableData + constantInfo->name);
if (!isPixelShader) if (!isPixelShader)
@ -1070,21 +1065,47 @@ void ShaderRecompiler::recompile(const uint8_t* shaderData)
hasIndexCount = true; hasIndexCount = true;
} }
print("\t[[vk::offset({})]] float4 {}", constantInfo->registerIndex * 16, constantName); switch (constantInfo->registerSet)
{
case RegisterSet::Float4:
{
const char* shaderName = isPixelShader ? "Pixel" : "Vertex";
if (constantInfo->registerCount > 1) if (constantInfo->registerCount > 1)
print("[{}]", constantInfo->registerCount.get()); {
println("#define {}(INDEX) vk::RawBufferLoad<float4>(g_PushConstants.{}ShaderConstants + ({} + INDEX) * 16, 0x10)",
println(" PACK_OFFSET(c{});", constantInfo->registerIndex.get()); constantName, shaderName, constantInfo->registerIndex.get());
}
else
{
println("#define {} vk::RawBufferLoad<float4>(g_PushConstants.{}ShaderConstants + {}, 0x10)",
constantName, shaderName, constantInfo->registerIndex * 16);
}
for (uint16_t j = 0; j < constantInfo->registerCount; j++) for (uint16_t j = 0; j < constantInfo->registerCount; j++)
float4Constants.emplace(constantInfo->registerIndex + j, constantInfo); float4Constants.emplace(constantInfo->registerIndex + j, constantInfo);
break;
}
case RegisterSet::Sampler:
{
println("#define {}_ResourceDescriptorIndex vk::RawBufferLoad<uint>(g_PushConstants.SharedConstants + {})",
constantName, constantInfo->registerIndex * 4);
println("#define {}_SamplerDescriptorIndex vk::RawBufferLoad<uint>(g_PushConstants.SharedConstants + {})",
constantName, 64 + constantInfo->registerIndex * 4);
samplers.emplace(constantInfo->registerIndex, constantName);
break;
}
} }
} }
out += "};\n\n"; out += "\n#else\n\n";
out += "CONSTANT_BUFFER(SharedConstants, b2)\n"; println("cbuffer {}ShaderConstants : register(b{}, space4)", isPixelShader ? "Pixel" : "Vertex", isPixelShader ? 1 : 0);
out += "{\n"; out += "{\n";
for (uint32_t i = 0; i < constantTableContainer->constantTable.constants; i++) for (uint32_t i = 0; i < constantTableContainer->constantTable.constants; i++)
@ -1092,36 +1113,64 @@ void ShaderRecompiler::recompile(const uint8_t* shaderData)
const auto constantInfo = reinterpret_cast<const ConstantInfo*>( const auto constantInfo = reinterpret_cast<const ConstantInfo*>(
constantTableData + constantTableContainer->constantTable.constantInfo + i * sizeof(ConstantInfo)); constantTableData + constantTableContainer->constantTable.constantInfo + i * sizeof(ConstantInfo));
if (constantInfo->registerSet == RegisterSet::Float4)
{
const char* constantName = reinterpret_cast<const char*>(constantTableData + constantInfo->name); const char* constantName = reinterpret_cast<const char*>(constantTableData + constantInfo->name);
assert(constantInfo->registerSet != RegisterSet::Int4); print("\tfloat4 {}", constantName);
switch (constantInfo->registerSet) if (constantInfo->registerCount > 1)
{ print("[{}]", constantInfo->registerCount.get());
case RegisterSet::Bool:
{
println("#define {} (1 << {})", constantName, constantInfo->registerIndex + (isPixelShader ? 16 : 0));
boolConstants.emplace(constantInfo->registerIndex, constantName);
break;
}
case RegisterSet::Sampler: println(" : packoffset(c{});", constantInfo->registerIndex.get());
{
println("\t[[vk::offset({})]] uint {}_ResourceDescriptorIndex PACK_OFFSET(c{}.{});",
constantInfo->registerIndex * 4, constantName, constantInfo->registerIndex / 4, SWIZZLES[constantInfo->registerIndex % 4]);
println("\t[[vk::offset({})]] uint {}_SamplerDescriptorIndex PACK_OFFSET(c{}.{});", if (constantInfo->registerCount > 1)
64 + constantInfo->registerIndex * 4, constantName, 4 + constantInfo->registerIndex / 4, SWIZZLES[constantInfo->registerIndex % 4]); println("#define {0}(INDEX) {0}[INDEX]", constantName);
samplers.emplace(constantInfo->registerIndex, constantName);
break;
}
} }
} }
out += "\tSHARED_CONSTANTS;\n";
out += "};\n\n"; out += "};\n\n";
out += "cbuffer SharedConstants : register(b2, space4)\n";
out += "{\n";
for (uint32_t i = 0; i < constantTableContainer->constantTable.constants; i++)
{
const auto constantInfo = reinterpret_cast<const ConstantInfo*>(
constantTableData + constantTableContainer->constantTable.constantInfo + i * sizeof(ConstantInfo));
if (constantInfo->registerSet == RegisterSet::Sampler)
{
const char* constantName = reinterpret_cast<const char*>(constantTableData + constantInfo->name);
println("\tuint {}_ResourceDescriptorIndex : packoffset(c{}.{});",
constantName, constantInfo->registerIndex / 4, SWIZZLES[constantInfo->registerIndex % 4]);
println("\tuint {}_SamplerDescriptorIndex : packoffset(c{}.{});",
constantName, 4 + constantInfo->registerIndex / 4, SWIZZLES[constantInfo->registerIndex % 4]);
}
}
out += "\tDEFINE_SHARED_CONSTANTS();\n";
out += "};\n\n";
out += "#endif\n";
for (uint32_t i = 0; i < constantTableContainer->constantTable.constants; i++)
{
const auto constantInfo = reinterpret_cast<const ConstantInfo*>(
constantTableData + constantTableContainer->constantTable.constantInfo + i * sizeof(ConstantInfo));
if (constantInfo->registerSet == RegisterSet::Bool)
{
const char* constantName = reinterpret_cast<const char*>(constantTableData + constantInfo->name);
println("\t#define {} (1 << {})", constantName, constantInfo->registerIndex + (isPixelShader ? 16 : 0));
boolConstants.emplace(constantInfo->registerIndex, constantName);
}
}
out += '\n';
const auto shader = reinterpret_cast<const Shader*>(shaderData + shaderContainer->shaderOffset); const auto shader = reinterpret_cast<const Shader*>(shaderData + shaderContainer->shaderOffset);
out += "void main(\n"; out += "void main(\n";
@ -1195,11 +1244,6 @@ void ShaderRecompiler::recompile(const uint8_t* shaderData)
out += ")\n"; out += ")\n";
out += "{\n"; out += "{\n";
out += "#ifdef __spirv__\n";
println("\tConstants constants = vk::RawBufferLoad<Constants>(g_PushConstants.{}ShaderConstants, 0x100);", isPixelShader ? "Pixel" : "Vertex");
out += "\tSharedConstants sharedConstants = vk::RawBufferLoad<SharedConstants>(g_PushConstants.SharedConstants, 0x100);\n";
out += "#endif\n\n";
if (shaderContainer->definitionTableOffset != NULL) if (shaderContainer->definitionTableOffset != NULL)
{ {
auto definitionTable = reinterpret_cast<const DefinitionTable*>(shaderData + shaderContainer->definitionTableOffset); auto definitionTable = reinterpret_cast<const DefinitionTable*>(shaderData + shaderContainer->definitionTableOffset);
@ -1293,7 +1337,7 @@ void ShaderRecompiler::recompile(const uint8_t* shaderData)
} }
else if (!isPixelShader && hasIndexCount && i == 0) else if (!isPixelShader && hasIndexCount && i == 0)
{ {
out += "float4(iVertexId + GET_CONSTANT(g_IndexCount).x * iInstanceId, 0.0, 0.0, 0.0);\n"; out += "float4(iVertexId + g_IndexCount.x * iInstanceId, 0.0, 0.0, 0.0);\n";
} }
else else
{ {
@ -1514,7 +1558,7 @@ void ShaderRecompiler::recompile(const uint8_t* shaderData)
{ {
auto findResult = boolConstants.find(cfInstr.condJmp.boolAddress); auto findResult = boolConstants.find(cfInstr.condJmp.boolAddress);
if (findResult != boolConstants.end()) if (findResult != boolConstants.end())
println("if ((GET_SHARED_CONSTANT(g_Booleans) & {}) {}= 0)", findResult->second, cfInstr.condJmp.condition ^ simpleControlFlow ? "!" : "="); println("if ((g_Booleans & {}) {}= 0)", findResult->second, cfInstr.condJmp.condition ^ simpleControlFlow ? "!" : "=");
else else
println("if (b{} {}= 0)", uint32_t(cfInstr.condJmp.boolAddress), cfInstr.condJmp.condition ^ simpleControlFlow ? "!" : "="); println("if (b{} {}= 0)", uint32_t(cfInstr.condJmp.boolAddress), cfInstr.condJmp.condition ^ simpleControlFlow ? "!" : "=");
} }
@ -1569,7 +1613,7 @@ void ShaderRecompiler::recompile(const uint8_t* shaderData)
if (textureFetch.constIndex == 10) // g_GISampler if (textureFetch.constIndex == 10) // g_GISampler
{ {
indent(); indent();
out += "[branch] if (GET_SHARED_CONSTANT(g_EnableGIBicubicFiltering))"; out += "[branch] if (g_EnableGIBicubicFiltering)";
indent(); indent();
out += '{'; out += '{';
@ -1611,24 +1655,24 @@ void ShaderRecompiler::recompile(const uint8_t* shaderData)
if (isPixelShader) if (isPixelShader)
{ {
indent(); indent();
out += "[branch] if (GET_SHARED_CONSTANT(g_AlphaTestMode) == 1)"; out += "[branch] if (g_AlphaTestMode == 1)";
indent(); indent();
out += '{'; out += '{';
indent(); indent();
out += "\tclip(oC0.w - GET_SHARED_CONSTANT(g_AlphaThreshold));\n"; out += "\tclip(oC0.w - g_AlphaThreshold);\n";
indent(); indent();
out += "}"; out += "}";
indent(); indent();
out += "else if (GET_SHARED_CONSTANT(g_AlphaTestMode) == 2)"; out += "else if (g_AlphaTestMode == 2)";
indent(); indent();
out += '{'; out += '{';
indent(); indent();
out += "\toC0.w *= 1.0 + computeMipLevel(pixelCoord) * 0.25;\n"; out += "\toC0.w *= 1.0 + computeMipLevel(pixelCoord) * 0.25;\n";
indent(); indent();
out += "\toC0.w = 0.5 + (oC0.w - GET_SHARED_CONSTANT(g_AlphaThreshold)) / max(fwidth(oC0.w), 1e-6);\n"; out += "\toC0.w = 0.5 + (oC0.w - g_AlphaThreshold) / max(fwidth(oC0.w), 1e-6);\n";
indent(); indent();
out += '}'; out += '}';