Files
happy-llm/docs/chapter5/code/tokenizer.py
2024-09-22 16:02:14 +08:00

68 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import struct
from sentencepiece import SentencePieceProcessor
from typing import List
TOKENIZER_MODEL = "./data/tok4096.model"
class Tokenizer:
def __init__(self, tokenizer_model=None):
"""
初始化分词器。加载预训练的SentencePiece模型并设置一些特殊的token ID。
参数:
tokenizer_model: str, 可选,分词器模型的路径,如果不指定则使用默认路径 TOKENIZER_MODEL。
"""
# 如果提供了分词器模型路径,使用该路径;否则使用默认模型路径
model_path = tokenizer_model if tokenizer_model else TOKENIZER_MODEL
# 确保模型文件存在
assert os.path.isfile(model_path), model_path
# 加载 SentencePiece 模型
self.sp_model = SentencePieceProcessor(model_file=model_path)
self.model_path = model_path
# 获取分词器的特殊token和词汇表大小
self.n_words: int = self.sp_model.vocab_size() # 词汇表大小
self.bos_id: int = self.sp_model.bos_id() # 句子开头 (BOS) 的ID
self.eos_id: int = self.sp_model.eos_id() # 句子结尾 (EOS) 的ID
self.pad_id: int = self.sp_model.pad_id() # 填充 (PAD) 的ID
# 验证分词器词汇表大小是否正确
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
"""
将字符串编码为词元ID列表。可以选择是否添加句子开头 (BOS) 和句子结尾 (EOS) 标记。
参数:
s: str, 要编码的字符串。
bos: bool, 是否在编码的词元列表前添加 BOS 标记。
eos: bool, 是否在编码的词元列表末尾添加 EOS 标记。
返回:
List[int]: 编码后的词元ID列表。
"""
# 确保输入是字符串类型
assert type(s) is str
# 使用SentencePiece将字符串编码为词元ID
t = self.sp_model.encode(s)
# 如果需要BOS标记将其添加到词元列表开头
if bos:
t = [self.bos_id] + t
# 如果需要EOS标记将其添加到词元列表末尾
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int]) -> str:
"""
将词元ID列表解码为字符串。
参数:
t: List[int], 词元ID列表。
返回:
str: 解码后的字符串。
"""
return self.sp_model.decode(t)