
TransformerEngine Module: TransformerEngineBaseModule

TE Module Structure

The module package consists mainly of the TEModule base class plus the basic building blocks of a transformer (a short usage sketch follows the file listing below):

├── __init__.py
├── _common.py
├── base.py
├── fp8_padding.py
├── fp8_unpadding.py
├── grouped_linear.py
├── layernorm.py
├── layernorm_linear.py
├── layernorm_mlp.py
├── linear.py
└── rmsnorm.py
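
These files back the user-facing modules exported from transformer_engine.pytorch. As a quick orientation, here is a minimal sketch (assuming a CUDA device and a standard TransformerEngine install) that constructs a few of them directly:

# Minimal orientation sketch; the files above back these public modules.
import torch
import transformer_engine.pytorch as te

linear  = te.Linear(1024, 1024, params_dtype=torch.bfloat16)  # linear.py
ln_lin  = te.LayerNormLinear(1024, 1024)                      # layernorm_linear.py
ln_mlp  = te.LayerNormMLP(1024, 4096)                         # layernorm_mlp.py
rmsnorm = te.RMSNorm(1024)                                    # rmsnorm.py

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
y = linear(x)  # plain BF16 GEMM when no fp8_autocast context is active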

The TEModule Base Class

The core base class is TransformerEngineBaseModule. Its initializer sets up a large number of member variables that track FP8-related state and control how each component performs FP8 quantization and its GEMMs.

class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
        self.name = None
        # FP8 metadata and state for this module
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_meta = {}
        self.fp8_meta["fp8_checkpoint"] = False
        self.fp8_meta["fp8_group"] = None
        self.fp8_meta_tensors_initialized = False
        # FP8 quantizer groups
        self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}}
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False
        self.param_init_meta = {}
        # FP8 quantization of model weights at initialization
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
        self.fsdp_wrapped = False
        self.fsdp_group = None
        self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
        self.activation_dtype: Optional[torch.dtype] = None
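
Since every TE module (te.Linear, te.LayerNormMLP, and so on) inherits from this base class, these FP8 state members can be inspected on any freshly constructed module. A small illustrative sketch (assuming a CUDA device; the values shown are those set by the __init__ above):

import transformer_engine.pytorch as te

layer = te.Linear(256, 256)
print(layer.fp8)              # False until the first forward inside fp8_autocast
print(layer.fp8_initialized)  # False until init_fp8_metadata() has run
print(layer.fp8_meta)         # typically {'fp8_checkpoint': False, 'fp8_group': None} right after construction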

Initialization of the FP8 RecipeState and of the quantizer groups, with separate quantizer groups for the forward and backward passes.

    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd."""
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"

        # Return early if recipe state matches recipe
        if self.fp8_meta_tensors_initialized:
            recipe_state = self.fp8_meta[fp8_meta_tensor_key]
            if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
                self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
                return
            if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
                return
            if recipe.float8_current_scaling() and isinstance(
                recipe_state, Float8CurrentScalingRecipeState
            ):
                return
            if recipe.float8_block_scaling() and isinstance(
                recipe_state, Float8BlockScalingRecipeState
            ):
                return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2

        # Initialize recipe state and quantizers
        recipe_state = RecipeState.create(
            recipe,
            mode=("forward" if fwd else "backward"),
            num_quantizers=num_fp8_tensors,
        )

        self.fp8_meta[fp8_meta_tensor_key] = recipe_state
        self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()
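
The forward/backward split is easy to see with a concrete recipe. The sketch below uses the public DelayedScaling recipe from transformer_engine.common.recipe; the counts simply restate the arithmetic in set_meta_tensor for a single GEMM:

from transformer_engine.common.recipe import DelayedScaling, Format

# A delayed-scaling recipe: the amax history drives the scaling factors.
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

num_gemms = 1
num_fwd_quantizers = num_gemms * 3  # "scaling_fwd": input, weight, output
num_bwd_quantizers = num_gemms * 2  # "scaling_bwd": grad_output, grad_input
print(num_fwd_quantizers, num_bwd_quantizers)  # 3 2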

FP8 metadata initialization, covering both FP8 quantization of the weights at model initialization and FP8 quantization during computation.

    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop."""
        _original_recipe = self.fp8_meta.get("recipe", None)

        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        fp8_enabled = self.fp8 or self.fp8_calibration
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

        if self.fp8_parameters or fp8_enabled:
            if (
                self.fp8_initialized
                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
            ):
                # FP8 init has already been run and recipe is the same, don't do anything.
                return
            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
        else:
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        if self.fp8_parameters and not self.fp8_initialized:
            self.fp8_meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        if fp8_enabled:
            # Set FP8 and other FP8 metadata
            self.fp8_meta["num_gemms"] = num_gemms
            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Set FP8_MAX per tensor according to recipe
            self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
            self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

            # Allocate scales and amaxes
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
            self.fp8_initialized = True

            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

        _current_recipe = self.fp8_meta["recipe"]
        if _original_recipe is not None and not (
            issubclass(_current_recipe.__class__, _original_recipe.__class__)
            or issubclass(_original_recipe.__class__, _current_recipe.__class__)
        ):
            warnings.warn(
                f"Recipe type changed from {_original_recipe.__class__.__name__} "
                f"to {_current_recipe.__class__.__name__}. "
                "This may affect model behavior."
            )
            # Clear cached workspaces as they were created with the old recipe/quantizer type
            self._fp8_workspaces.clear()
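
The two paths handled above map onto two public entry points: fp8_model_init for FP8 weights at model initialization (primary_weights_in_fp8 / fp8_parameters) and fp8_autocast for FP8 compute (self.fp8). A hedged usage sketch, assuming a CUDA GPU with FP8 support:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(fp8_format=Format.HYBRID)

# Weights stored directly as FP8 tensors at initialization time
with te.fp8_model_init(enabled=True):
    layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# Compute-time FP8: init_fp8_metadata() runs inside the first forward
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)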

The forward-preparation context manager triggers FP8 metadata initialization and gets everything ready for FP8 quantization.

    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.
        """
        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)
            self._check_weight_tensor_recipe_correspondence()

            if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

            # Activation recomputation is used and this is the first forward phase.
            if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            if not allow_non_contiguous and not inp.is_contiguous():
                inp = inp.contiguous()
            yield inp

        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
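
Although the real modules call prepare_forward internally from their own forward implementations, the context manager can also be exercised directly to see what it does. A hedged illustration (assuming a CUDA device; this is not how linear.py itself is written):

import torch
import transformer_engine.pytorch as te

layer = te.Linear(256, 256, params_dtype=torch.bfloat16)

# A non-contiguous input: prepare_forward makes it contiguous before yielding.
x = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)[:, ::2]

with layer.prepare_forward(x, num_gemms=1) as x_prep:
    # The activation dtype has been recorded and init_fp8_metadata() has already run here.
    print(layer.activation_dtype, x_prep.is_contiguous())  # torch.bfloat16 True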

FP8 quantization preprocessing of the output gradient in the backward pass; the quantized gradient is then used when computing dgrad and wgrad.

    @staticmethod
    def grad_output_preprocess(
        ctx,
        grad_output: torch.Tensor,
        row_parallel_mode: bool,
        quantizer: Optional[Quantizer],
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output`.
            R2: bias gradient on R1.

        """
        grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
        grad_output = grad_output.contiguous()
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # Non-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8 and not ctx.debug:
            if gather_grad_output:
                if not ctx.ub_overlap_ag:  # Perform NCCL all-gather
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
                else:  # Initialize Userbuffers all-gather
                    grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                        ctx.ub_obj_gradout,
                        grad_output,
                        None,
                        ctx.tp_group,
                    )
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose
        # Also supports debug quantization, which is handled inside gather_along_first_dim.
        if gather_grad_output:
            grad_bias = None
            if ctx.use_bias:
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            if ctx.ub_overlap_ag:
                # Quantize the gradient if needed
                if not isinstance(
                    grad_output,
                    (
                        QuantizedTensor,
                        Float8TensorBase,
                        MXFP8TensorBase,
                        Float8BlockwiseQTensorBase,
                    ),
                ):
                    grad_output = quantizer(grad_output)

                # Copy into communication buffer, and replace original gradient with it
                grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                    ctx.ub_obj_gradout,
                    grad_output,
                    quantizer,
                    ctx.tp_group,
                )
            else:
                grad_output, _ = gather_along_first_dim(
                    grad_output,
                    ctx.tp_group,
                    quantizer=quantizer,
                )
            return grad_output, grad_bias

        # Debug without all-gather: unfused cast and bgrad
        # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
        if ctx.debug:
            grad_output_ = quantizer(grad_output)
            if (
                isinstance(
                    grad_output_.get_tensor(True),
                    (
                        QuantizedTensor,
                        Float8TensorBase,
                        MXFP8TensorBase,
                        Float8BlockwiseQTensorBase,
                    ),
                )
                and ctx.use_bias
            ):
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias = None
            grad_output = grad_output_
            return grad_output, grad_bias

        # FP8 without all-gather: fused bgrad + cast + transpose
        grad_bias = None
        if ctx.use_bias:
            if isinstance(
                grad_output,
                (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
            ):
                grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                if isinstance(quantizer, Float8BlockQuantizer):
                    # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
                    grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
                else:
                    grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
        if not isinstance(
            grad_output,
            (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
        ):
            grad_output = quantizer(grad_output)
        return grad_output, grad_bias
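
For context on how the returned pair is consumed downstream: grad_output feeds the dgrad and wgrad GEMMs, and grad_bias is the per-output-channel sum. A plain-PyTorch sketch of that math (high precision, no FP8, no tensor parallelism; shapes are illustrative):

import torch

inp = torch.randn(32, 1024)          # flattened activations [tokens, in_features]
weight = torch.randn(4096, 1024)     # [out_features, in_features]
grad_output = torch.randn(32, 4096)  # [tokens, out_features]

grad_bias = grad_output.sum(dim=0)   # bgrad: reduce over tokens
dgrad = grad_output @ weight         # gradient w.r.t. the input
wgrad = grad_output.t() @ inp        # gradient w.r.t. the weight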

Weight-quantization caching: when there is more than one microbatch, the weight only needs to be quantized on the first microbatch's forward pass; the first call quantizes the weight (or updates the cache), and subsequent microbatches reuse the cached quantized weight directly.

    def get_weight_workspace(
        self,
        *,
        tensor: Optional[torch.Tensor] = None,
        quantizer: Optional[Quantizer] = None,
        cache_name: Optional[str] = None,
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        fsdp_group: Optional[dist_group_type] = None,
        workspace_dtype: Optional[torch.dtype] = None,
    ) -> QuantizedTensor:
        """Get workspace buffer for weights and maybe update its values

        The workspace buffer may be cached for future function calls.

        Parameters
        ----------
        tensor : torch.Tensor, optional
            Values to copy into workspace. Required if the workspace
            is being constructed or updated.
        quantizer: Quantizer, optional
            Quantizer used to cast the weights. Required if the
            workspace is being constructed or updated.
        cache_name: str, optional
            Key for caching.
        update_workspace: bool, default = `True`
            Update workspace with values from `tensor`.
        skip_update_flag: torch.Tensor, optional
            GPU flag to skip updating the workspace. Takes precedence
            over `update_workspace` if provided.
        fsdp_group: dist_group_type, default = None
            FSDP process group that the weights are distributed over.
        workspace_dtype: torch.dtype, default = None
            If the weight workspace contains a high-precision tensor (for
            example for debug quantization), this is the dtype of that tensor.
        """

        # Handle case where weights are already quantized
        # Note: Make sure weights have required usages, but do not
        # destroy unnecessary usages since they may be used later.
        if isinstance(tensor, QuantizedTensor):
            update_rowwise_usage = True if quantizer.rowwise_usage else None
            update_columnwise_usage = True if quantizer.columnwise_usage else None
            tensor.update_usage(
                rowwise_usage=update_rowwise_usage,
                columnwise_usage=update_columnwise_usage,
            )
            return tensor

        # Try getting workspace from cache
        out = None

        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)
            if quantizer is not None and isinstance(out, MXFP8TensorBase):
                if quantizer.rowwise_usage and out._rowwise_data is None:
                    out = None
                    del self._fp8_workspaces[cache_name]
                elif quantizer.columnwise_usage and out._columnwise_data is None:
                    out = None
                    del self._fp8_workspaces[cache_name]

            is_debug = isinstance(quantizer, DebugQuantizer)
            is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
            if is_debug != is_out_debug_tensor:
                out = None

        # Gather cached Fp8 workspace if it's distributed
        # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
        #       for models initialized with Fp8 primary weights.
        if (
            out is not None
            and tensor is not None
            and fsdp_group is not None
            and out.data.shape != tensor.data.shape
        ):
            _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

        # Construct workspace if needed
        if out is None:
            if tensor is None or quantizer is None:
                raise ValueError(
                    "tensor and quantizer kwargs must be provided to construct FP8 workspace"
                )

            if cache_name is not None:
                # Ensure the tensor in the cache is an instance of torch.Tensor,
                # as it persists beyond a single forward pass.
                # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
                quantizer_internal = quantizer.internal
                quantizer.internal = False
            out = quantizer.quantize(tensor, dtype=workspace_dtype)
            if cache_name is not None:
                quantizer.internal = quantizer_internal

            # Update cache
            if cache_name is not None:
                self._fp8_workspaces[cache_name] = out
            return out

        # Update workspace if needed
        if skip_update_flag is not None:
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(out, "quantize_"):
                out.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                tex.quantize(tensor, quantizer, out, skip_update_flag)
        return out
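
A hedged end-to-end illustration of this cache (assuming a CUDA GPU with FP8 support and a standard TE install): te.Linear forwards its is_first_microbatch argument down to get_weight_workspace, so the FP8 weight is quantized on the first microbatch and the cached workspace is reused afterwards. Inspecting the private _fp8_workspaces dict is for illustration only.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
recipe = DelayedScaling(fp8_format=Format.HYBRID)
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y0 = layer(x, is_first_microbatch=True)   # quantize the weight, fill the cache
    y1 = layer(x, is_first_microbatch=False)  # reuse the cached FP8 weight workspace

print(list(layer._fp8_workspaces.keys()))     # cache populated after the first microbatch (illustrative)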