
TransformerEngine FP8 Call Analysis

TE FP8 Management Structure

FP8 control in TransformerEngine is managed through the FP8GlobalStateManager class, which ties together two FP8 components: Recipe and RecipeState.

FP8 Recipe

Recipe is defined in transformer_engine/pytorch/fp8/recipe.py and describes an FP8 quantization strategy together with its configuration parameters. Each quantization strategy manages its implementation details through its own configuration parameters, and Recipe specifies which parameters each strategy requires.
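
The sketch below illustrates how different recipe classes expose only the parameters their scaling strategy needs. The import path transformer_engine.common.recipe and the argument values are my assumptions and may differ across TransformerEngine versions; treat this as an example, not a recommendation.

from transformer_engine.common import recipe

# Delayed (history-based) scaling: configured through an amax history
# window, a scaling margin, and the FP8 formats used per pass.
delayed = recipe.DelayedScaling(
    margin=0,
    fp8_format=recipe.Format.HYBRID,  # E4M3 forward, E5M2 backward
    amax_history_len=1024,
    amax_compute_algo="max",
)

# Blockwise scaling: per-block scale factors instead of a global amax
# history, so it is configured through a different set of parameters.
blockwise = recipe.Float8BlockScaling()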

FP8 RecipeState

RecipeState instantiates the configuration parameters of a Recipe. It declares an abstract method, make_quantizers, which produces the quantizers for the corresponding quantization strategy. Take Float8BlockScalingRecipeState as an example: it inherits from RecipeState and implements make_quantizers. Based on the current mode (forward or backward) and the number of quantizers, the method generates the corresponding list of quantizers, so different quantization settings can be configured for different stages such as the forward and backward passes.

    def make_quantizers(self) -> list:
        # TODO(ksivamani); Find better design for this, adding here to avoid circular import.
        from .tensor.float8_blockwise_tensor import Float8BlockQuantizer

        if self.mode == "forward":
            # The index convention (coming from base.py set_meta_tensor)
            # is somewhat awkward, and doesn't play nicely with QuantizeOp,
            # which is not associated with a GEMM.
            assert self.num_quantizers % 3 == 0  # x, w, output per gemm
            return list(
                itertools.chain.from_iterable(
                    [
                        [
                            Float8BlockQuantizer(
                                fp8_dtype=self.qx_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
                                block_scaling_dim=self.recipe.x_block_scaling_dim,
                            ),
                            Float8BlockQuantizer(
                                fp8_dtype=self.qw_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale,
                                block_scaling_dim=self.recipe.w_block_scaling_dim,
                            ),
                            Float8BlockQuantizer(
                                fp8_dtype=self.qx_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
                                block_scaling_dim=self.recipe.x_block_scaling_dim,
                            ),
                        ]
                        for _ in range(self.num_quantizers // 3)
                    ]
                )
            )

        assert self.mode == "backward", f"Unexpected mode {self.mode}"
        assert self.num_quantizers % 2 == 0  # grad_output and grad_input per gemm
        return list(
            itertools.chain.from_iterable(
                [
                    [
                        Float8BlockQuantizer(
                            fp8_dtype=self.qgrad_dtype,
                            rowwise=True,
                            columnwise=True,
                            amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
                            force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
                            block_scaling_dim=self.recipe.grad_block_scaling_dim,
                        ),
                        Float8BlockQuantizer(
                            fp8_dtype=self.qgrad_dtype,
                            rowwise=True,
                            columnwise=True,
                            amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
                            force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
                            block_scaling_dim=self.recipe.grad_block_scaling_dim,
                        ),
                    ]
                    for _ in range(self.num_quantizers // 2)
                ]
            )
        )
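
A minimal usage sketch of the above. It assumes a RecipeState.create factory with mode and num_quantizers keywords and the import path transformer_engine.pytorch.fp8; these are assumptions about TE internals (in practice the module base class creates the recipe state), so verify against the installed version.

from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState

# Build the recipe state for a single GEMM in the forward pass
# (3 quantizers per GEMM: input, weight, output).
block_recipe = recipe.Float8BlockScaling()
recipe_state = RecipeState.create(block_recipe, mode="forward", num_quantizers=3)

x_quantizer, w_quantizer, out_quantizer = recipe_state.make_quantizers()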

FP8GlobalStateManager

In the PyTorch framework there are two context-manager approaches for enabling FP8 computation in a model: autocast and model_init:

  1. An example of the model_init approach is shown below. During model initialization, FP8 computation is enabled via with fp8_model_init(enabled=True); the preserve_high_precision_init_val argument can additionally be used to keep the high-precision initial values of the model weights.
    
     import transformer_engine.pytorch
     from transformer_engine.pytorch import fp8_model_init

     with fp8_model_init(enabled=True):
         model = transformer_engine.pytorch.Linear(768, 768)
    
     # Preserving high precision initial value to initialize master weight
     with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
         model = transformer_engine.pytorch.Linear(768, 768)
     master_weight = model.weight.get_high_precision_init_val()
     model.weight.clear_high_precision_init_val()
    
  2. An example of the autocast approach is shown below. FP8 computation is enabled via with fp8_autocast(enabled=True).
    
     with fp8_autocast(enabled=True):
         out = model(inp)
    

    Both approaches rely on the FP8GlobalStateManager class to manage FP8 state and configuration.
    Taking the autocast approach as an example: inside the fp8_autocast context manager, the FP8 state from before entering the context is saved first, and then the recipe is passed to FP8GlobalStateManager.fp8_autocast_enter to set the FP8 quantization strategy.

    
    @contextmanager
    def fp8_autocast(
     enabled: bool = True,
     calibrating: bool = False,
     fp8_recipe: Optional[Recipe] = None,
     fp8_group: Optional[dist_group_type] = None,
     _graph: bool = False,
    ) -> None:
     """
     Context manager for FP8 usage.
    
     .. code-block:: python
    
         with fp8_autocast(enabled=True):
             out = model(inp)
    
     .. note::
    
         Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
         with shapes where both dimensions are divisible by 16. In terms of the input to the full
         Transformer network, this typically requires padding sequence length to be multiple of 16.
    
     .. note::
    
         When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once
         inside a single `fp8_autocast` region. This is unsupported behavior because the amax
         reduction is handled during the exit of the `fp8_autocast` context. Calling the same
         module more than once inside an `fp8_autocast` region overrides the amax tensors
         before reduction can occur.
    
     Parameters
     ----------
     enabled: bool, default = `True`
              whether or not to enable fp8
     calibrating: bool, default = `False`
                  calibration mode allows collecting statistics such as amax and scale
                  data of fp8 tensors even when executing without fp8 enabled. This is
                  useful for saving an inference ready fp8 checkpoint while training
                  using a higher precision.
     fp8_recipe: recipe.Recipe, default = `None`
                 recipe used for FP8 training.
     fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
                distributed group over which amaxes for the fp8 tensors
                are reduced at the end of each training step.
     """
     fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
     FP8GlobalStateManager.fp8_autocast_enter(
         enabled=enabled,
         calibrating=calibrating,
         fp8_recipe=fp8_recipe,
         fp8_group=fp8_group,
         _graph=_graph,
     )
     try:
         yield
     finally:
         FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
         FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
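
    Because the pre-entry state is saved and restored in the finally block above, fp8_autocast regions can be nested, e.g. to run part of a model in higher precision. A minimal sketch follows; block1, sensitive_block, block2, and inp are hypothetical stand-ins for TransformerEngine modules and an input tensor.

     from transformer_engine.pytorch import fp8_autocast

     with fp8_autocast(enabled=True):
         hidden = block1(inp)                    # hypothetical TE module
         # Run a sensitive layer in high precision; when the inner region
         # exits, the saved autocast state (FP8 enabled) is restored.
         with fp8_autocast(enabled=False):
             hidden = sensitive_block(hidden)    # hypothetical TE module
         out = block2(hidden)                    # hypothetical TE module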
    

    The FP8GlobalStateManager.fp8_autocast_enter method mainly sets the FP8 state, including whether FP8 is enabled, the quantization strategy, and the distributed group. This state is then consumed when the wrapped model runs, for example during weight quantization and quantizer creation.

    
     @classmethod
     def fp8_autocast_enter(
         cls,
         enabled: bool = False,
         calibrating: bool = False,
         fp8_recipe: Optional[Recipe] = None,
         fp8_group: Optional[dist_group_type] = None,
         _graph: bool = False,
     ) -> None:
         """Set state and tracking variables for entry into FP8 region."""
    
         fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
         autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
         cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)
    
         cls.FP8_ENABLED = enabled
         cls.FP8_CALIBRATION = calibrating
         cls.FP8_RECIPE = fp8_recipe
         cls.FP8_DISTRIBUTED_GROUP = fp8_group
         cls.FP8_GRAPH_CAPTURING = _graph
    
         if cls.FP8_AUTOCAST_DEPTH == 0:
             cls.IS_FIRST_FP8_MODULE = True
         cls.FP8_AUTOCAST_DEPTH += 1
    
         if enabled:
             fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
             assert fp8_available, reason_for_no_fp8
             if isinstance(fp8_recipe, MXFP8BlockScaling):
                 mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
                 assert mxfp8_available, reason_for_no_mxfp8
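
    A minimal sketch of how this state is read back when a wrapped module runs. The accessor names (is_fp8_enabled, get_fp8_recipe, get_fp8_group) and the import path transformer_engine.pytorch.fp8 are assumptions about the FP8GlobalStateManager API; verify them against the installed version.

     import transformer_engine.pytorch as te
     from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

     with te.fp8_autocast(enabled=True):
         assert FP8GlobalStateManager.is_fp8_enabled()
         active_recipe = FP8GlobalStateManager.get_fp8_recipe()  # recipe set at enter
         amax_group = FP8GlobalStateManager.get_fp8_group()      # None unless passed in

     # After exit, the previously saved autocast state has been restored.
     assert not FP8GlobalStateManager.is_fp8_enabled()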
    
This post is licensed under CC BY 4.0 by the author.