Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- ///////////////////////////////////////////////////////////////////////////////////////////////////
- //
- // TraversalBvh8.hlsli
- //
- // Implements ray traversal through multi-level BVH-8 (CWBVH) acceleration structure.
- //
- ///////////////////////////////////////////////////////////////////////////////////////////////////
- #ifndef __TRAVERSAL_BVH8__HLSLI__
- #define __TRAVERSAL_BVH8__HLSLI__
- #include "../Raytracer.hlsli"
// Parameter list of TraceRayCompute (BVH-8 variant): expands to the eight geometry,
// instance, acceleration-structure and Woop-triangle buffers the traversal reads from.
#define TRACE_RAY_PARAMS \
    RWStructuredBuffer<GeometryNode> Geometries, \
    RWStructuredBuffer<InstanceNode> Instances, \
    RWStructuredBuffer<MemoryNode>   ASTreeNodes, \
    RWStructuredBuffer<BVH8Node>     ASTreeData, \
    RWStructuredBuffer<MemoryNode>   ASIndexNodes, \
    RWStructuredBuffer<uint>         ASIndexData, \
    RWStructuredBuffer<MemoryNode>   WoopNodes, \
    RWStructuredBuffer<float4>       WoopData
// Argument list for calling TraceRayCompute (BVH-8 variant): forwards the same eight
// buffers, in the same order, as declared by TRACE_RAY_PARAMS.
#define TRACE_RAY_ARGS \
    Geometries, \
    Instances, \
    ASTreeNodes, \
    ASTreeData, \
    ASIndexNodes, \
    ASIndexData, \
    WoopNodes, \
    WoopData
