#include "shader.h"
#include "shader_recompiler.h"
#include "dxc_compiler.h"

namespace
{
    // Deleter so a FILE* can be owned by std::unique_ptr and closed on every exit path.
    struct FileCloser
    {
        void operator()(FILE* file) const { fclose(file); }
    };

    using FilePtr = std::unique_ptr<FILE, FileCloser>;
}

// Reads the entire file at `filePath` into a heap buffer.
// On success returns the buffer and stores its length in `fileSize`.
// On failure (cannot open, seek, size or fully read the file) returns nullptr and sets `fileSize` to 0.
static std::unique_ptr<uint8_t[]> readAllBytes(const char* filePath, size_t& fileSize)
{
    fileSize = 0;

    FilePtr file(fopen(filePath, "rb"));
    if (!file || fseek(file.get(), 0, SEEK_END) != 0)
        return nullptr;

    // ftell reports failure as -1L.
    long size = ftell(file.get());
    if (size < 0 || fseek(file.get(), 0, SEEK_SET) != 0)
        return nullptr;

    auto data = std::make_unique<uint8_t[]>(size_t(size));
    if (fread(data.get(), 1, size_t(size), file.get()) != size_t(size))
        return nullptr;

    fileSize = size_t(size);
    return data;
}

// Writes `dataSize` bytes from `data` to `filePath`, replacing any existing file.
// Returns false if the file could not be opened, fully written, or closed.
static bool writeAllBytes(const char* filePath, const void* data, size_t dataSize)
{
    FilePtr file(fopen(filePath, "wb"));
    if (!file)
        return false;

    bool written = fwrite(data, 1, dataSize, file.get()) == dataSize;

    // Close explicitly: fclose flushes buffered output and can itself fail.
    bool closed = fclose(file.release()) == 0;
    return written && closed;
}

// One unique shader found while scanning the input directory.
struct RecompiledShader
{
    uint8_t* data = nullptr;   // Start of the shader container inside a file buffer that main keeps alive.
    IDxcBlob* dxil = nullptr;  // Result of DxcCompiler::compile(..., false); not released here.
    IDxcBlob* spirv = nullptr; // Result of DxcCompiler::compile(..., true); not released here.
};

// Scans `fileData` for embedded ShaderContainer headers and registers every container whose
// XXH3 content hash is not yet present in `shaders`.
// Returns true if at least one new shader was recorded; the caller must then keep `fileData`
// alive, because the recorded RecompiledShader::data pointers point into it.
static bool collectShaders(uint8_t* fileData, size_t fileSize, std::map<XXH64_hash_t, RecompiledShader>& shaders)
{
    bool foundAny = false;

    // Equivalent to the original bound `i < fileSize - sizeof(ShaderContainer) - 1`,
    // written without the unsigned subtraction that needed a separate underflow guard.
    for (size_t i = 0; i + sizeof(ShaderContainer) + 1 < fileSize;)
    {
        auto shaderContainer = reinterpret_cast<const ShaderContainer*>(fileData + i);

        // Widen before adding, in case the size fields are narrower than size_t and could wrap.
        size_t dataSize = size_t(shaderContainer->virtualSize) + size_t(shaderContainer->physicalSize);

        bool isContainer =
            (shaderContainer->flags & 0xFFFFFF00) == 0x102A1100 &&
            dataSize != 0 && // A zero size would never advance `i`, looping forever.
            dataSize < (fileSize - i) &&
            shaderContainer->field1C == 0 &&
            shaderContainer->field20 == 0;

        if (isContainer)
        {
            XXH64_hash_t hash = XXH3_64bits(shaderContainer, dataSize);
            auto [entry, inserted] = shaders.try_emplace(hash);
            if (inserted)
            {
                entry->second.data = fileData + i;
                foundAny = true;
            }

            i += dataSize;
        }
        else
        {
            i += sizeof(uint32_t);
        }
    }

    return foundAny;
}

