I designed a struct template that is specialized on an enum, as follows:
template<DataType type>
struct TypeTrait;
template<>
struct TypeTrait<DATA_TYPE_INT8> {
    static constexpr uint32_t size = sizeof(int8_t);
};
template<>
struct TypeTrait<DATA_TYPE_INT16> {
    static constexpr uint32_t size = sizeof(int16_t);
};
template<>
struct TypeTrait<DATA_TYPE_FP16> {
    static constexpr uint32_t size = sizeof(uint16_t);
};
template<>
struct TypeTrait<DATA_TYPE_UINT8> {
    static constexpr uint32_t size = sizeof(uint8_t);
};
template<>
struct TypeTrait<DATA_TYPE_UINT16> {
    static constexpr uint32_t size = sizeof(uint16_t);
};
template<>
struct TypeTrait<DATA_TYPE_INT32> {
    static constexpr uint32_t size = sizeof(int32_t);
};
template<>
struct TypeTrait<DATA_TYPE_UINT32> {
    static constexpr uint32_t size = sizeof(uint32_t);
};
template<>
struct TypeTrait<DATA_TYPE_FP32> {
    static constexpr uint32_t size = sizeof(float);
};
The DataType enum is defined as follows:
enum DataType {
    DATA_TYPE_INT8 = 0,
    DATA_TYPE_INT16 = 1,
    DATA_TYPE_FP16 = 2,
    DATA_TYPE_UINT8 = 3,
    DATA_TYPE_UINT16 = 4,
    DATA_TYPE_INT32 = 5,
    DATA_TYPE_UINT32 = 6,
    DATA_TYPE_FP32 = 7,
    DATA_TYPE_UNKOWN
};
I want to pass a DataType variable to the TypeTrait struct, like this:
class Test {
public:
    ...
    void Convert() {
        ...
        uint32_t size = TypeTrait<type_>::size;
        ...
    }
private:
    DataType type_;
};
When I do this, the program fails to compile:
main.cc: In member function ‘void Test::Convert()’:
main.cc:63:35: error: use of ‘this’ in a constant expression
63 | uint32_t size = TypeTrait<type_>::size;
| ^~~~~
main.cc:63:40: error: use of ‘this’ in a constant expression
63 | uint32_t size = TypeTrait<type_>::size;
| ^
main.cc:63:35: note: in template argument for type ‘DataType’
63 | uint32_t size = TypeTrait<type_>::size;
| ^~~~~ ^
I tried many things, for example copying type_ into a const variable first:
const DataType dataType = type_;
uint32_t size = TypeTrait<dataType>::size;
Then I got this error instead:
main.cc: In member function ‘void Test::Convert()’:
main.cc:63:39: error: the value of ‘type’ is not usable in a constant expression
63 | uint32_t size = TypeTrait<type>::size;
| ^
main.cc:62:24: note: ‘type’ was not initialized with a constant expression
62 | const DataType type = GetType();
| ^~~~
main.cc:63:39: note: in template argument for type ‘DataType’
63 | uint32_t size = TypeTrait<type>::size;
|
I know the program compiles fine if I pass an enumerator directly, like this:
uint32_t size = TypeTrait<DataType::DATA_TYPE_UINT32>::size;
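Both errors come from the same rule: a template argument must be a constant expression. type_ is an ordinary data member whose value is only known at run time (reading it goes through this), and a const local is usable in a constant expression only if its own initializer is a constant expression, which, per the second error message, GetType() is not. A minimal sketch of two forms that do compile, assuming the type really is known at compile time: a constexpr variable, and a class that takes the DataType as a template parameter. TypedBuffer and kType are hypothetical names; the enum and traits are repeated compactly so the snippet compiles on its own.

```cpp
#include <cstddef>
#include <cstdint>

enum DataType {
    DATA_TYPE_INT8 = 0,
    DATA_TYPE_INT16 = 1,
    DATA_TYPE_FP16 = 2,
    DATA_TYPE_UINT8 = 3,
    DATA_TYPE_UINT16 = 4,
    DATA_TYPE_INT32 = 5,
    DATA_TYPE_UINT32 = 6,
    DATA_TYPE_FP32 = 7,
    DATA_TYPE_UNKOWN
};

template <DataType type> struct TypeTrait;
template <> struct TypeTrait<DATA_TYPE_INT8>   { static constexpr uint32_t size = sizeof(int8_t); };
template <> struct TypeTrait<DATA_TYPE_INT16>  { static constexpr uint32_t size = sizeof(int16_t); };
template <> struct TypeTrait<DATA_TYPE_FP16>   { static constexpr uint32_t size = sizeof(uint16_t); };
template <> struct TypeTrait<DATA_TYPE_UINT8>  { static constexpr uint32_t size = sizeof(uint8_t); };
template <> struct TypeTrait<DATA_TYPE_UINT16> { static constexpr uint32_t size = sizeof(uint16_t); };
template <> struct TypeTrait<DATA_TYPE_INT32>  { static constexpr uint32_t size = sizeof(int32_t); };
template <> struct TypeTrait<DATA_TYPE_UINT32> { static constexpr uint32_t size = sizeof(uint32_t); };
template <> struct TypeTrait<DATA_TYPE_FP32>   { static constexpr uint32_t size = sizeof(float); };

// Compiles: the initializer is a constant expression, so kType may be a template argument.
constexpr DataType kType = DATA_TYPE_INT32;
constexpr uint32_t kInt32Size = TypeTrait<kType>::size;

// Hypothetical class: the DataType is a template parameter instead of a data member,
// so it is a compile-time constant inside every member function.
template <DataType Type>
class TypedBuffer {
public:
    explicit TypedBuffer(std::size_t elemCnt) : elemCnt_(elemCnt) {}

    std::size_t ByteSize() const {
        // Type is a template parameter, so no run-time state is used as a template argument.
        return elemCnt_ * TypeTrait<Type>::size;
    }

private:
    std::size_t elemCnt_;
};
```

This only helps when the type of each object is fixed at compile time. If the DataType is read at run time (from a file or from argv), some mapping from a run-time value to a compile-time template argument is unavoidable, which is what the answers below provide.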
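I don't know how to solve this, so I have to fall back to a switch statement, which is not what I want. I just want to get rid of the switch in my code. The code to refactor: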
switch (dataType_) {
    case DATA_TYPE_INT8:
        byteSize = elemCnt * sizeof(int8_t);
        break;
    case DATA_TYPE_INT16:
        byteSize = elemCnt * sizeof(int16_t);
        break;
    case DATA_TYPE_FP16:
        byteSize = elemCnt * sizeof(uint16_t);
        break;
    case DATA_TYPE_UINT8:
        byteSize = elemCnt * sizeof(uint8_t);
        break;
    case DATA_TYPE_UINT16:
        byteSize = elemCnt * sizeof(uint16_t);
        break;
    case DATA_TYPE_INT32:
        byteSize = elemCnt * sizeof(int32_t);
        break;
    case DATA_TYPE_UINT32:
        byteSize = elemCnt * sizeof(uint32_t);
        break;
    case DATA_TYPE_FP32:
        byteSize = elemCnt * sizeof(float);
        break;
}
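Here is a way to get TypeTrait<type>::size at run time without a switch. It requires C++20: the fold expression itself is C++17, but a lambda with an explicit template parameter list ([&]<std::size_t... Is>) is a C++20 feature: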
#include <cstddef>
#include <utility>

uint32_t datatypeSize(DataType type) {
    // Sum over all enumerators; only the term whose index matches type is non-zero.
    return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
        return ((static_cast<std::size_t>(type) == Is ? TypeTrait<static_cast<DataType>(Is)>::size : 0) + ...);
    }(std::make_index_sequence<DATA_TYPE_UNKOWN>{});
}
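If you are limited to C++17, the same fold-expression technique still works once the parameter pack moves from the lambda into an ordinary helper function template. A sketch under that assumption; DatatypeSizeImpl is a hypothetical helper name, and the enum and traits are repeated compactly so the snippet compiles on its own.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>

enum DataType {
    DATA_TYPE_INT8 = 0,
    DATA_TYPE_INT16 = 1,
    DATA_TYPE_FP16 = 2,
    DATA_TYPE_UINT8 = 3,
    DATA_TYPE_UINT16 = 4,
    DATA_TYPE_INT32 = 5,
    DATA_TYPE_UINT32 = 6,
    DATA_TYPE_FP32 = 7,
    DATA_TYPE_UNKOWN
};

template <DataType type> struct TypeTrait;
template <> struct TypeTrait<DATA_TYPE_INT8>   { static constexpr uint32_t size = sizeof(int8_t); };
template <> struct TypeTrait<DATA_TYPE_INT16>  { static constexpr uint32_t size = sizeof(int16_t); };
template <> struct TypeTrait<DATA_TYPE_FP16>   { static constexpr uint32_t size = sizeof(uint16_t); };
template <> struct TypeTrait<DATA_TYPE_UINT8>  { static constexpr uint32_t size = sizeof(uint8_t); };
template <> struct TypeTrait<DATA_TYPE_UINT16> { static constexpr uint32_t size = sizeof(uint16_t); };
template <> struct TypeTrait<DATA_TYPE_INT32>  { static constexpr uint32_t size = sizeof(int32_t); };
template <> struct TypeTrait<DATA_TYPE_UINT32> { static constexpr uint32_t size = sizeof(uint32_t); };
template <> struct TypeTrait<DATA_TYPE_FP32>   { static constexpr uint32_t size = sizeof(float); };

// Hypothetical helper: a unary right fold over Is = 0..N-1. Each term is
// TypeTrait<Is>::size when Is equals the run-time value, otherwise 0,
// so the sum is the matching size (or 0 if nothing matches).
template <std::size_t... Is>
uint32_t DatatypeSizeImpl(DataType type, std::index_sequence<Is...>) {
    return ((static_cast<std::size_t>(type) == Is
                 ? TypeTrait<static_cast<DataType>(Is)>::size
                 : 0u) + ...);
}

uint32_t datatypeSize(DataType type) {
    // DATA_TYPE_UNKOWN is the last enumerator, so its value equals the number of real types.
    return DatatypeSizeImpl(
        type, std::make_index_sequence<static_cast<std::size_t>(DATA_TYPE_UNKOWN)>{});
}
```

A side effect of the summing approach is that DATA_TYPE_UNKOWN (or any value with no matching specialization) yields 0 instead of undefined behavior.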
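Another one, using std::array (also from Jarod42's comment; it likewise needs C++20 for the template lambda). Note that indexing with [type] is not bounds-checked, so passing DATA_TYPE_UNKOWN is undefined behavior: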
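#include <array>
#include <utility>

uint32_t datatypeSize(DataType type) {
    // Build {size of enumerator 0, size of enumerator 1, ...} and index it with type.
    return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
        return std::array{TypeTrait<static_cast<DataType>(Is)>::size...}[type];
    }(std::make_index_sequence<DATA_TYPE_UNKOWN>{});
}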
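A variation on the std::array idea, sketched under the assumption that the enumerators stay contiguous from 0 with DATA_TYPE_UNKOWN last: build the table once as a compile-time constant (a C++17 inline constexpr variable, so no template lambda is needed) and check the index before reading, so an invalid value returns 0 instead of reading past the end. MakeSizeTable and kSizeTable are hypothetical names; the enum and traits are repeated so the snippet compiles on its own.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <utility>

enum DataType {
    DATA_TYPE_INT8 = 0,
    DATA_TYPE_INT16 = 1,
    DATA_TYPE_FP16 = 2,
    DATA_TYPE_UINT8 = 3,
    DATA_TYPE_UINT16 = 4,
    DATA_TYPE_INT32 = 5,
    DATA_TYPE_UINT32 = 6,
    DATA_TYPE_FP32 = 7,
    DATA_TYPE_UNKOWN
};

template <DataType type> struct TypeTrait;
template <> struct TypeTrait<DATA_TYPE_INT8>   { static constexpr uint32_t size = sizeof(int8_t); };
template <> struct TypeTrait<DATA_TYPE_INT16>  { static constexpr uint32_t size = sizeof(int16_t); };
template <> struct TypeTrait<DATA_TYPE_FP16>   { static constexpr uint32_t size = sizeof(uint16_t); };
template <> struct TypeTrait<DATA_TYPE_UINT8>  { static constexpr uint32_t size = sizeof(uint8_t); };
template <> struct TypeTrait<DATA_TYPE_UINT16> { static constexpr uint32_t size = sizeof(uint16_t); };
template <> struct TypeTrait<DATA_TYPE_INT32>  { static constexpr uint32_t size = sizeof(int32_t); };
template <> struct TypeTrait<DATA_TYPE_UINT32> { static constexpr uint32_t size = sizeof(uint32_t); };
template <> struct TypeTrait<DATA_TYPE_FP32>   { static constexpr uint32_t size = sizeof(float); };

// Hypothetical helper: expand TypeTrait<Is>::size for every index into one array.
template <std::size_t... Is>
constexpr std::array<uint32_t, sizeof...(Is)> MakeSizeTable(std::index_sequence<Is...>) {
    return {{TypeTrait<static_cast<DataType>(Is)>::size...}};
}

// Evaluated entirely at compile time; one entry per real enumerator.
inline constexpr auto kSizeTable =
    MakeSizeTable(std::make_index_sequence<static_cast<std::size_t>(DATA_TYPE_UNKOWN)>{});

uint32_t datatypeSize(DataType type) {
    const auto index = static_cast<std::size_t>(type);
    // Reject DATA_TYPE_UNKOWN and any other out-of-range value instead of indexing blindly.
    return index < kSizeTable.size() ? kSizeTable[index] : 0;
}
```

With the table available, the switch from the question reduces to byteSize = elemCnt * datatypeSize(dataType_).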
Adding another answer that does not use template specialization at all.
I find this implementation more convenient and concise:
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#define DefineHelper(XX)\
    XX(DATA_TYPE_INT8, sizeof(int8_t), "DATA_TYPE_INT8")\
    XX(DATA_TYPE_INT16, sizeof(int16_t), "DATA_TYPE_INT16")\
    XX(DATA_TYPE_FP16, sizeof(uint16_t), "DATA_TYPE_FP16")\
    XX(DATA_TYPE_UINT8, sizeof(uint8_t), "DATA_TYPE_UINT8")\
    XX(DATA_TYPE_UINT16, sizeof(uint16_t), "DATA_TYPE_UINT16")\
    XX(DATA_TYPE_INT32, sizeof(int32_t), "DATA_TYPE_INT32")\
    XX(DATA_TYPE_UINT32, sizeof(uint32_t), "DATA_TYPE_UINT32")\
    XX(DATA_TYPE_FP32, sizeof(float), "DATA_TYPE_FP32")
int32_t GetDataTypeSize(DataType e) {
#define TypeSize(e, n, _) case e: return n;
    switch (e) {
        DefineHelper(TypeSize)
        default:
            return 0;
    }
#undef TypeSize
}
const char* GetDataTypeStr(DataType e) {
#define TypeStr(e, _, s) case e: return s;
    switch (e) {
        DefineHelper(TypeStr)
        default:
            return "unknowntype";
    }
#undef TypeStr
}
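int main(int argc, const char* argv[]) {
    // Guard against a missing argument before reading argv[1].
    if (argc < 2) {
        fprintf(stderr, "usage: %s <type-index>\n", argv[0]);
        return 1;
    }
    DataType type = static_cast<DataType>(atoi(argv[1]));
    printf("typesize:%d desc:%s\n",
           GetDataTypeSize(type), GetDataTypeStr(type));
    return 0;
}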
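One weakness of the macro above is that the list is maintained separately from the enum, so the two can drift apart, and each name string has to be typed by hand. A minimal sketch of the full X-macro pattern, assuming you are free to redefine DataType: the enum, the sizes and the names are all generated from one list, and the preprocessor's # operator produces the name strings. DATA_TYPE_LIST, AS_ENUM, AS_SIZE_CASE and AS_NAME_CASE are hypothetical names.

```cpp
#include <cstdint>

// Single source of truth: enumerator name and element size.
#define DATA_TYPE_LIST(XX)                 \
    XX(DATA_TYPE_INT8, sizeof(int8_t))     \
    XX(DATA_TYPE_INT16, sizeof(int16_t))   \
    XX(DATA_TYPE_FP16, sizeof(uint16_t))   \
    XX(DATA_TYPE_UINT8, sizeof(uint8_t))   \
    XX(DATA_TYPE_UINT16, sizeof(uint16_t)) \
    XX(DATA_TYPE_INT32, sizeof(int32_t))   \
    XX(DATA_TYPE_UINT32, sizeof(uint32_t)) \
    XX(DATA_TYPE_FP32, sizeof(float))

// Generate the enum itself from the list, so a new entry cannot be forgotten.
enum DataType {
#define AS_ENUM(e, n) e,
    DATA_TYPE_LIST(AS_ENUM)
#undef AS_ENUM
    DATA_TYPE_UNKOWN
};

int32_t GetDataTypeSize(DataType type) {
    switch (type) {
#define AS_SIZE_CASE(e, n) case e: return static_cast<int32_t>(n);
        DATA_TYPE_LIST(AS_SIZE_CASE)
#undef AS_SIZE_CASE
        default:
            return 0;
    }
}

const char* GetDataTypeStr(DataType type) {
    switch (type) {
// #e stringizes the enumerator token, so each name always matches its enumerator.
#define AS_NAME_CASE(e, n) case e: return #e;
        DATA_TYPE_LIST(AS_NAME_CASE)
#undef AS_NAME_CASE
        default:
            return "unknowntype";
    }
}
```

The enumerators are generated in the same order as the original definition (DATA_TYPE_INT8 = 0 through DATA_TYPE_FP32 = 7, then DATA_TYPE_UNKOWN), so existing numeric values such as the one parsed from argv[1] keep their meaning.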