TE PyTorch Tensor Module Structure
TransformerEngine's Tensor module handles the operations and computation of the various quantized tensor types. The module's directory structure is as follows:
```
├── __init__.py
├── _internal
│   ├── __init__.py
│   ├── float8_blockwise_tensor_base.py
│   ├── float8_tensor_base.py
│   └── mxfp8_tensor_base.py
├── float8_blockwise_tensor.py
├── float8_tensor.py
├── mxfp8_tensor.py
├── quantized_tensor.py
└── utils.py
```
The directory defines the following tensor types:
- float8_blockwise_tensor.py: a float8 tensor using a blockwise quantization strategy.
- float8_tensor.py: a float8 tensor using a per-tensor (whole-tensor) quantization strategy.
- mxfp8_tensor.py: a float8 tensor using the MXFP8 quantization strategy.
- quantized_tensor.py: the base classes for quantized tensors and quantizers.
Each quantized tensor type is split into an XTensor class and an XTensorBase class, with the XTensorBase implementations living in the _internal directory.
Tensor Base Class and Quantizer Base Class
QuantizedTensorBase
The quantized tensor base class QuantizedTensorBase defines the basic interface and attributes of a quantized tensor. It mainly consists of the following abstract methods:
The update_usage method updates the FP8 tensor according to the requested row/column usage. The rowwise_usage and columnwise_usage parameters indicate whether the data for row-wise or column-wise use is needed; the method creates or removes that usage data in the quantized tensor accordingly.
```python
def update_usage(
    self,
    rowwise_usage: Optional[bool] = None,
    columnwise_usage: Optional[bool] = None,
):
    r"""
    Generate or remove quantized data based on provided usage.

    Parameters
    ----------
    rowwise_usage : Optional[bool], default = `None`
                    Whether to create or keep the data needed for using the
                    tensor in rowwise fashion (e.g. as B argument in TN GEMM).
                    Leaving it as `None` preserves the original value in the
                    tensor.
    columnwise_usage : Optional[bool], default = `None`
                       Whether to create or keep the data needed for using the
                       tensor in columnwise fashion (e.g. as A argument in TN
                       GEMM). Leaving it as `None` preserves the original value
                       in the tensor.
    """
    raise NotImplementedError(
        f"{self.__class__.__name__} class does not implement update_usage function"
    )
```
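As a usage sketch (the variable name is hypothetical, and whether a dropped column-wise copy can later be rebuilt depends on the concrete tensor type), the flags let callers free or recreate the transposed data on demand:

```python
# `w_fp8` is assumed to be an instance of a concrete QuantizedTensorBase
# subclass. Drop the columnwise (transposed) copy once backward no longer
# needs it, freeing its memory; rowwise data is left untouched (None = keep).
w_fp8.update_usage(columnwise_usage=False)

# Request the columnwise copy again before a GEMM that consumes it.
# Depending on the tensor type this may recompute it from the rowwise
# data or raise if the information is no longer available.
w_fp8.update_usage(columnwise_usage=True)
```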
The prepare_for_saving method saves the FP8 tensor for backward during the forward pass; it returns the list of tensors to save together with the tensor base object's state. The restore_from_saved method restores the FP8 tensor from the saved tensor list.
```python
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
    """Prepare the tensor base for saving for backward"""
    raise NotImplementedError(
        f"{self.__class__.__name__} class does not implement prepare_for_saving function"
    )

def restore_from_saved(
    self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
    """Restore the tensor base data from the saved tensors list"""
    raise NotImplementedError(
        f"{self.__class__.__name__} class does not implement restore_from_saved function"
    )
```
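These two hooks exist so that autograd functions can stash only plain torch.Tensor objects via ctx.save_for_backward. A minimal round-trip sketch under that assumption (w_fp8 is again a hypothetical quantized tensor):

```python
# Flatten the quantized tensor into raw torch.Tensors plus a lightweight
# stateful object that remembers how to reassemble them.
tensors, tensor_obj = w_fp8.prepare_for_saving()
# ...typically followed by ctx.save_for_backward(*tensors) in an
# autograd.Function's forward...

# In backward: restore_from_saved consumes its own entries from the front
# of the list and returns whatever tensors it did not use.
remaining = tensor_obj.restore_from_saved(list(tensors))
```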
The update_quantizer method replaces the tensor's quantizer. It is typically used when FP8 weights loaded from a checkpoint carry a quantizer that is inconsistent with the current FP8 recipe.
```python
def update_quantizer(self, quantizer: Quantizer):
    """Update quantizer for the tensor"""
    if self._quantizer is None:
        raise RuntimeError("To be updated, quantizer must be set")
    if self._quantizer is not quantizer:
        warnings.warn("Quantizer is being updated, this may affect model behavior")
    self._quantizer = quantizer
```
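A hedged usage sketch: after deserializing FP8 weights from a checkpoint, re-attach the quantizer built from the currently active recipe (both names below are illustrative):

```python
# `weight_fp8` was loaded from a checkpoint and still carries the
# quantizer it was saved with; `recipe_quantizer` matches the current
# FP8 recipe. Re-attaching keeps future re-quantization consistent; a
# warning is emitted if the quantizer object actually changes.
weight_fp8.update_quantizer(recipe_quantizer)
```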
QuantizedTensor
The QuantizedTensor class inherits from torch.Tensor and implements the basic operations of an FP8 tensor.
- The core functions are the dequantization and quantization operations; each tensor type implements its own dequantize and quantize_ methods according to its quantization strategy. (A usage sketch follows the code below.)
```python
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Convert quantized data to standard PyTorch tensor"""
    raise NotImplementedError(
        f"{self.__class__.__name__} class does not implement dequantize function"
    )

def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor:
    """Update quantized data in-place"""
    raise NotImplementedError(
        f"{self.__class__.__name__} class does not implement quantize_ function"
    )
```
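For a concrete subclass, the pair behaves like a lossy round trip. A sketch assuming `quantizer` is a concrete Quantizer instance (building one by hand is recipe- and version-dependent, so it is taken as given here):

```python
import torch

x = torch.randn(128, 128, device="cuda")
x_fp8 = quantizer(x)                          # high precision -> quantized
x_hp = x_fp8.dequantize(dtype=torch.float32)  # back to a plain torch.Tensor
print((x - x_hp).abs().max())                 # small but nonzero quantization error

x_fp8.quantize_(torch.randn_like(x))          # overwrite quantized data in-place
```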
Quantizer
The quantizer base class Quantizer defines the basic interface and attributes for quantization operations.
- Initialization: rowwise and columnwise indicate whether the FP8 tensor needs the data for row-wise or column-wise use. If both are True, the FP8 tensor keeps both; the column-wise data is typically consumed in the backward pass to obtain the transposed result. When internal is True, the quantizer produces results of type QuantizedTensorBase; otherwise it produces QuantizedTensor. (A construction sketch follows the code below.)
```python
def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
    self.rowwise_usage = rowwise
    self.columnwise_usage = columnwise
    self.internal = False
```
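For instance, a weight quantizer that must serve both the forward GEMM and the backward wgrad/dgrad GEMMs would be constructed with both usages enabled. The subclass name below is illustrative; real subclasses also carry scaling state on top of these flags:

```python
# Hypothetical concrete subclass of Quantizer.
quantizer = SomeFloat8Quantizer(rowwise=True, columnwise=True)

# Opt into lightweight results: the quantizer now returns
# QuantizedTensorBase objects instead of QuantizedTensor subclasses of
# torch.Tensor, e.g. for purely internal buffers.
quantizer.internal = True
```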
The update_quantized method quantizes a source tensor into an existing destination FP8 tensor. The noop_flag parameter indicates whether the quantization should be skipped.
```python
def update_quantized(
    self,
    src: torch.Tensor,
    dst: QuantizedTensor,
    *,
    noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
    """Quantize tensor in-place"""
```
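A hedged reading of the parameter: noop_flag is presumably a one-element float tensor read by the quantization kernel itself, so the skip decision stays on the GPU and needs no host-side branch (useful e.g. under CUDA graphs). Illustrative call, with src, dst, and quantizer assumed given:

```python
import torch

# Assumed convention: a nonzero flag value makes the kernel skip the update.
noop = torch.zeros(1, dtype=torch.float32, device="cuda")
quantizer.update_quantized(src, dst, noop_flag=noop)  # update performed
noop.fill_(1.0)
quantizer.update_quantized(src, dst, noop_flag=noop)  # update skipped
```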
- Quantization: the quantize method produces the quantized tensor. The core work happens in _QuantizeFunc, which calls tex.quantize, the C++ quantization implementation in csrc. The multi_quantize method is the multi-tensor variant; it does not appear to be used anywhere yet. The __call__ method simply forwards to quantize. (A sketch of what _QuantizeFunc might look like follows the code below.)
```python
def quantize(
    self,
    tensor: torch.Tensor,
    *,
    out: Optional[QuantizedTensor] = None,
    dtype: Optional[torch.dtype] = None,  # pylint: disable=unused-argument # used by override
) -> QuantizedTensor:
    """Quantize tensor"""
    if out is not None:
        return self.update_quantized(tensor, out)
    if (not self.internal) and torch.is_grad_enabled():
        return _QuantizeFunc.apply(tensor, self)
    return _QuantizeFunc.forward(None, tensor, self)

def multi_quantize(self, list_of_tensors):
    """Quantize multiple tensors"""
    list_of_output_tensors = []
    for tensor in list_of_tensors:
        list_of_output_tensors.append(self.quantize(tensor))
    return list_of_output_tensors

def __call__(self, tensor: torch.Tensor) -> QuantizedTensor:
    """Quantize tensor"""
    return self.quantize(tensor)
```
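Based on the call pattern above (note that `_QuantizeFunc.forward(None, tensor, self)` passes None as the ctx when autograd is not involved), _QuantizeFunc is presumably a thin autograd wrapper around tex.quantize. A plausible sketch, not the verbatim TE source:

```python
import torch
import transformer_engine_torch as tex  # TE's C++ extension (import path assumed)

class _QuantizeFunc(torch.autograd.Function):
    """Cast a high-precision tensor to a quantized tensor (sketch)."""

    @staticmethod
    def forward(_ctx, tensor, quantizer):
        # All the heavy lifting happens in the C++ extension.
        return tex.quantize(tensor, quantizer)

    @staticmethod
    def backward(_ctx, grad):
        # Gradients flow through unchanged in high precision; the
        # quantizer argument receives no gradient.
        return grad, None
```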