This came out early this morning and it doesn't seem to have been posted here yet, so here is the code from the paper, typed up as-is.


## 3. PyTorch Implementation of BitNet b1.58
"""
여기에는 LLaMA LLM 아키텍쳐로부터 BitNet b1.58로 바꾸는 2단계가 있습니다:
1. Replace all nn.Linear in attention and SwiGLU with BitLinear(Figure 3);
2. Remove RMSNorm before attention and SwiGLU because BitLinear has built-in RMSNorm(Figure 4).
"""
#----------------------------------------
### 1st Stage (Figure 3)
import torch
import torch.nn as nn
import torch.nn.functional as F

def activation_quant(x):
    """
    Per-token quantization to 8 bits. No grouping is needed for quantization.
    Args:
        x: an activation tensor with shape [n, d]
    Returns:
        y: a quantized activation tensor with shape [n, d]
    """
    # absmax quantization: each token is scaled so its largest magnitude maps to 127
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y
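
# Quick sanity check (illustrative values, not from the paper): activation_quant
# fake-quantizes, i.e. the output stays in floating point but is snapped to the
# per-token 8-bit grid, so the rounding error per entry is at most 0.5 / scale.
x = torch.randn(2, 8)
y = activation_quant(x)
print((y - x).abs().max())  # small per-token rounding error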

def weight_quant(w):
    """
    Per-tensor quantization to 1.58 bits. No grouping is needed for quantization.
    Args:
        w: a weight tensor with shape [d, k]
    Returns:
        u: a quantized weight with shape [d, k]
    """
    # absmean quantization: scale by the mean absolute value, then round and
    # clamp to the ternary levels {-1, 0, +1} (log2(3) ~= 1.58 bits)
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u
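
# The BitLinear code below calls RMSNorm(x), which is not defined in the paper's
# figures. A minimal stand-in, assuming a weight-free sub-layer RMSNorm with an
# assumed eps (swap in your model's own RMSNorm module if it has one):
def RMSNorm(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)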

class BitLinear(nn.Linear):
    """
    This is only for training, and kernel optimization is needed for efficiency.
    """
    def forward(self, x):
        """
        Args:
            x: an input tensor with shape [n, d]
        Returns:
            y: an output tensor with shape [n, d]
        """
        w = self.weight  # a weight tensor with shape [d, k]
        x_norm = RMSNorm(x)
        # Straight-Through Estimator (STE) trick using detach(): forward uses the
        # quantized values, backward sees the identity
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y
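
# A minimal sketch of the two stages described at the top, applied to a
# simplified SwiGLU feed-forward block (module and argument names here are
# illustrative, not from the paper or the LLaMA codebase).
class SwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Stage 1: the projections are BitLinear instead of nn.Linear
        self.w_gate = BitLinear(dim, hidden_dim, bias=False)
        self.w_up = BitLinear(dim, hidden_dim, bias=False)
        self.w_down = BitLinear(hidden_dim, dim, bias=False)
        # Stage 2: no RMSNorm in front of the block; BitLinear already
        # normalizes its input internally

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))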
#----------------------------------------
### 2nd Stage (Figure 4)
"""
Pytorch code for the BitLinear component in training BitNet b1.58. It requires additional efforts to improve the training efficiency, such as kernel fusion.
As for the inference, there are some changes for efficiency.
1. The model weights are offline quantized to 1.58 bits.
2. The standard F.linear operation is replaced with a customized low-bit kernel.
3. The scaling factors for both weight quantization and activation quantization are applied after the F.linear operation.
4. There is no need to implement the Straight-Through Estimator (STE) trick.
5. The RMSNorm operation can be fused with the activation quantization.
"""

def activation_norm_quant(x):
    """
    RMSNorm & per-token quantization to 8 bits. It can be implemented as a fused kernel.
    Args:
        x: an activation tensor with shape [n, d]
    Returns:
        y: a quantized activation tensor with shape [n, d]
        scale: a per-token scale for dequantization with shape [n, 1]
    """
    x = RMSNorm(x)
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    # unlike activation_quant above, y stays on the integer grid and the scale
    # is returned separately so it can be applied after the low-bit GEMM
    y = (x * scale).round().clamp_(-128, 127)
    return y, scale

class BitLinear(nn.Linear):  # inference version; replaces the training class above
    """
    This is only for inference. The weights should be quantized in advance.
    """
    def forward(self, x):
        """
        Args:
            x: an input tensor with shape [n, d]
        Returns:
            y: an output tensor with shape [n, d]
        """
        w = self.weight  # a 1.58-bit weight tensor with shape [d, k]
        w_scale = self.weight_scale  # a full-precision weight scale tensor with shape [1]
        x_quant, x_scale = activation_norm_quant(x)
        # gemm_lowbit_kernel is the customized low-bit kernel from point 2 above
        y = gemm_lowbit_kernel(x_quant, w) / w_scale / x_scale
        return y
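
# gemm_lowbit_kernel is not provided in the paper; it stands for a customized
# low-bit GEMM. A slow full-precision reference fallback (an assumption, useful
# only for checking correctness, not for speed) could be:
def gemm_lowbit_kernel(x_quant, w):
    return F.linear(x_quant.float(), w.float())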
#----------------------------------------
## 4. PyTorch Implementation of BitNet b1 (Figure 5)
"""
Similarly, we provide the implementation of the original 1-bit BitNet by replacing the function of weight_quant(w) as in Figure 3.
"""
def weight_quant(w):
    """
    Per-tensor quantization to 1 bit. No grouping is needed for quantization.
    Args:
        w: a weight tensor with shape [d, k]
    Returns:
        u: a quantized weight with shape [d, k]
    """
    # binarization: center the weights, keep only the sign, and rescale by the
    # mean absolute value
    scale = w.abs().mean()
    e = w.mean()
    u = (w - e).sign() * scale
    return u
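
# Tiny check with illustrative numbers (not from the paper): the binarized
# weight keeps only the sign around the mean, scaled by the mean absolute value.
w = torch.tensor([[0.4, -0.2], [0.1, -0.5]])
print(weight_quant(w))  # every entry is +0.3 or -0.3, since mean(|w|) = 0.3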