详解Megatron中的数据混合算法(BlendableDataset)

news2024/9/21 15:45:06

🧑‍💻 本文主要讲解Megatron早期版本中的数据混合算法。

目录

  • 1. 数据混合
  • 2. 源码解析
  • 3. 证明部分&讨论
  • 4. 进一步优化

1. 数据混合

在谈源码之前,我们有必要先了解一下Megatron中的数据混合思想。

给定 n n n 个数据集 D 1 , D 2 , ⋯   , D n \mathcal{D}_1,\mathcal{D}_2,\cdots,\mathcal{D}_n D1,D2,,Dn 和对应的 n n n 个权重 w 1 , w 2 , ⋯   , w n w_1,w_2,\cdots,w_n w1,w2,,wn,我们要按照这些权重去混合 n n n 个数据集,设混合后的数据集为 D \mathcal{D} D

Megatron假定:

  • ∣ D ∣ = ∑ i = 1 n ∣ D i ∣ |\mathcal{D}|=\sum_{i=1}^n|\mathcal{D}_i| D=i=1nDi。即混合后的数据集大小等于混合前的各数据集大小之和。
  • D \mathcal{D} D 中有 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi 个样本来自 D i \mathcal{D}_i Di

那如何确定 D \mathcal{D} D 中到底有多少个样本是来自 D i \mathcal{D}_i Di 的呢?一种最直观的做法是,计算 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi,然后进行取整,但这种操作无法保证所有取整后的 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi 相加起来恰好是 ∣ D ∣ |\mathcal{D}| D 如果总和大于 ∣ D ∣ |\mathcal{D}| D,说明某些数据集被过采样了,应当减少相应数据集的采样数;如果总和小于 ∣ D ∣ |\mathcal{D}| D,说明某些数据集被欠采样了,应当增加相应数据集的采样数。可问题是,如何确定这些被过采样/欠采样的数据集呢?显然我们需要一个更加公平的算法。

我们可以把获取数据集 D \mathcal{D} D 看作是一个采样过程:一开始有 n n n 个数据源 { D i } i = 1 n \{\mathcal{D}_i\}_{i=1}^n {Di}i=1n,每一轮迭代,我们需要先从这 n n n 个数据源中选出一个数据源 D i \mathcal{D}_i Di,然后再从这个数据源中选出一个样本 S \mathcal{S} S 由于每一轮迭代只会选出一个样本,因此 ∣ D ∣ |\mathcal{D}| D 轮迭代结束后,我们便得到了 ∣ D ∣ |\mathcal{D}| D 个样本,这些样本构成了混合后的数据集 D \mathcal{D} D

每一轮迭代都会产生两个信息:要选取的数据源 D i \mathcal{D}_i Di,要从 D i \mathcal{D}_i Di 中选取的样本。我们可以考虑构造两个整数序列 P , S \mathcal{P},\mathcal{S} P,S,它们的长度均为 ∣ D ∣ |\mathcal{D}| D,含义如下:

  • P j \mathcal{P}_j Pj 代表的是第 j j j 轮迭代时,选取的数据源的下标。例如 P 10 = 3 \mathcal{P}_{10}=3 P10=3 意味着第 10 10 10 轮迭代选取的数据源是 D 3 \mathcal{D}_3 D3
  • S j \mathcal{S}_j Sj 代表的是第 j j j 轮迭代时,从数据源 D P j \mathcal{D}_{\mathcal{P}_j} DPj 选取的样本的下标。

由以上定义知, ∀ j \forall j j,都有 1 ≤ P j ≤ n 1\leq \mathcal{P}_j\leq n 1Pjn 1 ≤ S j ≤ ∣ D P j  ⁣ ∣ 1\leq \mathcal{S}_j\leq|\mathcal{D}_{\mathcal{P}_j}\!| 1SjDPj(下标均从 1 1 1 开始)。

接下来的问题是,如何确定每一轮的 P j \mathcal{P}_j Pj S j \mathcal{S}_j Sj 呢?

先谈 P j \mathcal{P}_j Pj。因为是一个从 1 1 1 ∣ D ∣ |\mathcal{D}| D 的一个逐步采样过程,在第 j j j 轮迭代时,我们已经抽取了 j − 1 j-1 j1 个样本,接下来要确定第 j j j 个样本。根据Megatron的假定,在确定下来第 j j j 个样本后,这 j j j 个样本中应当有约 j ⋅ w i j\cdot w_i jwi 个样本是来自 D i \mathcal{D}_i Di 的。

