之前我们介绍过DX12_Mesh Shaders Render,但是基于Mesh Shader我们能做的还有很多,比如实例化和剔除(视锥剔除与遮挡剔除),而这些也正是现在主流GPU-Driven管线的核心做法,可谓一举两得(毕竟MS本质上就是CS的变种嘛)。那么我们一步步来,先来说一下Mesh Shader实例化如何实现吧。
传统API中的实例化一般有两种做法:
- 一种是将实例化数据放在与VertexBuffer同级的绑定位上,并在管线的输入布局中声明,之后调用DrawInstanced即可;
- 另一种就是放到常量缓冲区上,调用DrawInstanced后在Shader中使用SV_InstanceID/gl_InstanceIndex索引实例数据进行绘制即可。
其实Mesh Shader的实例化和第二种方式类似:在MS中直接根据实例化数据生成对应的Meshlet输出,再接上PS即可。当然,这种方式和传统API的实例化还是有区别的:
- 效率通常比传统实例化更高(MS->PS的管线比VS->PS更精简,更不用说再加上臃肿的TS与GS了),原因是Meshlet粒度的数据更适合硬件并行计算,能够充分发挥GPU算力
- 更加灵活,可拓展完全GPU-Driven的算法实现
void D3D12MeshletInstancing::RegenerateInstances()
m_updateInstances = true;
const float radius = m_model.GetBoundingSphere().Radius;
const float padding = 0.0f;
const float spacing = (1.0f + padding) * radius * 2.0f;
const uint32_t width = m_instanceLevel * 2 + 1;
const float extents = spacing * m_instanceLevel;
m_instanceCount = width * width * width;
const uint32_t instanceBufferSize = (uint32_t)GetAlignedSize(m_instanceCount * sizeof(Instance));
// 实例化数量改变时重新创建默认堆数据
if (!m_instanceBuffer || m_instanceBuffer->GetDesc().Width < instanceBufferSize)
const CD3DX12_HEAP_PROPERTIES instanceBufferDefaultHeapProps(D3D12_HEAP_TYPE_DEFAULT);
const CD3DX12_RESOURCE_DESC instanceBufferDesc = CD3DX12_RESOURCE_DESC::Buffer(instanceBufferSize);
// 创建Buffer(常变数据,所以放共享显存中,最后析构再UnMap)
const CD3DX12_HEAP_PROPERTIES instanceBufferUploadHeapProps(D3D12_HEAP_TYPE_UPLOAD);
// 创建上传堆
m_instanceUpload->Map(0, nullptr, reinterpret_cast<void**>(&m_instanceData));
// CPU更新实例化数据
for (uint32_t i = 0; i < m_instanceCount; ++i)
XMVECTOR index = XMVectorSet(float(i % width), float((i / width) % width), float(i / (width * width)), 0);
XMVECTOR location = index * spacing - XMVectorReplicate(extents);
XMMATRIX world = XMMatrixTranslationFromVector(location);
auto& inst = m_instanceData[i];
XMStoreFloat4x4(&inst.World, XMMatrixTranspose(world));
XMStoreFloat4x4(&inst.WorldInvTranspose, XMMatrixTranspose(XMMatrixInverse(nullptr, XMMatrixTranspose(world))));
// Upload the instance data to the GPU buffer only when the instance set has changed.
if (m_updateInstances)
{
    // The instance buffer must be in COPY_DEST while it is being overwritten.
    const auto toCopyBarrier = CD3DX12_RESOURCE_BARRIER::Transition(m_instanceBuffer.Get(), D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_STATE_COPY_DEST);
    m_commandList->ResourceBarrier(1, &toCopyBarrier);

    m_commandList->CopyResource(m_instanceBuffer.Get(), m_instanceUpload.Get());

    // Return the buffer to GENERIC_READ so the shaders can read it again.
    const auto toGenericBarrier = CD3DX12_RESOURCE_BARRIER::Transition(m_instanceBuffer.Get(), D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_GENERIC_READ);
    m_commandList->ResourceBarrier(1, &toGenericBarrier);

    m_updateInstances = false;
}
二、Instance Mesh Shader实现
// Root signature for the instanced mesh shader:
//   b0      - constant buffer view (per-frame constants)
//   b1      - 2 x 32-bit root constants (draw parameters)
//   b2      - 3 x 32-bit root constants (mesh info)
//   t0 - t4 - shader resource views; t4 is the per-instance data buffer.
// No comments may appear inside the macro: every line but the last ends in a
// line continuation, and the string literal must be closed on the last line.
#define ROOT_SIG "CBV(b0), \
                  RootConstants(b1, num32bitconstants=2), \
                  RootConstants(b2, num32bitconstants=3), \
                  SRV(t0), \
                  SRV(t1), \
                  SRV(t2), \
                  SRV(t3), \
                  SRV(t4)"
// Per-frame constants shared by the mesh and pixel shaders.
struct Constants
{
    float4x4 World;
    float4x4 WorldView;
    float4x4 WorldViewProj;
    uint DrawMeshlets; // Non-zero: colour each meshlet by its index.
};

