TE FP8 Management Structure
TransformerEngine manages its FP8 control through the FP8GlobalStateManager class, which ties together the two FP8 components Recipe and RecipeState.
FP8 Recipe
Recipe is defined in transformer_engine/pytorch/fp8/recipe.py and describes the FP8 quantization strategies together with their configuration parameters. Each quantization strategy manages its implementation details through its own set of parameters, and Recipe specifies which parameters each strategy requires.
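For illustration, the sketch below constructs two recipes: the classic per-tensor DelayedScaling strategy and the block-wise Float8BlockScaling strategy used as the running example in the next section. The import path (transformer_engine.common.recipe) and the exact parameter set are assumptions based on recent TE releases and may differ across versions.

from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, Format

# Per-tensor delayed scaling: the amax-history length and scale update
# policy are properties of the recipe, not of the individual layers.
delayed_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,    # E4M3 for forward tensors, E5M2 for gradients
    amax_history_len=1024,
    amax_compute_algo="max",
)

# Block-wise scaling: the recipe carries the per-block quantization options
# that Float8BlockScalingRecipeState later turns into concrete quantizers.
block_recipe = Float8BlockScaling()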
FP8 RecipeState
RecipeState is responsible for instantiating a Recipe's configuration parameters. It declares an abstract method make_quantizers, which produces the quantizers for the corresponding quantization strategy. Take Float8BlockScalingRecipeState as an example: it inherits from RecipeState and implements make_quantizers. Based on the current mode (forward or backward) and the number of quantizers, the method builds the corresponding list of quantizers, so different quantization configurations can be applied to different stages such as the forward and backward passes.
def make_quantizers(self) -> list:
# TODO(ksivamani); Find better design for this, adding here to avoid circular import.
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
if self.mode == "forward":
# The index convention (coming from base.py set_meta_tensor)
# is somewhat awkward, and doesn't play nicely with QuantizeOp,
# which is not associated with a GEMM.
assert self.num_quantizers % 3 == 0 # x, w, output per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qw_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale,
block_scaling_dim=self.recipe.w_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 3)
]
)
)
assert self.mode == "backward", f"Unexpected mode {self.mode}"
assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 2)
]
)
)
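To make the flow concrete, here is a minimal usage sketch. It assumes the RecipeState.create factory exposed by transformer_engine.pytorch.fp8 and a Float8BlockScaling recipe from transformer_engine.common.recipe; the factory name, keyword arguments, and import paths are taken from recent TE sources and may differ in other versions.

from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.fp8 import RecipeState

recipe = Float8BlockScaling()

# Forward pass: one GEMM needs three quantizers (input, weight, output),
# so num_quantizers must be a multiple of 3.
fwd_state = RecipeState.create(recipe, mode="forward", num_quantizers=3)
fwd_quantizers = fwd_state.make_quantizers()

# Backward pass: two quantizers per GEMM (grad_output, grad_input).
bwd_state = RecipeState.create(recipe, mode="backward", num_quantizers=2)
bwd_quantizers = bwd_state.make_quantizers()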
FP8GlobalStateManager
Within the PyTorch framework, FP8 computation can be enabled for a model through two context managers, autocast and model_init:
- The model_init approach is used as shown below: FP8 is enabled at model initialization time via with fp8_model_init(enabled=True), and the preserve_high_precision_init_val argument keeps a high-precision copy of the initial weight values.

with fp8_model_init(enabled=True):
    model = transformer_engine.pytorch.Linear(768, 768)

# Preserving high precision initial value to initialize master weight
with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    model = transformer_engine.pytorch.Linear(768, 768)
master_weight = model.weight.get_high_precision_init_val()
model.weight.clear_high_precision_init_val()
- The autocast approach is used as shown below: FP8 is enabled by wrapping the forward pass in with fp8_autocast(enabled=True).

with fp8_autocast(enabled=True):
    out = model(inp)
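Putting the pieces together, an explicit recipe can also be passed to fp8_autocast so the wrapped modules use that quantization strategy. The snippet below is a minimal sketch; the tensor shapes and recipe arguments are illustrative.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

model = te.Linear(768, 768)
inp = torch.randn(16, 768, device="cuda")

# Explicitly select the quantization strategy for this autocast region.
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = model(inp)
out.sum().backward()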
Both approaches rely on the FP8GlobalStateManager class to manage FP8 state and configuration.
Taking the autocast approach as an example: the fp8_autocast context manager first saves the FP8 state as it was before entering the context, then passes the recipe to FP8GlobalStateManager.fp8_autocast_enter to set the FP8 quantization strategy.
@contextmanager
def fp8_autocast(
    enabled: bool = True,
    calibrating: bool = False,
    fp8_recipe: Optional[Recipe] = None,
    fp8_group: Optional[dist_group_type] = None,
    _graph: bool = False,
) -> None:
    """
    Context manager for FP8 usage.

    .. code-block:: python

        with fp8_autocast(enabled=True):
            out = model(inp)

    .. note::

        Support for FP8 in the Linear layer of Transformer Engine is currently
        limited to tensors with shapes where both dimensions are divisible by 16.
        In terms of the input to the full Transformer network, this typically
        requires padding sequence length to be multiple of 16.

    .. note::

        When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked
        more than once inside a single `fp8_autocast` region. This is unsupported
        behavior because the amax reduction is handled during the exit of the
        `fp8_autocast` context. Calling the same module more than once inside an
        `fp8_autocast` region overrides the amax tensors before reduction can occur.

    Parameters
    ----------
    enabled: bool, default = `True`
             whether or not to enable fp8
    calibrating: bool, default = `False`
                 calibration mode allows collecting statistics such as amax and scale
                 data of fp8 tensors even when executing without fp8 enabled. This is
                 useful for saving an inference ready fp8 checkpoint while training
                 using a higher precision.
    fp8_recipe: recipe.Recipe, default = `None`
                recipe used for FP8 training.
    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
               distributed group over which amaxes for the fp8 tensors
               are reduced at the end of each training step.
    """
    fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
    FP8GlobalStateManager.fp8_autocast_enter(
        enabled=enabled,
        calibrating=calibrating,
        fp8_recipe=fp8_recipe,
        fp8_group=fp8_group,
        _graph=_graph,
    )
    try:
        yield
    finally:
        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
        FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
The FP8GlobalStateManager.fp8_autocast_enter method mainly sets the FP8 state: whether FP8 is enabled, the quantization recipe, the distributed group, and so on. This state is then consumed while the wrapped model runs, for example during weight quantization and quantizer creation.
@classmethod
def fp8_autocast_enter(
    cls,
    enabled: bool = False,
    calibrating: bool = False,
    fp8_recipe: Optional[Recipe] = None,
    fp8_group: Optional[dist_group_type] = None,
    _graph: bool = False,
) -> None:
    """Set state and tracking variables for entry into FP8 region."""
    fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe

    autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
    cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)

    cls.FP8_ENABLED = enabled
    cls.FP8_CALIBRATION = calibrating
    cls.FP8_RECIPE = fp8_recipe
    cls.FP8_DISTRIBUTED_GROUP = fp8_group
    cls.FP8_GRAPH_CAPTURING = _graph

    if cls.FP8_AUTOCAST_DEPTH == 0:
        cls.IS_FIRST_FP8_MODULE = True
    cls.FP8_AUTOCAST_DEPTH += 1

    if enabled:
        fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
        assert fp8_available, reason_for_no_fp8

        if isinstance(fp8_recipe, MXFP8BlockScaling):
            mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
            assert mxfp8_available, reason_for_no_mxfp8
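Inside the wrapped modules, this global state is read back through FP8GlobalStateManager accessors. The sketch below assumes the is_fp8_enabled and get_fp8_recipe class methods from transformer_engine.pytorch.fp8 and only illustrates how a module-level branch could consult the state; it is not TE's actual module code.

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

def maybe_fp8_forward(module, inp):
    # Illustrative branch on the state set by fp8_autocast_enter.
    if FP8GlobalStateManager.is_fp8_enabled():
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        # A real TE module would (re)build its quantizers from this recipe
        # via the matching RecipeState before running the quantized GEMM.
        print(f"FP8 enabled with recipe: {type(recipe).__name__}")
    return module(inp)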