使用 Rust 实现 RMS Normalization 算法——InifiniTensor夏季训练营专业项目三作业二

使用 Rust 实现 RMS Normalization 算法——InifiniTensor夏季训练营专业项目三作业二

1. 函数签名与参数

pub fn rms_norm(y: &mut Tensor<f32>, x: &Tensor<f32>, w: &Tensor<f32>, epsilon: f32) {
    ...
}
  • y: 输出张量(可变),用于存储归一化后的结果。
  • x: 输入张量,待处理的数据。
  • w: 权重张量,应用于每个元素的缩放。
  • epsilon: 小的正数,用于防止 RMS 值为零时的数值不稳定问题。

2. 参数校验

assert_eq!(x.shape(), y.shape(), "Input and output tensors must have the same shape.");
assert_eq!(x.shape().last(), w.shape().last(), "Weight tensor must have the same size as the last dimension of input tensor.");
  • 相同形状校验:确保输入和输出张量 xy 的形状相同,确保每个输入元素有对应的输出位置。
  • 权重校验:确保权重张量 w 的大小与输入张量 x 的最后一个维度的大小相同。这是因为 RMS Normalization 是在最后一个维度上进行的,而权重应用于每个元素。

3. 准备批次处理

let num_elements = x.length();
let last_dim_size = *x.shape().last().clone().unwrap(); // 得到最后一个维度的大小,即长度 n
let num_batch = num_elements / last_dim_size; // 计算批次数量
  • 元素数量num_elements 表示张量 x 中的总元素数量。
  • 最后维度大小last_dim_sizen,表示最后一个维度的大小,即每个子向量的长度。
  • 批次数量num_batch 计算出一共有多少个批次,每个批次为一个长度为 n 的向量。

4. 遍历批次进行计算

for batch_idx in 0..num_batch {
    let start_idx = batch_idx * last_dim_size;
    let end_idx = start_idx + last_dim_size;

    // 提取当前批次的数据并计算平方和
    let batch_data = &x.data()[start_idx..end_idx];
    let sum_of_sequence: f32 = batch_data.iter().map(|&value| value * value).sum();
  • 批次索引:通过 batch_idx 遍历每一个批次。
  • 批次边界start_idxend_idx 确定了当前批次在数据中的起始和结束索引。
  • 提取批次数据batch_data 是当前批次的数据切片。
  • 平方和计算sum_of_sequence 计算当前批次的所有元素的平方和。

5. 计算 RMS 值

    let rms = (sum_of_sequence / last_dim_size as f32 + epsilon).sqrt();

RMS 计算:RMS 值通过计算平方和的平均值,然后取平方根,再加上一个小的正数 epsilon 来稳定计算。

6. 执行归一化并应用权重

    for i in 0..last_dim_size {
        let weight = &w.data()[i]; // 权重应用于每个元素
        unsafe { y.data_mut()[start_idx + i] = batch_data[i] * weight / rms; }
    }
}
  • 遍历批次元素:在当前批次内,遍历每个元素进行归一化。
  • 权重应用weight 作为缩放因子应用于每个元素。
  • 归一化计算:更新输出张量 y 中每个元素为:
  • 使用 unsafe:直接访问和修改输出张量的内部数据。这种操作要求开发者保证访问的安全性。

7. 函数调用示例

下面是如何使用 rms_norm 函数的一个简单示例:

fn main() {
    // 初始化示例张量
    let input_data = vec![0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
    let weight_data = vec![1.0, 1.0, 1.0];
    let shape = vec![2, 3]; // 示例张量形状为 2x3

    let x = Tensor::new(input_data.clone(), shape.clone());
    let w = Tensor::new(weight_data, vec![3]);
    let mut y = Tensor::new(vec![0.0; input_data.len()], shape);

    // 调用 rms_norm 函数
    rms_norm(&mut y, &x, &w, 1e-8);

    // 输出归一化结果
    println!("Normalized Tensor: {:?}", y.data());
}

解释总结

  1. 输入与输出准备:确保输入、输出和权重张量的形状匹配。
  2. 批次计算:针对张量的最后一维度,每个子向量作为一个批次进行处理。
  3. RMS 计算:使用均方根公式计算每个批次的 RMS 值。
  4. 归一化与权重应用:调整每个元素的值以实现归一化,并应用给定的权重。
  5. 结果存储:将计算结果存储在输出张量中。

完整代码

pub fn rms_norm(y: &mut Tensor<f32>, x: &Tensor<f32>, w: &Tensor<f32>, epsilon: f32) {
    assert_eq!(x.shape(), y.shape(), "Input and output tensors must have the same shape.");
    assert_eq!(x.shape().last(), w.shape().last(), "Weight tensor must have the same size as the last dimension of input tensor.");

    let num_elements = x.length();
    let last_dim_size  = *x.shape().last().clone().unwrap();//也就是提示当中提到的n
    let num_batch = num_elements/last_dim_size;//得到究竟有多少批次
    //“即张量 X(...,n)和 Y(...,n)都是由若干个长度为n的向量xi,yi组成的”
    for batch_idx in 0..num_batch{
        let start_idx = batch_idx*last_dim_size;
        let end_idx = start_idx+last_dim_size;

        //提前当前批次的所有数据,并计算出平方和
        let batch_data = &x.data()[start_idx..end_idx];
        let sum_of_sequence:f32 = batch_data.iter().map(|&value| value*value).sum();

        //计算本批次的RMS
        let rms = (sum_of_sequence/ last_dim_size as f32+epsilon).sqrt();

        //得出本批次的RMS Normalization
        for i in 0..last_dim_size{
            let weight =&w.data()[i];//说明了w就是一个一维向量
            unsafe { y.data_mut()[start_idx + i] = batch_data[i] * weight / rms; }
        }
    }
}

这种实现方式在处理大规模数据时特别有效,因为它能够利用批处理模式优化计算效率,并支持复杂张量的多维度操作。Rust 的类型系统和内存安全特性提供了安全而高效的数值计算能力。

Comments

No comments yet. Why don’t you start the discussion?

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注