// Per-instance transforms; layout matches what the CPU writes for each instance.
struct Instance
{
    float4x4 World;
    float4x4 WorldInvTranspose;
};

// Per-draw instancing parameters.
struct DrawParams
{
    uint InstanceCount;
    uint InstanceOffset;
};

// Per-mesh information.
struct MeshInfo
{
    uint IndexBytes;    // Size of one unique-vertex index in bytes (4 selects 32-bit, otherwise 16-bit).
    uint MeshletCount;
    uint MeshletOffset;
};

// Input vertex layout.
struct Vertex
{
    float3 Position;
    float3 Normal;
};

// Mesh shader output / pixel shader input.
struct VertexOut
{
    float4 PositionHS : SV_Position;
    float3 PositionVS : POSITION0;
    float3 Normal : NORMAL0;
    uint MeshletIndex : COLOR0;
};

// One meshlet: a range of unique-vertex indices and a range of packed primitives.
struct Meshlet
{
    uint VertCount;
    uint VertOffset;
    uint PrimCount;
    uint PrimOffset;
};
// Per-frame constants (register b0).
ConstantBuffer<Constants> Globals : register(b0);
// Per-draw instance count and offset (register b1).
ConstantBuffer<DrawParams> DrawParams : register(b1);
// Per-mesh index size and meshlet range (register b2).
ConstantBuffer<MeshInfo> MeshInfo : register(b2);

// Vertex attributes, indexed by the value returned from GetVertexIndex.
StructuredBuffer<Vertex> Vertices : register(t0);
// Meshlet descriptors, indexed by meshlet index.
StructuredBuffer<Meshlet> Meshlets : register(t1);
// Raw buffer of 16- or 32-bit vertex indices, selected by MeshInfo.IndexBytes.
ByteAddressBuffer UniqueVertexIndices : register(t2);
// One packed triangle (three 10-bit indices) per uint.
StructuredBuffer<uint> PrimitiveIndices : register(t3);
// Per-instance transforms.
StructuredBuffer<Instance> Instances : register(t4);
// Data Loaders
// Unpacks one triangle from a 32-bit value holding three 10-bit
// meshlet-local vertex indices (bits 0-9, 10-19 and 20-29).
uint3 UnpackPrimitive(uint primitive)
{
    return uint3(primitive & 0x3FF, (primitive >> 10) & 0x3FF, (primitive >> 20) & 0x3FF);
}
// Returns the `index`-th triangle of meshlet `m` as three meshlet-local vertex indices.
uint3 GetPrimitive(Meshlet m, uint index)
{
    return UnpackPrimitive(PrimitiveIndices[m.PrimOffset + index]);
}
// Maps a meshlet-local vertex slot to an index into the Vertices buffer.
// The indices are stored as either 32-bit or 16-bit values, selected by
// MeshInfo.IndexBytes.
uint GetVertexIndex(Meshlet m, uint localIndex)
{
    localIndex = m.VertOffset + localIndex;

    if (MeshInfo.IndexBytes == 4) // 32-bit vertex indices
    {
        return UniqueVertexIndices.Load(localIndex * 4);
    }
    else // 16-bit vertex indices
    {
        // Byte address must be 4-byte aligned, so load the 32-bit word that
        // contains the index and select the low or high half.
        uint wordOffset = (localIndex & 0x1);
        uint byteOffset = (localIndex / 2) * 4;

        // Grab the pair of 16-bit indices, shift & mask off the proper 16 bits.
        uint indexPair = UniqueVertexIndices.Load(byteOffset);
        uint index = (indexPair >> (wordOffset * 16)) & 0xffff;

        return index;
    }
}
// Builds the output vertex for one vertex of one instance.
//   meshletIndex  - index of the meshlet being emitted (forwarded to the pixel shader)
//   vertexIndex   - index into the Vertices buffer
//   instanceIndex - instance index relative to DrawParams.InstanceOffset
// The per-instance transform is applied first, followed by the shared Globals
// matrices the function already used.
// NOTE(review): this assumes the Globals matrices do not already contain a
// per-instance transform -- confirm against the CPU-side constant setup.
VertexOut GetVertexAttributes(uint meshletIndex, uint vertexIndex, uint instanceIndex)
{
    Instance inst = Instances[DrawParams.InstanceOffset + instanceIndex];
    Vertex v = Vertices[vertexIndex];

    // Move the vertex into its instance's placement before the shared transforms.
    float4 positionWS = mul(float4(v.Position, 1), inst.World);
    float3 normalWS = mul(float4(v.Normal, 0), inst.WorldInvTranspose).xyz;

    VertexOut vout;
    vout.PositionVS = mul(positionWS, Globals.WorldView).xyz;
    vout.PositionHS = mul(positionWS, Globals.WorldViewProj);
    vout.Normal = mul(float4(normalWS, 0), Globals.World).xyz;
    vout.MeshletIndex = meshletIndex;

    return vout;
}
// Output limits used by the instance-packing math below. They match the sizes
// of the `verts` / `tris` output arrays; guarded in case they are already
// defined elsewhere in the shader.
#ifndef MAX_VERTS
#define MAX_VERTS 64
#endif
#ifndef MAX_PRIMS
#define MAX_PRIMS 126
#endif