/// <summary>
/// Computes the inverse octant of a ray direction, as used to reorder CWBVH child slots.
/// </summary>
/// <param name="rayDirection">Ray direction (only x, y and z are inspected)</param>
/// <returns>
/// Inverse octant replicated into all four bytes. In every byte, bit 2 is set when
/// x is not less than zero, bit 1 when y is not less than zero and bit 0 when z is
/// not less than zero.
/// </returns>
uint GetOctant(float4 rayDirection)
{
    // Octant of the ray: one bit per axis whose direction component is negative
    uint octant = 0;
    if (rayDirection.x < 0.0f)
    {
        octant |= 4;
    }
    if (rayDirection.y < 0.0f)
    {
        octant |= 2;
    }
    if (rayDirection.z < 0.0f)
    {
        octant |= 1;
    }

    // Invert the three octant bits and broadcast the result into each byte lane
    return (7 - octant) * 0x01010101;
}
/// <summary>
/// Returns byte number n of x, counting from the least significant byte (n = 0).
/// </summary>
/// <param name="x">Packed 32-bit value</param>
/// <param name="n">Byte index, expected in the range 0..3</param>
/// <returns>The selected byte, zero-extended to 32 bits</returns>
uint ExtractByte(uint x, uint n)
{
    uint shift = n << 3;
    return (x >> shift) & 255u;
}
/// <summary>
/// Tests a ray against the (up to) eight quantized child boxes of one compressed wide BVH-8 node.
/// </summary>
/// <param name="origin">Ray origin (xyz used)</param>
/// <param name="direction">Ray direction (xyz used)</param>
/// <param name="octantInverse">Byte-replicated inverse ray octant, as returned by GetOctant</param>
/// <param name="maxDistance">Largest distance along the ray at which a child box may still be exited; entry distance is clamped below at 0.1</param>
/// <param name="node0">xyz: origin of the node's local quantization grid; w: per-axis exponent bytes in bytes 0..2 (byte 3 holds the inner-child mask, not read here)</param>
/// <param name="node1">z: meta bytes of children 0..3; w: meta bytes of children 4..7 (x/y hold base child/triangle indices, not read here)</param>
/// <param name="node2">Quantized X bounds: x/y hold the lower bytes, z/w the upper bytes, four children per component</param>
/// <param name="node3">Quantized Y bounds, same layout as node2</param>
/// <param name="node4">Quantized Z bounds, same layout as node2</param>
/// <returns>
/// Hit mask: for every intersected child, the top three bits of its meta byte shifted to the
/// position given by its low five bits. Inner children (bits 3 and 4 of meta both set) land in
/// the top byte with their slot XOR-ed by the inverse octant; leaf children land in the low 24 bits.
/// </returns>
uint IntersectNode(float4 origin, float4 direction, uint octantInverse, float maxDistance, float4 node0, float4 node1, float4 node2, float4 node3, float4 node4)
{
    // For a non-zero exponent byte e, placing e into a float's exponent field yields 2^(e - 127),
    // the grid scale on that axis. Folding it into the reciprocal direction lets a quantized
    // coordinate q be mapped to a ray distance as t = q * scaledInvDir + gridOrigin.
    uint exponentBytes = asuint(node0.w);
    float3 scaledInvDir;
    scaledInvDir.x = asfloat(ExtractByte(exponentBytes, 0) << 23) / direction.x;
    scaledInvDir.y = asfloat(ExtractByte(exponentBytes, 1) << 23) / direction.y;
    scaledInvDir.z = asfloat(ExtractByte(exponentBytes, 2) << 23) / direction.z;
    float3 gridOrigin = (node0.xyz - origin.xyz) / direction.xyz;

    // The ray's sign on each axis decides which quantized bound acts as the near plane
    bool flipX = direction.x < 0.0f;
    bool flipY = direction.y < 0.0f;
    bool flipZ = direction.z < 0.0f;

    uint mask = 0;

    // Children come in two packs of four: pack 0 reads meta from node1.z, lower bounds from .x
    // and upper bounds from .z; pack 1 reads node1.w, .y and .w respectively.
    [unroll]
    for (int pack = 0; pack < 2; pack++)
    {
        bool firstPack = pack == 0;

        // One meta byte per child
        uint metaBytes = asuint(firstPack ? node1.z : node1.w);
        // Inner children have bits 3 and 4 of their meta byte set; keep bit 4 as the flag
        uint innerFlags = (metaBytes & (metaBytes << 1)) & 0x10101010;
        // Widen each flag to a full 0xFF byte: move it to bit 0 of its byte, then multiply by 0xFF
        uint innerBytes = (((innerFlags << 3) >> 7) & 0x01010101) * 0xff;
        // Octant reordering applies to inner children only; low five bits give the hit-mask position
        uint hitShift4 = (metaBytes ^ (octantInverse & innerBytes)) & 0x1F1F1F1F;
        // Top three bits of each meta byte are the bits to set on a hit
        uint hitBits4 = (metaBytes >> 5) & 0x07070707;

        // Quantized lower/upper bounds, one byte per child
        uint lowX  = asuint(firstPack ? node2.x : node2.y);
        uint highX = asuint(firstPack ? node2.z : node2.w);
        uint lowY  = asuint(firstPack ? node3.x : node3.y);
        uint highY = asuint(firstPack ? node3.z : node3.w);
        uint lowZ  = asuint(firstPack ? node4.x : node4.y);
        uint highZ = asuint(firstPack ? node4.z : node4.w);

        uint nearX = flipX ? highX : lowX;
        uint farX  = flipX ? lowX : highX;
        uint nearY = flipY ? highY : lowY;
        uint farY  = flipY ? lowY : highY;
        uint nearZ = flipZ ? highZ : lowZ;
        uint farZ  = flipZ ? lowZ : highZ;

        [unroll]
        for (int child = 0; child < 4; child++)
        {
            float3 qNear = (float3)uint3(ExtractByte(nearX, child), ExtractByte(nearY, child), ExtractByte(nearZ, child));
            float3 qFar  = (float3)uint3(ExtractByte(farX, child), ExtractByte(farY, child), ExtractByte(farZ, child));

            // Map quantized planes to distances along the ray
            float3 tNear = mad(qNear, scaledInvDir, gridOrigin);
            float3 tFar  = mad(qFar, scaledInvDir, gridOrigin);

            // Slab test: enter no earlier than 0.1, exit no later than maxDistance
            float tEnter = max(max(tNear.x, tNear.y), max(tNear.z, 0.1f));
            float tExit  = min(min(tFar.x, tFar.y), min(tFar.z, maxDistance));

            [branch]
            if (tEnter <= tExit)
            {
                mask |= ExtractByte(hitBits4, child) << ExtractByte(hitShift4, child);
            }
        }
    }

    return mask;
}
#ifndef BVH8_MAX_TRAVERSAL_STEPS
// Upper bound on traversal loop iterations, used as a safety net against runaway loops.
// A ray needing more iterations stops early and reports whatever it found so far.
// Define this before including the header to change the limit.
#define BVH8_MAX_TRAVERSAL_STEPS 1024
#endif
/// <summary>
/// Performs ray traversal through the acceleration structure for a single ray.
///
/// The function walks the two-level compressed wide Bounding Volume Hierarchy (BVH-8/CWBVH).
/// The top-level BVH (TLAS) is traversed in world space. When an instance leaf is reached,
/// the ray is moved into that instance's space with TransformInverse and the geometry's
/// bottom-level BVH (BLAS) is traversed. Triangles are tested against Woop triangle records.
/// Once the traversal stack drains back to the depth at which the BLAS was entered, the ray
/// is restored to world space and TLAS traversal resumes.
/// At most BVH8_MAX_TRAVERSAL_STEPS loop iterations are executed.
/// </summary>
/// <param name="r">Ray to trace.</param>
/// <param name="Geometries">Buffer of GeometryNode - holds all definitions for geometries in the scene</param>
/// <param name="Instances">Buffer of InstanceNode - holds all geometry instance definitions in the scene</param>
/// <param name="ASTreeNodes">Buffer of memory nodes - each describing the location of one BVH's node data</param>
/// <param name="ASTreeData">Buffer of BVH nodes - BVH node data</param>
/// <param name="ASIndexNodes">Buffer of memory nodes - each describing the location of one BVH's index data (not read by this function)</param>
/// <param name="ASIndexData">Buffer of BVH indices - BVH index data</param>
/// <param name="WoopNodes">Buffer of memory nodes - each describing the location of a geometry's Woop triangle data</param>
/// <param name="WoopData">Buffer of Woop triangle data - three float4 records per triangle</param>
/// <returns>
/// 4-component vector, where:
/// 1st component has packed U/V barycentric coordinates (see PackFloat2/UnpackFloat2)
/// 2nd component is the distance along the ray to the closest hit, or 10000.0 when nothing was hit
/// 3rd component is the ASIndexData triangle index multiplied by 3, i.e. the base of its Woop
///     record (unsigned int, bit-cast to float)
/// 4th component is instance.GeometryNode of the hit instance (unsigned int, bit-cast to float)
/// On a miss, the barycentrics and both IDs are 0.
/// </returns>
float4 TraceRayCompute(Ray r, TRACE_RAY_PARAMS)
{
    // Ray in the space of the BVH being traversed: world space in the TLAS, instance space in a BLAS.
    // Origin and direction go through the same matrix without renormalization, so distances t keep
    // their meaning across spaces (assuming an affine TransformInverse; verify against callers).
    float4 o = r.Origin;
    float4 d = r.Direction;
    uint octInv4 = GetOctant(d);

    // Node group: x = base child index, y = pending child hits (top byte) | inner-child mask (low byte).
    // Traversal starts at the TLAS root, node 0.
    uint2 currentGroup = uint2(0, 0x80000000);
    // Primitive group: x = base primitive index, y = pending primitive hits (low 24 bits)
    uint2 triangleGroup = uint2(0, 0);

    uint2 stack[BVH_STACK_SIZE];
    uint stack_ptr = 0;
    // Stack depth at which the current BLAS was entered; -1 while traversing the TLAS
    int meshbvh_stack_ptr = -1;

    const float tmin = 0.0f;
    float tmax = 10000.0f;
    float bU = 0.0f;
    float bV = 0.0f;
    uint prim_id = 0;
    uint geometryID = 0;
    InstanceNode instance = Instances[0];

    [loop]
    for (int iteration = 0; iteration < BVH8_MAX_TRAVERSAL_STEPS; iteration++)
    {
        [branch]
        if (currentGroup.y & 0xff000000)
        {
            // Interior group: consume the highest pending child hit
            uint childIndexOffset = firstbithigh(currentGroup.y);
            // Undo the octant reordering to recover the child's real slot
            uint slotIndex = (childIndexOffset - 24) ^ (octInv4 & 0xff);
            // Inner children are stored contiguously: count inner slots below this one (low-byte mask)
            uint relativeIndex = countbits(currentGroup.y & ~(0xffffffff << slotIndex));
            uint childNodeIndex = currentGroup.x + relativeIndex;
            currentGroup.y &= ~(1u << childIndexOffset);
            if (currentGroup.y & 0xff000000)
            {
                stack[stack_ptr] = currentGroup;
                stack_ptr++;
            }

            // Test the ray against all 8 children of the selected node
            BVH8Node node = ASTreeData[childNodeIndex];
            uint hitMask = IntersectNode(o, d, octInv4, tmax, node.Node0, node.Node1, node.Node2, node.Node3, node.Node4);

            // Inner-child hits plus the node's inner-child mask (byte 3 of Node0.w) form the next node group;
            // leaf hits form the next primitive group
            currentGroup.x = asuint(node.Node1.x);
            currentGroup.y = (hitMask & 0xff000000) | ((asuint(node.Node0.w) >> 24) & 0xff);
            triangleGroup.x = asuint(node.Node1.y);
            triangleGroup.y = hitMask & 0x00ffffff;
        }
        else
        {
            // No inner hits pending: the group popped from the stack is a primitive group
            triangleGroup = currentGroup;
            currentGroup = uint2(0, 0);
        }

        if (triangleGroup.y != 0)
        {
            // NOTE(review): ASIndexData is indexed with the node's base index only (no ASIndexNodes
            // offset is applied) - confirm that TLAS/BLAS base indices are absolute within ASIndexData.
            if (meshbvh_stack_ptr == -1)
            {
                // TLAS leaf: enter one instance now, defer the rest of the work on the stack
                uint blas_offset = firstbithigh(triangleGroup.y);
                triangleGroup.y &= ~(1u << blas_offset);
                uint instance_index = ASIndexData[triangleGroup.x + blas_offset];
                instance = Instances[instance_index];

                if (triangleGroup.y != 0)
                {
                    stack[stack_ptr] = triangleGroup;
                    stack_ptr++;
                }
                if (currentGroup.y & 0xff000000)
                {
                    stack[stack_ptr] = currentGroup;
                    stack_ptr++;
                }
                meshbvh_stack_ptr = stack_ptr;

                // Move the ray into instance space and start at the BLAS root (80 bytes per BVH8Node)
                o = mul(r.Origin, instance.TransformInverse);
                d = mul(r.Direction, instance.TransformInverse);
                octInv4 = GetOctant(d);
                currentGroup.x = ASTreeNodes[Geometries[instance.GeometryNode].BVHNode + 1].Offset / 80;
                currentGroup.y = 0x80000000;
            }
            else
            {
                // BLAS leaf: test every pending triangle. The geometry (and hence its Woop buffer)
                // is the same for all of them, so look it up once.
                GeometryNode geom = Geometries[instance.GeometryNode];
                uint woopBase = WoopNodes[geom.WoopBufferNode].Offset / 16;

                while (triangleGroup.y != 0)
                {
                    uint triangleIndex = firstbithigh(triangleGroup.y);
                    triangleGroup.y &= ~(1u << triangleIndex);

                    // Each triangle owns three consecutive float4 Woop records
                    uint tri_idx = ASIndexData[triangleGroup.x + triangleIndex] * 3;
                    float4 woop0 = WoopData[woopBase + tri_idx + 0];
                    float4 woop1 = WoopData[woopBase + tri_idx + 1];
                    float4 woop2 = WoopData[woopBase + tri_idx + 2];

                    // Distance to the triangle plane in the triangle's unit space
                    float o_z = woop0.w - o.x * woop0.x - o.y * woop0.y - o.z * woop0.z;
                    float i_z = 1.0f / (d.x * woop0.x + d.y * woop0.y + d.z * woop0.z);
                    float t = o_z * i_z;
                    if (t > tmin && t < tmax)
                    {
                        float o_x = woop1.w + o.x * woop1.x + o.y * woop1.y + o.z * woop1.z;
                        float d_x = d.x * woop1.x + d.y * woop1.y + d.z * woop1.z;
                        float u = o_x + t * d_x;
                        if (u >= 0.0f && u <= 1.0f)
                        {
                            float o_y = woop2.w + o.x * woop2.x + o.y * woop2.y + o.z * woop2.z;
                            float d_y = d.x * woop2.x + d.y * woop2.y + d.z * woop2.z;
                            float v = o_y + t * d_y;
                            if (v >= 0.0f && u + v <= 1.0f)
                            {
                                // Closer hit: shrink the interval so later nodes/triangles are culled against it
                                tmax = t;
                                bU = u;
                                bV = v;
                                geometryID = instance.GeometryNode;
                                prim_id = tri_idx;
                            }
                        }
                    }
                }
            }
        }

        // Current node group exhausted: pop the stack, or finish when it is empty
        if ((currentGroup.y & 0xff000000) == 0)
        {
            if (stack_ptr == 0)
            {
                break;
            }

            // Leaving the BLAS: the next entry belongs to the TLAS, so restore the world-space ray
            if (stack_ptr == meshbvh_stack_ptr)
            {
                meshbvh_stack_ptr = -1;
                o = r.Origin;
                d = r.Direction;
                octInv4 = GetOctant(d);
            }

            stack_ptr--;
            currentGroup = stack[stack_ptr];
        }
    }

    return float4(PackFloat2(bU, bV), tmax, asfloat(prim_id), asfloat(geometryID));
}
- #endif
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement