机器翻译之Bahdanau注意力机制在Seq2Seq中的应用

目录

1.创建 添加了Bahdanau的decoder 

2. 训练

 3.定义评估函数BLEU

 4.预测

 5.知识点个人理解


1.创建 添加了Bahdanau的decoder 

import torch
from torch import nn
import dltools


#定义注意力解码器基类
class AttentionDecoder(dltools.Decoder):  #继承dltools.Decoder写注意力编码器的基类
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    @property    #装饰器, 定义的函数方法可以像类的属性一样被调用
    def attention_weights(self):
        #raise用于引发(或抛出)异常
        raise NotImplementedError  #通常用于抽象基类中,作为占位符,提醒子类必须实现这个方法。 


#创建 添加了Bahdanau的decoder
#继承AttentionDecoder这个基类创建Seq2SeqAttentionDecoder子类, 子类必须实现父类中NotImplementedError占位的方法
class Seq2SeqAttentionDecoder(AttentionDecoder):  
    #初始化属性和方法
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        """
        vocab_size:此表大小,  相当于输入数据的特征数features,  也是输出数据的特征数
        embed_size:嵌入层的大小:将输入数据处理成小批量的数据
        num_hiddens:隐藏层神经元的数量
        num_layers:循环网络的层数
        dropout=0:不释放模型的参数(比如:神经元)
        """
        super().__init__(**kwargs)
        
        #初始化注意力机制的评分函数方法
        self.attention = dltools.AdditiveAttention(key_size=num_hiddens,
                                                   query_size=num_hiddens, 
                                                   num_hiddens=num_hiddens,
                                                   dropout=dropout)
        #初始化嵌入层:将输入的数据处理成小批量的tensor数据   (文本--->数值的映射转化)
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        #初始化循环网络
        self.rnn = nn.GRU(embed_size+num_hiddens, num_hiddens, num_layers, dropout=dropout)
        #初始化线性层  (输出层)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    #初始化隐藏层的状态state   (计算state,需要编码器的输出结果、序列的有效长度)
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        #enc_outputs是一个元组(输出结果,隐藏状态)
        #outputs的shape=(batch_size, num_steps, num_hiddens)
        #hidden_state的shape=(num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        #返回一个元组(,),可以用一个变量接收
        #outputs.permute(1, 0, 2)转换数据的维度是因为rnn循环神经网络的输入要求是先num_steps,再batch_size,
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    #定义前向传播   (输入数据X,state)
    def forward(self, X, state):
        #变量赋值:接收编码器encoder的输出结果、隐藏状态、序列有效长度
        #enc_outputs的shape=(batch_size, num_steps, num_hiddens)
        #hidden_state的shape=(num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        #X的shape=(batch_size, num_steps, vocab_size)
        X = self.embedding(X)   #将X输入embedding嵌入层后, X的shape=(batch_size, num_steps, embed_size)
        #调换X的0维度和1维度数据
        X = X.permute(1, 0, 2)   #X的shape=(num_steps, batch_size, embed_size)
        outputs, self._attention_weights = [], []  #创建空列表,用于存储数据
        
        for x in X:  #遍历每一批数据
            #获取query
            #hidden_state[-1]表示最后一层循环网络的隐藏层状态  (有两层循环网络)
            #hidden_state[-1]的shape=(batch_size, num_hiddens)    #dim=1表示在原索引1的维度增加一个维度
            query = torch.unsqueeze(hidden_state[-1], dim=1)  
#             print('query的shape:', query.shape)   #query的shape=(batch_size, 1, num_hiddens)
            #通过注意力机制获取上下文序列
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
#             print('context的shape:', context.shape)  #context的shape=(batch_size, 1, num_hiddens)
            #用最后一个维度 拼接context, x 数据
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
#             print('x的shape:', x.shape)   #x的shape=(batch_size, 1, num_hiddens+embed_size)
            #将x和hidden_state输入循环神经网络中,获取输出结果和新的hidden_state
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
#             print('out的shape:', out.shape)   #out的shape=(1, batch_size, num_hiddens)
#             print('hidden_state的shape:', hidden_state.shape) #两层循环层:hidden_state的shape=(2, batch_size, num_hiddens)
            #将输出结果添加到列表中
            outputs.append(out)
            self._attention_weights.append(self.attention_weights)
        
        outputs = self.dense(torch.cat(outputs, dim=0))
#         print('outputs的shape:', outputs.shape)  #outputs的shape=(num_steps, batch_size, vocab_size)
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights



#测试代码
#创建编码器对象
encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
#需要预测, 要加encoder.eval()
encoder.eval()
#创建解码器对象
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()

#假设数据
batch_size, num_steps = 4, 7
X = torch.zeros((4, 7), dtype = torch.long)
#初始化状态state
state = decoder.init_state(encoder(X), None)
outputs, state = decoder(X, state)
#state包含三个东西(enc_outputs, hidden_state, enc_valid_lens)
#state[0]是 enc_outputs
#state[1]是 hidden_state, 两层循环层,就会有两个hidden_state, state[1][0]是第一层的hidden_state
outputs.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
outputs的shape: torch.Size([7, 4, 10])

Out[11]:

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

2. 训练

#声明变量
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()

#加载数据
train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)

#创建编辑器对象
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
#创建编辑器对象
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)

#创建网络模型
net = dltools.EncoderDecoder(encoder, decoder)

#模型训练
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 

 3.定义评估函数BLEU

def bleu(pred_seq, label_seq, k):
    print('pred_seq:', pred_seq)
    print('label_seq:', label_seq)
    #将pred_seq, label_seq分别进行空格分隔
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    #获取pred_seq, label_seq的长度
    len_pred, len_label = len(pred_seq), len(label_seq)
    
    score = math.exp(min(0, 1 - (len_label / len_pred)))
    for n in range(1, k+1): #n的取值范围,  range()左闭右开
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i+n])] += 1
            
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i+n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i+n])] -=1
        score *= math.pow(num_matches / (len_pred -n + 1), math.pow(0.5, n))
    return score

 4.预测

import math
import collections


engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est bon .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

 5.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/881799.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大数据实验2.Hadoop 集群搭建(单机/伪分布式/分布式)

实验二: Hadoop安装和使用 一、实验目的 实现hadoop的环境搭建和安装Hadoop的简单使用; 二、实验平台 操作系统:Linux(建议Ubuntu16.04或者18.04);Hadoop版本:3.1.3;JDK版本&…

安全热点问题

安全热点问题 1.DDOS2.补丁管理3.堡垒机管理4.加密机管理 1.DDOS 分布式拒绝服务攻击,是指黑客通过控制由多个肉鸡或服务器组成的僵尸网络,向目标发送大量看似合法的请求,从而占用大量网络资源使网络瘫痪,阻止用户对网络资源的正…

【靶点Talk】免疫检查点争夺战:TIGIT能否超越PD-1?

曾经的TIGIT靶点顶着“下一个PD-1”的名号横空出世,三年的“征程”中TIGIT走过一次又一次的失败,然而面对质疑和压力仍有一批公司选择前行。今天给大家分享TIGIT靶点的相关内容,更多靶点科普视频请关注义翘神州B站和知乎官方账号。 TIGIT的“…

2024年最新 Python 大数据网络爬虫技术基础案例详细教程(更新中)

