PyTorch gather 方法详解:作用、应用场景与示例解析(中英双语)

news/2025/2/22 16:45:41

PyTorch gather 方法详解:作用、应用场景与示例解析

在深度学习和自然语言处理(NLP)任务中,我们经常需要从高维张量中提取特定索引的数据
PyTorch 提供的 torch.gather 方法可以高效地从张量的指定维度收集数据,广泛应用于语言模型(Transformer)、分类任务、强化学习等场景

在本文中,我们将详细介绍:

  • gather 方法的作用
  • 使用 gather 进行索引操作
  • gather 在 NLP 模型中的应用
  • gather 的计算效率与优化

1. torch.gather 的作用

1.1 gather 的基本用法

gather 允许我们在张量的指定维度上,按照给定的索引提取数据
其基本语法如下:

python">torch.gather(input, dim, index)
  • input:输入张量,形状为 (B, L, V)(可以是任意维度)。
  • dim:指定在哪个维度上收集数据(例如 dim=-1 代表在最后一个维度索引)。
  • index:索引张量,形状必须与 inputdim 之外的维度相同

1.2 gather 的核心逻辑

给定 inputindexgather 沿 dim 维度逐元素地获取 input 中指定索引位置的值


2. gather 的基础示例

2.1 从二维张量中提取元素

python">import torch

# 定义一个 3x4 的张量
input_tensor = torch.tensor([[10, 20, 30, 40], 
                             [50, 60, 70, 80], 
                             [90, 100, 110, 120]])

# 定义索引张量
index_tensor = torch.tensor([[0, 1], 
                             [2, 3], 
                             [1, 2]])

# 在 `dim=1` 维度上使用 gather
output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)

输出:

tensor([[ 10,  20],
        [ 70,  80],
        [100, 110]])

解释:

  • input_tensor 形状为 (3,4),即 3 行 4 列
  • index_tensor 形状为 (3,2),其中的值指示要从 input_tensordim=1(列) 选取的数据:
    • 第 1 行:取 input_tensor[0,0]input_tensor[0,1],即 [10, 20]
    • 第 2 行:取 input_tensor[1,2]input_tensor[1,3],即 [70, 80]
    • 第 3 行:取 input_tensor[2,1]input_tensor[2,2],即 [100, 110]

3. gather 在 NLP 中的应用

3.1 计算 Token 的对数概率

在语言模型(如 Transformer)中,我们通常需要计算目标 token 的概率,即:
P ( y t ) = e logit y t ∑ e logit v P(y_t) = \frac{e^{\text{logit}_{y_t}}}{\sum e^{\text{logit}_{v}}} P(yt)=elogitvelogityt

其中:

  • logits 形状为 (B, L, V),表示 batch 里每个 token 对整个词表(vocabulary)中所有词的 logit 分数。
  • input_ids 形状为 (B, L),表示实际的 token 索引(即每个 token 在词表中的 ID)。

我们使用 gather 取出每个 input_idlogits 中对应的 logit 分值:

python">import torch

# 假设 batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([[[2.0, 1.0, 0.5, -1.0, 0.2], 
                         [0.1, -0.5, 2.2, 1.5, 0.0], 
                         [1.1, 3.5, 0.8, -0.2, -1.5]],

                        [[0.0, 2.3, -0.5, 1.0, 0.8], 
                         [-1.2, 1.7, 2.0, 0.3, -0.8], 
                         [2.5, -0.1, -1.2, 0.5, 3.0]]])

input_ids = torch.tensor([[0, 2, 1],  # 对应每个 token 在词表中的索引
                          [1, 3, 4]])

# 取出 input_ids 在 logits 中的 logit 值
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
print(token_logits)

输出:

tensor([[ 2.0000,  2.2000,  3.5000],
        [ 2.3000,  0.3000,  3.0000]])

解释:

  • logits.gather(dim=-1, index=input_ids.unsqueeze(-1))
    • dim=-1 代表从 Vocab 维度(最后一维)索引数据。
    • input_ids.unsqueeze(-1) 扩展维度,让 input_ids 形状变为 (B, L, 1),符合 gather 要求。
    • squeeze(-1) 还原到 (B, L) 形状,使结果是 每个 token 的 logit 值

4. gatherscatter 的对比

除了 gather(用于提取数据),PyTorch 还提供 scatter(用于写入数据)。

4.1 scatter 基本用法

python">import torch

# 初始化 3x3 零张量
x = torch.zeros(3, 3)

# 指定索引
index = torch.tensor([[0, 2], 
                      [1, 1], 
                      [2, 0]])

# 指定填充值
updates = torch.tensor([[5, 8], 
                         [3, 7], 
                         [6, 2]])

