Skip to content

Commit

Permalink
Add HLSL generator
Browse files Browse the repository at this point in the history
  • Loading branch information
alichraghi committed Aug 30, 2024
1 parent 719dc53 commit 5577b4b
Show file tree
Hide file tree
Showing 3 changed files with 1,384 additions and 0 deletions.
256 changes: 256 additions & 0 deletions tools/hlsl_generator/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
# TODO: OVERLOADS! Currently, we are generating multiple functions for the same function with different types.
# e.g. `groupNonUniformIAdd` and `groupNonUniformFAdd` can be simplified to a single function named `groupNonUniformAdd`
# with multiple overloads. as an extra point, we can drop the requirement for templates and generate the type

import json
import io
from enum import Enum
from argparse import ArgumentParser
import os
from typing import NamedTuple
from typing import Optional

# Fixed preamble emitted verbatim at the top of the generated HLSL header:
# opens the nbl::hlsl::spirv namespaces and declares the hand-written helpers
# (pointer_t, copyObject, matrixInverse, the bitcast overloads) that the
# generator does not derive from the grammar JSON.
head = """#ifdef __HLSL_VERSION
#include "spirv/unified1/spirv.hpp"
#include "spirv/unified1/GLSL.std.450.h"
#endif
#include "nbl/builtin/hlsl/type_traits.hlsl"
namespace nbl
{
namespace hlsl
{
#ifdef __HLSL_VERSION
namespace spirv
{
//! General Decls
template<uint32_t StorageClass, typename T>
using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal<vk::integral_constant<uint32_t, StorageClass>>, T>;
// The holy operation that makes addrof possible
template<uint32_t StorageClass, typename T>
[[vk::ext_instruction(spv::OpCopyObject)]]
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);
//! Std 450 Extended set operations
template<typename SquareMatrix>
[[vk::ext_instruction(GLSLstd450MatrixInverse)]]
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
template<typename T, typename U>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
enable_if_t<is_spirv_type_v<T> && is_spirv_type_v<U>, T> bitcast(U);
template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
uint64_t bitcast(pointer_t<spv::StorageClassPhysicalStorageBuffer,T>);
template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
pointer_t<spv::StorageClassPhysicalStorageBuffer,T> bitcast(uint64_t);
template<class T, class U>
[[vk::ext_instruction(spv::OpBitcast)]]
T bitcast(U);
"""

# Closing text emitted after all generated declarations: closes the spirv,
# hlsl and nbl namespaces opened in the preamble, plus the matching #endif
# guards for __HLSL_VERSION.
foot = """}
#endif
}
}
#endif
"""

def gen(grammer_path, metadata_path, output_path):
grammer_raw = open(grammer_path, "r").read()
grammer = json.loads(grammer_raw)
del grammer_raw

metadata_raw = open(metadata_path, "r").read()
metadata = json.loads(metadata_raw)
del metadata_raw

output = open(output_path, "w", buffering=1024**2)

builtins = [x for x in grammer["operand_kinds"] if x["kind"] == "BuiltIn"][0]["enumerants"]
execution_modes = [x for x in grammer["operand_kinds"] if x["kind"] == "ExecutionMode"][0]["enumerants"]
group_operations = [x for x in grammer["operand_kinds"] if x["kind"] == "GroupOperation"][0]["enumerants"]

with output as writer:
writer.write(head)

writer.write("\n//! Builtins\n")
for name in metadata["builtins"].keys():
# Validate
builtin_exist = False
for b in builtins:
if b["enumerant"] == name: builtin_exist = True

if (builtin_exist):
bm = metadata["builtins"][name]
is_mutable = "const" in bm.keys() and bm["mutable"]
writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + name + ")]]\n")
writer.write("static " + ("" if is_mutable else "const ") + bm["type"] + " " + name + ";\n")
else:
raise Exception("Invalid builtin " + name)

writer.write("\n//! Execution Modes\nnamespace execution_mode\n{")
for em in execution_modes:
name = em["enumerant"]
name_l = name[0].lower() + name[1:]
writer.write("\n\tvoid " + name_l + "()\n\t{\n\t\tvk::ext_execution_mode(spv::ExecutionMode" + name + ");\n\t}\n")
writer.write("}\n")