考虑构造一个长度为 n n n 的序列 C \mathcal{C} C,该序列随着迭代不断更新。 C i \mathcal{C}_i Ci 代表当前已经从 D i \mathcal{D}_i Di 抽取了多少个样本。显然可知,第一轮迭代开始时,有 C i = 0 ,   i = 1 , 2 , ⋯   , n \mathcal{C}_i=0,\,i=1,2,\cdots,n Ci=0,i=1,2,,n。最后一轮迭代结束后,有 ∑ i = 1 n C i = ∣ D ∣ \sum_{i=1}^n\mathcal{C}_i=|\mathcal{D}| i=1nCi=D,并且

C i = { ∑ t = 1 j − 1 I ( P t = i ) , P j 确定前 ∑ t = 1 j I ( P t = i ) , P j 确定后 , ∀ i \mathcal{C}_i=\begin{cases} \sum_{t=1}^{j-1} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定前} \\ \sum_{t=1}^{j} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定后} \\ \end{cases},\quad \forall i Ci={t=1j1I(Pt=i),t=1jI(Pt=i),Pj确定前Pj确定后,i

回到对 P j \mathcal{P}_j Pj 的讨论中。假设在确定第 j j j 个样本前已经从 D i \mathcal{D}_i Di 中抽取了 C i \mathcal{C}_i Ci 个样本,在确定第 j j j 个样本后,诸 C i \mathcal{C}_i Ci有且仅有一个的值会增加 1 1 1,不妨记为 C k \mathcal{C}_k Ck,这个过程可以形容为

[ C 1 , ⋯   , C k , ⋯   , C n ] ⏟ 第 j 轮迭代开始时 → [ C 1 , ⋯   , C k + 1 , ⋯   , C n ] ⏟ 第 j 轮迭代结束时 [ j ⋅ w 1 , j ⋅ w 2 , ⋯   , j ⋅ w n ] ⏟ 理论值 \underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_k,\cdots,\mathcal{C}_n]}_{第j轮迭代开始时}\to\underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n]}_{第j轮迭代结束时}\qquad \underbrace{[j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n]}_{理论值} j轮迭代开始时 [C1,,Ck,,Cn]j轮迭代结束时 [C1,,Ck+1,,Cn]理论值 [jw1,jw2,,jwn]

我们期望第 j j j 轮迭代结束时,诸 C i \mathcal{C}_i Ci 应当尽可能地接近理论值(在MSE下)。由于只能让其中一个 C k \mathcal{C}_k Ck 自增 1 1 1,显然有 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)

再谈 S j \mathcal{S}_j Sj。在确定了数据源是 D k \mathcal{D}_k Dk 后,为了避免重复,我们应当做到不放回、随机地从中采样。如何做到这两点呢?我们可以在一开始就对 n n n 个数据源进行打乱,然后在采样的时候只需要从前往后进行,就可以做到以上两点。注意到 C i \mathcal{C}_i Ci 的值是从 0 0 0 开始,以步长为 1 1 1 依次递增,所以我们可以用每次更新完的 C i \mathcal{C}_i Ci 赋值给相应的 S j \mathcal{S}_j Sj,即 S j = 第 j 轮迭代结束时的 C i \mathcal{S}_j=第j轮迭代结束时的\mathcal{C}_i Sj=j轮迭代结束时的Ci

由此我们可以得到整个算法的伪代码:

2. 源码解析

Python部分:

class BlendableDataset(torch.utils.data.Dataset):
    def __init__(self, datasets, weights):
        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights), "The number of datasets and weights must match."

        self.size = sum(len(dataset) for dataset in self.datasets)

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0, "Sum of weights must be positive."
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        assert num_datasets < 255, "Number of datasets must be less than 255."
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        helpers.build_blending_indices(
            self.dataset_index,
            self.dataset_sample_index,
            weights,
            num_datasets,
            self.size,
            torch.distributed.get_rank() == 0,
        )
        print_rank_0(f'> elapsed time for building blendable dataset indices: '
                     f'{time.time() - start_time:.2f} sec')

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return {
            "dataset_idx": dataset_idx,
            **self.datasets[dataset_idx][sample_idx],
        }

C++部分:

