Skip to content

Commit

Permalink
Add Shader:getBufferFormat.
Browse files Browse the repository at this point in the history
Returns a table with the same setup as the data format table used with love.graphics.newBuffer.
  • Loading branch information
slime73 committed Jun 16, 2024
1 parent 12bbadb commit 78f8cfc
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 0 deletions.
121 changes: 121 additions & 0 deletions src/modules/graphics/Shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,14 @@ void Shader::getLocalThreadgroupSize(int *x, int *y, int *z)
*z = reflection.localThreadgroupSize[2];
}

const std::vector<Buffer::DataDeclaration> *Shader::getBufferFormat(const std::string &name) const
{
auto it = reflection.bufferFormats.find(name);
if (it != reflection.bufferFormats.end())
return &it->second;
return nullptr;
}

bool Shader::validate(StrongRef<ShaderStage> stages[], std::string& err)
{
Reflection reflection;
Expand All @@ -946,6 +954,71 @@ static DataBaseType getBaseType(glslang::TBasicType basictype)
}
}

static DataFormat getDataFormat(glslang::TBasicType basictype, int components, int rows, int columns, bool matrix)
{
if (matrix)
{
if (basictype != glslang::EbtFloat)
return DATAFORMAT_MAX_ENUM;

if (rows == 2 && columns == 2)
return DATAFORMAT_FLOAT_MAT2X2;
else if (rows == 2 && columns == 3)
return DATAFORMAT_FLOAT_MAT2X3;
else if (rows == 2 && columns == 4)
return DATAFORMAT_FLOAT_MAT2X4;
else if (rows == 3 && columns == 2)
return DATAFORMAT_FLOAT_MAT3X2;
else if (rows == 3 && columns == 3)
return DATAFORMAT_FLOAT_MAT3X3;
else if (rows == 3 && columns == 4)
return DATAFORMAT_FLOAT_MAT3X4;
else if (rows == 4 && columns == 2)
return DATAFORMAT_FLOAT_MAT4X2;
else if (rows == 4 && columns == 3)
return DATAFORMAT_FLOAT_MAT4X3;
else if (rows == 4 && columns == 4)
return DATAFORMAT_FLOAT_MAT4X4;
else
return DATAFORMAT_MAX_ENUM;
}
else if (basictype == glslang::EbtFloat)
{
if (components == 1)
return DATAFORMAT_FLOAT;
else if (components == 2)
return DATAFORMAT_FLOAT_VEC2;
else if (components == 3)
return DATAFORMAT_FLOAT_VEC2;
else if (components == 4)
return DATAFORMAT_FLOAT_VEC2;
}
else if (basictype == glslang::EbtInt)
{
if (components == 1)
return DATAFORMAT_INT32;
else if (components == 2)
return DATAFORMAT_INT32_VEC2;
else if (components == 3)
return DATAFORMAT_INT32_VEC2;
else if (components == 4)
return DATAFORMAT_INT32_VEC2;
}
else if (basictype == glslang::EbtUint)
{
if (components == 1)
return DATAFORMAT_UINT32;
else if (components == 2)
return DATAFORMAT_UINT32_VEC2;
else if (components == 3)
return DATAFORMAT_UINT32_VEC2;
else if (components == 4)
return DATAFORMAT_UINT32_VEC2;
}

return DATAFORMAT_MAX_ENUM;
}

static PixelFormat getPixelFormat(glslang::TLayoutFormat format)
{
using namespace glslang;
Expand Down Expand Up @@ -1037,6 +1110,48 @@ static T convertData(const glslang::TConstUnion &data)
}
}