// Recompiles every collected shader and compiles the result to DXIL and SPIR-V in parallel,
// printing progress every 10 shaders and once when the last shader finishes.
static void recompileShaders(std::map<XXH64_hash_t, RecompiledShader>& shaders)
{
    std::atomic<size_t> progress{0};
    const size_t total = shaders.size();

    // `par`, not `par_unseq`: the body allocates, prints and reuses thread_local compiler
    // state, none of which is permitted for invocations that may be interleaved on one thread.
    std::for_each(std::execution::par, shaders.begin(), shaders.end(), [&](auto& hashShaderPair)
    {
        auto& shader = hashShaderPair.second;

        thread_local ShaderRecompiler recompiler;
        recompiler = {};
        recompiler.recompile(shader.data);

        thread_local DxcCompiler dxcCompiler;
        shader.dxil = dxcCompiler.compile(recompiler.out, recompiler.isPixelShader, false);
        shader.spirv = dxcCompiler.compile(recompiler.out, recompiler.isPixelShader, true);

        assert(shader.dxil != nullptr && shader.spirv != nullptr);
        // The second 32-bit word of the DXIL output must be non-zero; a zero here is reported as unsigned DXIL.
        assert(*(reinterpret_cast<const uint32_t*>(shader.dxil->GetBufferPointer()) + 1) != 0 && "DXIL was not signed properly!");

        // `++progress` yields 1..total, so the last finished shader reports `total`, not `total - 1`.
        size_t currentProgress = ++progress;
        if ((currentProgress % 10) == 0 || currentProgress == total)
            std::println("Recompiling shaders... {}%", currentProgress / float(total) * 100.0f);
    });
}

// Appends the raw bytes of a DXC output blob to `out`.
static void appendBlob(std::vector<uint8_t>& out, IDxcBlob* blob)
{
    auto bytes = reinterpret_cast<const uint8_t*>(blob->GetBufferPointer());
    out.insert(out.end(), bytes, bytes + blob->GetBufferSize());
}

// Compresses `input` with zstd at `level` into `output`.
// Returns false (printing the zstd error and leaving `output` empty) if compression fails.
static bool compressZstd(const std::vector<uint8_t>& input, int level, std::vector<uint8_t>& output)
{
    output.resize(ZSTD_compressBound(input.size()));
    size_t compressedSize = ZSTD_compress(output.data(), output.size(), input.data(), input.size(), level);

    // ZSTD_compress reports errors through its size_t return value.
    if (ZSTD_isError(compressedSize))
    {
        std::println(stderr, "ZSTD compression failed: {}", ZSTD_getErrorName(compressedSize));
        output.clear();
        return false;
    }

    output.resize(compressedSize);
    return true;
}

// Emits `uint8_t <name>[] = {b0,b1,...,};` followed by a newline into the generated source.
static void printByteArray(StringBuffer& f, const char* name, const std::vector<uint8_t>& bytes)
{
    f.print("uint8_t {}[] = {{", name);
    for (uint8_t byte : bytes)
        f.print("{},", byte);
    f.println("}};");
}

