Advertisement
Zgragselus

RenderPassTrace

Nov 15th, 2023
874
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 6.91 KB | None | 0 0
  1. #include "Raytracer.hlsli"
  2.  
  3. cbuffer Params : register(b0)
  4. {
  5.     float4 ResolutionParams;
  6.     float4 BoundsMin;
  7.     float4 BoundsMax;
  8. }
  9.  
  10. RWStructuredBuffer<Ray> Rays : register(u0);
  11. RWTexture2D<float4> Output: register(u1);
  12. RWStructuredBuffer<float4> VertexData: register(u2);
  13. RWStructuredBuffer<MemoryNode> VertexNodes: register(u3);
  14. RWStructuredBuffer<uint> IndexData: register(u4);
  15. RWStructuredBuffer<MemoryNode> IndexNodes: register(u5);
  16. RWStructuredBuffer<GeometryNode> Geometries: register(u6);
  17. RWStructuredBuffer<InstanceNode> Instances: register(u7);
  18. RWStructuredBuffer<BVHNode> ASTreeData: register(u8);
  19. RWStructuredBuffer<MemoryNode> ASTreeNodes: register(u9);
  20. RWStructuredBuffer<uint> ASIndexData: register(u10);
  21. RWStructuredBuffer<MemoryNode> ASIndexNodes: register(u11);
  22. RWStructuredBuffer<float4> WoopData: register(u12);
  23. RWStructuredBuffer<MemoryNode> WoopNodes: register(u13);
  24.  
  25. bool IntersectRayTriangle(float4 origin, float4 direction, float4 v0, float4 v1, float4 v2, out float distance, out float u, out float v)
  26. {
  27.     float4 e1 = v1 - v0;
  28.     float4 e2 = v2 - v0;
  29.     float4 pvec = float4(cross(direction.xyz, e2.xyz), 0.0f);
  30.     float det = dot(e1, pvec);
  31.     if (det > -1e-6 && det < 1e-6)
  32.     {
  33.         return false;
  34.     }
  35.     float inv_det = 1.0f / det;
  36.  
  37.     float4 tvec = origin - v0;
  38.     u = dot(tvec, pvec) * inv_det;
  39.     if (u < 0.0f || u > 1.0f)
  40.     {
  41.         return false;
  42.     }
  43.  
  44.     float4 qvec = float4(cross(tvec.xyz, e1.xyz), 0.0f);
  45.     v = dot(direction, qvec) * inv_det;
  46.     if (v < 0.0f || u + v > 1.0f)
  47.     {
  48.         return false;
  49.     }
  50.  
  51.     distance = dot(e2, qvec) * inv_det;
  52.     return (distance > 0.0f);
  53. }
  54.  
  55. #define BVH_STACK_SIZE 64
  56.  
  57. [numthreads(32, 32, 1)]
  58. void RenderPass(uint3 GI : SV_GroupID, uint3 DTid : SV_DispatchThreadID, uint3 GTid : SV_GroupThreadID)
  59. {
  60.     uint width = asuint(ResolutionParams.x);
  61.     uint height = asuint(ResolutionParams.y);
  62.  
  63.     int id = DTid.y * width + DTid.x;
  64.    
  65.     Ray r = Rays[id];
  66.  
  67.     float4 o = r.Origin;
  68.     float4 d = r.Direction;
  69.     float4 inv = r.Inverse;
  70.     float4 oinv = o * inv;
  71.  
  72.     uint node_id = 0;
  73.     uint stack[BVH_STACK_SIZE];
  74.     uint stack_ptr = 0;
  75.     stack[stack_ptr] = 0xFFFFFFFF;
  76.     int meshbvh_stack_ptr = -1;
  77.  
  78.     float tmin = 0.0f;
  79.     float tmax = 10000.0f;
  80.     float bU = 0.0f;
  81.     float bV = 0.0f;
  82.     float dist = tmax;
  83.     bool hit = false;
  84.  
  85.     float4 temp = float4(0.0f, 0.0f, 0.0f, 1.0f);
  86.  
  87.     InstanceNode instance = Instances[0];
  88.  
  89.     int i;
  90.  
  91.     // Traversal (use for for testing)
  92.     [loop] for (i = 0; i < 1000; i++)
  93.     //while (node_id != 0xFFFFFFFF) // This is significantly slower on RDNA 2, not sure about others
  94.     {
  95.         temp.x += 0.001f;
  96.  
  97.         [branch] if (ASTreeData[node_id].PrimitiveCount == 0)
  98.         {
  99.             // Fetch children bounding boxes
  100.             float4 n0xy = ASTreeData[node_id].LXY;
  101.             float4 n1xy = ASTreeData[node_id].RXY;
  102.             float4 nz = ASTreeData[node_id].LRZ;
  103.  
  104.             // Test against child AABBs
  105.             float c0lox = n0xy.x * inv.x - oinv.x;
  106.             float c0hix = n0xy.y * inv.x - oinv.x;
  107.             float c0loy = n0xy.z * inv.y - oinv.y;
  108.             float c0hiy = n0xy.w * inv.y - oinv.y;
  109.             float c0loz = nz.x * inv.z - oinv.z;
  110.             float c0hiz = nz.y * inv.z - oinv.z;
  111.             float c1loz = nz.z * inv.z - oinv.z;
  112.             float c1hiz = nz.w * inv.z - oinv.z;
  113.             float c0min = max(max(min(c0lox, c0hix), min(c0loy, c0hiy)), max(min(c0loz, c0hiz), tmin));
  114.             float c0max = min(min(max(c0lox, c0hix), max(c0loy, c0hiy)), min(max(c0loz, c0hiz), tmax));
  115.             float c1lox = n1xy.x * inv.x - oinv.x;
  116.             float c1hix = n1xy.y * inv.x - oinv.x;
  117.             float c1loy = n1xy.z * inv.y - oinv.y;
  118.             float c1hiy = n1xy.w * inv.y - oinv.y;
  119.             float c1min = max(max(min(c1lox, c1hix), min(c1loy, c1hiy)), max(min(c1loz, c1hiz), tmin));
  120.             float c1max = min(min(max(c1lox, c1hix), max(c1loy, c1hiy)), min(max(c1loz, c1hiz), tmax));
  121.  
  122.             bool traverseChild0 = (c0max >= c0min);
  123.             bool traverseChild1 = (c1max >= c1min);
  124.  
  125.             // If no children was hit, get node from stack
  126.             if (!traverseChild0 && !traverseChild1)
  127.             {
  128.                 if (stack_ptr == meshbvh_stack_ptr)
  129.                 {
  130.                     meshbvh_stack_ptr = -1;
  131.                     o = r.Origin;
  132.                     d = r.Direction;
  133.                     inv = r.Inverse;
  134.                     oinv = o * inv;
  135.                 }
  136.  
  137.                 node_id = stack[stack_ptr];
  138.                 stack_ptr--;
  139.             }
  140.             else if (traverseChild0 || traverseChild1)
  141.             {
  142.                 uint first_child = node_id + 1;
  143.                 uint second_child = ASTreeData[node_id].PrimitiveOffset;
  144.                 node_id = (traverseChild0) ? first_child : second_child;
  145.  
  146.                 if (traverseChild0 && traverseChild1)
  147.                 {
  148.                     if (c1min < c0min)
  149.                     {
  150.                         node_id = second_child;
  151.                         stack_ptr++;
  152.                         stack[stack_ptr] = first_child;
  153.                     }
  154.                     else
  155.                     {
  156.                         stack_ptr++;
  157.                         stack[stack_ptr] = second_child;
  158.                     }
  159.                 }
  160.             }
  161.         }
  162.         else if (ASTreeData[node_id].PrimitiveCount == -1)
  163.         {
  164.             meshbvh_stack_ptr = stack_ptr;
  165.  
  166.             uint blas_offset = ASTreeData[node_id].PrimitiveOffset;
  167.             uint instance_index = ASIndexData[blas_offset];
  168.             instance = Instances[instance_index];
  169.  
  170.             node_id = ASTreeNodes[Geometries[instance.GeometryNode].BVHNode + 1].Offset / 64;
  171.  
  172.             o = mul(r.Origin, instance.TransformInverse);
  173.             d = mul(r.Direction, instance.TransformInverse);
  174.             inv = rcp(d);
  175.             oinv = o * inv;
  176.         }
  177.         else
  178.         {
  179.             if (ASTreeData[node_id].PrimitiveCount > 0)
  180.             {
  181.                 GeometryNode geom = Geometries[instance.GeometryNode];
  182.                 //MemoryNode vbo = VertexNodes[geom.VertexBufferNode];
  183.                 //MemoryNode ibo = IndexNodes[geom.IndexBufferNode];
  184.                 MemoryNode wbo = WoopNodes[geom.WoopBufferNode];
  185.  
  186.                 uint index_offset = ASIndexNodes[ASTreeData[node_id].PrimitiveOffset].Offset / 4;
  187.  
  188.                 for (uint j = 0; j < ASTreeData[node_id].PrimitiveCount; j++)
  189.                 {
  190.                     // Don't trash cache by reading index through it
  191.                     uint tri_idx = ASIndexData[ASTreeData[node_id].PrimitiveOffset + j] * 3;
  192.                     float4 r = WoopData[wbo.Offset / 16 + tri_idx + 0];
  193.                     float4 p = WoopData[wbo.Offset / 16 + tri_idx + 1];
  194.                     float4 q = WoopData[wbo.Offset / 16 + tri_idx + 2];
  195.  
  196.                     float o_z = r.w - o.x * r.x - o.y * r.y - o.z * r.z;
  197.                     float i_z = 1.0f / (d.x * r.x + d.y * r.y + d.z * r.z);
  198.                     float t = o_z * i_z;
  199.  
  200.                     if (t > tmin && t < tmax)
  201.                     {
  202.                         float o_x = p.w + o.x * p.x + o.y * p.y + o.z * p.z;
  203.                         float d_x = d.x * p.x + d.y * p.y + d.z * p.z;
  204.                         float u = o_x + t * d_x;
  205.  
  206.                         if (u >= 0.0f && u <= 1.0f)
  207.                         {
  208.                             float o_y = q.w + o.x * q.x + o.y * q.y + o.z * q.z;
  209.                             float d_y = d.x * q.x + d.y * q.y + d.z * q.z;
  210.                             float v = o_y + t * d_y;
  211.  
  212.                             if (v >= 0.0f && u + v <= 1.0f)
  213.                             {
  214.                                 tmax = t;
  215.                                 bU = u;
  216.                                 bV = v;
  217.                                 hit = true;
  218.  
  219.                                 //id = prims_ids[n];
  220.                             }
  221.                         }
  222.                     }
  223.                 }
  224.             }
  225.  
  226.             if (stack_ptr == meshbvh_stack_ptr)
  227.             {
  228.                 meshbvh_stack_ptr = -1;
  229.                 o = r.Origin;
  230.                 d = r.Direction;
  231.                 inv = r.Inverse;
  232.                 oinv = o * inv;
  233.             }
  234.  
  235.             node_id = stack[stack_ptr];
  236.             stack_ptr--;
  237.         }
  238.  
  239.         if (node_id == 0xFFFFFFFF)
  240.         {
  241.             break;
  242.         }
  243.     }
  244.  
  245.     //Output[DTid.xy] = float4(temp.x, temp.x * 0.1f, temp.x * 0.01f, 1.0f);
  246.    
  247.     Output[DTid.xy] = float4(bU, bV, temp.x, temp.w);
  248. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement