Advertisement
Zgragselus

BVH Traversal in HLSL

Apr 15th, 2023 (edited)
858
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 6.62 KB | None | 0 0
  1. #define BVH_STACK_SIZE 64
  2.  
  3. [numthreads(32, 32, 1)]
  4. void RenderPass(uint3 GI : SV_GroupID, uint3 DTid : SV_DispatchThreadID, uint3 GTid : SV_GroupThreadID)
  5. {
  6.     // This is a testing traversal that takes viewport and casts primary ray for each pixel (ray directions are generated in separate kernel)
  7.     uint width = asuint(ResolutionParams.x);
  8.     uint height = asuint(ResolutionParams.y);
  9.  
  10.     int id = DTid.y * width + DTid.x;
  11.    
  12.     // Get ray from buffer and initialize
  13.     Ray r = Rays[id];
  14.     float4 o = r.Origin;
  15.     float4 d = r.Direction;
  16.     float4 inv = r.Inverse;
  17.     float4 oinv = o * inv;
  18.  
  19.     // Initialize stack for BVH stack traversal (0xFFFFFFFF being the entrypoint sentinel)
  20.     uint node_id = 0;
  21.     uint stack[BVH_STACK_SIZE];
  22.     uint stack_ptr = 0;
  23.     stack[stack_ptr] = 0xFFFFFFFF;
  24.     // Entrypoint sentinel for BLAS
  25.     int meshbvh_stack_ptr = -1;
  26.  
  27.     // Traversal variables (holding results)
  28.     float tmin = 0.0f;
  29.     float tmax = 10000.0f;
  30.     float bU = 0.0f;
  31.     float bV = 0.0f;
  32.     float dist = tmax;
  33.     bool hit = false;
  34.     float4 temp = float4(0.0f, 0.0f, 0.0f, 1.0f);
  35.     InstanceNode instance = Instances[0];
  36.  
  37.     // So, here the craziness begins, natively this should be a WHILE loop waiting until we are on entrypoint sentinel,
  38.     // i.e. 0xFFFFFFFF - BUT - that is extremely slow, results easily get between 100-200ms for 1920x1080 rays. This being said
  39.     // just switching over to for loop without static parameter (i.e. for (;;)) yields the same slow results.
  40.     //
  41.     // BUUUUT! Switching to fixed number of steps (i.e. 1000) results in much faster rendering (<15ms), fun fact this time holds
  42.     // even when I increase resolution twice. There is never 1000 steps in the traversal at all.
  43.     //
  44.     // I could leave it as-is, but the problem is - occupancy. Using speculative while-while traversal or persistent threads
  45.     // is just much more efficient, but trying to do that on a for loop is again - slowing it down.
  46.     int i;
  47.     [loop] for (i = 0; i < 1000; i++)
  48.     //while (node_id != 0xFFFFFFFF) // This is significantly slower
  49.     {
  50.         temp.x += 0.1f;
  51.  
  52.         // Interior node hit
  53.         [branch] if (ASTreeData[node_id].PrimitiveCount == 0)
  54.         {
  55.             // Fetch children bounding boxes
  56.             float4 n0xy = ASTreeData[node_id].LXY;
  57.             float4 n1xy = ASTreeData[node_id].RXY;
  58.             float4 nz = ASTreeData[node_id].LRZ;
  59.  
  60.             // Test against both child AABBs
  61.             float c0lox = n0xy.x * inv.x - oinv.x;
  62.             float c0hix = n0xy.y * inv.x - oinv.x;
  63.             float c0loy = n0xy.z * inv.y - oinv.y;
  64.             float c0hiy = n0xy.w * inv.y - oinv.y;
  65.             float c0loz = nz.x * inv.z - oinv.z;
  66.             float c0hiz = nz.y * inv.z - oinv.z;
  67.             float c1loz = nz.z * inv.z - oinv.z;
  68.             float c1hiz = nz.w * inv.z - oinv.z;
  69.             float c0min = max(max(min(c0lox, c0hix), min(c0loy, c0hiy)), max(min(c0loz, c0hiz), tmin));
  70.             float c0max = min(min(max(c0lox, c0hix), max(c0loy, c0hiy)), min(max(c0loz, c0hiz), tmax));
  71.             float c1lox = n1xy.x * inv.x - oinv.x;
  72.             float c1hix = n1xy.y * inv.x - oinv.x;
  73.             float c1loy = n1xy.z * inv.y - oinv.y;
  74.             float c1hiy = n1xy.w * inv.y - oinv.y;
  75.             float c1min = max(max(min(c1lox, c1hix), min(c1loy, c1hiy)), max(min(c1loz, c1hiz), tmin));
  76.             float c1max = min(min(max(c1lox, c1hix), max(c1loy, c1hiy)), min(max(c1loz, c1hiz), tmax));
  77.  
  78.             // Which child AABBs were hit
  79.             bool traverseChild0 = (c0max >= c0min);
  80.             bool traverseChild1 = (c1max >= c1min);
  81.  
  82.             // If no children was hit, get node from stack
  83.             if (!traverseChild0 && !traverseChild1)
  84.             {
  85.                 // If we're on entrypoint sentinel of BLAS, get back into TLAS - reset ray
  86.                 if (stack_ptr == meshbvh_stack_ptr)
  87.                 {
  88.                     meshbvh_stack_ptr = -1;
  89.                     o = r.Origin;
  90.                     d = r.Direction;
  91.                     inv = r.Inverse;
  92.                     oinv = o * inv;
  93.                 }
  94.  
  95.                 node_id = stack[stack_ptr];
  96.                 stack_ptr--;
  97.             }
  98.             // One or more child nodes was hit - continue in first one, push second (further) on stack (if both were hit)
  99.             else if (traverseChild0 || traverseChild1)
  100.             {
  101.                 uint first_child = node_id + 1;
  102.                 uint second_child = ASTreeData[node_id].PrimitiveOffset;
  103.                 node_id = (traverseChild0) ? first_child : second_child;
  104.  
  105.                 if (traverseChild0 && traverseChild1)
  106.                 {
  107.                     if (c1min < c0min)
  108.                     {
  109.                         node_id = second_child;
  110.                         stack_ptr++;
  111.                         stack[stack_ptr] = first_child;
  112.                     }
  113.                     else
  114.                     {
  115.                         stack_ptr++;
  116.                         stack[stack_ptr] = second_child;
  117.                     }
  118.                 }
  119.             }
  120.         }
  121.         // Leaf node of TLAS
  122.         else if (ASTreeData[node_id].PrimitiveCount == -1)
  123.         {
  124.             // Store entrypoint sentinel for BLAS and continue traversal in BLAS
  125.             meshbvh_stack_ptr = stack_ptr;
  126.  
  127.             uint blas_offset = ASTreeData[node_id].PrimitiveOffset;
  128.             uint instance_index = ASIndexData[blas_offset];
  129.             instance = Instances[instance_index];
  130.  
  131.             node_id = ASTreeNodes[Geometries[instance.GeometryNode].BVHNode + 1].Offset / 64;
  132.  
  133.             o = mul(r.Origin, instance.TransformInverse);
  134.             d = mul(r.Direction, instance.TransformInverse);
  135.             inv = rcp(d);
  136.             oinv = o * inv;
  137.         }
  138.         // Leaf node of BLAS
  139.         else
  140.         {
  141.             // Intersect ALL triangles in the node
  142.             if (ASTreeData[node_id].PrimitiveCount > 0)
  143.             {
  144.                 GeometryNode geom = Geometries[instance.GeometryNode];
  145.                 MemoryNode wbo = WoopNodes[geom.WoopBufferNode];
  146.  
  147.                 uint index_offset = ASIndexNodes[ASTreeData[node_id].PrimitiveOffset].Offset / 4;
  148.  
  149.                 for (uint j = 0; j < ASTreeData[node_id].PrimitiveCount; j++)
  150.                 {
  151.                     // Don't trash cache by reading index through it
  152.                     uint tri_idx = ASIndexData[ASTreeData[node_id].PrimitiveOffset + j] * 3;
  153.                     float4 r = WoopData[wbo.Offset / 16 + tri_idx + 0];
  154.                     float4 p = WoopData[wbo.Offset / 16 + tri_idx + 1];
  155.                     float4 q = WoopData[wbo.Offset / 16 + tri_idx + 2];
  156.  
  157.                     float o_z = r.w - o.x * r.x - o.y * r.y - o.z * r.z;
  158.                     float i_z = 1.0f / (d.x * r.x + d.y * r.y + d.z * r.z);
  159.                     float t = o_z * i_z;
  160.  
  161.                     if (t > tmin && t < tmax)
  162.                     {
  163.                         float o_x = p.w + o.x * p.x + o.y * p.y + o.z * p.z;
  164.                         float d_x = d.x * p.x + d.y * p.y + d.z * p.z;
  165.                         float u = o_x + t * d_x;
  166.  
  167.                         if (u >= 0.0f && u <= 1.0f)
  168.                         {
  169.                             float o_y = q.w + o.x * q.x + o.y * q.y + o.z * q.z;
  170.                             float d_y = d.x * q.x + d.y * q.y + d.z * q.z;
  171.                             float v = o_y + t * d_y;
  172.  
  173.                             if (v >= 0.0f && u + v <= 1.0f)
  174.                             {
  175.                                 tmax = t;
  176.                                 bU = u;
  177.                                 bV = v;
  178.                                 hit = true;
  179.                             }
  180.                         }
  181.                     }
  182.                 }
  183.             }
  184.  
  185.             if (stack_ptr == meshbvh_stack_ptr)
  186.             {
  187.                 meshbvh_stack_ptr = -1;
  188.                 o = r.Origin;
  189.                 d = r.Direction;
  190.                 inv = r.Inverse;
  191.                 oinv = o * inv;
  192.             }
  193.  
  194.             node_id = stack[stack_ptr];
  195.             stack_ptr--;
  196.         }
  197.  
  198.         // Termination condition for for-loop approach
  199.         if (node_id == 0xFFFFFFFF)
  200.         {
  201.             break;
  202.         }
  203.     }
  204.    
  205.     Output[DTid.xy] = float4(bU, bV, temp.x, temp.w);
  206. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement