|
|
#include <ATen/ATen.h> |
|
|
#include "compat.h" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Runtime-to-compile-time dtype dispatch for Float and Half.
//
//   TYPE   expression yielding an at::ScalarType; it is switched on and, on
//          the error path, evaluated a second time for the message.
//   LEVEL  token pasted onto `scalar_t_`, so dispatches that use different
//          LEVEL values define differently named aliases.
//   NAME   stringified with `#` for the error message (a string-literal
//          argument therefore keeps its quotes in the message).
//   ...    statement(s) expanded inside each case, where `scalar_t_<LEVEL>`
//          names the selected C++ type.
//
// Any other scalar type falls through to AT_ERROR.  The expansion is a bare
// switch statement (not wrapped in do/while).
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                        \
  switch(TYPE)                                                                 \
  {                                                                            \
    case at::ScalarType::Float: {                                              \
      using scalar_t_##LEVEL = float;                                          \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Half: {                                               \
      using scalar_t_##LEVEL = at::Half;                                       \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");          \
  }
|
|
|
|
|
|
|
|
// Runtime-to-compile-time dtype dispatch for Float, Half and BFloat16.
//
//   TYPE   expression yielding an at::ScalarType; switched on, and evaluated
//          again on the error path.
//   LEVEL  token pasted onto `scalar_t_` to form the alias name.
//   NAME   stringified with `#` for the error message.
//   ...    statement(s) expanded once per case with `scalar_t_<LEVEL>` bound
//          to float, at::Half or at::BFloat16 respectively.
//
// Unsupported scalar types reach AT_ERROR.  Expands to a bare switch
// statement.
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...)                 \
  switch(TYPE)                                                                 \
  {                                                                            \
    case at::ScalarType::Float: {                                              \
      using scalar_t_##LEVEL = float;                                          \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Half: {                                               \
      using scalar_t_##LEVEL = at::Half;                                       \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::BFloat16: {                                           \
      using scalar_t_##LEVEL = at::BFloat16;                                   \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");          \
  }
|
|
|
|
|
|
|
|
// Runtime-to-compile-time dtype dispatch for Float, Half and Byte.
//
//   TYPE   expression yielding an at::ScalarType; switched on, and evaluated
//          again on the error path.
//   LEVEL  token pasted onto `scalar_t_` to form the alias name.
//   NAME   stringified with `#` for the error message.
//   ...    statement(s) expanded once per case with `scalar_t_<LEVEL>` bound
//          to float, at::Half or uint8_t respectively.
//
// Unsupported scalar types reach AT_ERROR.  Expands to a bare switch
// statement.
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...)                   \
  switch(TYPE)                                                                 \
  {                                                                            \
    case at::ScalarType::Float: {                                              \
      using scalar_t_##LEVEL = float;                                          \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Half: {                                               \
      using scalar_t_##LEVEL = at::Half;                                       \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Byte: {                                               \
      using scalar_t_##LEVEL = uint8_t;                                        \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");          \
  }
|
|
|
|
|
|
|
|
// Runtime-to-compile-time dtype dispatch for Double, Float and Half.
//
//   TYPE   expression yielding an at::ScalarType; switched on, and evaluated
//          again on the error path.
//   LEVEL  token pasted onto `scalar_t_` to form the alias name.
//   NAME   stringified with `#` for the error message.
//   ...    statement(s) expanded once per case with `scalar_t_<LEVEL>` bound
//          to double, float or at::Half respectively.
//
// Unsupported scalar types reach AT_ERROR.  Expands to a bare switch
// statement.
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                 \
  switch(TYPE)                                                                 \
  {                                                                            \
    case at::ScalarType::Double: {                                             \
      using scalar_t_##LEVEL = double;                                         \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Float: {                                              \
      using scalar_t_##LEVEL = float;                                          \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Half: {                                               \
      using scalar_t_##LEVEL = at::Half;                                       \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");          \
  }
|
|
|
|
|
|
|
|
// Runtime-to-compile-time dtype dispatch for Double, Float, Half and
// BFloat16.
//
//   TYPE   expression yielding an at::ScalarType; switched on, and evaluated
//          again on the error path.
//   LEVEL  token pasted onto `scalar_t_` to form the alias name.
//   NAME   stringified with `#` for the error message.
//   ...    statement(s) expanded once per case with `scalar_t_<LEVEL>` bound
//          to double, float, at::Half or at::BFloat16 respectively.
//
// Unsupported scalar types reach AT_ERROR.  Expands to a bare switch
// statement.
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...)          \
  switch(TYPE)                                                                 \
  {                                                                            \
    case at::ScalarType::Double: {                                             \
      using scalar_t_##LEVEL = double;                                         \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Float: {                                              \
      using scalar_t_##LEVEL = float;                                          \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Half: {                                               \
      using scalar_t_##LEVEL = at::Half;                                       \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::BFloat16: {                                           \
      using scalar_t_##LEVEL = at::BFloat16;                                   \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");          \
  }
|
|
|
|
|
|
|
|
// Runtime-to-compile-time dtype dispatch for Double and Float.
//
//   TYPE   expression yielding an at::ScalarType; switched on, and evaluated
//          again on the error path.
//   LEVEL  token pasted onto `scalar_t_` to form the alias name.
//   NAME   stringified with `#` for the error message.
//   ...    statement(s) expanded once per case with `scalar_t_<LEVEL>` bound
//          to double or float respectively.
//
// Unsupported scalar types reach AT_ERROR.  Expands to a bare switch
// statement.
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...)                      \
  switch(TYPE)                                                                 \
  {                                                                            \
    case at::ScalarType::Double: {                                             \
      using scalar_t_##LEVEL = double;                                         \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Float: {                                              \
      using scalar_t_##LEVEL = float;                                          \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");          \
  }
|
|
|
|
|
|
|
|
// Runtime-to-compile-time dtype dispatch for Half and BFloat16.
//
// Unlike the LEVEL-taking dispatch macros, this one takes no LEVEL argument
// and always names the alias plain `scalar_t`.
//
//   TYPE   expression yielding an at::ScalarType; switched on, and evaluated
//          again on the error path.
//   NAME   stringified with `#` for the error message.
//   ...    statement(s) expanded once per case with `scalar_t` bound to
//          at::Half or at::BFloat16 respectively.
//
// Unsupported scalar types reach AT_ERROR.  Expands to a bare switch
// statement.
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                              \
  switch(TYPE)                                                                 \
  {                                                                            \
    case at::ScalarType::Half: {                                               \
      using scalar_t = at::Half;                                               \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::BFloat16: {                                           \
      using scalar_t = at::BFloat16;                                           \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");          \
  }
|
|
|
|
|
|
|
|
// Two-dtype dispatch: binds `scalar_t_in` from TYPEIN and `scalar_t_out`
// from TYPEOUT, then expands the variadic body.
//
// Combinations handled:
//   TYPEIN Float     -> TYPEOUT Float, Half or BFloat16 (nested switch;
//                       any other TYPEOUT reaches AT_ERROR naming TYPEOUT).
//   TYPEIN Half      -> scalar_t_out is at::Half; TYPEOUT is not inspected.
//   TYPEIN BFloat16  -> scalar_t_out is at::BFloat16; TYPEOUT is not
//                       inspected.
//   anything else    -> AT_ERROR naming TYPEIN.
//
//   TYPEIN/TYPEOUT  expressions yielding an at::ScalarType; each may be
//                   evaluated a second time on its error path.
//   NAME            stringified with `#` for the error message.
//   ...             statement(s) expanded inside each supported case.
//
// Expands to a bare switch statement.
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch(TYPEIN)                                                               \
  {                                                                            \
    case at::ScalarType::Float: {                                              \
      using scalar_t_in = float;                                               \
      switch(TYPEOUT)                                                          \
      {                                                                        \
        case at::ScalarType::Float: {                                          \
          using scalar_t_out = float;                                          \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::Half: {                                           \
          using scalar_t_out = at::Half;                                       \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::BFloat16: {                                       \
          using scalar_t_out = at::BFloat16;                                   \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        default:                                                               \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'");   \
      }                                                                        \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Half: {                                               \
      using scalar_t_in = at::Half;                                            \
      using scalar_t_out = at::Half;                                           \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::BFloat16: {                                           \
      using scalar_t_in = at::BFloat16;                                        \
      using scalar_t_out = at::BFloat16;                                       \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");        \
  }
|
|
|
|
|
|
|
|
// Two-dtype dispatch: binds `scalar_t_in` from TYPEIN and `scalar_t_out`
// from TYPEOUT, then expands the variadic body.
//
// Combinations handled:
//   TYPEIN Double    -> TYPEOUT Double, Float, Half or BFloat16.
//   TYPEIN Float     -> TYPEOUT Float, Half or BFloat16 (Double output is
//                       not a case here).
//   TYPEIN Half      -> scalar_t_out is at::Half; TYPEOUT is not inspected.
//   TYPEIN BFloat16  -> scalar_t_out is at::BFloat16; TYPEOUT is not
//                       inspected.
// An unsupported TYPEOUT inside the Double/Float branches reaches AT_ERROR
// naming TYPEOUT; an unsupported TYPEIN reaches AT_ERROR naming TYPEIN.
//
//   TYPEIN/TYPEOUT  expressions yielding an at::ScalarType; each may be
//                   evaluated a second time on its error path.
//   NAME            stringified with `#` for the error message.
//   ...             statement(s) expanded inside each supported case.
//
// Expands to a bare switch statement.
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch(TYPEIN)                                                               \
  {                                                                            \
    case at::ScalarType::Double: {                                             \
      using scalar_t_in = double;                                              \
      switch(TYPEOUT)                                                          \
      {                                                                        \
        case at::ScalarType::Double: {                                         \
          using scalar_t_out = double;                                         \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::Float: {                                          \
          using scalar_t_out = float;                                          \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::Half: {                                           \
          using scalar_t_out = at::Half;                                       \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::BFloat16: {                                       \
          using scalar_t_out = at::BFloat16;                                   \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        default:                                                               \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'");   \
      }                                                                        \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Float: {                                              \
      using scalar_t_in = float;                                               \
      switch(TYPEOUT)                                                          \
      {                                                                        \
        case at::ScalarType::Float: {                                          \
          using scalar_t_out = float;                                          \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::Half: {                                           \
          using scalar_t_out = at::Half;                                       \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::BFloat16: {                                       \
          using scalar_t_out = at::BFloat16;                                   \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        default:                                                               \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'");   \
      }                                                                        \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Half: {                                               \
      using scalar_t_in = at::Half;                                            \
      using scalar_t_out = at::Half;                                           \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::BFloat16: {                                           \
      using scalar_t_in = at::BFloat16;                                        \
      using scalar_t_out = at::BFloat16;                                       \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");        \
  }
|
|
|
|
|
|
|
|
// Block-wide sum reduction that leaves its results in the first `lanes`
// threads of the block.
//
// Every thread in the block must call this function: it contains
// __syncthreads() barriers.
//
// Parameters:
//   x            Scratch buffer indexed by the flattened thread id.  When the
//                block has 64 or more threads it is written at x[tid] for
//                every thread, so it must hold at least blockSize elements.
//                NOTE(review): presumably block-shared memory, since threads
//                read each other's slots — confirm at the call sites.
//   val          This thread's contribution to the sum.
//   lanes        Number of result lanes (default 1).  The shuffle stage halves
//                its offset from 16 down to `lanes`, so with a power-of-two
//                `lanes` <= 32 thread t (t < lanes) ends up holding the sum of
//                the contributions whose thread id is congruent to t modulo
//                `lanes`.
//   share_result If true, threads with tid < lanes store their result to
//                x[tid] and the block synchronizes, so every thread can read
//                the results from x afterwards.
//
// Returns: the reduced value in threads with tid < lanes.  Other threads of
// the first warp return an intermediate partial; threads with tid >= 32
// return `val` unchanged.
//
// Preconditions visible from the code:
//   * blockSize should be a power of two: the halving loop starts at
//     blockSize/2 and stops below 64, so for other sizes some slots of x are
//     never folded into the result.
//   * The shuffles use a full-warp mask (0xffffffff).
//     NOTE(review): this looks like it expects blocks of at least 32
//     threads — confirm.
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
  (T *x,
   T val,
   int lanes=1,
   bool share_result=false)
{
  // Flattened thread index and thread count over the x/y block dimensions.
  int tid = threadIdx.x + threadIdx.y*blockDim.x;
  int blockSize = blockDim.x*blockDim.y;

  // Blocks of fewer than 64 threads skip the scratch buffer entirely and go
  // straight to the warp-shuffle stage below.
  if(blockSize >= 64)
  {
    x[tid] = val;
    __syncthreads();
  }

  // Tree reduction through x until 64 partial sums remain in x[0..63].
  #pragma unroll
  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
  {
    if(tid < i)
      x[tid] = x[tid] + x[tid+i];
    __syncthreads();
  }

  // Start from this thread's own contribution so that threads outside the
  // first warp return (and, with share_result and lanes > 32, store) a
  // defined value instead of an uninitialized one.
  T final = val;

  if(tid < 32)
  {
    // Fold the remaining 64 partials down to one per lane of the first warp.
    if(blockSize >= 64)
      final = x[tid] + x[tid+32];

    // Warp-level reduction; stopping at `lanes` leaves `lanes` partial sums.
    #pragma unroll
    for(int i = 16; i >= lanes; i >>= 1)
      final = final + __shfl_down_sync(0xffffffff, final, i);
  }

  if(share_result)
  {
    if(tid < lanes)
      x[tid] = final;

    // Make the stored results visible to the whole block before returning.
    __syncthreads();
  }

  return final;
}
|
|
|
|
|
// Block-wide reduction computing the maximum absolute value, leaving its
// results in the first `lanes` threads of the block.
//
// Same structure as a sum reduction over the block, but each combining step
// is fmaxf(fabsf(a), fabsf(b)).  Because the float-typed fmaxf/fabsf are
// used, operands of type T are passed through float at every step.
//
// Every thread in the block must call this function: it contains
// __syncthreads() barriers.
//
// Parameters:
//   x            Scratch buffer indexed by the flattened thread id.  When the
//                block has 64 or more threads it is written at x[tid] for
//                every thread, so it must hold at least blockSize elements.
//                NOTE(review): presumably block-shared memory, since threads
//                read each other's slots — confirm at the call sites.
//   val          This thread's contribution.
//   lanes        Number of result lanes (default 1).  The shuffle stage halves
//                its offset from 16 down to `lanes`, so with a power-of-two
//                `lanes` <= 32 thread t (t < lanes) ends up holding the
//                maximum over the contributions whose thread id is congruent
//                to t modulo `lanes`.
//   share_result If true, threads with tid < lanes store their result to
//                x[tid] and the block synchronizes, so every thread can read
//                the results from x afterwards.
//
// Returns: the reduced value in threads with tid < lanes.  Other threads of
// the first warp return an intermediate partial; threads with tid >= 32
// return `val` unchanged.
//
// Notes visible from the code:
//   * blockSize should be a power of two: the halving loop starts at
//     blockSize/2 and stops below 64, so for other sizes some slots of x are
//     never folded into the result.
//   * When blockSize < 64 and lanes > 16, no combining step runs, so the
//     value returned is `val` itself without fabsf applied.
//   * The shuffles use a full-warp mask (0xffffffff).
//     NOTE(review): this looks like it expects blocks of at least 32
//     threads — confirm.
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op
  (T *x,
   T val,
   int lanes=1,
   bool share_result=false)
{
  // Flattened thread index and thread count over the x/y block dimensions.
  int tid = threadIdx.x + threadIdx.y*blockDim.x;
  int blockSize = blockDim.x*blockDim.y;

  // Blocks of fewer than 64 threads skip the scratch buffer entirely and go
  // straight to the warp-shuffle stage below.
  if(blockSize >= 64)
  {
    x[tid] = val;
    __syncthreads();
  }

  // Tree reduction through x until 64 partial maxima remain in x[0..63].
  #pragma unroll
  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
  {
    if(tid < i)
      x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
    __syncthreads();
  }

  // Start from this thread's own contribution so that threads outside the
  // first warp return (and, with share_result and lanes > 32, store) a
  // defined value instead of an uninitialized one.
  T final = val;

  if(tid < 32)
  {
    // Fold the remaining 64 partials down to one per lane of the first warp.
    if(blockSize >= 64)
      final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));

    // Warp-level reduction; stopping at `lanes` leaves `lanes` partial maxima.
    #pragma unroll
    for(int i = 16; i >= lanes; i >>= 1)
      final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
  }

  if(share_result)
  {
    if(tid < lanes)
      x[tid] = final;

    // Make the stored results visible to the whole block before returning.
    __syncthreads();
  }

  return final;
}
|
|
|