static bool AddFieldsToFormat(std::vector<Buffer::DataDeclaration> &format, int level, const glslang::TType *type, int arraylength, const std::string &basename, std::string &err)
{
if (type->isStruct())
{
auto fields = type->getStruct();

for (int i = 0; i < std::max(arraylength, 1); i++)
{
std::string name = basename;
if (level > 0)
{
name += type->getFieldName().c_str();
if (arraylength > 0)
name += "[" + std::to_string(i) + "]";
name += ".";
}
for (size_t fieldi = 0; fieldi < fields->size(); fieldi++)
{
const glslang::TType *fieldtype = (*fields)[fieldi].type;
int fieldlength = fieldtype->isSizedArray() ? fieldtype->getCumulativeArraySize() : 0;

if (!AddFieldsToFormat(format, level + 1, fieldtype, fieldlength, name, err))
return false;
}
}
}
else
{
DataFormat dataformat = getDataFormat(type->getBasicType(), type->getVectorSize(), type->getMatrixRows(), type->getMatrixCols(), type->isMatrix());
if (dataformat == DATAFORMAT_MAX_ENUM)
{
err = "Shader validation error:\n";
return false;
}

std::string name = basename.empty() ? type->getFieldName().c_str() : basename + type->getFieldName().c_str();
format.emplace_back(name.c_str(), dataformat, arraylength);
}

return true;
}

bool Shader::validateInternal(StrongRef<ShaderStage> stages[], std::string &err, Reflection &reflection)
{
glslang::TProgram program;
Expand Down Expand Up @@ -1295,6 +1410,12 @@ bool Shader::validateInternal(StrongRef<ShaderStage> stages[], std::string &err,
u.access = (Access)(ACCESS_READ | ACCESS_WRITE);

reflection.storageBuffers[u.name] = u;

std::vector<Buffer::DataDeclaration> format;
if (!AddFieldsToFormat(format, 0, elementtype, 0, "", err))
return false;

reflection.bufferFormats[u.name] = format;
}
else
{
Expand Down
5 changes: 5 additions & 0 deletions src/modules/graphics/Shader.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "Texture.h"
#include "ShaderStage.h"
#include "Resource.h"
#include "Buffer.h"

// STL
#include <string>
Expand Down Expand Up @@ -266,6 +267,8 @@ class Shader : public Object, public Resource

void getLocalThreadgroupSize(int *x, int *y, int *z);

const std::vector<Buffer::DataDeclaration> *getBufferFormat(const std::string &name) const;

static SourceInfo getSourceInfo(const std::string &src);
static std::string createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const CompileOptions &options, const SourceInfo &info, bool gles, bool checksystemfeatures);

Expand Down Expand Up @@ -296,6 +299,8 @@ class Shader : public Object, public Resource

std::map<std::string, std::vector<LocalUniformValue>> localUniformInitializerValues;

std::map<std::string, std::vector<Buffer::DataDeclaration>> bufferFormats;

int textureCount;
int bufferCount;

Expand Down
36 changes: 36 additions & 0 deletions src/modules/graphics/wrap_Shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,41 @@ int w_Shader_getLocalThreadgroupSize(lua_State* L)
return 3;
}

int w_Shader_getBufferFormat(lua_State *L)
{
Shader *shader = luax_checkshader(L, 1);
const char *name = luaL_checkstring(L, 2);
const std::vector<Buffer::DataDeclaration> *format = shader->getBufferFormat(name);
if (name != nullptr)
{
lua_createtable(L, (int)format->size(), 0);

for (size_t i = 0; i < format->size(); i++)
{
const Buffer::DataDeclaration &member = (*format)[i];

lua_createtable(L, 0, 3);

lua_pushstring(L, member.name.c_str());
lua_setfield(L, -2, "name");

const char* formatstr = "unknown";
getConstant(member.format, formatstr);
lua_pushstring(L, formatstr);
lua_setfield(L, -2, "format");

lua_pushinteger(L, member.arrayLength);
lua_setfield(L, -2, "arraylength");

lua_rawseti(L, -2, i + 1);
}

return 1;
}

return luaL_error(L, "Buffer '%s' does not exist in the Shader.", name);
}

int w_Shader_getDebugName(lua_State *L)
{
Shader *shader = luax_checkshader(L, 1);
Expand All @@ -533,6 +568,7 @@ static const luaL_Reg w_Shader_functions[] =
{ "hasUniform", w_Shader_hasUniform },
{ "hasStage", w_Shader_hasStage },
{ "getLocalThreadgroupSize", w_Shader_getLocalThreadgroupSize },
{ "getBufferFormat", w_Shader_getBufferFormat },
{ "getDebugName", w_Shader_getDebugName },
{ 0, 0 }
};
Expand Down

0 comments on commit 78f8cfc

Please sign in to comment.