// Instanced mesh shader entry point.
// Every meshlet except the last is dispatched once per instance (one
// thread group = one meshlet of one instance). Thread groups for the last
// meshlet pack as many instances as fit into the output limits.
[RootSignature(ROOT_SIG)]
[NumThreads(128, 1, 1)]
[OutputTopology("triangle")]
void main(
    uint gtid : SV_GroupThreadID,
    uint gid : SV_GroupID,
    out indices uint3 tris[126],
    out vertices VertexOut verts[64]
)
{
    uint meshletIndex = gid / DrawParams.InstanceCount;
    Meshlet m = Meshlets[meshletIndex];

    // Common case: each thread group emits exactly one instance.
    uint startInstance = gid % DrawParams.InstanceCount;
    uint instanceCount = 1;

    // The last meshlet is handled separately: one thread group submits several instances.
    if (meshletIndex == MeshInfo.MeshletCount - 1)
    {
        const uint instancesPerGroup = min(MAX_VERTS / m.VertCount, MAX_PRIMS / m.PrimCount);

        // Work out which instances this group covers.
        uint unpackedGroupCount = (MeshInfo.MeshletCount - 1) * DrawParams.InstanceCount;
        uint packedIndex = gid - unpackedGroupCount;

        startInstance = packedIndex * instancesPerGroup;
        instanceCount = min(DrawParams.InstanceCount - startInstance, instancesPerGroup);
    }

    // Number of vertices and primitives this group outputs.
    uint vertCount = m.VertCount * instanceCount;
    uint primCount = m.PrimCount * instanceCount;

    SetMeshOutputCounts(vertCount, primCount);

    // Export vertices: one thread per output vertex.
    if (gtid < vertCount)
    {
        uint readIndex = gtid % m.VertCount;  // Wrap our reads for packed instancing.
        uint instanceId = gtid / m.VertCount; // Instance index into this threadgroup's instances (only non-zero for packed threadgroups.)

        uint vertexIndex = GetVertexIndex(m, readIndex);
        uint instanceIndex = startInstance + instanceId;

        verts[gtid] = GetVertexAttributes(meshletIndex, vertexIndex, instanceIndex);
    }

    // Export primitives: one thread per output triangle.
    if (gtid < primCount)
    {
        uint readIndex = gtid % m.PrimCount;  // Wrap our reads for packed instancing.
        uint instanceId = gtid / m.PrimCount; // Instance index within this threadgroup (only non-zero in last meshlet threadgroups.)

        // Must offset the vertex indices to this thread's instanced verts
        tris[gtid] = GetPrimitive(m, readIndex) + (m.VertCount * instanceId);
    }
}
// Per-frame constants; the pixel shader only reads DrawMeshlets.
struct Constants
{
    float4x4 World;
    float4x4 WorldView;
    float4x4 WorldViewProj;
    uint DrawMeshlets; // Non-zero: colour each meshlet by its index.
};

// Interpolated input produced by the mesh shader.
struct VertexOut
{
    float4 PositionHS : SV_Position;
    float3 PositionVS : POSITION0;
    float3 Normal : NORMAL0;
    uint MeshletIndex : COLOR0;
};

ConstantBuffer<Constants> Globals : register(b0);
// Pixel shader: Blinn-Phong shading with a fixed directional light.
// When Globals.DrawMeshlets is set, each meshlet gets a colour derived from
// its index; otherwise a uniform grey material is used.
float4 main(VertexOut input) : SV_TARGET
{
    float ambientIntensity = 0.1;
    float3 lightColor = float3(1, 1, 1);
    float3 lightDir = -normalize(float3(1, -1, 1));

    float3 diffuseColor;
    float shininess;
    if (Globals.DrawMeshlets)
    {
        // Debug view: derive a colour from the low bits of the meshlet index.
        uint meshletIndex = input.MeshletIndex;
        diffuseColor = float3(
            float(meshletIndex & 1),
            float(meshletIndex & 3) / 4,
            float(meshletIndex & 7) / 8);
        shininess = 16.0;
    }
    else
    {
        // Default material.
        diffuseColor = 0.8;
        shininess = 64.0;
    }

    float3 normal = normalize(input.Normal);

    // Do some fancy Blinn-Phong shading!
    float cosAngle = saturate(dot(normal, lightDir));
    float3 viewDir = -normalize(input.PositionVS);
    float3 halfAngle = normalize(lightDir + viewDir);

    float blinnTerm = saturate(dot(normal, halfAngle));
    // No specular highlight on surfaces facing away from the light.
    blinnTerm = cosAngle != 0.0 ? blinnTerm : 0.0;
    blinnTerm = pow(blinnTerm, shininess);

    float3 finalColor = (cosAngle + blinnTerm + ambientIntensity) * diffuseColor;

    return float4(finalColor, 1);
}