void build_blending_indices(
    py::array_t<uint8_t> &dataset_index,
    py::array_t<int64_t> &dataset_sample_index,
    const py::array_t<double> &weights,
    const int32_t num_datasets,
    const int64_t size,
    const bool verbose
) {
    /* Given multiple datasets and a weighting array, build samples
       such that it follows those weights. */

    if (verbose) {
        std::cout << "> building indices for blendable datasets ..." << std::endl;
    }

    // Get the pointer access without the checks.
    auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
    auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
    auto weights_ptr = weights.unchecked<1>();

    // Initialize buffer for number of samples used for each dataset.
    int64_t current_samples[num_datasets];
    for (int64_t i = 0; i < num_datasets; ++i) {
        current_samples[i] = 0;
    }

    // For each sample:
    for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {

        // Determine where the max error in sampling is happening.
        auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
        int64_t max_error_index = 0;
        double max_error = weights_ptr[0] * sample_idx_double - static_cast<double>(current_samples[0]);
        for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
            double error = weights_ptr[dataset_idx] * sample_idx_double - static_cast<double>(current_samples[dataset_idx]);
            if (error > max_error) {
                max_error = error;
                max_error_index = dataset_idx;
            }
        }

        // Populate the indices.
        dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
        dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

        // Update the total samples.
        current_samples[max_error_index] += 1;
    }

    // Print info
    if (verbose) {
        std::cout << " > sample ratios:" << std::endl;
        for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
            auto ratio = static_cast<double>(current_samples[dataset_idx]) / static_cast<double>(size);
            std::cout << "   dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
        }
    }
}

具体的算法实现是在C++的函数中,我们先来看Python部分。

self.size 实际上就是 ∣ D ∣ |\mathcal{D}| D,即混合后的数据集大小(从后面的 __len__ 也能看出)。在构造函数中,首先会对 weights 进行归一化,然后声明 P , S \mathcal{P},\mathcal{S} P,S 两个数组。注意 self.dataset_index 实际上就是 P \mathcal{P} Pself.dataset_sample_index 实际上就是 S \mathcal{S} S。由于 P \mathcal{P} P 的数据类型是 uint8,这表明其中元素的范围是 [ 0 , 2 8 − 1 = 255 ] [0,2^8-1=255] [0,281=255],故 P \mathcal{P} P 最多能表示 256 256 256 个数据集,而源码中规定了参与混合的数据集个数必须严格少于 255 255 255博主不是很懂这一点,看懂的小伙伴可以在评论区留言)。

再来看C++部分。前五个形参分别是 P , S , { w i } i , n , ∣ D ∣ \mathcal{P},\mathcal{S},\{w_i\}_i,n,|\mathcal{D}| P,S,{wi}i,n,D

C \mathcal{C} C 数组会在该函数中进行声明并初始化。随后的两个嵌套 for 循环则是整个算法的核心流程,注意到这里的实现中,sample_idx(即 j j j)是从 0 0 0 开始的,而算法伪代码中的 j j j 是从 1 1 1 开始的,所以一开始要执行 j = max ⁡ ( j , 1 ) j=\max(j,1) j=max(j,1) 以确保 j j j 至少是 1 1 1但这样做有一个弊端就是前两轮的循环里, j j j 的值是相同的,和我们期望的每一轮里 j j j 值不同相违背,这是源码中的一个缺陷,实际上应该计算 ( j + 1 ) ⋅ w i − C i (j+1)\cdot w_i-\mathcal{C}_i (j+1)wiCi)。内层循环中的 error 实际上就是 j ⋅ w i − C i j\cdot w_i-\mathcal{C}_i jwiCi。此外,由于 j j j 是从 0 0 0 开始的,所以 C P j \mathcal{C}_{\mathcal{P}_j} CPj 的更新要放到最后执行。

一言以蔽之, j j j 1 1 1 开始,更新顺序为 P → C → S \mathcal{P}\to\mathcal{C}\to\mathcal{S} PCS j j j 0 0 0 开始,更新顺序为 P → S → C \mathcal{P}\to\mathcal{S}\to\mathcal{C} PSC

得到了 P , S \mathcal{P},\mathcal{S} P,S 数组后,我们便可得到混合后的数据集 D \mathcal{D} D

D j = D P j [ S j ] , j = 1 , 2 , ⋯   , ∣ D ∣ \mathcal{D}_j=\mathcal{D}_{\mathcal{P}_j}[\mathcal{S}_j],\quad j=1,2,\cdots,|\mathcal{D}| Dj=DPj[Sj],j=1,2,,D

其中 D i [ j ] \mathcal{D}_i[j] Di[j] 代表数据集 D i \mathcal{D}_i Di 中的第 j j j 个样本。

回到Python部分,__getitem__ 中传入的 idx 实际上就是 j j jself.datasets[dataset_idx][sample_idx] 实际上就是上述的 D P j [ S j ] \mathcal{D}_{\mathcal{P}_j}[\mathcal{S}_j] DPj[Sj]

3. 证明部分&讨论


Prop    1.   \text{Prop} \;1.\, Prop1. 每一轮循环开始时所有误差加和为 1 1 1,即 ∑ i = 1 n e i = 1 \sum_{i=1}^n e_i=1 i=1nei=1,其中 e i ≜ j ⋅ w i − C i e_i\triangleq j\cdot w_i-\mathcal{C}_i eijwiCi

