TE Module Structure
The module package consists mainly of a TEModule base class plus the basic transformer building blocks:
├── __init__.py
├── _common.py
├── base.py
├── fp8_padding.py
├── fp8_unpadding.py
├── grouped_linear.py
├── layernorm.py
├── layernorm_linear.py
├── layernorm_mlp.py
├── linear.py
└── rmsnorm.py
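Each of these files backs a user-facing layer exposed through transformer_engine.pytorch. A minimal usage sketch (assuming Transformer Engine is installed with CUDA support; shapes and argument values are illustrative):

import torch
import transformer_engine.pytorch as te

hidden, ffn = 1024, 4096
linear = te.Linear(hidden, hidden, bias=True,
                   params_dtype=torch.bfloat16, device="cuda")        # linear.py
ln_mlp = te.LayerNormMLP(hidden, ffn,
                         params_dtype=torch.bfloat16, device="cuda")  # layernorm_mlp.py

x = torch.randn(32, hidden, device="cuda", dtype=torch.bfloat16)
y = linear(x)   # plain BF16 GEMM when called outside fp8_autocast
z = ln_mlp(x)   # fused LayerNorm + MLP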
The TEModule Base Class
The core base class is TransformerEngineBaseModule. Its initializer creates a large set of member variables that track FP8-related state and control how each component performs FP8 quantization and FP8 matrix multiplication.
class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
        self.name = None

        # The module's FP8 metadata and state
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_meta = {}
        self.fp8_meta["fp8_checkpoint"] = False
        self.fp8_meta["fp8_group"] = None
        self.fp8_meta_tensors_initialized = False

        # FP8 quantizer groups
        self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}}

        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False
        self.param_init_meta = {}

        # FP8 quantization of model weights at initialization
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()

        self.fsdp_wrapped = False
        self.fsdp_group = None
        self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
        self.activation_dtype: Optional[torch.dtype] = None
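The primary_weights_in_fp8 flag above is controlled by FP8GlobalStateManager.with_fp8_parameters(), which the public fp8_model_init context manager toggles. A minimal sketch, assuming a recent TE release that exposes te.fp8_model_init:

import transformer_engine.pytorch as te

# Inside fp8_model_init, with_fp8_parameters() returns True, so modules created
# here set primary_weights_in_fp8 and allocate their weights directly in FP8.
with te.fp8_model_init(enabled=True):
    layer = te.Linear(1024, 1024, bias=False, device="cuda")

print(type(layer.weight))  # a quantized tensor/parameter class, not a plain torch.nn.Parameter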
FP8 recipe-state and quantizer-group initialization; separate quantizer groups are kept for the forward and backward passes.
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
    """Init scales and amaxes for fwd | bwd."""
    fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"

    # Return early if recipe state matches recipe
    if self.fp8_meta_tensors_initialized:
        recipe_state = self.fp8_meta[fp8_meta_tensor_key]
        if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
            self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
            return
        if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
            return
        if recipe.float8_current_scaling() and isinstance(
            recipe_state, Float8CurrentScalingRecipeState
        ):
            return
        if recipe.float8_block_scaling() and isinstance(
            recipe_state, Float8BlockScalingRecipeState
        ):
            return

    # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
    # 2 (grad_output and grad_input) for bwd
    num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2

    # Initialize recipe state and quantizers
    recipe_state = RecipeState.create(
        recipe,
        mode=("forward" if fwd else "backward"),
        num_quantizers=num_fp8_tensors,
    )
    self.fp8_meta[fp8_meta_tensor_key] = recipe_state
    self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()
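To make the bookkeeping concrete, the sketch below builds a DelayedScaling recipe with the public API; the quantizer counts follow directly from the num_fp8_tensors computation above:

from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(
    fp8_format=recipe.Format.HYBRID,  # E4M3 for forward tensors, E5M2 for gradients
    amax_history_len=16,
    amax_compute_algo="max",
)

# For a module with num_gemms = 1, set_meta_tensor allocates:
#   quantizers["scaling_fwd"] -> 3 quantizers (input, weight, output)
#   quantizers["scaling_bwd"] -> 2 quantizers (grad_output, grad_input)
num_gemms = 1
print({"scaling_fwd": num_gemms * 3, "scaling_bwd": num_gemms * 2})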
FP8 metadata initialization, covering both FP8 quantization of weights at model initialization and FP8 quantization during computation.
# This routine is shared across FP8 and FP8_calibration paths so should not actually
# assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
    """Initialize fp8 related metadata and tensors during fprop."""
    _original_recipe = self.fp8_meta.get("recipe", None)

    self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
    self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
    self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
    fp8_enabled = self.fp8 or self.fp8_calibration
    self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

    if self.fp8_parameters or fp8_enabled:
        if (
            self.fp8_initialized
            and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
        ):
            # FP8 init has already been run and recipe is the same, don't do anything.
            return
        self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
    else:
        # If fp8 isn't enabled, turn off and return.
        self.fp8_initialized = False
        return

    if self.fp8_parameters and not self.fp8_initialized:
        self.fp8_meta["num_gemms"] = num_gemms
        self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

    if fp8_enabled:
        # Set FP8 and other FP8 metadata
        self.fp8_meta["num_gemms"] = num_gemms
        self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

        # Set FP8_MAX per tensor according to recipe
        self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
        self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

        # Allocate scales and amaxes
        self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
        self.fp8_initialized = True

        self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

    _current_recipe = self.fp8_meta["recipe"]
    if _original_recipe is not None and not (
        issubclass(_current_recipe.__class__, _original_recipe.__class__)
        or issubclass(_original_recipe.__class__, _current_recipe.__class__)
    ):
        warnings.warn(
            f"Recipe type changed from {_original_recipe.__class__.__name__} "
            f"to {_current_recipe.__class__.__name__}. "
            "This may affect model behavior."
        )
        # Clear cached workspaces as they were created with the old recipe/quantizer type
        self._fp8_workspaces.clear()
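In practice this initialization is driven by the first forward pass inside te.fp8_autocast; a minimal sketch (shapes and recipe values are illustrative):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)  # first call: fp8_meta["recipe"], scales and amax histories are allocated

print(layer.fp8_initialized, type(layer.fp8_meta["recipe"]).__name__)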
The forward-preparation context manager triggers FP8 metadata initialization and gets the module ready for FP8 quantization.
@contextmanager
def prepare_forward(
    self,
    inp: torch.Tensor,
    num_gemms: int = 1,
    allow_non_contiguous: bool = False,
) -> Generator[torch.Tensor, None, None]:
    """Checks and prep for FWD.
    The context manager is needed because there isn't a way for a module to know
    if it's the last FP8 module in the forward autocast. It is useful
    to setup the forward aggregated amax reduction for every module
    just in case. The autocast exit will pick up the most recent one.
    """
    # Activation recomputation is used and this is the second forward phase.
    if self.fp8 and in_fp8_activation_recompute_phase():
        FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
    else:
        assert inp.is_cuda, "TransformerEngine needs CUDA."

        if self.tp_size > 1:
            assert self.tp_group_initialized, "TP group not initialized."

        self.set_activation_dtype(inp)
        self.init_fp8_metadata(num_gemms=num_gemms)
        self._check_weight_tensor_recipe_correspondence()

        if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
            assert self.fp8_meta["recipe"].reduce_amax, (
                "Amax reduction across tensor parallel group is "
                "necessary when using sequence parallelism with FP8."
            )

        if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
            FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

        # Activation recomputation is used and this is the first forward phase.
        if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
            FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

    with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
        if not allow_non_contiguous and not inp.is_contiguous():
            inp = inp.contiguous()
        yield inp

    if self.fp8 and in_fp8_activation_recompute_phase():
        FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
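A hypothetical sketch of how a subclass's forward would use this context manager (MyTELinear and _custom_gemm are illustrative names, not TE API; the import path follows the module layout shown at the top):

import torch
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

class MyTELinear(TransformerEngineBaseModule):
    # Parameter setup elided; this only illustrates the prepare_forward protocol.
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # prepare_forward sets the activation dtype, runs init_fp8_metadata,
        # registers amax tensors for reduction, and yields a contiguous input.
        with self.prepare_forward(inp, num_gemms=1) as inp:
            out = self._custom_gemm(inp)  # hypothetical compute using self.quantizers
        return out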
FP8 quantization preprocessing of the gradient in the backward pass; the quantized gradient is then used to compute dgrad and wgrad.
@staticmethod
def grad_output_preprocess(
    ctx,
    grad_output: torch.Tensor,
    row_parallel_mode: bool,
    quantizer: Optional[Quantizer],
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Utility function for backward.
    Returns tuple in order (all optional/None based on training precision/recipe):
    R1: gathered `grad_output`.
    R2: bias gradient on R1.
    """
    grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
    grad_output = grad_output.contiguous()
    gather_grad_output = row_parallel_mode and ctx.sequence_parallel

    # Non-FP8 case: bgrad is fused with wgrad for this case.
    if not ctx.fp8 and not ctx.debug:
        if gather_grad_output:
            if not ctx.ub_overlap_ag:  # Perform NCCL all-gather
                grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
            else:  # Initialize Userbuffers all-gather
                grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                    ctx.ub_obj_gradout,
                    grad_output,
                    None,
                    ctx.tp_group,
                )
        return grad_output, None

    # FP8 with all-gather: unfused bgrad, fused cast + transpose
    # Also supports debug quantization, which is handled inside gather_along_first_dim.
    if gather_grad_output:
        grad_bias = None
        if ctx.use_bias:
            grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
        if ctx.ub_overlap_ag:
            # Quantize the gradient if needed
            if not isinstance(
                grad_output,
                (
                    QuantizedTensor,
                    Float8TensorBase,
                    MXFP8TensorBase,
                    Float8BlockwiseQTensorBase,
                ),
            ):
                grad_output = quantizer(grad_output)
            # Copy into communication buffer, and replace original gradient with it
            grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                ctx.ub_obj_gradout,
                grad_output,
                quantizer,
                ctx.tp_group,
            )
        else:
            grad_output, _ = gather_along_first_dim(
                grad_output,
                ctx.tp_group,
                quantizer=quantizer,
            )
        return grad_output, grad_bias

    # Debug without all-gather: unfused cast and bgrad
    # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
    if ctx.debug:
        grad_output_ = quantizer(grad_output)
        if (
            isinstance(
                grad_output_.get_tensor(True),
                (
                    QuantizedTensor,
                    Float8TensorBase,
                    MXFP8TensorBase,
                    Float8BlockwiseQTensorBase,
                ),
            )
            and ctx.use_bias
        ):
            grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
        else:
            grad_bias = None
        grad_output = grad_output_
        return grad_output, grad_bias

    # FP8 without all-gather: fused bgrad + cast + transpose
    grad_bias = None
    if ctx.use_bias:
        if isinstance(
            grad_output,
            (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
        ):
            grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
        else:
            if isinstance(quantizer, Float8BlockQuantizer):
                # Unfuse bgrad for now until cast_transpose + dgrad calculation is ready
                # for Float8BlockQuantizer.
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
    if not isinstance(
        grad_output,
        (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
    ):
        grad_output = quantizer(grad_output)
    return grad_output, grad_bias
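A hypothetical sketch of the call site inside a custom autograd backward (the ctx fields such as fp8, debug, use_bias, sequence_parallel, ub_overlap_ag and tp_group are assumed to be stashed by the matching forward; ctx.grad_output_quantizer and the GEMM comments are illustrative, not fixed TE API):

import torch
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

class _LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        # Forward elided; it would stash the ctx fields read by grad_output_preprocess.
        ...

    @staticmethod
    def backward(ctx, grad_output):
        quantizer = ctx.grad_output_quantizer  # picked from quantizers["scaling_bwd"] in forward
        grad_output, grad_bias = TransformerEngineBaseModule.grad_output_preprocess(
            ctx, grad_output, row_parallel_mode=True, quantizer=quantizer
        )
        # The (gathered and) quantized grad_output now feeds both backward GEMMs:
        #   dgrad = grad_output @ weight      (row-wise FP8 data)
        #   wgrad = grad_output^T @ input     (column-wise / transposed FP8 data)
        ...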
Weight quantization cache: when more than one microbatch is used per step, the weights only need to be quantized during the first microbatch's forward pass; later microbatches reuse the cached quantized weights, and the cache is created or updated on that first pass.
def get_weight_workspace(
    self,
    *,
    tensor: Optional[torch.Tensor] = None,
    quantizer: Optional[Quantizer] = None,
    cache_name: Optional[str] = None,
    update_workspace: bool = True,
    skip_update_flag: Optional[torch.Tensor] = None,
    fsdp_group: Optional[dist_group_type] = None,
    workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor:
    """Get workspace buffer for weights and maybe update its values
    The workspace buffer may be cached for future function calls.
    Parameters
    ----------
    tensor : torch.Tensor, optional
        Values to copy into workspace. Required if the workspace
        is being constructed or updated.
    quantizer: Quantizer, optional
        Quantizer used to cast the weights. Required if the
        workspace is being constructed or updated.
    cache_name: str, optional
        Key for caching.
    update_workspace: bool, default = `True`
        Update workspace with values from `tensor`.
    skip_update_flag: torch.Tensor, optional
        GPU flag to skip updating the workspace. Takes precedence
        over `update_workspace` if provided.
    fsdp_group: bool, default = None
        FSDP process group that the weights are distributed over.
    workspace_dtype: torch.dtype, default = None
        If weight workspace contains high-precision tensor - for example
        for debug quantization, this is dtype of the tensor.
    """
    # Handle case where weights are already quantized
    # Note: Make sure weights have required usages, but do not
    # destroy unnecessary usages since they may be used later.
    if isinstance(tensor, QuantizedTensor):
        update_rowwise_usage = True if quantizer.rowwise_usage else None
        update_columnwise_usage = True if quantizer.columnwise_usage else None
        tensor.update_usage(
            rowwise_usage=update_rowwise_usage,
            columnwise_usage=update_columnwise_usage,
        )
        return tensor

    # Try getting workspace from cache
    out = None
    if cache_name is not None:
        out = self._fp8_workspaces.get(cache_name, None)
        if quantizer is not None and isinstance(out, MXFP8TensorBase):
            if quantizer.rowwise_usage and out._rowwise_data is None:
                out = None
                del self._fp8_workspaces[cache_name]
            elif quantizer.columnwise_usage and out._columnwise_data is None:
                out = None
                del self._fp8_workspaces[cache_name]
        is_debug = isinstance(quantizer, DebugQuantizer)
        is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
        if is_debug != is_out_debug_tensor:
            out = None

    # Gather cached Fp8 workspace if it's distributed
    # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
    # for models initialized with Fp8 primary weights.
    if (
        out is not None
        and tensor is not None
        and fsdp_group is not None
        and out.data.shape != tensor.data.shape
    ):
        _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

    # Construct workspace if needed
    if out is None:
        if tensor is None or quantizer is None:
            raise ValueError(
                "tensor and quantizer kwargs must be provided to construct FP8 workspace"
            )
        if cache_name is not None:
            # Ensure the tensor in the cache is an instance of torch.Tensor,
            # as it persists beyond a single forward pass.
            # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
            quantizer_internal = quantizer.internal
            quantizer.internal = False
        out = quantizer.quantize(tensor, dtype=workspace_dtype)
        if cache_name is not None:
            quantizer.internal = quantizer_internal

        # Update cache
        if cache_name is not None:
            self._fp8_workspaces[cache_name] = out
        return out

    # Update workspace if needed
    if skip_update_flag is not None:
        update_workspace = True
    if update_workspace:
        if tensor is None:
            raise ValueError("tensor kwarg must be provided to update FP8 workspace")
        if hasattr(out, "quantize_"):
            out.quantize_(tensor, noop_flag=skip_update_flag)
        else:
            tex.quantize(tensor, quantizer, out, skip_update_flag)
    return out
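A hypothetical sketch of how a module's forward drives this cache across microbatches (is_first_microbatch is the flag accepted by TE layers such as te.Linear; the helper name and the quantizer index are illustrative):

import torch

def _quantized_weight(module, weight: torch.Tensor, is_first_microbatch):
    """Quantize `weight` once, then reuse the cached FP8 copy on later microbatches."""
    quantizer = module.quantizers["scaling_fwd"][1]  # illustrative: the weight quantizer slot
    return module.get_weight_workspace(
        tensor=weight,
        quantizer=quantizer,
        cache_name="weight",  # keyed into module._fp8_workspaces
        update_workspace=is_first_microbatch is None or is_first_microbatch,
        skip_update_flag=None,
    )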