From 88196ee50830204f6d9cd538ed79c12510a0a94a Mon Sep 17 00:00:00 2001 From: Isaac Marovitz Date: Tue, 25 Mar 2025 19:39:05 -0400 Subject: [PATCH] MSL cast down workaround Explicitly adds single component swizzle for narrowing vector to scalar conversion Signed-off-by: Isaac Marovitz --- XenosRecomp/shader_recompiler.cpp | 308 ++++++++++++++++++++---------- 1 file changed, 210 insertions(+), 98 deletions(-) diff --git a/XenosRecomp/shader_recompiler.cpp b/XenosRecomp/shader_recompiler.cpp index 77d65ce..a385e4b 100644 --- a/XenosRecomp/shader_recompiler.cpp +++ b/XenosRecomp/shader_recompiler.cpp @@ -417,6 +417,12 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) SCALAR_CONSTANT_1 }; + struct OperationResult + { + std::string expression; + size_t componentCount; + }; + auto op = [&](size_t operand) { size_t reg = 0; @@ -521,16 +527,16 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) } } - std::string result; + OperationResult opResult {}; if (negate) - result += '-'; + opResult.expression += '-'; if (abs) - result += "abs("; + opResult.expression += "abs("; - result += regFormatted; - result += '.'; + opResult.expression += regFormatted; + opResult.expression += '.'; switch (operand) { @@ -562,8 +568,10 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) for (size_t i = 0; i < 4; i++) { - if ((mask >> i) & 0x1) - result += SWIZZLES[((swizzle >> (i * 2)) + i) & 0x3]; + if ((mask >> i) & 0x1) { + opResult.componentCount++; + opResult.expression += SWIZZLES[((swizzle >> (i * 2)) + i) & 0x3]; + } } break; @@ -571,41 +579,43 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) case SCALAR_0: case SCALAR_CONSTANT_0: - result += SWIZZLES[((swizzle >> 6) + 3) & 0x3]; + opResult.componentCount = 1; + opResult.expression += SWIZZLES[((swizzle >> 6) + 3) & 0x3]; break; case SCALAR_1: case SCALAR_CONSTANT_1: - result += SWIZZLES[swizzle & 0x3]; + opResult.componentCount = 1; + opResult.expression += SWIZZLES[swizzle & 0x3]; break; } if (abs) - result += ")"; + opResult.expression += ")"; - return result; + return opResult; }; switch (instr.vectorOpcode) { case AluVectorOpcode::KillEq: indent(); - println("clip(any({} == {}) ? -1 : 1);", op(VECTOR_0), op(VECTOR_1)); + println("clip(any({} == {}) ? -1 : 1);", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::KillGt: indent(); - println("clip(any({} > {}) ? -1 : 1);", op(VECTOR_0), op(VECTOR_1)); + println("clip(any({} > {}) ? -1 : 1);", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::KillGe: indent(); - println("clip(any({} >= {}) ? -1 : 1);", op(VECTOR_0), op(VECTOR_1)); + println("clip(any({} >= {}) ? -1 : 1);", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::KillNe: indent(); - println("clip(any({} != {}) ? -1 : 1);", op(VECTOR_0), op(VECTOR_1)); + println("clip(any({} != {}) ? -1 : 1);", op(VECTOR_0).expression, op(VECTOR_1).expression); break; } @@ -674,7 +684,7 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) if (instr.vectorOpcode >= AluVectorOpcode::SetpEqPush && instr.vectorOpcode <= AluVectorOpcode::SetpGePush) { indent(); - print("p0 = {} == 0.0 && {} ", op(VECTOR_0), op(VECTOR_1)); + print("p0 = {} == 0.0 && {} ", op(VECTOR_0).expression, op(VECTOR_1).expression); switch (instr.vectorOpcode) { @@ -697,7 +707,7 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) else if (instr.vectorOpcode >= AluVectorOpcode::MaxA) { indent(); - println("a0 = (int)clamp(floor(({}).w + 0.5), -256.0, 255.0);", op(VECTOR_0)); + println("a0 = (int)clamp(floor(({}).w + 0.5), -256.0, 255.0);", op(VECTOR_0).expression); } uint32_t vectorWriteMask = instr.vectorWriteMask; @@ -732,87 +742,169 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) out += " = "; + bool scalarRegWrite = vectorWriteSize <= 1; + if (vectorWriteSize > 1) print("(float{})(", vectorWriteSize); else - out += "(float)("; + out += "(float)(("; if (instr.vectorSaturate) out += "saturate("; + size_t operationResultComponentCount; + switch (instr.vectorOpcode) { case AluVectorOpcode::Add: - print("{} + {}", op(VECTOR_0), op(VECTOR_1)); - break; + { + auto v0 = op(VECTOR_0); + auto v1 = op(VECTOR_1); + operationResultComponentCount = std::max(v0.componentCount, v1.componentCount); + + print("{} + {}", v0.expression, v1.expression); + break; + } case AluVectorOpcode::Mul: - print("{} * {}", op(VECTOR_0), op(VECTOR_1)); - break; + { + auto v0 = op(VECTOR_0); + auto v1 = op(VECTOR_1); + operationResultComponentCount = std::max(v0.componentCount, v1.componentCount); + + print("{} * {}", v0.expression, v1.expression); + break; + } case AluVectorOpcode::Max: case AluVectorOpcode::MaxA: - print("max({}, {})", op(VECTOR_0), op(VECTOR_1)); - break; + { + auto v0 = op(VECTOR_0); + auto v1 = op(VECTOR_1); + operationResultComponentCount = v0.componentCount; + + print("max({}, {})", v0.expression, v1.expression); + break; + } case AluVectorOpcode::Min: - print("min({}, {})", op(VECTOR_0), op(VECTOR_1)); - break; + { + auto v0 = op(VECTOR_0); + auto v1 = op(VECTOR_1); + operationResultComponentCount = v0.componentCount; + + print("min({}, {})", v0.expression, v1.expression); + break; + } case AluVectorOpcode::Seq: - print("{} == {}", op(VECTOR_0), op(VECTOR_1)); + operationResultComponentCount = 1; + print("{} == {}", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::Sgt: - print("{} > {}", op(VECTOR_0), op(VECTOR_1)); + operationResultComponentCount = 1; + print("{} > {}", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::Sge: - print("{} >= {}", op(VECTOR_0), op(VECTOR_1)); + operationResultComponentCount = 1; + print("{} >= {}", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::Sne: - print("{} != {}", op(VECTOR_0), op(VECTOR_1)); + operationResultComponentCount = 1; + print("{} != {}", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::Frc: - print("frac({})", op(VECTOR_0)); - break; + { + auto v0 = op(VECTOR_0); + operationResultComponentCount = v0.componentCount; + + print("frac({})", v0.expression); + break; + } case AluVectorOpcode::Trunc: - print("trunc({})", op(VECTOR_0)); - break; + { + auto v0 = op(VECTOR_0); + operationResultComponentCount = v0.componentCount; + + print("trunc({})", v0.expression); + break; + } case AluVectorOpcode::Floor: - print("floor({})", op(VECTOR_0)); - break; + { + auto v0 = op(VECTOR_0); + operationResultComponentCount = v0.componentCount; + + print("floor({})", v0.expression); + break; + } case AluVectorOpcode::Mad: - print("{} * {} + {}", op(VECTOR_0), op(VECTOR_1), op(VECTOR_2)); - break; + { + auto v0 = op(VECTOR_0); + auto v1 = op(VECTOR_1); + auto v2 = op(VECTOR_2); + operationResultComponentCount = std::max(std::max(v0.componentCount, v1.componentCount), v2.componentCount); + + print("{} * {} + {}", v0.expression, v1.expression, v2.expression); + break; + } case AluVectorOpcode::CndEq: - print("selectWrapper({} == 0.0, {}, {})", op(VECTOR_0), op(VECTOR_1), op(VECTOR_2)); - break; + { + auto v0 = op(VECTOR_0); + auto v1 = op(VECTOR_1); + auto v2 = op(VECTOR_2); + operationResultComponentCount = std::max(v1.componentCount, v2.componentCount); + + print("selectWrapper({} == 0.0, {}, {})", v0.expression, v1.expression, v2.expression); + break; + } case AluVectorOpcode::CndGe: - print("selectWrapper({} >= 0.0, {}, {})", op(VECTOR_0), op(VECTOR_1), op(VECTOR_2)); - break; + { + auto v0 = op(VECTOR_0); + auto v1 = op(VECTOR_1); + auto v2 = op(VECTOR_2); + operationResultComponentCount = std::max(v1.componentCount, v2.componentCount); + + print("selectWrapper({} >= 0.0, {}, {})", v0.expression, v1.expression, v2.expression); + break; + } case AluVectorOpcode::CndGt: - print("selectWrapper({} > 0.0, {}, {})", op(VECTOR_0), op(VECTOR_1), op(VECTOR_2)); - break; + { + auto v0 = op(VECTOR_0); + auto v1 = op(VECTOR_1); + auto v2 = op(VECTOR_2); + operationResultComponentCount = std::max(v1.componentCount, v2.componentCount); + + print("selectWrapper({} > 0.0, {}, {})", v0.expression, v1.expression, v2.expression); + break; + } case AluVectorOpcode::Dp4: case AluVectorOpcode::Dp3: - print("dot({}, {})", op(VECTOR_0), op(VECTOR_1)); + operationResultComponentCount = 1; + print("dot({}, {})", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::Dp2Add: - print("dot({}, {}) + {}", op(VECTOR_0), op(VECTOR_1), op(VECTOR_2)); - break; + { + auto v2 = op(VECTOR_2); + operationResultComponentCount = v2.componentCount; + + print("dot({}, {}) + {}", op(VECTOR_0).expression, op(VECTOR_1).expression, v2.expression); + break; + } case AluVectorOpcode::Cube: + operationResultComponentCount = 4; println("\n#ifdef __air__"); indent(); print("cube(r{}, &cubeMapData)", instr.src1Register); @@ -823,41 +915,61 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) break; case AluVectorOpcode::Max4: - print("max4({})", op(VECTOR_0)); + operationResultComponentCount = 4; + print("max4({})", op(VECTOR_0).expression); break; case AluVectorOpcode::SetpEqPush: case AluVectorOpcode::SetpNePush: case AluVectorOpcode::SetpGtPush: case AluVectorOpcode::SetpGePush: - print("p0 ? 0.0 : {} + 1.0", op(VECTOR_0)); - break; + { + auto v0 = op(VECTOR_0); + operationResultComponentCount = v0.componentCount; + + print("p0 ? 0.0 : {} + 1.0", v0.expression); + break; + } case AluVectorOpcode::KillEq: - print("any({} == {})", op(VECTOR_0), op(VECTOR_1)); + operationResultComponentCount = 1; + print("any({} == {})", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::KillGt: - print("any({} > {})", op(VECTOR_0), op(VECTOR_1)); + operationResultComponentCount = 1; + print("any({} > {})", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::KillGe: - print("any({} >= {})", op(VECTOR_0), op(VECTOR_1)); + operationResultComponentCount = 1; + print("any({} >= {})", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::KillNe: - print("any({} != {})", op(VECTOR_0), op(VECTOR_1)); + operationResultComponentCount = 1; + print("any({} != {})", op(VECTOR_0).expression, op(VECTOR_1).expression); break; case AluVectorOpcode::Dst: - print("dst({}, {})", op(VECTOR_0), op(VECTOR_1)); + operationResultComponentCount = 4; + print("dst({}, {})", op(VECTOR_0).expression, op(VECTOR_1).expression); break; } + out += ")"; + + if (scalarRegWrite) { + if (operationResultComponentCount > 1) + out += ".x"; + + out += ")"; + } + if (instr.vectorSaturate) out += ')'; - out += ");\n"; + out += ";\n"; } if (instr.scalarOpcode != AluScalarOpcode::RetainPrev) @@ -870,27 +982,27 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) switch (instr.scalarOpcode) { case AluScalarOpcode::SetpEq: - print("{} == 0.0", op(SCALAR_0)); + print("{} == 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::SetpNe: - print("{} != 0.0", op(SCALAR_0)); + print("{} != 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::SetpGt: - print("{} > 0.0", op(SCALAR_0)); + print("{} > 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::SetpGe: - print("{} >= 0.0", op(SCALAR_0)); + print("{} >= 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::SetpInv: - print("{} == 1.0", op(SCALAR_0)); + print("{} == 1.0", op(SCALAR_0).expression); break; case AluScalarOpcode::SetpPop: - print("{} - 1.0 <= 0.0", op(SCALAR_0)); + print("{} - 1.0 <= 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::SetpClr: @@ -898,7 +1010,7 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) break; case AluScalarOpcode::SetpRstr: - print("{} == 0.0", op(SCALAR_0)); + print("{} == 0.0", op(SCALAR_0).expression); break; } @@ -913,87 +1025,87 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) switch (instr.scalarOpcode) { case AluScalarOpcode::Adds: - print("{} + {}", op(SCALAR_0), op(SCALAR_1)); + print("{} + {}", op(SCALAR_0).expression, op(SCALAR_1).expression); break; case AluScalarOpcode::AddsPrev: - print("{} + ps", op(SCALAR_0)); + print("{} + ps", op(SCALAR_0).expression); break; case AluScalarOpcode::Muls: - print("{} * {}", op(SCALAR_0), op(SCALAR_1)); + print("{} * {}", op(SCALAR_0).expression, op(SCALAR_1).expression); break; case AluScalarOpcode::MulsPrev: case AluScalarOpcode::MulsPrev2: - print("{} * ps", op(SCALAR_0)); + print("{} * ps", op(SCALAR_0).expression); break; case AluScalarOpcode::Maxs: case AluScalarOpcode::MaxAs: case AluScalarOpcode::MaxAsf: - print("max({}, {})", op(SCALAR_0), op(SCALAR_1)); + print("max({}, {})", op(SCALAR_0).expression, op(SCALAR_1).expression); break; case AluScalarOpcode::Mins: - print("min({}, {})", op(SCALAR_0), op(SCALAR_1)); + print("min({}, {})", op(SCALAR_0).expression, op(SCALAR_1).expression); break; case AluScalarOpcode::Seqs: - print("{} == 0.0", op(SCALAR_0)); + print("{} == 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::Sgts: - print("{} > 0.0", op(SCALAR_0)); + print("{} > 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::Sges: - print("{} >= 0.0", op(SCALAR_0)); + print("{} >= 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::Snes: - print("{} != 0.0", op(SCALAR_0)); + print("{} != 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::Frcs: - print("frac({})", op(SCALAR_0)); + print("frac({})", op(SCALAR_0).expression); break; case AluScalarOpcode::Truncs: - print("trunc({})", op(SCALAR_0)); + print("trunc({})", op(SCALAR_0).expression); break; case AluScalarOpcode::Floors: - print("floor({})", op(SCALAR_0)); + print("floor({})", op(SCALAR_0).expression); break; case AluScalarOpcode::Exp: - print("exp2({})", op(SCALAR_0)); + print("exp2({})", op(SCALAR_0).expression); break; case AluScalarOpcode::Logc: case AluScalarOpcode::Log: - print("clamp(log2({}), FLT_MIN, FLT_MAX)", op(SCALAR_0)); + print("clamp(log2({}), FLT_MIN, FLT_MAX)", op(SCALAR_0).expression); break; case AluScalarOpcode::Rcpc: case AluScalarOpcode::Rcpf: case AluScalarOpcode::Rcp: - print("clamp(rcp({}), FLT_MIN, FLT_MAX)", op(SCALAR_0)); + print("clamp(rcp({}), FLT_MIN, FLT_MAX)", op(SCALAR_0).expression); break; case AluScalarOpcode::Rsqc: case AluScalarOpcode::Rsqf: case AluScalarOpcode::Rsq: - print("clamp(rsqrt({}), FLT_MIN, FLT_MAX)", op(SCALAR_0)); + print("clamp(rsqrt({}), FLT_MIN, FLT_MAX)", op(SCALAR_0).expression); break; case AluScalarOpcode::Subs: - print("{} - {}", op(SCALAR_0), op(SCALAR_1)); + print("{} - {}", op(SCALAR_0).expression, op(SCALAR_1).expression); break; case AluScalarOpcode::SubsPrev: - print("{} - ps", op(SCALAR_0)); + print("{} - ps", op(SCALAR_0).expression); break; case AluScalarOpcode::SetpEq: @@ -1004,11 +1116,11 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) break; case AluScalarOpcode::SetpInv: - print("{0} == 0.0 ? 1.0 : {0}", op(SCALAR_0)); + print("{0} == 0.0 ? 1.0 : {0}", op(SCALAR_0).expression); break; case AluScalarOpcode::SetpPop: - print("p0 ? 0.0 : ({} - 1.0)", op(SCALAR_0)); + print("p0 ? 0.0 : ({} - 1.0)", op(SCALAR_0).expression); break; case AluScalarOpcode::SetpClr: @@ -1016,54 +1128,54 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) break; case AluScalarOpcode::SetpRstr: - print("p0 ? 0.0 : {}", op(SCALAR_0)); + print("p0 ? 0.0 : {}", op(SCALAR_0).expression); break; case AluScalarOpcode::KillsEq: - print("{} == 0.0", op(SCALAR_0)); + print("{} == 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::KillsGt: - print("{} > 0.0", op(SCALAR_0)); + print("{} > 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::KillsGe: - print("{} >= 0.0", op(SCALAR_0)); + print("{} >= 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::KillsNe: - print("{} != 0.0", op(SCALAR_0)); + print("{} != 0.0", op(SCALAR_0).expression); break; case AluScalarOpcode::KillsOne: - print("{} == 1.0", op(SCALAR_0)); + print("{} == 1.0", op(SCALAR_0).expression); break; case AluScalarOpcode::Sqrt: - print("sqrt({})", op(SCALAR_0)); + print("sqrt({})", op(SCALAR_0).expression); break; case AluScalarOpcode::Mulsc0: case AluScalarOpcode::Mulsc1: - print("{} * {}", op(SCALAR_CONSTANT_0), op(SCALAR_CONSTANT_1)); + print("{} * {}", op(SCALAR_CONSTANT_0).expression, op(SCALAR_CONSTANT_1).expression); break; case AluScalarOpcode::Addsc0: case AluScalarOpcode::Addsc1: - print("{} + {}", op(SCALAR_CONSTANT_0), op(SCALAR_CONSTANT_1)); + print("{} + {}", op(SCALAR_CONSTANT_0).expression, op(SCALAR_CONSTANT_1).expression); break; case AluScalarOpcode::Subsc0: case AluScalarOpcode::Subsc1: - print("{} - {}", op(SCALAR_CONSTANT_0), op(SCALAR_CONSTANT_1)); + print("{} - {}", op(SCALAR_CONSTANT_0).expression, op(SCALAR_CONSTANT_1).expression); break; case AluScalarOpcode::Sin: - print("sin({})", op(SCALAR_0)); + print("sin({})", op(SCALAR_0).expression); break; case AluScalarOpcode::Cos: - print("cos({})", op(SCALAR_0)); + print("cos({})", op(SCALAR_0).expression); break; } @@ -1076,11 +1188,11 @@ void ShaderRecompiler::recompile(const AluInstruction& instr) { case AluScalarOpcode::MaxAs: indent(); - println("a0 = (int)clamp(floor({} + 0.5), -256.0, 255.0);", op(SCALAR_0)); + println("a0 = (int)clamp(floor({} + 0.5), -256.0, 255.0);", op(SCALAR_0).expression); break; case AluScalarOpcode::MaxAsf: indent(); - println("a0 = (int)clamp(floor({}), -256.0, 255.0);", op(SCALAR_0)); + println("a0 = (int)clamp(floor({}), -256.0, 255.0);", op(SCALAR_0).expression); break; } }