P r o o f .    Proof.\; Proof. 注意到第 j j j 轮循环开始时,此时一共只采样了 j − 1 j-1 j1 个样本,所以 ∑ i = 1 n C i = j − 1 \sum_{i=1}^n\mathcal{C}_i=j-1 i=1nCi=j1,从而

∑ i = 1 n e i = ∑ i = 1 n ( j ⋅ w i − C i ) = j ⋅ ∑ i = 1 n w i − ∑ i = 1 n C i = j − ∑ i = 1 n C i = j − ( j − 1 ) = 1 \sum_{i=1}^n e_i=\sum_{i=1}^n (j\cdot w_i-\mathcal{C}_i)=j\cdot\sum_{i=1}^n w_i-\sum_{i=1}^n\mathcal{C}_i=j-\sum_{i=1}^n\mathcal{C}_i=j-(j-1)=1 i=1nei=i=1n(jwiCi)=ji=1nwii=1nCi=ji=1nCi=j(j1)=1

进一步可知,每一轮循环结束时所有误差加和为 0 0 0


Prop    2.   \text{Prop} \;2.\, Prop2. 假定下标从 1 1 1 开始,且 n = 2 n=2 n=2(即只有两个数据源)。若 e 1 ≥ 0.5 e_1\geq 0.5 e10.5,则 P j = 1 \mathcal{P}_j=1 Pj=1,否则 P j = 2 \mathcal{P}_j=2 Pj=2

P r o o f .    Proof.\; Proof. e 1 > 0.5 e_1>0.5 e1>0.5 的情况显然。当 e 1 = e 2 = 0.5 e_1=e_2=0.5 e1=e2=0.5 时, arg max ⁡ \argmax argmax 会优先挑选下标最小的,故此时 P j \mathcal{P}_j Pj 仍是 1 1 1


Prop    3.   \text{Prop} \;3.\, Prop3. 假定下标从 1 1 1 开始。可能存在一组 { D i } i \{\mathcal{D}_i\}_i {Di}i { w i } i \{w_i\}_i {wi}i,使得经由上述算法得到的 P , S \mathcal{P},\mathcal{S} P,S 数组, ∃   j ,   s.t.     S j > ∣ D P j ∣ \exists \,j,\,\text{s.t.}\;\,\mathcal{S}_j>|\mathcal{D}_{\mathcal{P}_j}| j,s.t.Sj>DPj,意味着 __getitem__ 会出现下标越界的错误。

P r o o f .    Proof.\; Proof. 构造特殊情形即可。令 n = 2 n=2 n=2 ∣ D 1 ∣ = ∣ D 2 ∣ = 2 |\mathcal{D}_1|=|\mathcal{D}_2|=2 D1=D2=2 w 1 = 0.1 ,   w 2 = 0.9 w_1=0.1,\,w_2=0.9 w1=0.1,w2=0.9

∣ D 1 ∣ + ∣ D 2 ∣ = 4 |\mathcal{D}_1|+|\mathcal{D}_2|=4 D1+D2=4 可知,总共会有 4 4 4 轮循环。且理应有 1 ≤ P j , S j ≤ 2 ,   j = 1 , 2 , 3 , 4 1\leq \mathcal{P}_j,\mathcal{S}_j\leq 2,\,j=1,2,3,4 1Pj,Sj2,j=1,2,3,4

利用 Prop    2 \text{Prop} \;2 Prop2 快速计算:

  • 第一轮循环,计算误差 e 1 = 1 ⋅ w 1 − 0 = 0.1 < 0.5 e_1=1\cdot w_1-0=0.1<0.5 e1=1w10=0.1<0.5,故 P 1 = 2 \mathcal{P}_1=2 P1=2 C = { 0 , 1 } \mathcal{C}=\{0,1\} C={0,1} S 1 = C 2 = 1 \mathcal{S}_1=\mathcal{C}_2=1 S1=C2=1

  • 第二轮循环,计算误差 e 1 = 2 ⋅ w 1 − 0 = 0.2 < 0.5 e_1=2\cdot w_1-0=0.2<0.5 e1=2w10=0.2<0.5,故 P 2 = 2 \mathcal{P}_2=2 P2=2 C = { 0 , 2 } \mathcal{C}=\{0,2\} C={0,2} S 2 = C 2 = 2 \mathcal{S}_2=\mathcal{C}_2=2 S2=C2=2

  • 第三轮循环,计算误差 e 1 = 3 ⋅ w 1 − 0 = 0.3 < 0.5 e_1=3\cdot w_1-0=0.3<0.5 e1=3w10=0.3<0.5,故 P 3 = 2 \mathcal{P}_3=2 P3=2 C = { 0 , 3 } \mathcal{C}=\{0,3\} C={0,3} S 3 = C 2 = 3 \mathcal{S}_3=\mathcal{C}_2=3 S3=C2=3

  • 第四轮循环,计算误差 e 1 = 4 ⋅ w 1 − 0 = 0.4 < 0.5 e_1=4\cdot w_1-0=0.4<0.5 e1=4w10=0.4<0.5,故 P 4 = 2 \mathcal{P}_4=2 P4=2 C = { 0 , 4 } \mathcal{C}=\{0,4\} C={0,4} S 4 = C 2 = 4 \mathcal{S}_4=\mathcal{C}_2=4 S4=C2=4