# 在 dim=1 维度上 scatter
x.scatter_(dim=1, index=index, src=updates)
print(x)

输出:

tensor([[5., 0., 8.],
        [0., 7., 0.],
        [2., 0., 6.]])

scatter_() 用于替换 index 位置的值,而 gather() 用于提取 index 位置的值


5. 总结

  • torch.gather(input, dim, index) 提取 input dim 维度上指定 index 位置的值
  • 常用于 NLP 任务中计算 token 对数概率、分类任务中提取预测分数
  • 通过 gather 提取 logits 对应 input_ids,可高效计算 对数概率损失函数
  • gather从索引获取数据,而 scatter根据索引写入数据

🚀 掌握 gather 让你在深度学习项目中更高效地处理索引操作! 🚀

深入理解 gather(dim=-1):作用、计算过程与 dim=-2 的对比

在 PyTorch 中,torch.gather 是一个强大的索引操作函数,它可以根据提供的 index 张量,从 input 张量的指定维度(dim)中提取相应的数据。
在 NLP(自然语言处理)任务中,我们常用 gather(dim=-1) 来从 logits 中获取 输入 token(input_ids)对应的 logits 值,用于计算损失或评估模型表现。


1. gather(dim=-1) 的作用

1.1 dim=-1 的含义

  • dim=-1 代表最后一个维度(即词表维度)。

  • logits.shape = (batch_size, sequence_length, vocab_size) 这样一个张量中:

    • dim=0:表示 batch 维度(不同样本)。
    • dim=1:表示 sequence 维度(句子中的不同 token)。
    • dim=2(即 dim=-1:表示词汇表(vocab),即每个 token 对所有单词的 logits 评分。
  • gather(dim=-1, index=input_ids.unsqueeze(-1)) 的作用:

    • dim=-1(词表维度)上提取 input_ids 对应的 logits 值
    • 这样,每个 token 只保留它对应的 logits,而不是整个词表的所有 logits。

2. 代码示例与计算过程

2.1 示例:计算 Token Logits

python">import torch

# 假设 batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([
    [[2.0, 1.0, 0.5, -1.0, 0.2], 
     [0.1, -0.5, 2.2, 1.5, 0.0], 
     [1.1, 3.5, 0.8, -0.2, -1.5]],

    [[0.0, 2.3, -0.5, 1.0, 0.8], 
     [-1.2, 1.7, 2.0, 0.3, -0.8], 
     [2.5, -0.1, -1.2, 0.5, 3.0]]
])

input_ids = torch.tensor([
    [0, 2, 1],  # 第一个样本的 token 索引
    [1, 3, 4]   # 第二个样本的 token 索引
])

# 扩展维度,使 input_ids 形状变为 (batch_size, sequence_length, 1)
expanded_index = input_ids.unsqueeze(-1)

# 使用 gather 从 logits 中提取相应的 token logits
token_logits = logits.gather(dim=-1, index=expanded_index).squeeze(-1)
print(token_logits)

输出:

tensor([[2.0000, 2.2000, 3.5000],
        [2.3000, 0.3000, 3.0000]])

2.2 gather(dim=-1) 计算过程解析

对于 logits.shape = (2, 3, 5)

  • dim=-1 代表最后一维,即 vocab_size=5 维度。
  • input_ids.shape = (2, 3),表示每个 batch 的 token 在词表中的索引。

让我们手动解析 gather(dim=-1) 的计算步骤:

BatchToken 索引 (dim=1)Index (dim=-1)Extracted Logit (gather 结果)
第 1 个样本第 1 个 tokeninput_ids[0,0] = 0logits[0,0,0] = 2.0
第 2 个 tokeninput_ids[0,1] = 2logits[0,1,2] = 2.2
第 3 个 tokeninput_ids[0,2] = 1logits[0,2,1] = 3.5
第 2 个样本第 1 个 tokeninput_ids[1,0] = 1logits[1,0,1] = 2.3
第 2 个 tokeninput_ids[1,1] = 3logits[1,1,3] = 0.3
第 3 个 tokeninput_ids[1,2] = 4logits[1,2,4] = 3.0

这正是 gather(dim=-1) 提取的值。


3. dim=-1 vs. dim=-2 的区别

3.1 什么是 dim=-2

