forked from KhronosGroup/SPIRV-Headers
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
719dc53
commit 5577b4b
Showing
3 changed files
with
1,384 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
# TODO: OVERLOADS! Currently, we are generating multiple functions for the same function with different types. | ||
# e.g. `groupNonUniformIAdd` and `groupNonUniformFAdd` can be simplifed to a single function named `groupNonUniformAdd` | ||
# with multiple overloads. as an extra point, we can drop the requirement for templates and generate the type | ||
|
||
import json | ||
import io | ||
from enum import Enum | ||
from argparse import ArgumentParser | ||
import os | ||
from typing import NamedTuple | ||
from typing import Optional | ||
|
||
head = """#ifdef __HLSL_VERSION | ||
#include "spirv/unified1/spirv.hpp" | ||
#include "spirv/unified1/GLSL.std.450.h" | ||
#endif | ||
#include "nbl/builtin/hlsl/type_traits.hlsl" | ||
namespace nbl | ||
{ | ||
namespace hlsl | ||
{ | ||
#ifdef __HLSL_VERSION | ||
namespace spirv | ||
{ | ||
//! General Decls | ||
template<uint32_t StorageClass, typename T> | ||
using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal<vk::integral_constant<uint32_t, StorageClass>>, T>; | ||
// The holy operation that makes addrof possible | ||
template<uint32_t StorageClass, typename T> | ||
[[vk::ext_instruction(spv::OpCopyObject)]] | ||
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value); | ||
//! Std 450 Extended set operations | ||
template<typename SquareMatrix> | ||
[[vk::ext_instruction(GLSLstd450MatrixInverse)]] | ||
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat); | ||
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on) | ||
template<typename T, typename U> | ||
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]] | ||
[[vk::ext_instruction(spv::OpBitcast)]] | ||
enable_if_t<is_spirv_type_v<T> && is_spirv_type_v<U>, T> bitcast(U); | ||
template<typename T> | ||
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]] | ||
[[vk::ext_instruction(spv::OpBitcast)]] | ||
uint64_t bitcast(pointer_t<spv::StorageClassPhysicalStorageBuffer,T>); | ||
template<typename T> | ||
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]] | ||
[[vk::ext_instruction(spv::OpBitcast)]] | ||
pointer_t<spv::StorageClassPhysicalStorageBuffer,T> bitcast(uint64_t); | ||
template<class T, class U> | ||
[[vk::ext_instruction(spv::OpBitcast)]] | ||
T bitcast(U); | ||
""" | ||
|
||
foot = """} | ||
#endif | ||
} | ||
} | ||
#endif | ||
""" | ||
|
||
def gen(grammer_path, metadata_path, output_path): | ||
grammer_raw = open(grammer_path, "r").read() | ||
grammer = json.loads(grammer_raw) | ||
del grammer_raw | ||
|
||
metadata_raw = open(metadata_path, "r").read() | ||
metadata = json.loads(metadata_raw) | ||
del metadata_raw | ||
|
||
output = open(output_path, "w", buffering=1024**2) | ||
|
||
builtins = [x for x in grammer["operand_kinds"] if x["kind"] == "BuiltIn"][0]["enumerants"] | ||
execution_modes = [x for x in grammer["operand_kinds"] if x["kind"] == "ExecutionMode"][0]["enumerants"] | ||
group_operations = [x for x in grammer["operand_kinds"] if x["kind"] == "GroupOperation"][0]["enumerants"] | ||
|
||
with output as writer: | ||
writer.write(head) | ||
|
||
writer.write("\n//! Builtins\n") | ||
for name in metadata["builtins"].keys(): | ||
# Validate | ||
builtin_exist = False | ||
for b in builtins: | ||
if b["enumerant"] == name: builtin_exist = True | ||
|
||
if (builtin_exist): | ||
bm = metadata["builtins"][name] | ||
is_mutable = "const" in bm.keys() and bm["mutable"] | ||
writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + name + ")]]\n") | ||
writer.write("static " + ("" if is_mutable else "const ") + bm["type"] + " " + name + ";\n") | ||
else: | ||
raise Exception("Invalid builtin " + name) | ||
|
||
writer.write("\n//! Execution Modes\nnamespace execution_mode\n{") | ||
for em in execution_modes: | ||
name = em["enumerant"] | ||
name_l = name[0].lower() + name[1:] | ||
writer.write("\n\tvoid " + name_l + "()\n\t{\n\t\tvk::ext_execution_mode(spv::ExecutionMode" + name + ");\n\t}\n") | ||
writer.write("}\n") | ||
|
||
writer.write("\n//! Group Operations\nnamespace group_operation\n{\n") | ||
for go in group_operations: | ||
name = go["enumerant"] | ||
value = go["value"] | ||
writer.write("\tstatic const uint32_t " + name + " = " + str(value) + ";\n") | ||
writer.write("}\n") | ||
|
||
writer.write("\n//! Instructions\n") | ||
for instruction in grammer["instructions"]: | ||
match instruction["class"]: | ||
case "Atomic": | ||
match instruction["opname"]: | ||
# integers operate on 2s complement so same op for signed and unsigned | ||
case "OpAtomicIAdd" | "OpAtomicISub" | "OpAtomicIIncrement" | "OpAtomicIDecrement" | "OpAtomicAnd" | "OpAtomicOr" | "OpAtomicXor": | ||
processInst(writer, instruction, InstOptions({"uint32_t", "int32_t"})) | ||
processInst(writer, instruction, InstOptions({"uint32_t", "int32_t"}, Shape.PTR_TEMPLATE)) | ||
processInst(writer, instruction, InstOptions({"uint64_t", "int64_t"})) | ||
processInst(writer, instruction, InstOptions({"uint64_t", "int64_t"}, Shape.PTR_TEMPLATE)) | ||
case "OpAtomicUMin" | "OpAtomicUMax": | ||
processInst(writer, instruction, InstOptions({"uint32_t"})) | ||
processInst(writer, instruction, InstOptions({"uint32_t"}, Shape.PTR_TEMPLATE)) | ||
case "OpAtomicSMin" | "OpAtomicSMax": | ||
processInst(writer, instruction, InstOptions({"int32_t"})) | ||
processInst(writer, instruction, InstOptions({"int32_t"}, Shape.PTR_TEMPLATE)) | ||
case "OpAtomicFMinEXT" | "OpAtomicFMaxEXT" | "OpAtomicFAddEXT": | ||
processInst(writer, instruction, InstOptions({"float"})) | ||
processInst(writer, instruction, InstOptions({"float"}, Shape.PTR_TEMPLATE)) | ||
case _: | ||
processInst(writer, instruction, InstOptions()) | ||
processInst(writer, instruction, InstOptions({}, Shape.PTR_TEMPLATE)) | ||
case "Memory": | ||
processInst(writer, instruction, InstOptions({}, Shape.PTR_TEMPLATE)) | ||
processInst(writer, instruction, InstOptions({}, Shape.PSB_RT)) | ||
case "Barrier": | ||
processInst(writer, instruction, InstOptions()) | ||
case "Bit": | ||
match instruction["opname"]: | ||
case "OpBitFieldUExtract": | ||
processInst(writer, instruction, InstOptions({"Unsigned"})) | ||
case "OpBitFieldSExtract": | ||
processInst(writer, instruction, InstOptions({"Signed"})) | ||
case "OpBitFieldInsert": | ||
processInst(writer, instruction, InstOptions({"Signed", "Unsigned"})) | ||
case "Reserved": | ||
match instruction["opname"]: | ||
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT": | ||
processInst(writer, instruction, InstOptions()) | ||
case "Non-Uniform": | ||
processInst(writer, instruction, InstOptions()) | ||
case _: continue # TODO | ||
|
||
writer.write(foot) | ||
|
||
class Shape(Enum): | ||
DEFAULT = 0, | ||
PTR_TEMPLATE = 1, # TODO: this is a DXC Workaround | ||
PSB_RT = 2, # PhysicalStorageBuffer Result Type | ||
|
||
class InstOptions(NamedTuple): | ||
allowed_types: list = {} | ||
shape: Shape = Shape.DEFAULT | ||
|
||
def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions): | ||
name = instruction["opname"] | ||
|
||
# Attributes | ||
templates = ["typename T"] | ||
conds = [] | ||
result_ty = "void" | ||
args = [] | ||
|
||
if options.shape == Shape.PTR_TEMPLATE: | ||
templates.append("typename P") | ||
|
||
if options.shape == Shape.PTR_TEMPLATE: | ||
conds.append("is_spirv_type_v<P>") | ||
if len(options.allowed_types) > 0: | ||
allowed_types_conds = [] | ||
for at in options.allowed_types: | ||
if at == "Signed": | ||
allowed_types_conds.append("is_signed_v<T>") | ||
elif at == "Unsigned": | ||
allowed_types_conds.append("is_unsigned_v<T>") | ||
else: | ||
allowed_types_conds.append("is_same_v<T, " + at + ">") | ||
conds.append("(" + " || ".join(allowed_types_conds) + ")") | ||
|
||
if "operands" in instruction: | ||
for operand in instruction["operands"]: | ||
op_name = operand["name"].strip("'") if "name" in operand else None | ||
op_name = op_name[0].lower() + op_name[1:] if (op_name != None) else "" | ||
match operand["kind"]: | ||
case "IdResultType" | "IdResult": | ||
result_ty = "T" | ||
case "IdRef": | ||
match operand["name"]: | ||
case "'Pointer'": | ||
if options.shape == Shape.PTR_TEMPLATE: | ||
args.append("P " + op_name) | ||
elif options.shape == Shape.PSB_RT: | ||
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, T> " + op_name) | ||
else: | ||
args.append("[[vk::ext_reference]] T " + op_name) | ||
case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'": | ||
args.append("T " + op_name) | ||
case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'": | ||
args.append("uint32_t " + op_name) | ||
case "'Predicate'": args.append("bool " + op_name) | ||
case "'ClusterSize'": | ||
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload | ||
else: return # TODO | ||
case _: return # TODO | ||
case "IdScope": args.append("uint32_t " + op_name.lower() + "Scope") | ||
case "IdMemorySemantics": args.append(" uint32_t " + op_name) | ||
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + op_name) | ||
case "MemoryAccess": | ||
writeInst(writer, templates, name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess"]) | ||
writeInst(writer, templates, name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"]) | ||
writeInst(writer, templates + ["uint32_t alignment"], name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"]) | ||
case _: return # TODO | ||
|
||
writeInst(writer, templates, name, conds, result_ty, args) | ||
|
||
|
||
def writeInst(writer: io.TextIOWrapper, templates, name, conds, result_ty, args): | ||
fn_name = name[2].lower() + name[3:] | ||
writer.write("template<" + ", ".join(templates) + ">\n[[vk::ext_instruction(spv::" + name + ")]]\n") | ||
if len(conds) > 0: | ||
writer.write("enable_if_t<" + " && ".join(conds) + ", " + result_ty + ">") | ||
else: | ||
writer.write(result_ty) | ||
writer.write(" " + fn_name + "(" + ", ".join(args) + ");\n\n") | ||
|
||
|
||
if __name__ == "__main__": | ||
script_dir_path = os.path.abspath(os.path.dirname(__file__)) | ||
|
||
parser = ArgumentParser(description="Generate HLSL from SPIR-V instructions") | ||
parser.add_argument("output", type=str, help="HLSL output file") | ||
parser.add_argument("--grammer", required=False, type=str, help="Input SPIR-V grammer JSON file", default=os.path.join(script_dir_path, "../../include/spirv/unified1/spirv.core.grammar.json")) | ||
parser.add_argument("--metadata", required=False, type=str, help="Input SPIR-V Instructions/BuiltIns type mapping/attributes/etc", default=os.path.join(script_dir_path, "metadata.json")) | ||
args = parser.parse_args() | ||
|
||
gen(args.grammer, args.metadata, args.output) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
{ | ||
"builtins": { | ||
"HelperInvocation": { | ||
"type": "bool", | ||
"mutable": true | ||
}, | ||
"Position": { | ||
"type": "float32_t4" | ||
}, | ||
"VertexIndex": { | ||
"type": "uint32_t", | ||
"mutable": true | ||
}, | ||
"InstanceIndex": { | ||
"type": "uint32_t", | ||
"mutable": true | ||
}, | ||
"NumWorkgroups": { | ||
"type": "uint32_t3", | ||
"mutable": true | ||
}, | ||
"WorkgroupId": { | ||
"type": "uint32_t3", | ||
"mutable": true | ||
}, | ||
"LocalInvocationId": { | ||
"type": "uint32_t3", | ||
"mutable": true | ||
}, | ||
"GlobalInvocationId": { | ||
"type": "uint32_t3", | ||
"mutable": true | ||
}, | ||
"LocalInvocationIndex": { | ||
"type": "uint32_t", | ||
"mutable": true | ||
}, | ||
"SubgroupEqMask": { | ||
"type": "uint32_t4" | ||
}, | ||
"SubgroupGeMask": { | ||
"type": "uint32_t4" | ||
}, | ||
"SubgroupGtMask": { | ||
"type": "uint32_t4" | ||
}, | ||
"SubgroupLeMask": { | ||
"type": "uint32_t4" | ||
}, | ||
"SubgroupLtMask": { | ||
"type": "uint32_t4" | ||
}, | ||
"SubgroupSize": { | ||
"type": "uint32_t" | ||
}, | ||
"NumSubgroups": { | ||
"type": "uint32_t" | ||
}, | ||
"SubgroupId": { | ||
"type": "uint32_t" | ||
}, | ||
"SubgroupLocalInvocationId": { | ||
"type": "uint32_t" | ||
} | ||
} | ||
} |
Oops, something went wrong.