由上可知 j = 3 , 4 j=3,4 j=3,4 满足要求。


Prop    4.   \text{Prop} \;4.\, Prop4. 在MSE下,要使 [ C 1 , ⋯   , C k + 1 , ⋯   , C n ] [\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n] [C1,,Ck+1,,Cn] 尽可能接近 [ j ⋅ w 1 , j ⋅ w 2 , ⋯   , j ⋅ w n ] [j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n] [jw1,jw2,,jwn],应当有 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)

P r o o f .    Proof.\; Proof. 注意到

Δ MSE = MSE a f t e r − MSE b e f o r e = 1 n [ ( j ⋅ w k − C k − 1 ) 2 − ( j ⋅ w k − C k ) 2 ] = 1 n [ 1 − 2 ( j ⋅ w k − C k ) ] \begin{aligned} \Delta \text{MSE}=\text{MSE}_{after}-\text{MSE}_{before}&=\frac1n[(j\cdot w_k-\mathcal{C}_k-1)^2-(j\cdot w_k-\mathcal{C}_k)^2] \\ &=\frac1n[1-2(j\cdot w_k-\mathcal{C}_k)] \end{aligned} ΔMSE=MSEafterMSEbefore=n1[(jwkCk1)2(jwkCk)2]=n1[12(jwkCk)]

由上式可知,要使 Δ MSE \Delta \text{MSE} ΔMSE 越小,应使 j ⋅ w k − C k j\cdot w_k-\mathcal{C}_k jwkCk 越大,故 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)


Prop    5.   \text{Prop} \;5.\, Prop5. 假定下标从 1 1 1 开始。若 w 1 = w 2 = ⋯ = w n = 1 / n w_1=w_2=\cdots=w_n=1/n w1=w2==wn=1/n,令 ∣ D ∣ = q ⋅ n + r |\mathcal{D}|=q\cdot n+r D=qn+r,其中 q q q 是商, r r r 是余数,则有

P = [ 1 , 2 , ⋯   , n ] ∗ q + [ 1 , 2 , ⋯   , r ] S = [ 1 , 1 , ⋯   , 1 ] + [ 2 , 2 , ⋯   , 2 ] + ⋯ + [ q , q , ⋯   , q ] ⏟ 每个列表的长度均为 n + [ q + 1 , q + 1 , ⋯   , q + 1 ] ⏟ 长度为 r C = [ q + 1 , q + 1 , ⋯   , q + 1 ⏟ r 个 , q , q , ⋯   , q ⏟ n − r 个 ] \begin{aligned} \mathcal{P}&=[1,2,\cdots,n] * q + [1,2,\cdots,r] \\ \mathcal{S}&=\underbrace{[1,1,\cdots,1] + [2,2,\cdots,2] + \cdots+[q,q,\cdots,q]}_{每个列表的长度均为 n}+\underbrace{[q+1,q+1,\cdots,q+1]}_{长度为r} \\ \mathcal{C}&=[\underbrace{q+1,q+1,\cdots,q+1}_{r个},\underbrace{q,q,\cdots,q}_{n-r个}] \end{aligned} PSC=[1,2,,n]q+[1,2,,r]=每个列表的长度均为n [1,1,,1]+[2,2,,2]++[q,q,,q]+长度为r [q+1,q+1,,q+1]=[r q+1,q+1,,q+1,nr q,q,,q]

上述的 ∗ * + + + 均是列表运算符

P r o o f .    Proof.\; Proof. 证明留给读者。


讨论:

Prop    3.   \text{Prop} \;3.\, Prop3. 中提到了可能会出现下标越界的错误,为了避免这个错误,我们可以在得到 P , S \mathcal{P},\mathcal{S} P,S 数组后,对 S \mathcal{S} S 进行更新(假定下标从 1 1 1 开始):