writer.write("\n//! Group Operations\nnamespace group_operation\n{\n")
for go in group_operations:
name = go["enumerant"]
value = go["value"]
writer.write("\tstatic const uint32_t " + name + " = " + str(value) + ";\n")
writer.write("}\n")

writer.write("\n//! Instructions\n")
for instruction in grammer["instructions"]:
match instruction["class"]:
case "Atomic":
match instruction["opname"]:
# integers operate on 2s complement so same op for signed and unsigned
case "OpAtomicIAdd" | "OpAtomicISub" | "OpAtomicIIncrement" | "OpAtomicIDecrement" | "OpAtomicAnd" | "OpAtomicOr" | "OpAtomicXor":
processInst(writer, instruction, InstOptions({"uint32_t", "int32_t"}))
processInst(writer, instruction, InstOptions({"uint32_t", "int32_t"}, Shape.PTR_TEMPLATE))
processInst(writer, instruction, InstOptions({"uint64_t", "int64_t"}))
processInst(writer, instruction, InstOptions({"uint64_t", "int64_t"}, Shape.PTR_TEMPLATE))
case "OpAtomicUMin" | "OpAtomicUMax":
processInst(writer, instruction, InstOptions({"uint32_t"}))
processInst(writer, instruction, InstOptions({"uint32_t"}, Shape.PTR_TEMPLATE))
case "OpAtomicSMin" | "OpAtomicSMax":
processInst(writer, instruction, InstOptions({"int32_t"}))
processInst(writer, instruction, InstOptions({"int32_t"}, Shape.PTR_TEMPLATE))
case "OpAtomicFMinEXT" | "OpAtomicFMaxEXT" | "OpAtomicFAddEXT":
processInst(writer, instruction, InstOptions({"float"}))
processInst(writer, instruction, InstOptions({"float"}, Shape.PTR_TEMPLATE))
case _:
processInst(writer, instruction, InstOptions())
processInst(writer, instruction, InstOptions({}, Shape.PTR_TEMPLATE))
case "Memory":
processInst(writer, instruction, InstOptions({}, Shape.PTR_TEMPLATE))
processInst(writer, instruction, InstOptions({}, Shape.PSB_RT))
case "Barrier":
processInst(writer, instruction, InstOptions())
case "Bit":
match instruction["opname"]:
case "OpBitFieldUExtract":
processInst(writer, instruction, InstOptions({"Unsigned"}))
case "OpBitFieldSExtract":
processInst(writer, instruction, InstOptions({"Signed"}))
case "OpBitFieldInsert":
processInst(writer, instruction, InstOptions({"Signed", "Unsigned"}))
case "Reserved":
match instruction["opname"]:
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT":
processInst(writer, instruction, InstOptions())
case "Non-Uniform":
processInst(writer, instruction, InstOptions())
case _: continue # TODO

writer.write(foot)

class Shape(Enum):
    """How a generated wrapper exposes its pointer operand.

    BUGFIX: the original members had trailing commas, which made every
    value a 1-tuple ((0,), (1,), (2,)) instead of an int.
    """
    DEFAULT = 0
    PTR_TEMPLATE = 1  # TODO: this is a DXC Workaround
    PSB_RT = 2  # PhysicalStorageBuffer Result Type


class InstOptions(NamedTuple):
    """Per-emission options passed to processInst.

    BUGFIX: the original default was `{}` — an empty *dict* annotated as
    `list`. Callers pass sets, so the field is now a (frozen)set.
    """
    # HLSL type names (or the classes "Signed"/"Unsigned") that T is
    # constrained to; empty means unconstrained.
    allowed_types: frozenset = frozenset()
    shape: Shape = Shape.DEFAULT

def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
name = instruction["opname"]

# Attributes
templates = ["typename T"]
conds = []
result_ty = "void"
args = []

if options.shape == Shape.PTR_TEMPLATE:
templates.append("typename P")

if options.shape == Shape.PTR_TEMPLATE:
conds.append("is_spirv_type_v<P>")
if len(options.allowed_types) > 0:
allowed_types_conds = []
for at in options.allowed_types:
if at == "Signed":
allowed_types_conds.append("is_signed_v<T>")
elif at == "Unsigned":
allowed_types_conds.append("is_unsigned_v<T>")
else:
allowed_types_conds.append("is_same_v<T, " + at + ">")
conds.append("(" + " || ".join(allowed_types_conds) + ")")

