danieldk (HF Staff) committed
Commit 8304685 · 1 Parent(s): 79667ee

Remove source

README.md CHANGED
@@ -1,10 +1,13 @@
 ---
 tags:
- - kernel
+ - kernel
 ---

 ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/activation)

 ## Activation

- Activation kernels from [vLLM](https://github.com/vllm-project/vllm/blob/main/csrc/activation_kernels.cu).
+ Activation kernels from [vLLM](https://github.com/vllm-project/vllm/blob/main/csrc/activation_kernels.cu).
+
+ Kernel source: https://github.com/huggingface/kernels-community/tree/main/activation
+
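For context (not part of the diff), these kernels are normally consumed as a prebuilt Hub kernel rather than compiled from the sources removed below; a minimal usage sketch, assuming the Hugging Face `kernels` package is installed and a CUDA device is available:

```python
# Sketch: load the prebuilt activation kernels from the Hub instead of
# building the removed CUDA sources locally.
import torch
from kernels import get_kernel

activation = get_kernel("kernels-community/activation")

x = torch.randn(4, 2 * 512, dtype=torch.float16, device="cuda")
out = torch.empty(4, 512, dtype=torch.float16, device="cuda")
activation.silu_and_mul(out, x)  # out = silu(x[..., :512]) * x[..., 512:]
```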
activation/activation_kernels.cu DELETED
@@ -1,244 +0,0 @@
1
- #include <ATen/cuda/CUDAContext.h>
2
- #include <torch/all.h>
3
- #include <c10/cuda/CUDAGuard.h>
4
-
5
- #include <cmath>
6
-
7
- #include "cuda_compat.h"
8
- #include "dispatch_utils.h"
9
-
10
- namespace vllm {
11
-
12
- template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
13
- bool act_first>
14
- __device__ __forceinline__ scalar_t compute(const scalar_t& x,
15
- const scalar_t& y) {
16
- return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
17
- }
18
- // Activation and gating kernel template.
19
-
20
- template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
21
- bool act_first>
22
- __global__ void act_and_mul_kernel(
23
- scalar_t* __restrict__ out, // [..., d]
24
- const scalar_t* __restrict__ input, // [..., 2, d]
25
- const int d) {
26
- const int64_t token_idx = blockIdx.x;
27
- for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
28
- const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
29
- const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
30
- out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
31
- }
32
- }
33
-
34
- template <typename T>
35
- __device__ __forceinline__ T silu_kernel(const T& x) {
36
- // x * sigmoid(x)
37
- return (T)(((float)x) / (1.0f + expf((float)-x)));
38
- }
39
-
40
- template <typename T>
41
- __device__ __forceinline__ T gelu_kernel(const T& x) {
42
- // Equivalent to PyTorch GELU with 'none' approximation.
43
- // Refer to:
44
- // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
45
- const float f = (float)x;
46
- constexpr float ALPHA = M_SQRT1_2;
47
- return (T)(f * 0.5f * (1.0f + erf(f * ALPHA)));
48
- }
49
-
50
- template <typename T>
51
- __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
52
- // Equivalent to PyTorch GELU with 'tanh' approximation.
53
- // Refer to:
54
- // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
55
- const float f = (float)x;
56
- constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
57
- constexpr float KAPPA = 0.044715;
58
- float x_cube = f * f * f;
59
- float inner = BETA * (f + KAPPA * x_cube);
60
- return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
61
- }
62
-
63
- } // namespace vllm
64
-
65
- // Launch activation and gating kernel.
66
- // Use ACT_FIRST (bool) indicating whether to apply the activation function
67
- // first.
68
- #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
69
- int d = input.size(-1) / 2; \
70
- int64_t num_tokens = input.numel() / input.size(-1); \
71
- dim3 grid(num_tokens); \
72
- dim3 block(std::min(d, 1024)); \
73
- if (num_tokens == 0) { \
74
- return; \
75
- } \
76
- const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
77
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
78
- VLLM_DISPATCH_FLOATING_TYPES( \
79
- input.scalar_type(), "act_and_mul_kernel", [&] { \
80
- vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
81
- <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
82
- input.data_ptr<scalar_t>(), d); \
83
- });
84
-
85
- void silu_and_mul(torch::Tensor& out, // [..., d]
86
- torch::Tensor& input) // [..., 2 * d]
87
- {
88
- LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
89
- }
90
-
91
- void mul_and_silu(torch::Tensor& out, // [..., d]
92
- torch::Tensor& input) // [..., 2 * d]
93
- {
94
- // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
95
- // applies the silu to the latter half of the input.
96
- LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
97
- }
98
-
99
- void gelu_and_mul(torch::Tensor& out, // [..., d]
100
- torch::Tensor& input) // [..., 2 * d]
101
- {
102
- LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
103
- }
104
-
105
- void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
106
- torch::Tensor& input) // [..., 2 * d]
107
- {
108
- LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
109
- }
110
-
111
- namespace vllm {
112
-
113
- template <typename T>
114
- __device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
115
- const float f = (float)x;
116
- return (T)(f > threshold ? f : 0.0f);
117
- }
118
-
119
- template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
120
- __global__ void act_and_mul_kernel_with_param(
121
- scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
122
- const float param) {
123
- const int64_t token_idx = blockIdx.x;
124
- for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
125
- const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
126
- const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
127
- out[token_idx * d + idx] = ACT_FN(x, param) * y;
128
- }
129
- }
130
-
131
- } // namespace vllm
132
-
133
- #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
134
- int d = input.size(-1) / 2; \
135
- int64_t num_tokens = input.numel() / input.size(-1); \
136
- dim3 grid(num_tokens); \
137
- dim3 block(std::min(d, 1024)); \
138
- const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
139
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
140
- VLLM_DISPATCH_FLOATING_TYPES( \
141
- input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
142
- vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
143
- <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
144
- input.data_ptr<scalar_t>(), d, \
145
- PARAM); \
146
- });
147
-
148
- void fatrelu_and_mul(torch::Tensor& out, // [..., d],
149
- torch::Tensor& input, // [..., 2 * d]
150
- double threshold) {
151
- LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
152
- }
153
- namespace vllm {
154
-
155
- // Element-wise activation kernel template.
156
- template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
157
- __global__ void activation_kernel(
158
- scalar_t* __restrict__ out, // [..., d]
159
- const scalar_t* __restrict__ input, // [..., d]
160
- const int d) {
161
- const int64_t token_idx = blockIdx.x;
162
- for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
163
- const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
164
- out[token_idx * d + idx] = ACT_FN(x);
165
- }
166
- }
167
-
168
- } // namespace vllm
169
-
170
- // Launch element-wise activation kernel.
171
- #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
172
- int d = input.size(-1); \
173
- int64_t num_tokens = input.numel() / d; \
174
- dim3 grid(num_tokens); \
175
- dim3 block(std::min(d, 1024)); \
176
- const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
177
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
178
- VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
179
- vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
180
- <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
181
- input.data_ptr<scalar_t>(), d); \
182
- });
183
-
184
- namespace vllm {
185
-
186
-
187
- template <typename T>
188
- __device__ __forceinline__ T gelu_new_kernel(const T& x) {
189
- const float x3 = (float)(x * x * x);
190
- const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
191
- return ((T)0.5) * x * (((T)1.0) + t);
192
- }
193
-
194
- template <typename T>
195
- __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
196
- const float f = (float)x;
197
- const T t =
198
- (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
199
- return ((T)0.5) * x * (((T)1.0) + t);
200
- }
201
-
202
- template <typename T>
203
- __device__ __forceinline__ T gelu_quick_kernel(const T& x) {
204
- // x * sigmoid(1.702 * x)
205
- return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
206
- }
207
-
208
- } // namespace vllm
209
-
210
- void gelu_new(torch::Tensor& out, // [..., d]
211
- torch::Tensor& input) // [..., d]
212
- {
213
- LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
214
- }
215
-
216
- void gelu_fast(torch::Tensor& out, // [..., d]
217
- torch::Tensor& input) // [..., d]
218
- {
219
- LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
220
- }
221
-
222
- void gelu_quick(torch::Tensor& out, // [..., d]
223
- torch::Tensor& input) // [..., d]
224
- {
225
- LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
226
- }
227
-
228
- void gelu(torch::Tensor& out, // [..., d]
229
- torch::Tensor& input) // [..., d]
230
- {
231
- LAUNCH_ACTIVATION_KERNEL(vllm::gelu_kernel);
232
- }
233
-
234
- void gelu_tanh(torch::Tensor& out, // [..., d]
235
- torch::Tensor& input) // [..., d]
236
- {
237
- LAUNCH_ACTIVATION_KERNEL(vllm::gelu_tanh_kernel);
238
- }
239
-
240
- void silu(torch::Tensor& out, // [..., d]
241
- torch::Tensor& input) // [..., d]
242
- {
243
- LAUNCH_ACTIVATION_KERNEL(vllm::silu_kernel);
244
- }
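For reference (not part of the diff), every gated op above instantiates the `compute` template: `act_first=true` applies the activation to the first half of the last dimension (`silu_and_mul`, `gelu_and_mul`, `gelu_tanh_and_mul`), while `act_first=false` applies it to the second half (`mul_and_silu`). A PyTorch sketch of the same semantics:

```python
# Reference semantics of act_and_mul_kernel<..., act_first> (a sketch, not the CUDA code).
import torch
import torch.nn.functional as F

def act_and_mul(x: torch.Tensor, act, act_first: bool) -> torch.Tensor:
    d = x.shape[-1] // 2
    a, b = x[..., :d], x[..., d:]
    # act_first=True  -> silu_and_mul / gelu_and_mul / gelu_tanh_and_mul
    # act_first=False -> mul_and_silu
    return act(a) * b if act_first else a * act(b)

x = torch.randn(3, 8)
assert torch.allclose(act_and_mul(x, F.silu, True), F.silu(x[..., :4]) * x[..., 4:])
assert torch.allclose(act_and_mul(x, F.silu, False), x[..., :4] * F.silu(x[..., 4:]))
```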
activation/cuda_compat.h DELETED
@@ -1,49 +0,0 @@
1
- #pragma once
2
-
3
- #ifdef USE_ROCM
4
- #include <hip/hip_runtime.h>
5
- #endif
6
-
7
- #if defined(USE_ROCM) && defined(__GFX9__)
8
- #define WARP_SIZE 64
9
- #else
10
- #define WARP_SIZE 32
11
- #endif
12
-
13
- #ifndef USE_ROCM
14
- #define VLLM_LDG(arg) __ldg(arg)
15
- #else
16
- #define VLLM_LDG(arg) *(arg)
17
- #endif
18
-
19
- #ifndef USE_ROCM
20
- #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
21
- __shfl_xor_sync(uint32_t(-1), var, lane_mask)
22
- #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
23
- __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
24
- #else
25
- #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
26
- #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
27
- __shfl_xor(var, lane_mask, width)
28
- #endif
29
-
30
- #ifndef USE_ROCM
31
- #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
32
- #else
33
- #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
34
- #endif
35
-
36
- #ifndef USE_ROCM
37
- #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
38
- __shfl_down_sync(uint32_t(-1), var, lane_delta)
39
- #else
40
- #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
41
- #endif
42
-
43
- #ifndef USE_ROCM
44
- #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
45
- cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
46
- #else
47
- #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
48
- hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
49
- #endif
activation/dispatch_utils.h DELETED
@@ -1,83 +0,0 @@
1
- /*
2
- * Adapted from
3
- * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
4
- */
5
- #pragma once
6
-
7
- #include <torch/all.h>
8
-
9
- // Need a special dispatch case macro since we will nest the FP8 dispatch.
10
- // Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
11
- #define AT_DISPATCH_FP8_CASE(enum_type, ...) \
12
- AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
13
-
14
- #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
15
- AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
16
- AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
17
- AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
18
-
19
- #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
20
- AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
21
-
22
- // ROCm devices might use either fn or fnuz, so set up dispatch table for both.
23
- // A host-based check at runtime will create a preferred FP8 type for ROCm
24
- // such that the correct kernel is dispatched.
25
- #ifdef USE_ROCM
26
- #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
27
- AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
28
- AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
29
-
30
- #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
31
- AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
32
- AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
33
- AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
34
- #else
35
- #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
36
- AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
37
-
38
- #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
39
- AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
40
- AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
41
- #endif
42
-
43
- // When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
44
- // See AT_DISPATCH_FP8_CASE above.
45
- #define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
46
- AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
47
-
48
- #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
49
- AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
50
-
51
- #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
52
- AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
53
- AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
54
- AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
55
- AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
56
-
57
- #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
58
- AT_DISPATCH_SWITCH(TYPE, NAME, \
59
- VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
60
-
61
- #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
62
- AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
63
- AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
64
- AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
65
- AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
66
- AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
67
-
68
- #define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
69
- AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
70
- AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
71
- AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
72
- AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
73
- AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
74
- AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
75
- AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
76
- AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
77
-
78
- #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
79
- AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
80
-
81
- #define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
82
- AT_DISPATCH_SWITCH( \
83
- TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
build.toml DELETED
@@ -1,18 +0,0 @@
1
- [general]
2
- name = "activation"
3
- universal = false
4
-
5
- [torch]
6
- src = [
7
- "torch-ext/torch_binding.cpp",
8
- "torch-ext/torch_binding.h",
9
- ]
10
-
11
- [kernel.activation]
12
- backend = "cuda"
13
- depends = ["torch"]
14
- src = [
15
- "activation/activation_kernels.cu",
16
- "activation/cuda_compat.h",
17
- "activation/dispatch_utils.h",
18
- ]
flake.lock DELETED
@@ -1,168 +0,0 @@
1
- {
2
- "nodes": {
3
- "flake-compat": {
4
- "locked": {
5
- "lastModified": 1747046372,
6
- "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
- "owner": "edolstra",
8
- "repo": "flake-compat",
9
- "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
- "type": "github"
11
- },
12
- "original": {
13
- "owner": "edolstra",
14
- "repo": "flake-compat",
15
- "type": "github"
16
- }
17
- },
18
- "flake-compat_2": {
19
- "locked": {
20
- "lastModified": 1747046372,
21
- "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
22
- "owner": "edolstra",
23
- "repo": "flake-compat",
24
- "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
25
- "type": "github"
26
- },
27
- "original": {
28
- "owner": "edolstra",
29
- "repo": "flake-compat",
30
- "type": "github"
31
- }
32
- },
33
- "flake-utils": {
34
- "inputs": {
35
- "systems": "systems"
36
- },
37
- "locked": {
38
- "lastModified": 1731533236,
39
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
- "owner": "numtide",
41
- "repo": "flake-utils",
42
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
- "type": "github"
44
- },
45
- "original": {
46
- "owner": "numtide",
47
- "repo": "flake-utils",
48
- "type": "github"
49
- }
50
- },
51
- "flake-utils_2": {
52
- "inputs": {
53
- "systems": "systems_2"
54
- },
55
- "locked": {
56
- "lastModified": 1731533236,
57
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
- "owner": "numtide",
59
- "repo": "flake-utils",
60
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
- "type": "github"
62
- },
63
- "original": {
64
- "owner": "numtide",
65
- "repo": "flake-utils",
66
- "type": "github"
67
- }
68
- },
69
- "hf-nix": {
70
- "inputs": {
71
- "flake-compat": "flake-compat_2",
72
- "flake-utils": "flake-utils_2",
73
- "nixpkgs": "nixpkgs"
74
- },
75
- "locked": {
76
- "lastModified": 1759493343,
77
- "narHash": "sha256-8fhl0gwMAnOkQbogPIVq+Fha+Yeq52FaRXfwF+F9Q+k=",
78
- "owner": "huggingface",
79
- "repo": "hf-nix",
80
- "rev": "b1fc3a18b52447a0f24bc6884418edc5e66082b9",
81
- "type": "github"
82
- },
83
- "original": {
84
- "owner": "huggingface",
85
- "repo": "hf-nix",
86
- "type": "github"
87
- }
88
- },
89
- "kernel-builder": {
90
- "inputs": {
91
- "flake-compat": "flake-compat",
92
- "flake-utils": "flake-utils",
93
- "hf-nix": "hf-nix",
94
- "nixpkgs": [
95
- "kernel-builder",
96
- "hf-nix",
97
- "nixpkgs"
98
- ]
99
- },
100
- "locked": {
101
- "lastModified": 1759516823,
102
- "narHash": "sha256-UJVvZHtS9c64Dm4iZRaOKWB+VHI7jzcazGH57KXWeg8=",
103
- "owner": "huggingface",
104
- "repo": "kernel-builder",
105
- "rev": "e13610a05f67b7296be9ead89ad172a0a088a1c3",
106
- "type": "github"
107
- },
108
- "original": {
109
- "owner": "huggingface",
110
- "repo": "kernel-builder",
111
- "type": "github"
112
- }
113
- },
114
- "nixpkgs": {
115
- "locked": {
116
- "lastModified": 1755963616,
117
- "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
118
- "owner": "nixos",
119
- "repo": "nixpkgs",
120
- "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
121
- "type": "github"
122
- },
123
- "original": {
124
- "owner": "nixos",
125
- "ref": "nixos-unstable-small",
126
- "repo": "nixpkgs",
127
- "type": "github"
128
- }
129
- },
130
- "root": {
131
- "inputs": {
132
- "kernel-builder": "kernel-builder"
133
- }
134
- },
135
- "systems": {
136
- "locked": {
137
- "lastModified": 1681028828,
138
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
- "owner": "nix-systems",
140
- "repo": "default",
141
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
- "type": "github"
143
- },
144
- "original": {
145
- "owner": "nix-systems",
146
- "repo": "default",
147
- "type": "github"
148
- }
149
- },
150
- "systems_2": {
151
- "locked": {
152
- "lastModified": 1681028828,
153
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
- "owner": "nix-systems",
155
- "repo": "default",
156
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
- "type": "github"
158
- },
159
- "original": {
160
- "owner": "nix-systems",
161
- "repo": "default",
162
- "type": "github"
163
- }
164
- }
165
- },
166
- "root": "root",
167
- "version": 7
168
- }
flake.nix DELETED
@@ -1,17 +0,0 @@
1
- {
2
- description = "Flake for activation kernels";
3
-
4
- inputs = {
5
- kernel-builder.url = "github:huggingface/kernel-builder";
6
- };
7
-
8
- outputs =
9
- {
10
- self,
11
- kernel-builder,
12
- }:
13
- kernel-builder.lib.genFlakeOutputs {
14
- inherit self;
15
- path = ./.;
16
- };
17
- }
tests/__init__.py DELETED
File without changes
tests/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (142 Bytes)
 
tests/kernels/__init__.py DELETED
File without changes
tests/kernels/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (150 Bytes)
 
tests/kernels/__pycache__/allclose_default.cpython-312.pyc DELETED
Binary file (842 Bytes)
 
tests/kernels/__pycache__/test_activation.cpython-312-pytest-8.4.2.pyc DELETED
Binary file (11.7 kB)
 
tests/kernels/__pycache__/utils.cpython-312.pyc DELETED
Binary file (2.75 kB)
 
tests/kernels/allclose_default.py DELETED
@@ -1,14 +0,0 @@
1
- import torch
2
-
3
- # Reference default values of atol and rtol are from
4
- # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
5
- default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
6
- default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}
7
-
8
-
9
- def get_default_atol(output) -> float:
10
- return default_atol[output.dtype]
11
-
12
-
13
- def get_default_rtol(output) -> float:
14
- return default_rtol[output.dtype]
tests/kernels/test_activation.py DELETED
@@ -1,206 +0,0 @@
1
- # SPDX-License-Identifier: Apache-2.0
2
- # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
-
4
- import math
5
- import random
6
- from typing import Type
7
-
8
- import activation
9
- import pytest
10
- import torch
11
- import torch.nn.functional as F
12
-
13
- from .utils import opcheck
14
- from .allclose_default import get_default_atol, get_default_rtol
15
-
16
- DTYPES = [torch.half, torch.bfloat16, torch.float]
17
- NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
18
- D = [512, 13824] # Arbitrary values for testing
19
- SEEDS = [0]
20
- CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
21
-
22
-
23
- def gelu_fast(x: torch.Tensor) -> torch.Tensor:
24
- return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
25
-
26
-
27
- def gelu_new(x: torch.Tensor) -> torch.Tensor:
28
- c = math.sqrt(2.0 / math.pi)
29
- return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
30
-
31
-
32
- def gelu_quick(x: torch.Tensor) -> torch.Tensor:
33
- return x * torch.sigmoid(1.702 * x)
34
-
35
-
36
- def fatrelu_and_mul(x: torch.Tensor, threshold: float) -> torch.Tensor:
37
- d = x.shape[-1] // 2
38
- x1 = x[..., :d]
39
- x2 = x[..., d:]
40
- x1 = F.threshold(x1, threshold, 0.0)
41
- return x1 * x2
42
-
43
-
44
- def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
45
- d = x.shape[-1] // 2
46
- return F.silu(x[..., :d]) * x[..., d:]
47
-
48
-
49
- def mul_and_silu(x: torch.Tensor) -> torch.Tensor:
50
- d = x.shape[-1] // 2
51
- return x[..., :d] * F.silu(x[..., d:])
52
-
53
-
54
- def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
55
- d = x.shape[-1] // 2
56
- return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
57
-
58
- def gelu(x: torch.Tensor) -> torch.Tensor:
59
- return F.gelu(x)
60
-
61
- def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
62
- return F.gelu(x, approximate="tanh")
63
-
64
- def silu(x: torch.Tensor) -> torch.Tensor:
65
- return F.silu(x)
66
-
67
- @pytest.mark.parametrize(
68
- "activation_name", ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"]
69
- )
70
- @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
71
- @pytest.mark.parametrize("d", D)
72
- @pytest.mark.parametrize("dtype", DTYPES)
73
- @pytest.mark.parametrize("seed", SEEDS)
74
- @pytest.mark.parametrize("device", CUDA_DEVICES)
75
- @torch.inference_mode()
76
- def test_act_and_mul(
77
- activation_name: str,
78
- num_tokens: int,
79
- d: int,
80
- dtype: torch.dtype,
81
- seed: int,
82
- device: str,
83
- ) -> None:
84
- random.seed(seed)
85
- torch.manual_seed(seed)
86
- torch.set_default_device(device)
87
- x = torch.randn(num_tokens, 2 * d, dtype=dtype)
88
- if activation_name == "silu_and_mul":
89
- torch_fn = silu_and_mul
90
- fn = activation.silu_and_mul
91
- op = activation.ops.silu_and_mul
92
- layer = activation.layers.SiluAndMul()
93
- elif activation_name == "mul_and_silu":
94
- torch_fn = mul_and_silu
95
- fn = activation.mul_and_silu
96
- op = activation.ops.mul_and_silu
97
- layer = activation.layers.MulAndSilu()
98
- elif activation_name == "gelu":
99
- torch_fn = lambda x: gelu_and_mul(x, "none")
100
- fn = activation.gelu_and_mul
101
- op = activation.ops.gelu_and_mul
102
- layer = activation.layers.GeluAndMul()
103
- elif activation_name == "gelu_tanh":
104
- torch_fn = lambda x: gelu_and_mul(x, "tanh")
105
- fn = activation.gelu_tanh_and_mul
106
- op = activation.ops.gelu_tanh_and_mul
107
- layer = activation.layers.GeluTanhAndMul()
108
- elif activation_name == "fatrelu":
109
- threshold = random.uniform(0, 1)
110
- torch_fn = lambda x: fatrelu_and_mul(x, threshold)
111
- fn = lambda out, x: activation.fatrelu_and_mul(out, x, threshold)
112
- op = activation.ops.fatrelu_and_mul
113
- layer = activation.layers.FatreluAndMul(threshold)
114
-
115
- out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
116
- out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
117
- out = fn(out, x)
118
- mod_out = layer(x)
119
- ref_out = torch_fn(x)
120
-
121
- # The SiLU, GELU and FatReLU implementations are equivalent to the native
122
- # PyTorch implementations, so we can do exact comparison.
123
- torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
124
- torch.testing.assert_close(mod_out, ref_out, atol=0.0, rtol=0.0)
125
-
126
- d = x.shape[-1] // 2
127
- output_shape = x.shape[:-1] + (d,)
128
- out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
129
- if activation_name == "fatrelu":
130
- opcheck(op, (out, x, threshold))
131
- else:
132
- opcheck(op, (out, x))
133
-
134
-
135
- @pytest.mark.parametrize(
136
- "activation_fns",
137
- [
138
- (
139
- gelu_fast,
140
- activation.gelu_fast,
141
- activation.ops.gelu_fast,
142
- activation.layers.FastGELU,
143
- ),
144
- (
145
- gelu_new,
146
- activation.gelu_new,
147
- activation.ops.gelu_new,
148
- activation.layers.NewGELU,
149
- ),
150
- (
151
- gelu_quick,
152
- activation.gelu_quick,
153
- activation.ops.gelu_quick,
154
- activation.layers.QuickGELU,
155
- ),
156
- (
157
- gelu_tanh,
158
- activation.gelu_tanh,
159
- activation.ops.gelu_tanh,
160
- activation.layers.GeluTanh,
161
- ),
162
- (
163
- silu,
164
- activation.silu,
165
- activation.ops.silu,
166
- activation.layers.Silu,
167
- ),
168
- (
169
- gelu,
170
- activation.gelu,
171
- activation.ops.gelu,
172
- activation.layers.Gelu
173
- ),
174
- ],
175
- )
176
- @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
177
- @pytest.mark.parametrize("d", D)
178
- @pytest.mark.parametrize("dtype", DTYPES)
179
- @pytest.mark.parametrize("seed", SEEDS)
180
- @pytest.mark.parametrize("device", CUDA_DEVICES)
181
- @torch.inference_mode()
182
- def test_activation(
183
- activation_fns,
184
- num_tokens: int,
185
- d: int,
186
- dtype: torch.dtype,
187
- seed: int,
188
- device: str,
189
- ) -> None:
190
- torch.manual_seed(seed)
191
- torch.set_default_device(device)
192
- x = torch.randn(num_tokens, d, dtype=dtype)
193
- torch_fn, fn, op, cls = activation_fns
194
- layer = cls()
195
- out = fn(torch.empty_like(x), x)
196
- layer_out = layer(x)
197
- ref_out = torch_fn(x)
198
- torch.testing.assert_close(
199
- out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
200
- )
201
- torch.testing.assert_close(
202
- out, layer_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
203
- )
204
-
205
- out = torch.empty_like(x)
206
- opcheck(op, (out, x))
tests/kernels/utils.py DELETED
@@ -1,73 +0,0 @@
1
- """Kernel test utils"""
2
-
3
- import itertools
4
- import random
5
- import unittest
6
- from numbers import Number
7
- from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
8
-
9
- import pytest
10
- import torch
11
- from torch._prims_common import TensorLikeType
12
-
13
- # For now, disable "test_aot_dispatch_dynamic" since there are some
14
- # bugs related to this test in PyTorch 2.4.
15
- DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
16
- "test_schema",
17
- "test_autograd_registration",
18
- "test_faketensor",
19
- )
20
-
21
- ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
22
- "test_schema",
23
- "test_autograd_registration",
24
- "test_faketensor",
25
- "test_aot_dispatch_dynamic",
26
- )
27
-
28
-
29
- # Copied/modified from torch._refs.__init__.py
30
- def fp8_allclose(
31
- a: TensorLikeType,
32
- b: TensorLikeType,
33
- rtol: float = 1e-05,
34
- atol: float = 1e-08,
35
- equal_nan: bool = False,
36
- ) -> bool:
37
- """
38
- Reference implementation of torch.allclose
39
- """
40
- torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
41
-
42
- return bool(
43
- torch.all(
44
- torch.isclose(
45
- a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
46
- )
47
- ).item()
48
- )
49
-
50
-
51
- # A special version of op check that has a restricted default set of test_utils
52
- # and a patched version of allclose that supports fp8 types.
53
- def opcheck(
54
- op: Union[
55
- torch._ops.OpOverload,
56
- torch._ops.OpOverloadPacket,
57
- torch._library.custom_ops.CustomOpDef,
58
- ],
59
- args: Tuple[Any, ...],
60
- kwargs: Optional[Dict[str, Any]] = None,
61
- *,
62
- test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
63
- raise_exception: bool = True,
64
- cond: bool = True
65
- ) -> Dict[str, str]:
66
- with unittest.mock.patch("torch.allclose", new=fp8_allclose):
67
- return (
68
- torch.library.opcheck(
69
- op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
70
- )
71
- if cond
72
- else {}
73
- )
torch-ext/activation/__init__.py DELETED
@@ -1,75 +0,0 @@
1
- import torch
2
-
3
- from ._ops import ops
4
-
5
- from . import layers
6
-
7
-
8
- def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
9
- ops.silu_and_mul(out, x)
10
- return out
11
-
12
-
13
- def mul_and_silu(out: torch.Tensor, x: torch.Tensor) -> None:
14
- ops.mul_and_silu(out, x)
15
- return out
16
-
17
-
18
- def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
19
- ops.gelu_and_mul(out, x)
20
- return out
21
-
22
-
23
- def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
24
- ops.gelu_tanh_and_mul(out, x)
25
- return out
26
-
27
-
28
- def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0) -> None:
29
- ops.fatrelu_and_mul(out, x, threshold)
30
- return out
31
-
32
-
33
- def gelu(out: torch.Tensor, x: torch.Tensor) -> None:
34
- ops.gelu(out, x)
35
- return out
36
-
37
- def silu(out: torch.Tensor, x: torch.Tensor) -> None:
38
- ops.silu(out, x)
39
- return out
40
-
41
-
42
- def gelu_tanh(out: torch.Tensor, x: torch.Tensor) -> None:
43
- ops.gelu_tanh(out, x)
44
- return out
45
-
46
-
47
- def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
48
- ops.gelu_fast(out, x)
49
- return out
50
-
51
-
52
- def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
53
- ops.gelu_new(out, x)
54
- return out
55
-
56
-
57
- def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
58
- ops.gelu_quick(out, x)
59
- return out
60
-
61
-
62
- __all__ = [
63
- "silu_and_mul",
64
- "mul_and_silu",
65
- "gelu_and_mul",
66
- "gelu_tanh_and_mul",
67
- "fatrelu_and_mul",
68
- "gelu_fast",
69
- "gelu_new",
70
- "gelu_quick",
71
- "gelu_tanh",
72
- "silu",
73
- "gelu",
74
- "layers",
75
- ]
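As the wrappers above show, each function fills the caller-provided `out` tensor in place and returns it. A short sketch of that calling convention, assuming the built extension imports as `activation` and a CUDA device is available:

```python
# Calling-convention sketch for the out-parameter wrappers (assumes the built extension + CUDA).
import torch
import activation

x = torch.randn(16, 2 * 1024, device="cuda", dtype=torch.bfloat16)
out = torch.empty(16, 1024, device="cuda", dtype=torch.bfloat16)

result = activation.fatrelu_and_mul(out, x, threshold=0.5)
assert result is out  # the wrapper writes into `out` and returns the same tensor
```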
torch-ext/activation/layers.py DELETED
@@ -1,179 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from ._ops import ops
5
-
6
-
7
- class SiluAndMul(nn.Module):
8
- """An activation function for SwiGLU.
9
-
10
- The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
11
-
12
- Shapes:
13
- x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
14
- return: (num_tokens, d) or (batch_size, seq_len, d)
15
- """
16
-
17
- can_torch_compile: bool = True
18
-
19
- def forward(self, x: torch.Tensor):
20
- d = x.shape[-1] // 2
21
- output_shape = x.shape[:-1] + (d,)
22
- out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
23
- ops.silu_and_mul(out, x)
24
- return out
25
-
26
- class Silu(nn.Module):
27
- """An activation function for SiLU.
28
-
29
- The function computes x -> silu(x).
30
-
31
- Shapes:
32
- x: (num_tokens, d) or (batch_size, seq_len, d)
33
- return: (num_tokens, d) or (batch_size, seq_len, d)
34
- """
35
-
36
- can_torch_compile: bool = True
37
-
38
- def forward(self, x: torch.Tensor):
39
- out = torch.empty_like(x)
40
- ops.silu(out, x)
41
- return out
42
-
43
- class Gelu(nn.Module):
44
- """An activation function for GELU.
45
-
46
- The function computes x -> gelu(x).
47
-
48
- Shapes:
49
- x: (num_tokens, d) or (batch_size, seq_len, d)
50
- return: (num_tokens, d) or (batch_size, seq_len, d)
51
- """
52
-
53
- can_torch_compile: bool = True
54
-
55
- def forward(self, x: torch.Tensor):
56
- out = torch.empty_like(x)
57
- ops.gelu(out, x)
58
- return out
59
-
60
- class GeluTanh(nn.Module):
61
- """An activation function for GELU with `tanh` approximation.
62
-
63
- The function computes x -> gelu_tanh(x).
64
-
65
- Shapes:
66
- x: (num_tokens, d) or (batch_size, seq_len, d)
67
- return: (num_tokens, d) or (batch_size, seq_len, d)
68
- """
69
-
70
- can_torch_compile: bool = True
71
-
72
- def forward(self, x: torch.Tensor):
73
- out = torch.empty_like(x)
74
- ops.gelu_tanh(out, x)
75
- return out
76
-
77
-
78
- class MulAndSilu(nn.Module):
79
- """An activation function for SwiGLU.
80
-
81
- The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
82
-
83
- Shapes:
84
- x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
85
- return: (num_tokens, d) or (batch_size, seq_len, d)
86
- """
87
-
88
- can_torch_compile: bool = True
89
-
90
- def forward(self, x: torch.Tensor) -> torch.Tensor:
91
- d = x.shape[-1] // 2
92
- output_shape = x.shape[:-1] + (d,)
93
- out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
94
- ops.mul_and_silu(out, x)
95
- return out
96
-
97
-
98
- class GeluAndMul(nn.Module):
99
- """An activation function for GeGLU.
100
-
101
- The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
102
-
103
- Shapes:
104
- x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
105
- return: (batch_size, seq_len, d) or (num_tokens, d)
106
- """
107
-
108
- can_torch_compile: bool = True
109
-
110
- def forward(self, x: torch.Tensor):
111
- d = x.shape[-1] // 2
112
- output_shape = x.shape[:-1] + (d,)
113
- out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
114
- ops.gelu_and_mul(out, x)
115
- return out
116
-
117
-
118
- class GeluTanhAndMul(nn.Module):
119
- can_torch_compile: bool = True
120
-
121
- def forward(self, x: torch.Tensor):
122
- d = x.shape[-1] // 2
123
- output_shape = x.shape[:-1] + (d,)
124
- out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
125
- ops.gelu_tanh_and_mul(out, x)
126
- return out
127
-
128
-
129
- class FatreluAndMul(nn.Module):
130
- """An activation function for FATReLU.
131
-
132
- The function computes x -> FATReLU(x[:d]) * x[d:] where
133
- d = x.shape[-1] // 2.
134
- This is used in openbmb/MiniCPM-S-1B-sft.
135
-
136
- Shapes:
137
- x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
138
- return: (num_tokens, d) or (batch_size, seq_len, d)
139
- """
140
-
141
- can_torch_compile: bool = True
142
-
143
- def __init__(self, threshold: float = 0.0):
144
- super().__init__()
145
- self.threshold = threshold
146
-
147
- def forward(self, x: torch.Tensor):
148
- d = x.shape[-1] // 2
149
- output_shape = x.shape[:-1] + (d,)
150
- out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
151
- ops.fatrelu_and_mul(out, x, self.threshold)
152
- return out
153
-
154
-
155
- class FastGELU(nn.Module):
156
- can_torch_compile: bool = True
157
-
158
- def forward(self, x: torch.Tensor) -> torch.Tensor:
159
- out = torch.empty_like(x)
160
- ops.gelu_fast(out, x)
161
- return out
162
-
163
-
164
- class NewGELU(nn.Module):
165
- can_torch_compile: bool = True
166
-
167
- def forward(self, x: torch.Tensor) -> torch.Tensor:
168
- out = torch.empty_like(x)
169
- ops.gelu_new(out, x)
170
- return out
171
-
172
-
173
- class QuickGELU(nn.Module):
174
- can_torch_compile: bool = True
175
-
176
- def forward(self, x: torch.Tensor) -> torch.Tensor:
177
- out = torch.empty_like(x)
178
- ops.gelu_quick(out, x)
179
- return out
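The classes above are thin `nn.Module` wrappers (flagged `can_torch_compile`) so the ops can be dropped into a model definition. A minimal sketch, assuming the built extension's `layers` module and a CUDA device:

```python
# Sketch: a SwiGLU-style MLP block built around layers.SiluAndMul (assumes the built extension + CUDA).
import torch
import torch.nn as nn
from activation import layers

class SwiGLUMLP(nn.Module):
    def __init__(self, hidden: int, intermediate: int):
        super().__init__()
        self.gate_up = nn.Linear(hidden, 2 * intermediate, bias=False)
        self.act = layers.SiluAndMul()  # silu(x[..., :d]) * x[..., d:]
        self.down = nn.Linear(intermediate, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(self.act(self.gate_up(x)))

mlp = SwiGLUMLP(1024, 4096).to(device="cuda", dtype=torch.float16)
y = mlp(torch.randn(2, 8, 1024, device="cuda", dtype=torch.float16))
```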
torch-ext/torch_binding.cpp DELETED
@@ -1,52 +0,0 @@
1
- #include <torch/library.h>
2
-
3
- #include "registration.h"
4
- #include "torch_binding.h"
5
-
6
- TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
- // Activation ops
8
- // Activation function used in SwiGLU.
9
- ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
10
- ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
11
-
12
- ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
13
- ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
14
-
15
- // Activation function used in GeGLU with `none` approximation.
16
- ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
17
- ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
18
-
19
- // Activation function used in GeGLU with `tanh` approximation.
20
- ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
21
- ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
22
-
23
- // FATReLU implementation.
24
- ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
25
- ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
26
-
27
- // GELU implementation used in GPT-2.
28
- ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
29
- ops.impl("gelu_new", torch::kCUDA, &gelu_new);
30
-
31
- // Approximate GELU implementation.
32
- ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
33
- ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
34
-
35
- // Quick GELU implementation.
36
- ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
37
- ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
38
-
39
- // GELU with `tanh` approximation.
40
- ops.def("gelu_tanh(Tensor! out, Tensor input) -> ()");
41
- ops.impl("gelu_tanh", torch::kCUDA, &gelu_tanh);
42
-
43
- // SiLU implementation.
44
- ops.def("silu(Tensor! out, Tensor input) -> ()");
45
- ops.impl("silu", torch::kCUDA, &silu);
46
-
47
- // GELU with none approximation.
48
- ops.def("gelu(Tensor! out, Tensor input) -> ()");
49
- ops.impl("gelu", torch::kCUDA, &gelu);
50
- }
51
-
52
- REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h DELETED
@@ -1,26 +0,0 @@
1
- #pragma once
2
-
3
- #include <torch/torch.h>
4
-
5
- void silu_and_mul(torch::Tensor &out, torch::Tensor &input);
6
-
7
- void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
8
-
9
- void gelu_and_mul(torch::Tensor &out, torch::Tensor &input);
10
-
11
- void gelu_tanh_and_mul(torch::Tensor &out, torch::Tensor &input);
12
-
13
- void fatrelu_and_mul(torch::Tensor &out, torch::Tensor &input,
14
- double threshold);
15
-
16
- void gelu_new(torch::Tensor &out, torch::Tensor &input);
17
-
18
- void gelu_fast(torch::Tensor &out, torch::Tensor &input);
19
-
20
- void gelu_quick(torch::Tensor &out, torch::Tensor &input);
21
-
22
- void gelu_tanh(torch::Tensor &out, torch::Tensor &input);
23
-
24
- void silu(torch::Tensor &out, torch::Tensor &input);
25
-
26
- void gelu(torch::Tensor &out, torch::Tensor &input);