S j = S j    mod    ( ∣ D P j ∣ + 1 ) , j = 1 , 2 , ⋯   , ∣ D ∣ \mathcal{S}_j=\mathcal{S}_j\;\text{mod}\; (|\mathcal{D}_{\mathcal{P}_j}|+1),\quad j=1,2,\cdots,|\mathcal{D}| Sj=Sjmod(DPj+1),j=1,2,,D

例如某个数据集是 [ 1 , 2 , 3 , 4 , 5 ] [1,2,3,4,5] [1,2,3,4,5],如果要从这个数据集采样 8 8 8 个样本,则原先的算法会在采样第 6 6 6 个样本时抛出下标越界错误,修正后的算法的采样结果为 [ 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 ] [1,2,3,4,5,1,2,3] [1,2,3,4,5,1,2,3]

为什么Megatron源码里没有规避这个错误但在使用的过程中却好像并没有遇到bug呢?注意到 self.datasets[dataset_idx] 实际上指向的是 megatron/data/gpt_dataset.py 中的 GPTDataset 类,在混合数据集场景下,Megatron会预先根据权重计算每个数据集所需要的样本数,然后根据这个样本数构建 GPTDataset,而非根据document数去构建。所以,即使对于两个完全相同的数据集,当赋予它们的权重不同时,所得到的 GPTDataset 的长度也不同,这一点可以通过向 BlendableDataset 源码中加入以下代码来验证:

for i, dataset in enumerate(self.datasets):
    print(f"dataset {i}: {len(dataset)}")

因为 GPTDataset 的长度已经根据权重做出了相应的调整,所以绝大部分时候是不会出现bug的,但我们依然可以构造极端情形来触发bug。

考虑在训练脚本中提供两个完全相同的路径,但却赋予它们不同的权重,如下:

--train-data-path 0.001 /path/to/your/data_text_document 0.999 /path/to/your/data_text_document

然后在 BlendableDataset 源码中的 __getitem__ 方法中固定索引,即:

def __getitem__(self, idx):
	idx = self.size - 1  # 意味着我们总是取BlendableDataset的最后一个样本
    dataset_idx = self.dataset_index[idx]
    sample_idx = self.dataset_sample_index[idx]
    return {
        "dataset_idx": dataset_idx,
        **self.datasets[dataset_idx][sample_idx],
    }

这样就可以稳定的触发下标越界的bug。

📝 注意到从 GPTDataset 中取出来的是sample,所以Megatron的混合算法实际上是以sample为单位的,而非以document为单位。

4. 进一步优化

根据 Prop    3.   \text{Prop} \;3.\, Prop3. Prop    5.   \text{Prop} \;5.\, Prop5. 以及其他细节,我们有以下几个优化方向:

  • 修复可能会出现的下标越界错误(可通过取余来实现)。
  • 在等权重情形下加速混合(利用 numpy)。
  • 支持更多数据集进行混合(修改 uint8 为其他类型)。

假设相应的接口名为 make_blendable_dataset,它接收两个形参:datasetsweights。前者是一个二维列表,包含了要进行混合的数据集(每个数据集是一个一维列表),后者是一个一维列表,包含了每个数据集的权重。

使用Python进行实现:

from typing import List, Any, Union
import numpy as np
import random
from tqdm import tqdm

def make_blendable_dataset(datasets: List[List[Any]], weights: List[Union[float, int]]) -> List[Any]:
    num_datasets = len(datasets)
    assert num_datasets == len(weights), "The number of datasets must match the number of weights."

    # Shuffle
    size = 0
    for dataset in datasets:
        size += len(dataset)
        random.shuffle(dataset)

    # Normalize weights
    weights = np.array(weights, dtype=np.float64)
    assert np.all(weights > 0), "All weights must be positive."
    weights /= weights.sum()

    # Determine if all weights are equal
    if np.ptp(weights) < 1e-5:
        q, r = divmod(size, num_datasets)
        dataset_index = np.concatenate([
            np.tile(np.arange(num_datasets, dtype=np.int16), q),
            np.arange(r, dtype=np.int16)
        ])
        dataset_sample_index = np.concatenate([
            np.repeat(np.arange(q, dtype=np.int64), num_datasets),
            np.full(r, q, dtype=np.int64)
        ])
        current_samples = np.full(num_datasets, q, dtype=np.int64)
        current_samples[:r] += 1
    else:
        dataset_index = np.zeros(size, dtype=np.int16)
        dataset_sample_index = np.zeros(size, dtype=np.int64)
        current_samples = np.zeros(num_datasets, dtype=np.int64)

        for sample_idx in tqdm(range(size), desc="Calculating error"):
            errors = weights * (sample_idx + 1) - current_samples
            max_error_index = np.argmax(errors)
            dataset_index[sample_idx] = max_error_index
            dataset_sample_index[sample_idx] = current_samples[max_error_index]
            current_samples[max_error_index] += 1

    print(f"Ratios:")
    for i in range(num_datasets):
        print(f"input: {weights[i]}, achieved: {current_samples[i] / size}")

    # Blend
    res = []
    for i in tqdm(range(size), desc="Blending"):
        dataset_idx = dataset_index[i]
        sample_idx = dataset_sample_index[i] % len(datasets[dataset_idx])
        res.append(datasets[dataset_idx][sample_idx])

    return res

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1468596.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【加密周报】中美非“出手”压制比特币?以太坊飙涨震醒沉睡8年巨鲸!“AI热潮”刺激相关代币集体拉涨!

