
TransformerEngine Tensor Module Analysis

Structure of the TE PyTorch Tensor Module

The Tensor module in TransformerEngine handles the operations and computations for the various quantized tensor types. The module's directory layout is as follows:

├── __init__.py
├── _internal
│   ├── __init__.py
│   ├── float8_blockwise_tensor_base.py
│   ├── float8_tensor_base.py
│   └── mxfp8_tensor_base.py
├── float8_blockwise_tensor.py
├── float8_tensor.py
├── mxfp8_tensor.py
├── quantized_tensor.py
└── utils.py

The directory defines the following tensor types:

  • float8_blockwise_tensor.py: float8 tensors quantized with a blockwise scaling strategy.
  • float8_tensor.py: float8 tensors quantized with a per-tensor (whole-tensor) scaling strategy.
  • mxfp8_tensor.py: float8 tensors quantized with the MXFP8 (microscaling) strategy.
  • quantized_tensor.py: base classes for quantized tensors and quantizers.

Each quantized tensor type comes in two variants, XTensor and XTensorBase, with the XTensorBase implementations located in the _internal directory.
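For orientation, here is a minimal import sketch of that split; the module paths simply follow the tree above and are assumed to match the installed TransformerEngine version:

    # Public classes (torch.Tensor subclasses) live at the top level of the tensor package.
    from transformer_engine.pytorch.tensor.quantized_tensor import (
        QuantizedTensor,
        QuantizedTensorBase,
        Quantizer,
    )
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
    # The lighter "base" counterpart (no torch.Tensor subclassing) lives under _internal.
    from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase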

Tensor Base Class and Quantizer Base Class

QuantizedTensorBase

The quantized tensor base class QuantizedTensorBase defines the basic interface and attributes of a quantized tensor. Its main abstract methods are:

  1. The update_usage method updates the FP8 tensor according to the requested row-wise/column-wise usage. The parameters rowwise_usage and columnwise_usage indicate whether the row-wise or column-wise usage data is needed, so the method is also used to create or delete the row-wise/column-wise data held by the quantized tensor.
    
     def update_usage(
         self,
         rowwise_usage: Optional[bool] = None,
         columnwise_usage: Optional[bool] = None,
     ):
         r"""
         Generate or remove quantized data based on provided usage.
    
         Parameters
         ----------
         rowwise_usage : Optional[bool], default = `None`
                         Whether to create or keep the data needed for using the tensor
                         in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as `None`
                         preserves the original value in the tensor.
         columnwise_usage : Optional[bool], default = `None`
                            Whether to create or keep the data needed for using the tensor
                            in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as
                            `None` preserves the original value in the tensor.
    
         """
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not implement update_usage function"
         )
    
  2. The prepare_for_saving method prepares the FP8 tensor for being saved in the forward pass, returning the list of tensors to save together with the state of the current tensor-base object. The restore_from_saved method restores the FP8 tensor from that saved tensor list (a usage sketch follows this list).
    
     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
         """Prepare the tensor base for saving for backward"""
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not implement prepare_for_saving function"
         )
    
     def restore_from_saved(
         self, tensors: list[Optional[torch.Tensor]]
     ) -> list[Optional[torch.Tensor]]:
         """Restore the tensor base data from the saved tensors list"""
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not implement restore_from_saved function"
         )
    
  3. The update_quantizer method updates the tensor's quantizer. It is typically needed when FP8 weights loaded from a checkpoint were produced with a quantization recipe that differs from the current FP8 recipe.
    
     def update_quantizer(self, quantizer: Quantizer):
         """Update quantizer for the tensor"""
         if self._quantizer is None:
             raise RuntimeError("To be updated, quantizer must be set")
         if self._quantizer is not quantizer:
             warnings.warn("Quantizer is being updated, this may affect model behavior")
             self._quantizer = quantizer
    

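Below is a minimal sketch (not TransformerEngine's actual module code) of how these three hooks are typically wired into a custom autograd function. The class name _SketchLinear is hypothetical, weight_q is assumed to be a full quantized tensor such as Float8Tensor (which exposes both the hooks above and dequantize), and the dequantize-plus-matmul is only placeholder math standing in for the real FP8 GEMM.

    import torch

    class _SketchLinear(torch.autograd.Function):
        """Illustrative only: shows where the QuantizedTensorBase hooks are called."""

        @staticmethod
        def forward(ctx, inp, weight_q):
            # Placeholder math: dequantize + matmul instead of the real FP8 GEMM.
            out = inp @ weight_q.dequantize(dtype=inp.dtype).t()
            # Make sure the columnwise (transposed) data exists for the backward pass.
            weight_q.update_usage(columnwise_usage=True)
            # Split the quantized tensor into raw torch.Tensors plus a lightweight shell.
            tensors_to_save, weight_shell = weight_q.prepare_for_saving()
            ctx.save_for_backward(inp, *tensors_to_save)
            ctx.weight_shell = weight_shell
            return out

        @staticmethod
        def backward(ctx, grad_out):
            inp, *saved = ctx.saved_tensors
            # Re-attach the saved raw tensors to the shell; returns any unconsumed tensors.
            ctx.weight_shell.restore_from_saved(list(saved))
            grad_inp = grad_out @ ctx.weight_shell.dequantize(dtype=grad_out.dtype)
            return grad_inp, None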
QuantizedTensor

The QuantizedTensor class inherits from torch.Tensor and implements the basic operations of an FP8 tensor.

  1. The core functions are dequantization and quantization; each tensor type implements its own dequantize and quantize logic according to its quantization strategy (a usage sketch follows this list).
    
     def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Convert quantized data to standard PyTorch tensor"""
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not implement dequantize function"
         )
    
     def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor:
         """Update quantized data in-place"""
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not implement quantize_ function"
         )
    

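A minimal usage sketch of the two hooks; q is assumed to be an existing quantized tensor (e.g. a Float8Tensor produced elsewhere by a quantizer), and nothing here beyond the two method names comes from the excerpt above.

    import torch

    # `q` is assumed to be an existing QuantizedTensor subclass instance (e.g. Float8Tensor).
    hp = q.dequantize(dtype=torch.bfloat16)  # decode to a plain high-precision tensor
    hp.add_(0.01)                            # do some work in high precision
    q.quantize_(hp)                          # write the result back into q in-place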
Quantizer

The quantizer base class Quantizer defines the basic interface and attributes of the quantization operation.

  1. Initialization: rowwise and columnwise indicate whether the FP8 tensor needs row-wise or column-wise usage data. If both are True, the FP8 tensor keeps both, with the column-wise data generally used in the backward pass to obtain the transposed result.
    When internal is True, the quantizer produces QuantizedTensorBase results; otherwise it produces QuantizedTensor results.
    
     def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
         self.rowwise_usage = rowwise
         self.columnwise_usage = columnwise
         self.internal = False
    
  2. The update_quantized method quantizes a source tensor and writes the result into an existing destination FP8 tensor. The noop_flag parameter indicates whether the quantization should be skipped.
    
     def update_quantized(
         self,
         src: torch.Tensor,
         dst: QuantizedTensor,
         *,
         noop_flag: Optional[torch.Tensor] = None,
     ) -> QuantizedTensor:
         """Quantize tensor in-place"""
    
  3. Quantization: the quantize method produces the quantized tensor. The core work happens in _QuantizeFunc, which calls the C++ quantization kernel tex.quantize from csrc. multi_quantize is the multi-tensor variant, which I have not yet seen used anywhere. __call__ simply forwards to quantize (a usage sketch follows this list).
    
     def quantize(
         self,
         tensor: torch.Tensor,
         *,
         out: Optional[QuantizedTensor] = None,
         dtype: Optional[torch.dtype] = None,  # pylint: disable=unused-argument # used by override
     ) -> QuantizedTensor:
         """Quantize tensor"""
         if out is not None:
             return self.update_quantized(tensor, out)
         if (not self.internal) and torch.is_grad_enabled():
             return _QuantizeFunc.apply(tensor, self)
         return _QuantizeFunc.forward(None, tensor, self)
    
     def multi_quantize(self, list_of_tensors):
         """Quantize multiple tensors"""
         list_of_output_tensors = []
         for tensor in list_of_tensors:
             list_of_output_tensors.append(self.quantize(tensor))
         return list_of_output_tensors
    
     def __call__(self, tensor: torch.Tensor) -> QuantizedTensor:
         """Quantize tensor"""
         return self.quantize(tensor)
    
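Putting the pieces together, here is a minimal driving sketch. The Float8Quantizer constructor arguments (scale, amax, fp8_dtype) and the transformer_engine_torch import are assumptions based on the per-tensor-scaling implementation in float8_tensor.py, not on the excerpt above.

    import torch
    import transformer_engine_torch as tex  # C++ extension (module name assumed)
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

    # Assumed constructor: per-tensor scale/amax buffers plus the target FP8 format.
    quantizer = Float8Quantizer(
        scale=torch.ones(1, device="cuda"),
        amax=torch.zeros(1, device="cuda"),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )

    x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
    x_fp8 = quantizer(x)                        # __call__ -> quantize() -> tex.quantize
    quantizer.update_quantized(x + 1.0, x_fp8)  # re-quantize new data into the same tensor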