Hugging Face BertConfig 的常用方法和属性

BertConfigtransformers 库中的 BERT 模型配置类,用于 定义 BERT 模型的结构、超参数,如 隐藏层维度、注意力头数、最大序列长度、dropout 比例 等。


1. BertConfig 的常见属性

from transformers import BertConfig

config = BertConfig()
属性 作用 默认值
config.hidden_size 隐藏层维度 768
config.num_hidden_layers Transformer 层数 12
config.num_attention_heads 自注意力头数 12
config.intermediate_size 前馈网络层维度 3072
config.hidden_dropout_prob 隐藏层 Dropout 0.1
config.attention_probs_dropout_prob 注意力 Dropout 0.1
config.max_position_embeddings 最大序列长度 512
config.vocab_size 词汇表大小 30522
config.type_vocab_size Segment Embeddings 词汇表大小 2
config.is_decoder 是否是解码器模型 False
config.pad_token_id [PAD] 的 ID 0
config.bos_token_id [CLS] 的 ID 101
config.eos_token_id [SEP] 的 ID 102

2. BertConfig 的常用方法

方法 作用
BertConfig.from_pretrained(model_name_or_path) 加载预训练 BERT 配置
config.to_dict() 转换为 Python 字典
config.to_json_string() 转换为 JSON 格式
config.save_pretrained(path) 保存配置到本地
config.update({"hidden_size": 512}) 更新配置参数

3. BertConfig 详细用法

3.1. 从预训练模型加载 BertConfig

from transformers import BertConfig

config = BertConfig.from_pretrained("bert-base-uncased")
print(config.hidden_size)  # 768
print(config.num_attention_heads)  # 12
  • 这里加载的是 bert-base-uncased 预训练模型的默认配置。

3.2. 自定义 BertConfig

如果你要 创建一个自定义 BERT 模型

config = BertConfig(
    hidden_size=512,          # 修改隐藏层维度
    num_hidden_layers=6,      # 6 层 Transformer
    num_attention_heads=8,    # 8 个注意力头
    intermediate_size=2048,   # FFN 维度
    max_position_embeddings=256  # 最大序列长度
)

print(config.hidden_size)  # 512

3.3. 以 JSON 格式查看 BertConfig

json_config = config.to_json_string()
print(json_config)

示例输出

{
  "hidden_size": 512,
  "num_hidden_layers": 6,
  "num_attention_heads": 8,
  "intermediate_size": 2048,
  "max_position_embeddings": 256
}

3.4. 保存 BertConfig

config.save_pretrained("./my_bert_config")

3.5. 重新加载 BertConfig

from transformers import BertConfig

config = BertConfig.from_pretrained("./my_bert_config")

3.6. 通过 config 创建 BERT 模型

from transformers import BertModel

config = BertConfig(hidden_size=512, num_hidden_layers=6)
model = BertModel(config)
print(model.config.hidden_size)  # 512

4. BertConfig 在不同任务中的应用

4.1. 文本分类

from transformers import BertConfig, BertForSequenceClassification

config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2)
model = BertForSequenceClassification(config)

4.2. 问答任务

from transformers import BertConfig, BertForQuestionAnswering

config = BertConfig.from_pretrained("bert-base-uncased")
model = BertForQuestionAnswering(config)

4.3. 命名实体识别(NER)

from transformers import BertConfig, BertForTokenClassification

config = BertConfig.from_pretrained("bert-base-uncased", num_labels=9)
model = BertForTokenClassification(config)

5. 总结

BertConfig 是 BERT 模型的配置对象,包含 模型架构、隐藏层维度、注意力头数、dropout 等信息

常见属性:

  • config.hidden_size 隐藏层维度
  • config.num_attention_heads 注意力头数
  • config.num_hidden_layers Transformer 层数
  • config.vocab_size 词汇表大小
  • config.max_position_embeddings 最大序列长度
  • config.hidden_dropout_prob Dropout 概率

常见方法:

  • BertConfig.from_pretrained("bert-base-uncased") 加载预训练配置
  • config.to_dict() 转换为字典
  • config.to_json_string() 转换为 JSON
  • config.save_pretrained(path) 保存配置
  • config.update({...}) 更新参数
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