回顾本周&#xff0c;中美非三国出现压制加密货币行动&#xff0c;比特币空头暂获胜利&#xff0c;币价最低触及50521美元。以太币表现跑赢比特币&#xff0c;牛市回归下震醒沉睡8年的ICO巨鲸。美国人工智能(AI)热潮下&#xff0c;刺激世界币(Worldcoin)突破历史新高&#xff0…

美团外卖流程解析:便捷、高效、安全的美食配送

美团外卖作为中国最大的外卖平台之一&#xff0c;提供了丰富多样的美食选择&#xff0c;并通过高效的配送服务将美食送到用户手中。本文将深入探讨美团外卖的流程&#xff0c;从下单到送达&#xff0c;揭秘背后的便捷、高效、安全的运营体系。 1. 下单与支付 美团外卖的下单过…

成功解决No module named ‘sklearn‘(ModuleNotFoundError)

成功解决No module named ‘sklearn’(ModuleNotFoundError) &#x1f308; 个人主页&#xff1a;高斯小哥 &#x1f525; 高质量专栏&#xff1a;Matplotlib之旅&#xff1a;零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程 &#x1f448; 希望得到您…

【前端素材】推荐优质后台管理系统Qovex平台模板(附源码)

一、需求分析 1、定义 后台管理系统是一种用于管理和监控网站、应用程序或系统的在线工具。它通常是通过网页界面进行访问和操作&#xff0c;用于管理网站内容、用户权限、数据分析等。后台管理系统是网站或应用程序的控制中心&#xff0c;管理员可以通过后台系统进行各种管理…

MATLAB环境下使用滤波自适应算法进行主动噪声消除

滤波作为自适应滤波系统中信号处理等研究领域的重要组成模块&#xff0c;主要被应用于信道均衡、系统识别、声学回波抵消、生物医学、雷达、波束形成等模块。在自适应滤波系统中&#xff0c;当信息数据统计方面的相关先验知识是已知的情况下&#xff0c;滤波器才能处理相关的输…

vulnhub靶场之driftingblues-1

一.环境搭建 1.靶场描述 get flags difficulty: easy about vm: tested and exported from virtualbox. dhcp and nested vtx/amdv enabled. you can contact me by email (it should be on my profile) for troubleshooting or questions. 2.靶场下载 https://www.vulnhub.…

新版Java面试专题视频教程——虚拟机篇①

新版Java面试专题视频教程——虚拟机篇① 1 JVM组成1.1 JVM由那些部分组成&#xff0c;运行流程是什么&#xff1f;1.2 什么是程序计数器&#xff1f;1.3 你能给我详细的介绍Java堆吗?1.3.1 1.7和1.8 堆的区别1.3.2 元空间(MetaSpace)介绍 1.4 什么是虚拟机栈1.4.1 堆和栈的区…

二进制部署k8s集群之cni网络插件

目录 k8s的三种网络模式 pod内容器之间的通信 同一个node节点中pod之间通信 不同的node节点的pod之间通信 flannel网络插件 flannel的三种工作方式 VxLAN host-GW UDP Flannel udp 模式 Flannel VXLAN 模式 flannel插件的三大模式的总结 calico网络插件 k8s 组网…

Python urllib、requests、HTMLParser

HTTP协议 HTTP 协议&#xff1a;一般指HTTP(超文本传输)协议。 HTTP是为Web浏览器和Web服务器之间的通信而设计的&#xff0c;基于TCP/IP通信协议嘞传递数据。 HTTP消息结构 客户端请求消息 客户端发送一个HTTP请求到服务器的请求消息包括以下格式 请求行(request line)请求…

排列组合(附10道题目及解析)