if "operands" in instruction:
for operand in instruction["operands"]:
op_name = operand["name"].strip("'") if "name" in operand else None
op_name = op_name[0].lower() + op_name[1:] if (op_name != None) else ""
match operand["kind"]:
case "IdResultType" | "IdResult":
result_ty = "T"
case "IdRef":
match operand["name"]:
case "'Pointer'":
if options.shape == Shape.PTR_TEMPLATE:
args.append("P " + op_name)
elif options.shape == Shape.PSB_RT:
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, T> " + op_name)
else:
args.append("[[vk::ext_reference]] T " + op_name)
case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'":
args.append("T " + op_name)
case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
args.append("uint32_t " + op_name)
case "'Predicate'": args.append("bool " + op_name)
case "'ClusterSize'":
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
else: return # TODO
case _: return # TODO
case "IdScope": args.append("uint32_t " + op_name.lower() + "Scope")
case "IdMemorySemantics": args.append(" uint32_t " + op_name)
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + op_name)
case "MemoryAccess":
writeInst(writer, templates, name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
writeInst(writer, templates, name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
writeInst(writer, templates + ["uint32_t alignment"], name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
case _: return # TODO

writeInst(writer, templates, name, conds, result_ty, args)


def writeInst(writer: io.TextIOWrapper, templates, name, conds, result_ty, args):
    """Write a single templated HLSL declaration for the SPIR-V opcode `name`.

    The function name is the opcode minus its "Op" prefix with the first
    letter lowercased; `conds`, when present, are folded into an enable_if_t
    around the result type.
    """
    fn_name = name[2].lower() + name[3:]
    if conds:
        ret = f"enable_if_t<{' && '.join(conds)}, {result_ty}>"
    else:
        ret = result_ty
    writer.write(
        f"template<{', '.join(templates)}>\n"
        f"[[vk::ext_instruction(spv::{name})]]\n"
        f"{ret} {fn_name}({', '.join(args)});\n\n"
    )


if __name__ == "__main__":
    # Default grammar/metadata locations are resolved relative to this script.
    script_dir_path = os.path.abspath(os.path.dirname(__file__))
    default_grammer = os.path.join(script_dir_path, "../../include/spirv/unified1/spirv.core.grammar.json")
    default_metadata = os.path.join(script_dir_path, "metadata.json")

    parser = ArgumentParser(description="Generate HLSL from SPIR-V instructions")
    parser.add_argument("output", type=str, help="HLSL output file")
    parser.add_argument("--grammer", required=False, type=str,
                        default=default_grammer,
                        help="Input SPIR-V grammer JSON file")
    parser.add_argument("--metadata", required=False, type=str,
                        default=default_metadata,
                        help="Input SPIR-V Instructions/BuiltIns type mapping/attributes/etc")
    cli = parser.parse_args()

    gen(cli.grammer, cli.metadata, cli.output)

66 changes: 66 additions & 0 deletions tools/hlsl_generator/metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
{
"builtins": {
"HelperInvocation": {
"type": "bool",
"mutable": true
},
"Position": {
"type": "float32_t4"
},
"VertexIndex": {
"type": "uint32_t",
"mutable": true
},
"InstanceIndex": {
"type": "uint32_t",
"mutable": true
},
"NumWorkgroups": {
"type": "uint32_t3",
"mutable": true
},
"WorkgroupId": {
"type": "uint32_t3",
"mutable": true
},
"LocalInvocationId": {
"type": "uint32_t3",
"mutable": true
},
"GlobalInvocationId": {
"type": "uint32_t3",
"mutable": true
},
"LocalInvocationIndex": {
"type": "uint32_t",
"mutable": true
},
"SubgroupEqMask": {
"type": "uint32_t4"
},
"SubgroupGeMask": {
"type": "uint32_t4"
},
"SubgroupGtMask": {
"type": "uint32_t4"
},
"SubgroupLeMask": {
"type": "uint32_t4"
},
"SubgroupLtMask": {
"type": "uint32_t4"
},
"SubgroupSize": {
"type": "uint32_t"
},
"NumSubgroups": {
"type": "uint32_t"
},
"SubgroupId": {
"type": "uint32_t"
},
"SubgroupLocalInvocationId": {
"type": "uint32_t"
}
}
}
Loading

0 comments on commit 5577b4b

Please sign in to comment.