如果改成 gather(dim=-2),它会尝试在 sequence 维度(dim=1 上进行索引,这会导致错误的行为。
因为 input_ids 只包含 token 在词表中的索引,而不是 token 在句子中的索引。

3.2 如果错误地使用 dim=-2

python">wrong_gather = logits.gather(dim=-2, index=input_ids.unsqueeze(-1))
print(wrong_gather.shape)
  • dim=-2 代表第二个维度(sequence 维度)
  • 这意味着 PyTorch 会尝试从 logits 中选取整个 token 级别的数据,而不是单独的 token logits

❌ 结果错误,因为 input_ids 里的索引根本不适用于 dim=-2

3.3 为什么 dim=-1 才是正确的?

  • input_ids 里的索引指向 词表索引(vocab index),所以应该沿着词表维度(dim=-1)索引数据。
  • dim=-1 选择的是 单个 token 对应的 logits,不会影响整个句子结构。

4. 结论

dim含义是否正确
dim=-1 (最后一维)提取每个 token 在词表中的 logits正确
dim=-2 (倒数第二维)尝试索引整个句子级别的数据错误

核心要点

dim=-1(最后一维)用于 获取输入 token 对应的 logits 值,常用于 NLP 任务。
gather(dim=-1, index=input_ids.unsqueeze(-1))input_ids 选择 logits 里的正确位置。
dim=-2 会错误地索引整个 token 级别的数据,而不是单个 token logits。

🚀 正确理解 gather(dim=-1),能够帮助你高效地提取模型输出,用于计算损失、评估模型! 🚀

Understanding torch.gather: Purpose, Use Cases, and Implementation in NLP

In deep learning, particularly in Natural Language Processing (NLP) and reinforcement learning, we often need to extract specific values from high-dimensional tensors using given indices. torch.gather is a powerful PyTorch function that efficiently retrieves data along a specified dimension based on an index tensor.

This article will cover:

  • The purpose of torch.gather
  • How gather works with examples
  • Practical applications in NLP and deep learning
  • Performance considerations and comparisons with scatter

1. What is torch.gather?

1.1 Basic Syntax

python">torch.gather(input, dim, index)
  • input: The source tensor from which values are gathered.
  • dim: The dimension along which to index values.
  • index: A tensor containing the indices of elements to extract.

1.2 How gather Works

  • It retrieves values from input at positions specified by index along the dim dimension.
  • The index tensor must have the same shape as input, except for the dim dimension.

2. Basic Examples of torch.gather

2.1 Extracting Elements from a 2D Tensor

python">import torch

# Define a 3x4 tensor
input_tensor = torch.tensor([[10, 20, 30, 40], 
                             [50, 60, 70, 80], 
                             [90, 100, 110, 120]])

# Define the index tensor
index_tensor = torch.tensor([[0, 1], 
                             [2, 3], 
                             [1, 2]])

# Gather values along dimension 1 (columns)
output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)

Output:

tensor([[ 10,  20],
        [ 70,  80],
        [100, 110]])

Explanation:

  • dim=1 means we are indexing columns.
  • index_tensor[i, j] determines which element to select from input_tensor[i].

3. gather in NLP: Extracting Token Logits

3.1 Why Use gather in NLP?

In Transformer-based language models (GPT, BERT, etc.), we often need to compute the log probability of specific tokens. Given:

  • logits: The model’s output scores for each token.
  • input_ids: The actual token indices.

We use gather to efficiently retrieve the logits corresponding to each token.

3.2 Extracting Token Logits for Loss Calculation

python">import torch

# Simulated logits for batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([[[2.0, 1.0, 0.5, -1.0, 0.2], 
                         [0.1, -0.5, 2.2, 1.5, 0.0], 
                         [1.1, 3.5, 0.8, -0.2, -1.5]],

                        [[0.0, 2.3, -0.5, 1.0, 0.8], 
                         [-1.2, 1.7, 2.0, 0.3, -0.8], 
                         [2.5, -0.1, -1.2, 0.5, 3.0]]])

input_ids = torch.tensor([[0, 2, 1],  # Token indices
                          [1, 3, 4]])

# Extracting logits corresponding to input_ids
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
print(token_logits)

Output:

tensor([[ 2.0000,  2.2000,  3.5000],
        [ 2.3000,  0.3000,  3.0000]])

3.3 Explanation

  1. input_ids.unsqueeze(-1) converts shape (B, L)(B, L, 1), making it compatible with gather.
  2. gather(dim=-1, index=input_ids.unsqueeze(-1)) retrieves the logits corresponding to input_ids.
  3. squeeze(-1) removes the unnecessary last dimension.

This operation is efficient and memory-friendly compared to iterating through tokens manually.


4. gather vs. scatter

While gather retrieves values from an input tensor using an index, scatter does the opposite: it writes values to an output tensor at specific indices.

4.1 Using scatter to Modify a Tensor

python">import torch

# Initialize a 3x3 zero tensor
x = torch.zeros(3, 3)

# Define index positions
index = torch.tensor([[0, 2], 
                      [1, 1], 
                      [2, 0]])