网络爬虫概述 网络爬虫(Web Crawler),又称为网页蜘蛛(Web Spider)或网络机器人(Web Robot),是一种自动化程序或脚本,用于浏览万维网(World Wide Web&#xf…

串口助手的qt实现思路

要求实现如下功能&#xff1a; 获取串口号&#xff1a; foreach (const QSerialPortInfo &serialPortInfo, QSerialPortInfo::availablePorts()) {qDebug() << "Port: " << serialPortInfo.portName(); // e.g. "COM1"qDebug() <<…

2024年中国研究生数学建模竞赛C题——解题思路

2024年中国研究生数学建模竞赛C题——解题思路 数据驱动下磁性元件的磁芯损耗建模——解决思路 二、问题描述 为解决磁性元件磁芯材料损耗精确计算问题&#xff0c;通过实测磁性元件在给定工况&#xff08;不同温度、频率、磁通密度&#xff09;下磁芯材料损耗的数据&#xf…

【HTTPS】中间人攻击和证书的验证

中间人攻击 服务器可以创建出一堆公钥和私钥&#xff0c;黑客也可以按照同样的方式&#xff0c;创建一对公钥和私钥&#xff0c;冒充自己是服务器&#xff08;搅屎棍&#xff09; 黑客自己也能生成一对公钥和私钥。生成公钥和私钥的算法是开放的&#xff0c;服务器能生产&…

「iOS」——单例模式

iOS学习 前言单例模式的概念单例模式的优缺点单例模式的两种模式懒汉模式饿汉模式单例模式的写法 总结 前言 在一开始学习OC的时候&#xff0c;我们初步接触过单例模式。在学习定时器与视图移动的控件中&#xff0c;我们初步意识到单例模式的重要性。对于我们需要保持的控件&a…

java基础知识20 Intern方法的作用

一 Intern方法作用 1.1 Intern方法 1.在jdk1.6中&#xff1a; intern()方法&#xff1a;在jdk1.6中&#xff0c;根据字符串对象&#xff0c;检查常量池中是否存在相同字符串对象 如果字符串常量池里面已经包含了等于字符串X的字符串&#xff0c;那么就返回常量池中这个字符…

windows安装docker 本地打包代码

参考文章1&#xff1a;https://gitcode.csdn.net/65ea814b1a836825ed792f4a.html 参考文章2&#xff1a; Windows 安装docker&#xff08;详细图解&#xff09;-CSDN博客 一 下载 Docker Desktop 在官网上下载 Docker Desktop&#xff0c;可以从以下链接下载最新版本&#x…

【大模型实战篇】关于Bert的一些实操回顾以及clip-as-service的介绍

最近在整理之前的一些实践工作&#xff0c;一方面是为了笔记记录&#xff0c;另一方面也是自己做一些温故知新&#xff0c;或许对于理解一些现在大模型工作也有助益。 1. 基于bert模型实现中文语句的embedding编码 首先是基于bert模型实现中文语句的embedding编码&#xff0c;…

鸿蒙【项目打包】- .hap 和 .app;(测试如何安装发的hap包)(应用上架流程)

#打包成.hap需要用到真机 原因是&#xff1a;只有用上了真机才能在项目中配置 自动签名 #步骤: ##第一步:选择真机->选择项目结构->点Sigining Configs(签名配置) ##第二步:勾选Automatically generate signature(自动签名)->点击ok ##第三步:点击构建->点击 …

接口幂等性和并发安全的区别?

目录标题 幂等性并发安全总结 接口幂等性和并发安全是两个不同的概念&#xff0c;虽然它们在设计API时都很重要&#xff0c;但侧重点不同。 幂等性 定义&#xff1a;幂等性指的是无论对接口进行多少次相同的操作&#xff0c;结果都是一致的。例如&#xff0c;HTTP的PUT和DELE…

QT快速安装使用指南

在Ubuntu 16.04上安装Qt可以通过多种方式进行。以下是使用Qt在线安装程序和apt包管理器的两种常见方法&#xff1a; 方法一&#xff1a;使用Qt在线安装程序 下载Qt在线安装程序 访问Qt官方网站&#xff1a;Try Qt | Develop Applications and Embedded Systems | Qt找到并下载…

Hadoop的安装

文章目录 一. 到Hadoop官网下载安装文件hadoop-3.4.0.tar.gz。二. 环境变量三. 配置 一. 到Hadoop官网下载安装文件hadoop-3.4.0.tar.gz。 随后点击下载即可 由于Hadoop不直接支持Windows系统&#xff0c;因此&#xff0c;需要修改一些配置才能运行 二. 环境变量 三. 配置 进…

arcgisPro地理配准

1、添加图像 2、在【影像】选项卡中&#xff0c;点击【地理配准】 3、 点击添加控制点 4、选择影像左上角格点&#xff0c;然后右击填入目标点的投影坐标 5、依次输入四个格角点的坐标 6、点击【变换】按钮&#xff0c;选择【一阶多项式&#xff08;仿射&#xff09;】变换 7…

VisualPromptGFSS

COCO-20 i ^i i太大&#xff0c;不建议复现

大学生必看!60万人在用的GPT4o大学数学智能体有多牛

❤️作者主页&#xff1a;小虚竹 ❤️作者简介&#xff1a;大家好,我是小虚竹。2022年度博客之星&#x1f3c6;&#xff0c;Java领域优质创作者&#x1f3c6;&#xff0c;CSDN博客专家&#x1f3c6;&#xff0c;华为云享专家&#x1f3c6;&#xff0c;掘金年度人气作者&#x1…

14年数据结构

第一题 解析&#xff1a; 求时间复杂度就是看程序执行了多少次。 假设最外层执行了k次&#xff0c;我们看终止条件是kn&#xff0c;则&#xff1a; 有, 内层是一个j1到jn的循环&#xff0c;显然执行了n次。 总的时间复杂度是内层外层 答案选C。 第二题 解析&#xff1a; 一步一…

基于协同过滤+python+django+vue的音乐推荐系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于协同过滤pythondjangovue…