Skip to content

Commit

Permalink
more shader input
Browse files Browse the repository at this point in the history
  • Loading branch information
iaomw committed Dec 3, 2024
1 parent 6e80b9e commit 629cd71
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 12 deletions.
43 changes: 42 additions & 1 deletion zeno/src/nodes/mtl/ShaderAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,51 @@
#include <zeno/types/PrimitiveObject.h>
#include <zeno/types/NumericObject.h>
#include <zeno/utils/string.h>
#include <magic_enum.hpp>
#include <algorithm>

namespace zeno {

enum struct SurfaceAttr {
pos, clr, nrm, uv, tang, bitang, NoL, LoV, N, T, L, V, H, reflectance, fresnel,
worldNrm, worldTan, worldBTn,
camFront, camUp, camRight
};

enum struct InstAttr {
instIdx, instPos, instNrm, instUv, instClr, instTang
};

enum struct VolumeAttr {};

enum struct RayAttr {
rayLength, isBackFace, isShadowRay,
};

static std::string dataTypeDefaultString() {
auto name = magic_enum::enum_name(SurfaceAttr::pos);
return std::string(name);
}

static std::string dataTypeListString() {
auto list0 = magic_enum::enum_names<SurfaceAttr>();
auto list1 = magic_enum::enum_names<InstAttr>();
auto list2 = magic_enum::enum_names<RayAttr>();

std::string result;

auto concat = [&](const auto &list) {
for (auto& ele : list) {
result += " ";
result += ele;
}
};

concat(list0); concat(list1); concat(list2);

result += " prd.rndf() attrs.localPosLazy() attrs.uniformPosLazy()";
return result;
}

struct ShaderInputAttr : ShaderNodeClone<ShaderInputAttr> {
virtual int determineType(EmissionPass *em) override {
Expand All @@ -33,7 +74,7 @@ struct ShaderInputAttr : ShaderNodeClone<ShaderInputAttr> {

ZENDEFNODE(ShaderInputAttr, {
{
{"enum pos clr nrm uv tang bitang NoL LoV N T L V H reflectance fresnel instPos instNrm instUv instClr instTang prd.rndf() attrs.localPosLazy() attrs.uniformPosLazy() rayLength isShadowRay worldNrm worldTan worldBTn camFront camUp camRight", "attr", "pos"},
{"enum" + dataTypeListString(), "attr", dataTypeDefaultString()},
{"enum float vec2 vec3 vec4 bool", "type", "vec3"},
},
{
Expand Down
3 changes: 3 additions & 0 deletions zenovis/xinxinoptix/CallableDefault.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ extern "C" __device__ MatOutput __direct_callable__evalmat(cudaTextureObject_t z
auto att_uv = attrs.uv;
auto att_nrm = attrs.nrm;
auto att_tang = attrs.tang;

auto att_instIdx = attrs.instIdx;
auto att_instPos = attrs.instPos;
auto att_instNrm = attrs.instNrm;
auto att_instUv = attrs.instUv;
auto att_instClr = attrs.instClr;
auto att_instTang = attrs.instTang;
auto att_rayLength = attrs.rayLength;

auto att_isBackFace = attrs.isBackFace ? 1.0f:0.0f;
auto att_isShadowRay = attrs.isShadowRay ? 1.0f:0.0f;

vec3 b = normalize(cross(attrs.T, attrs.N));
Expand Down
17 changes: 13 additions & 4 deletions zenovis/xinxinoptix/DeflMatShader.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,19 @@ __inline__ __device__ void cihouSphereInstanceAux(MatInput& attrs) {
assert(lut != nullptr);

auto tmp = lut[optixGetInstanceId()];
auto auxBuffer = reinterpret_cast<float3*>(tmp);
auto auxBuffer = reinterpret_cast<float*>(tmp);
assert(auxBuffer != nullptr);

auto aux = auxBuffer + optixGetPrimitiveIndex() * 4;

attrs.clr = {};
attrs.tang = {};

attrs.instIdx = *(uint*)aux;
attrs.instPos = {}; //rt_data->instPos[inst_idx2];
attrs.instNrm = {}; //rt_data->instNrm[inst_idx2];
attrs.instUv = {}; //rt_data->instUv[inst_idx2];
attrs.instClr = auxBuffer[optixGetPrimitiveIndex()];
attrs.instClr = *(float3*)(aux+1);
attrs.instTang = {}; //rt_data->instTang[inst_idx2];
}
}
Expand Down Expand Up @@ -113,6 +117,7 @@ extern "C" __global__ void __anyhit__shadow_cutout()
attrs.nrm = N;
attrs.uv = sphereUV(_normal_object_, false);

attrs.instPos = _center_object_;
cihouSphereInstanceAux(attrs);

#else
Expand Down Expand Up @@ -183,6 +188,7 @@ extern "C" __global__ void __anyhit__shadow_cutout()
attrs.tang = optixTransformVectorFromObjectToWorldSpace(attrs.tang);
attrs.rayLength = optixGetRayTmax();

attrs.instIdx = params.instIdx[inst_idx];
attrs.instPos = decodeHalf( rt_data->instPos[inst_idx] );
attrs.instNrm = decodeHalf( rt_data->instNrm[inst_idx] );
attrs.instUv = decodeHalf( rt_data->instUv[inst_idx] );
Expand Down Expand Up @@ -341,8 +347,9 @@ extern "C" __global__ void __closesthit__radiance()
float3 P = ray_orig + optixGetRayTmax() * ray_dir;

HitGroupData* rt_data = (HitGroupData*)optixGetSbtDataPointer();
MatInput attrs{};
float estimation = 0;

MatInput attrs {};
attrs.isBackFace = optixIsBackFaceHit();

#if (_P_TYPE_==2)

Expand Down Expand Up @@ -420,6 +427,7 @@ extern "C" __global__ void __closesthit__radiance()
attrs.nrm = N;
attrs.uv = sphereUV(objNorm, false);

attrs.instPos = sphere_center;
cihouSphereInstanceAux(attrs);

#else
Expand Down Expand Up @@ -492,6 +500,7 @@ extern "C" __global__ void __closesthit__radiance()
attrs.tang = normalize(interp(barys, tan0, tan1, tan2));
attrs.tang = optixTransformNormalFromObjectToWorldSpace(attrs.tang);

attrs.instIdx = params.instIdx[inst_idx];
attrs.instPos = decodeHalf( rt_data->instPos[inst_idx] );
attrs.instNrm = decodeHalf( rt_data->instNrm[inst_idx] );
attrs.instUv = decodeHalf( rt_data->instUv[inst_idx] );
Expand Down
7 changes: 7 additions & 0 deletions zenovis/xinxinoptix/IOMat.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

#include "zxxglslvec.h"

#ifndef uint
#define uint unsigned int
#endif

struct MatOutput {
vec3 basecolor;
float roughness;
Expand Down Expand Up @@ -59,6 +63,8 @@ struct MatInput {
vec3 uv;
vec3 clr;
vec3 tang;

uint instIdx;
vec3 instPos;
vec3 instNrm;
vec3 instUv;
Expand All @@ -68,6 +74,7 @@ struct MatInput {
float LoV;

float rayLength;
bool isBackFace;
bool isShadowRay;

vec3 reflectance;
Expand Down
31 changes: 24 additions & 7 deletions zenovis/xinxinoptix/optixPathTracer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ struct PathTracerState
raii<CUdeviceptr> _meshAux;
raii<CUdeviceptr> _instToMesh;

raii<CUdeviceptr> d_instIdx;
raii<CUdeviceptr> d_instPos;
raii<CUdeviceptr> d_instNrm;
raii<CUdeviceptr> d_instUv;
Expand Down Expand Up @@ -785,6 +786,8 @@ static void buildMeshIAS(PathTracerState& state, int rayTypeCount, std::vector<s
0,1,0,0,
0,0,1,0};


uint defaultInstIdx = 0u;
float3 defaultInstPos = {0, 0, 0};
float3 defaultInstNrm = {0, 1, 0};
float3 defaultInstUv = {0, 0, 0};
Expand Down Expand Up @@ -813,6 +816,8 @@ static void buildMeshIAS(PathTracerState& state, int rayTypeCount, std::vector<s
#else
using AuxType = float3;
#endif

std::vector<uint> instIdx(num_instances);
std::vector<AuxType> instPos(num_instances);
std::vector<AuxType> instNrm(num_instances);
std::vector<AuxType> instUv(num_instances);
Expand Down Expand Up @@ -892,7 +897,6 @@ static void buildMeshIAS(PathTracerState& state, int rayTypeCount, std::vector<s
buildMeshAccel(state, mesh);
}


for (std::size_t k = 0; k < instMats.size(); ++k)
{
const auto &instMat = instMats[k];
Expand All @@ -908,6 +912,7 @@ static void buildMeshIAS(PathTracerState& state, int rayTypeCount, std::vector<s
instance.traversableHandle = mesh->gas_handle;
memcpy(instance.transform, instMat3r4c, sizeof(float) * 12);

instIdx[instanceID] = k;
instPos[instanceID] = toHalf(instAttrs.pos[k]);
instNrm[instanceID] = toHalf(instAttrs.nrm[k]);
instUv[instanceID] = toHalf(instAttrs.uv[k]);
Expand Down Expand Up @@ -940,6 +945,15 @@ static void buildMeshIAS(PathTracerState& state, int rayTypeCount, std::vector<s

state.params.instToMesh = (void*)state._instToMesh.handle;

state.d_instIdx.resize(sizeof(instIdx[0]) * instIdx.size(), 0);
CUDA_CHECK( cudaMemcpy(
reinterpret_cast<void*>( (CUdeviceptr)state.d_instIdx ),
instIdx.data(),
sizeof(instIdx[0]) * instIdx.size(),
cudaMemcpyHostToDevice
) );
state.params.instIdx = (uint*)state.d_instIdx.handle;

state.d_instPos.resize(sizeof(instPos[0]) * instPos.size(), 0);
CUDA_CHECK( cudaMemcpy(
reinterpret_cast<void*>( (CUdeviceptr)state.d_instPos ),
Expand Down Expand Up @@ -2777,15 +2791,18 @@ void UpdateInst()
auto sia = std::make_shared<SphereInstanceAgent>(sphereInstanceBase);

sia->radius_list = std::vector<float>(element_count, sphereInstanceBase.radius);
sia->aux_list = std::vector<float>(element_count * 3, 0);

for (size_t i=0; i<element_count; ++i) {
sia->radius_list[i] *= instTrs.tang[3*i +0];
const uint aux_size = 4;
sia->aux_list = std::vector<float>(element_count * aux_size, 0);

for (uint i=0; i<element_count; ++i) {
sia->radius_list[i] *= instTrs.tang[3*i];

sia->aux_list[i*3+0] = instTrs.clr[3*i +0];
sia->aux_list[i*3+1] = instTrs.clr[3*i +1];
sia->aux_list[i*3+2] = instTrs.clr[3*i +2];
sia->aux_list[i*aux_size+0] = reinterpret_cast<float&>(i); // instIdx

sia->aux_list[i*aux_size+1] = instTrs.clr[3*i+0];
sia->aux_list[i*aux_size+2] = instTrs.clr[3*i+1];
sia->aux_list[i*aux_size+3] = instTrs.clr[3*i+2];
}

sia->center_list.resize(element_count);
Expand Down
2 changes: 2 additions & 0 deletions zenovis/xinxinoptix/optixPathTracer.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ struct Params
uint32_t hairInstOffset;
void* hairAux;

uint* instIdx;

void* dlights_ptr;
void* plights_ptr;

Expand Down

0 comments on commit 629cd71

Please sign in to comment.