一、认识C,P,A: A.排列 A(x,y)(x!)/[(x-y)!]x(x-1)...(x-y1) P.排列 P(x,y)A(x,y) C.组合 C(x,y)A(x,y)(y!)x(x-1)...(x-y1)/(y!)(x!){(y!)[(x-y)!]} 例&#xff1a;C(5,2)(54)(21)10 例&#xff1a;A(5,2)5420 例&#xff1a;P(5,2)5420 二、特殊情况 C…

Linux: yum查看、安装、删除软件包

Linux: yum安装删除软件包 yum查找软件包yum 安装软件yum 卸载软件 yum查找软件包 在Linux中提供一条yum list指令用于查看当前系统中已存在和可以安装的软件包&#xff0c;但由于软件包的数量过多&#xff0c;所以我们可以通过grep指令来过滤出我们需要查找的软件包。 yum l…

欧拉函数性质和快速幂算法及python实现

目录 欧拉函数 快速幂算法 快速模幂算法 欧拉函数 两个不同的正整数a,b&#xff0c;若gcd(a,b)1,则a和b互质&#xff0c;1与任何正整数都互质 欧拉函数的意义 φ(n) 表示小于或等于正整数n的所有正整数中与n互质的数的个数 如φ(32) 16&#xff0c;即小于32的数中有16个…

【大厂AI课学习笔记】【2.2机器学习开发任务实例】(8)模型训练

好吧&#xff0c;搞了半天&#xff0c;都是围绕数据在干活&#xff0c;这也就验证了&#xff0c;我们说的&#xff0c;数据准备等工作&#xff0c;要占到机器学习项目一半以上的工作量和时间。而且数据决定了模型的天花板&#xff0c;算法只是去达到上限。 我们今天来学习模型…

优秀实践| 运营商核心系统国产数据库迁移实践

作者介绍 陕西移动信息技术部 张云川 陕西移动信息技术部 王永强 新炬网络中北三部 张建 随着国家对自主可控战略的深入推进&#xff0c;笔者所在省份聚焦数据库国产化替换&#xff0c;全面加速数据库国产化替换进程。以核心系统带动周边系统&#xff0c;成功在能力运营中…

EasyRecovery易恢复16软件最新中文版本下载

一、主要功能与特点 EasyRecovery易恢复16是一款专业的数据恢复软件&#xff0c;其主要功能和特点包括&#xff1a; 全面的数据恢复&#xff1a;支持恢复因误删除、格式化、系统崩溃、病毒攻击等多种原因丢失的数据。深度扫描技术&#xff1a;对于严重损坏或格式化的存储设备…

【坑】Spring Boot整合MyBatis,一级缓存失效

一、Spring Boot整合MyBatis&#xff0c;一级缓存失效 1.1、概述 MyBatis一级缓存的作用域是同一个SqlSession&#xff0c;在同一个SqlSession中执行两次相同的查询&#xff0c;第一次执行完毕后&#xff0c;Mybatis会将查询到的数据缓存起来&#xff08;缓存到内存中&#xf…

Flutter Engine 编译

本地环境 Flutter 开发基本环境配置&#xff0c;SDK【】 MAC. M2芯片 git工具 python环境[MAC自带] xcode Chromium depot_tools depot_tools 是调试 Flutter 引擎的必备工具包&#xff0c;包含了 gclient、gn 和 ninja 等工具&#xff0c;这些在下面会用到&#xff01;…

洛谷P5738 歌唱比赛 题解

#题外话&#xff08;第37篇题解&#xff09;&#xff08;本题为普及-难度&#xff09; #先看题目 题目链接https://www.luogu.com.cn/problem/P5738 #思路&#xff08;好像和P5726-打分有点像&#xff0c;参考一下&#xff09; #代码 #include <bits/stdc.h> using na…

谈谈智能机器人智能电表运维管理体系在铁路牵引变电所建设构想

关键词&#xff1a;智能移动机器人状态检测分析决策智能化管理 0前言 牵引变电所是铁路机车供电的来源&#xff0c;是保障铁路运行的基本环节&#xff0c;经过多年的改造和发展&#xff0c;牵引变电所在我国大部分的铁路线路中已经得到了合理的覆盖&#xff0c;无人值班牵引变…

【ArcGIS微课1000例】0104:二位面状数据转三维多面体(建筑物按高度拉伸)

文章目录 一、加载数据二、添加高度字段三、三维拉伸显示四、生成三维体数据五、注意事项一、加载数据 打开ArcScene,加载配套实验数据(0104.rar中的二维建筑物矢量数据,订阅专栏,获取专栏所有文章阅读权限及配套数据),如下图所示: 二、添加高度字段 本实验将二维数据…