diff --git a/src/modules/graphics/Shader.cpp b/src/modules/graphics/Shader.cpp index 68fc16b70..4545d00da 100644 --- a/src/modules/graphics/Shader.cpp +++ b/src/modules/graphics/Shader.cpp @@ -928,6 +928,14 @@ void Shader::getLocalThreadgroupSize(int *x, int *y, int *z) *z = reflection.localThreadgroupSize[2]; } +const std::vector *Shader::getBufferFormat(const std::string &name) const +{ + auto it = reflection.bufferFormats.find(name); + if (it != reflection.bufferFormats.end()) + return &it->second; + return nullptr; +} + bool Shader::validate(StrongRef stages[], std::string& err) { Reflection reflection; @@ -946,6 +954,71 @@ static DataBaseType getBaseType(glslang::TBasicType basictype) } } +static DataFormat getDataFormat(glslang::TBasicType basictype, int components, int rows, int columns, bool matrix) +{ + if (matrix) + { + if (basictype != glslang::EbtFloat) + return DATAFORMAT_MAX_ENUM; + + if (rows == 2 && columns == 2) + return DATAFORMAT_FLOAT_MAT2X2; + else if (rows == 2 && columns == 3) + return DATAFORMAT_FLOAT_MAT2X3; + else if (rows == 2 && columns == 4) + return DATAFORMAT_FLOAT_MAT2X4; + else if (rows == 3 && columns == 2) + return DATAFORMAT_FLOAT_MAT3X2; + else if (rows == 3 && columns == 3) + return DATAFORMAT_FLOAT_MAT3X3; + else if (rows == 3 && columns == 4) + return DATAFORMAT_FLOAT_MAT3X4; + else if (rows == 4 && columns == 2) + return DATAFORMAT_FLOAT_MAT4X2; + else if (rows == 4 && columns == 3) + return DATAFORMAT_FLOAT_MAT4X3; + else if (rows == 4 && columns == 4) + return DATAFORMAT_FLOAT_MAT4X4; + else + return DATAFORMAT_MAX_ENUM; + } + else if (basictype == glslang::EbtFloat) + { + if (components == 1) + return DATAFORMAT_FLOAT; + else if (components == 2) + return DATAFORMAT_FLOAT_VEC2; + else if (components == 3) + return DATAFORMAT_FLOAT_VEC2; + else if (components == 4) + return DATAFORMAT_FLOAT_VEC2; + } + else if (basictype == glslang::EbtInt) + { + if (components == 1) + return DATAFORMAT_INT32; + else if (components == 2) + return DATAFORMAT_INT32_VEC2; + else if (components == 3) + return DATAFORMAT_INT32_VEC2; + else if (components == 4) + return DATAFORMAT_INT32_VEC2; + } + else if (basictype == glslang::EbtUint) + { + if (components == 1) + return DATAFORMAT_UINT32; + else if (components == 2) + return DATAFORMAT_UINT32_VEC2; + else if (components == 3) + return DATAFORMAT_UINT32_VEC2; + else if (components == 4) + return DATAFORMAT_UINT32_VEC2; + } + + return DATAFORMAT_MAX_ENUM; +} + static PixelFormat getPixelFormat(glslang::TLayoutFormat format) { using namespace glslang; @@ -1037,6 +1110,48 @@ static T convertData(const glslang::TConstUnion &data) } } +static bool AddFieldsToFormat(std::vector &format, int level, const glslang::TType *type, int arraylength, const std::string &basename, std::string &err) +{ + if (type->isStruct()) + { + auto fields = type->getStruct(); + + for (int i = 0; i < std::max(arraylength, 1); i++) + { + std::string name = basename; + if (level > 0) + { + name += type->getFieldName().c_str(); + if (arraylength > 0) + name += "[" + std::to_string(i) + "]"; + name += "."; + } + for (size_t fieldi = 0; fieldi < fields->size(); fieldi++) + { + const glslang::TType *fieldtype = (*fields)[fieldi].type; + int fieldlength = fieldtype->isSizedArray() ? fieldtype->getCumulativeArraySize() : 0; + + if (!AddFieldsToFormat(format, level + 1, fieldtype, fieldlength, name, err)) + return false; + } + } + } + else + { + DataFormat dataformat = getDataFormat(type->getBasicType(), type->getVectorSize(), type->getMatrixRows(), type->getMatrixCols(), type->isMatrix()); + if (dataformat == DATAFORMAT_MAX_ENUM) + { + err = "Shader validation error:\n"; + return false; + } + + std::string name = basename.empty() ? type->getFieldName().c_str() : basename + type->getFieldName().c_str(); + format.emplace_back(name.c_str(), dataformat, arraylength); + } + + return true; +} + bool Shader::validateInternal(StrongRef stages[], std::string &err, Reflection &reflection) { glslang::TProgram program; @@ -1295,6 +1410,12 @@ bool Shader::validateInternal(StrongRef stages[], std::string &err, u.access = (Access)(ACCESS_READ | ACCESS_WRITE); reflection.storageBuffers[u.name] = u; + + std::vector format; + if (!AddFieldsToFormat(format, 0, elementtype, 0, "", err)) + return false; + + reflection.bufferFormats[u.name] = format; } else { diff --git a/src/modules/graphics/Shader.h b/src/modules/graphics/Shader.h index be4e577b5..26dc414e9 100644 --- a/src/modules/graphics/Shader.h +++ b/src/modules/graphics/Shader.h @@ -26,6 +26,7 @@ #include "Texture.h" #include "ShaderStage.h" #include "Resource.h" +#include "Buffer.h" // STL #include @@ -266,6 +267,8 @@ class Shader : public Object, public Resource void getLocalThreadgroupSize(int *x, int *y, int *z); + const std::vector *getBufferFormat(const std::string &name) const; + static SourceInfo getSourceInfo(const std::string &src); static std::string createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const CompileOptions &options, const SourceInfo &info, bool gles, bool checksystemfeatures); @@ -296,6 +299,8 @@ class Shader : public Object, public Resource std::map> localUniformInitializerValues; + std::map> bufferFormats; + int textureCount; int bufferCount; diff --git a/src/modules/graphics/wrap_Shader.cpp b/src/modules/graphics/wrap_Shader.cpp index 51de56b8c..5902640ea 100644 --- a/src/modules/graphics/wrap_Shader.cpp +++ b/src/modules/graphics/wrap_Shader.cpp @@ -514,6 +514,41 @@ int w_Shader_getLocalThreadgroupSize(lua_State* L) return 3; } +int w_Shader_getBufferFormat(lua_State *L) +{ + Shader *shader = luax_checkshader(L, 1); + const char *name = luaL_checkstring(L, 2); + const std::vector *format = shader->getBufferFormat(name); + if (name != nullptr) + { + lua_createtable(L, (int)format->size(), 0); + + for (size_t i = 0; i < format->size(); i++) + { + const Buffer::DataDeclaration &member = (*format)[i]; + + lua_createtable(L, 0, 3); + + lua_pushstring(L, member.name.c_str()); + lua_setfield(L, -2, "name"); + + const char* formatstr = "unknown"; + getConstant(member.format, formatstr); + lua_pushstring(L, formatstr); + lua_setfield(L, -2, "format"); + + lua_pushinteger(L, member.arrayLength); + lua_setfield(L, -2, "arraylength"); + + lua_rawseti(L, -2, i + 1); + } + + return 1; + } + + return luaL_error(L, "Buffer '%s' does not exist in the Shader.", name); +} + int w_Shader_getDebugName(lua_State *L) { Shader *shader = luax_checkshader(L, 1); @@ -533,6 +568,7 @@ static const luaL_Reg w_Shader_functions[] = { "hasUniform", w_Shader_hasUniform }, { "hasStage", w_Shader_hasStage }, { "getLocalThreadgroupSize", w_Shader_getLocalThreadgroupSize }, + { "getBufferFormat", w_Shader_getBufferFormat }, { "getDebugName", w_Shader_getDebugName }, { 0, 0 } };