# Define values to write
updates = torch.tensor([[5, 8], 
                         [3, 7], 
                         [6, 2]])

# Use scatter_ to update x
x.scatter_(dim=1, index=index, src=updates)
print(x)

Output:

tensor([[5., 0., 8.],
        [0., 7., 0.],
        [2., 0., 6.]])

4.2 Key Difference

  • gather: Extracts values from specific indices.
  • scatter: Writes values to specific indices.

5. Performance and Memory Efficiency

5.1 Why gather is Efficient?

  • Vectorized indexing: Instead of looping through individual indices, gather efficiently extracts multiple values in parallel.
  • Lower memory footprint: Since gather does not require additional tensor allocations, it is more memory-efficient than manually indexing with loops.
  • Optimized for GPU: PyTorch internally optimizes gather to run efficiently on CUDA devices.

5.2 Performance Benchmark

python">import time

x = torch.randn(1000, 1000)
index = torch.randint(0, 1000, (1000, 500))

start = time.time()
_ = x.gather(dim=1, index=index)
end = time.time()
print(f"gather execution time: {end - start:.6f} s")

Results (Example):

gather execution time: 0.002341 s

This is much faster than manually iterating over indices.


6. Summary

  • torch.gather(input, dim, index) efficiently extracts values from a tensor using an index tensor.
  • Common use cases:
    • Extracting token logits for NLP tasks (e.g., loss computation in Transformer models).
    • Indexing probability distributions in reinforcement learning.
    • Selecting specific elements from multi-dimensional tensors.
  • gather is memory-efficient, parallelized, and optimized for GPU acceleration.
  • Comparison with scatter:
    • gather extracts values from an input tensor.
    • scatter writes values into an output tensor.

🚀 Mastering torch.gather will help you write more efficient deep learning models! 🚀

后记

2025年2月21日19点14分于上海,在GPT4o大模型辅助下完成。


http://www.niftyadmin.cn/n/5862536.html

相关文章

软件集成测试的技术要求

文章目录 一、软件集成测试的概念二、测试对象三、测试目的四、进入条件五、测试内容六、测试环境七、测试实施方一、软件集成测试的概念 软件集成测试(Software Integration Testing),也称部件测试,一种旨在暴露接口以及集成组件间交互时存在的缺陷的测试。集成测试是灰盒…

【gitlab】认识 持续集成与部署

持续集成(CI)与持续部署(CD) 1. 什么是持续集成(CI)? 持续集成(Continuous Integration,CI)是一种软件开发实践,强调开发人员频繁地将代码提交到…

【Python爬虫(35)】解锁Python多进程爬虫:高效数据抓取秘籍

【Python爬虫】专栏简介:本专栏是 Python 爬虫领域的集大成之作,共 100 章节。从 Python 基础语法、爬虫入门知识讲起,深入探讨反爬虫、多线程、分布式等进阶技术。以大量实例为支撑,覆盖网页、图片、音频等各类数据爬取&#xff…

Windows 系统下,使用 PyTorch 的 DataLoader 时,如果 num_workers 参数设置为大于 0 的值,报错

在 Windows 系统下,使用 PyTorch 的 DataLoader 时,如果 num_workers 参数设置为大于 0 的值,可能会遇到以下错误: RuntimeError: An attempt has been made to start a new process before thecurrent process has finished its…

现场可以通过手机或者pad实时拍照上传到大屏幕的照片墙现场大屏电子照片墙功能

现场可以通过手机或者pad实时拍照上传到大屏幕的照片墙现场大屏电子照片墙功能,每个人都可以通过手机实时拍照上传到大屏幕上,同时还可以发布留言内容,屏幕上会同步滚动播放展示所有人的照片和留言。相比校传统的照片直播功能更加灵活方便,而…

什么是事务?并发事务引发的问题?什么是MVCC?

文章目录 什么是事务?并发事务引发的问题?什么是MVCC?1.事务的四大特性2.并发事务下产生的问题:脏读、不可重复读、幻读3.如何应对并发事务引发的问题?4.什么是MVCC?5.可见性规则?参考资料 什么…

两相四线步进电机的步距角为什么是1.8度

机缘 在CSDN查了好多文章,发现都是用公式来解释1.8的步距角(Q=360/MZ),因为转子是50齿,4拍一个循环,所以θ360度/(50x4)1.8度。估计第一次接触步进电机的什么…

SPRING10_SPRING的生命周期流程图

经过前面使用三大后置处理器BeanPostProcessor、BeanFactoryPostProcessor、InitializingBean对创建Bean流程中的干扰,梳理出SPRING的生命周期流程图如下