// Directory mode: finds unique shader containers in every regular file in `inputDir`,
// recompiles them, and writes a generated source file to `output`. The file holds the
// entry table, the zstd-compressed DXIL/SPIR-V blobs, and their sizes.
// Returns the process exit code.
static int buildShaderCache(const char* inputDir, const char* output)
{
    // Keeps file buffers alive: RecompiledShader::data points into them.
    std::vector<std::unique_ptr<uint8_t[]>> files;
    std::map<XXH64_hash_t, RecompiledShader> shaders;

    for (auto& file : std::filesystem::directory_iterator(inputDir))
    {
        if (!file.is_regular_file())
            continue;

        size_t fileSize = 0;
        auto fileData = readAllBytes(file.path().string().c_str(), fileSize);
        if (fileData == nullptr)
        {
            std::println(stderr, "Failed to read {}, skipping.", file.path().string());
            continue;
        }

        if (collectShaders(fileData.get(), fileSize, shaders))
            files.emplace_back(std::move(fileData));
    }

    recompileShaders(shaders);

    std::println("Creating shader cache...");

    StringBuffer f;
    f.println("#include \"shader_cache.h\"");
    f.println("ShaderCacheEntry g_shaderCacheEntries[] = {{");

    std::vector<uint8_t> dxil;
    std::vector<uint8_t> spirv;

    for (auto& [hash, shader] : shaders)
    {
        // The asserts in recompileShaders vanish in release builds; fail cleanly instead of dereferencing null.
        if (shader.dxil == nullptr || shader.spirv == nullptr)
        {
            std::println(stderr, "Shader {:X} failed to compile.", hash);
            return 1;
        }

        // Each entry records the hash plus offset/size of this shader inside the concatenated DXIL and SPIR-V streams.
        f.println("\t{{ 0x{:X}, {}, {}, {}, {} }},", hash, dxil.size(), shader.dxil->GetBufferSize(), spirv.size(), shader.spirv->GetBufferSize());
        appendBlob(dxil, shader.dxil);
        appendBlob(spirv, shader.spirv);
    }

    f.println("}};");

    std::println("Compressing DXIL cache...");

    int level = ZSTD_maxCLevel();
    //level = ZSTD_defaultCLevel();

    std::vector<uint8_t> dxilCompressed;
    if (!compressZstd(dxil, level, dxilCompressed))
        return 1;
    printByteArray(f, "g_compressedDxilCache", dxilCompressed);

    std::println("Compressing SPIRV cache...");

    std::vector<uint8_t> spirvCompressed;
    if (!compressZstd(spirv, level, spirvCompressed))
        return 1;
    printByteArray(f, "g_compressedSpirvCache", spirvCompressed);

    f.println("size_t g_shaderCacheEntryCount = {};", shaders.size());
    f.println("size_t g_dxilCacheCompressedSize = {};", dxilCompressed.size());
    f.println("size_t g_dxilCacheDecompressedSize = {};", dxil.size());
    f.println("size_t g_spirvCacheCompressedSize = {};", spirvCompressed.size());
    f.println("size_t g_spirvCacheDecompressedSize = {};", spirv.size());

    if (!writeAllBytes(output, f.out.data(), f.out.size()))
    {
        std::println(stderr, "Failed to write {}.", output);
        return 1;
    }

    return 0;
}

// Single-file mode: recompiles one shader container file and writes the recompiler's
// text output (the source DxcCompiler consumes in directory mode) to `output`.
// Returns the process exit code.
static int recompileSingleShader(const char* input, const char* output)
{
    size_t fileSize = 0;
    auto fileData = readAllBytes(input, fileSize);
    if (fileData == nullptr)
    {
        std::println(stderr, "Failed to read {}.", input);
        return 1;
    }

    ShaderRecompiler recompiler;
    recompiler.recompile(fileData.get());

    if (!writeAllBytes(output, recompiler.out.data(), recompiler.out.size()))
    {
        std::println(stderr, "Failed to write {}.", output);
        return 1;
    }

    return 0;
}

// Entry point. Input and output paths come from SHADER_RECOMP_INPUT / SHADER_RECOMP_OUTPUT
// when defined at build time, otherwise from argv[1] / argv[2]. A directory input builds the
// shader cache; any other input is recompiled as a single shader.
int main(int argc, char** argv)
{
    const char* input =
#ifdef SHADER_RECOMP_INPUT
        SHADER_RECOMP_INPUT
#else
        argc > 1 ? argv[1] : nullptr
#endif
        ;

    const char* output =
#ifdef SHADER_RECOMP_OUTPUT
        SHADER_RECOMP_OUTPUT
#else
        argc > 2 ? argv[2] : nullptr
#endif
        ;

    if (input == nullptr || output == nullptr)
    {
        std::println(stderr, "Usage: <program> <input file or directory> <output file>");
        return 1;
    }

    if (std::filesystem::is_directory(input))
        return buildShaderCache(input, output);

    return recompileSingleShader(input, output);
}