随着对话式AI和聊天机器人越来越先进,理解对话的流程对于创建引人入胜且自然的互动至关重要。其中一个挑战在于检测用户的发言结束,这项任务主要是识别用户何时结束发言,以便助手可以接着回应,从而确保交流的顺畅。
基于语音活动检测(VAD)的简单解决方案在构建语音机器人时可能会造成困扰,因为它只能检测是否有说话声或沉默,而无法理解对话的上下文。虽然VAD可以识别何时有人在说话,但它无法处理诸如思考中的停顿、重叠说话或对话中的轮流说话等细微差别。这可能导致在用户自然停顿时过早插话,或在用户仍在处理信息时错过他们的发言结束。
与更复杂的模型相比,后者能够分析对话流程和上下文,基于VAD的解决方案无法解读停顿或不完整句子背后的意图,而这些情况在自然的人类对话中非常普遍。因此,与语音机器人互动时可能会感觉机械且缺乏流畅性,从而影响用户的使用体验。
在这篇博客文章中,我们将展示如何使用变压器模型来检测对话中的回合结束(EoT)。提供的示例代码利用了Hugging Face的强大变压器模型,以及ONNX Runtime工具,来预测给定对话回合是否结束。
问题简述对话回合结束检测可能是一项具有挑战性的任务,尤其是在动态对话中。准确预测用户何时结束发言对于有效的对话管理至关重要。如果模型错误地假设用户已经结束发言,而实际上他们还没有结束发言,这可能导致尴尬或延迟的回复。相反,如果等待时间过长,它会减慢对话的节奏。
为了做到这一点,我们使用了一个基于变压器的模型来识别对话结束的模式和标志。在提供的实现中,该模型会根据输入的聊天记录来预测对话是否可能已经结束。
咱们把解决方法分成几个关键步骤吧
- 模型的初始化
- 文本标准化和格式化
- 预测结束的概率
- 主要功能和模型测试
让我们来详细地检查代码的每一部分。
1. 首先,我们来初始化模型我们将从Huggingface下载livekit开源的模型(model)。我们首先需要初始化分词器和ONNX模型。分词器将输入文本转换成模型可以处理的格式。这个模型是一个量化过的ONNX模型,特别适合在生产环境中进行快速推理。
访问 https://huggingface.co/livekit/turn-detector,访问此链接以查看详细的转述检测器模型页面。
def 初始化模型函数():
try:
开始时间 = time.time()
模型文件路径 = 'models/model_quantized.onnx'
会话对象 = ort.InferenceSession(str(模型文件路径))
分词器对象 = AutoTokenizer.from_pretrained('models')
# eou_index 表示编码为空字符串的索引
eou_index = 分词器对象.encode('')[0]
print(f"模型初始化花费时间: {time.time() - 开始时间:.2f} 秒")
return 分词器对象, 会话对象, eou_index
except Exception as e:
print(f"异常信息: {e}")
raise
- ONNX Runtime (
onnxruntime as ort
) 用于高效的模型推理。 - AutoTokenizer 用于将文本转换为转换器模型可以理解的分词格式。
eou_index
是表示结束发言 (eou
) 的标记 (eou_index
),这个标记对于检测对话的结束至关重要。
在输入文本到模型之前,我们需要对其进行处理以确保一致和易于阅读。处理过程会去除标点并统一空格。
def normalize(text):
PUNCS = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~' # 不包含单引号的标点符号
stripped = ''.join(char for char in text if char not in PUNCS) # 去除文本中的标点符号
return ' '.join(stripped.lower().split()) # 将处理后的字符串转换为小写并用空格分隔
这个功能会去掉所有标点,并把所有文字变成小写,确保模型只关注内容,而不受格式变化的影响。
我们将对话内容整理成以便模型能够使用的格式:
def 格式化聊天上下文(chat上下文, 分词器):
标准化上下文 = [
消息 for 消息 in [
{**消息, '内容': 格式化(消息['内容'])}
for 消息 in chat上下文
]
if 消息['内容']
]
对话文本 = 分词器应用对话模板(标准化上下文, 添加生成提示=True, 添加特殊标记=False, 分词=False)
结束标记 = "
对话被处理并准备好,以便模型能够解读。其中,关键的结束标记符(`)帮助模型识别一个回合的结束和另一个回合的开始。
# 3\. 预测结束对话的概率
模型的核心在于预测当前对话轮次是否已经结束。这是通过将对话上下文输入模型,然后提取结束标记的概率来完成的。
# 预测对话结束的函数
def predict_end_of_turn(chat_context, model_data):
# 解包模型数据
tokenizer, session, eou_index = model_data
# 格式化对话上下文
formatted_text = format_chat_context(chat_context, tokenizer)
# 将格式化后的文本转换为输入张量
inputs = tokenizer(formatted_text, add_special_tokens=False, return_tensors="np", max_length=MAX_HISTORY_TOKENS, truncation=True)
# 创建输入字典
input_dict = {"input_ids": np.array(inputs["input_ids"], dtype=np.int64)}
# 运行会话获取输出
output = session.run(["logits"], input_dict)
# 提取logits
logits = output[0]
# 获取最后一个token的logits
last_token_logits = logits[0, -1]
# 计算softmax概率
probs = softmax(last_token_logits)
# 返回结束标志的概率
return float(probs[eou_index])
`predict_end_of_turn`函数返回当前对话可能结束的概率。如果这个概率很高,表示对话到了结尾。如果这个概率很低,模型认为用户还有更多话要说。
# 4\. 主要功能测试和模型测试
最后,`main()` 函数初始化模型并通过三个示例对话展示了应用。
def 主函数():
模型数据 = 初始化模型()
for i, 示例 in enumerate([聊天示例1, 聊天示例2, 聊天示例3], 1):
概率 = 预测回合结束(示例, 模型数据)
print(f"回合结束概率 {i}: {概率}")
该模型分别在三个不同的聊天对话(包括英语和印地语)中进行了测试,并为每个对话回合打印了结束的概率。
端到端的编码
<https://gist.github.com/monuminu/d526ac9ba4b67fed08dd20911162b172>(这是一个链接,指向包含源代码的GitHub gist页面。)
# 最后
这种使用此架构实现的回合结束检测展示了对话模型如何有效预测对话的流程。通过利用回合检测模型的力量,系统可以高效处理输入并预测用户是否已完成说话的可能性。这只是构建更自然、更吸引人的聊天机器人,能够理解和管理动态对话的一个步骤。
随意尝试提供的代码,以适应您的具体需求。您可以试试不同的模型,调整结束回合的阈值,或者优化预处理步骤来获得更好的效果。