大规模数值计算问题——算法系列教程(c++版)

梦想不会自己发光,真正闪耀的是那个为梦狂奔的你。献给知行的孩子们!(Eric.He著)


  分治算法(Divide and Conquer)不仅在查找问题中展现出高效优势,在数值计算领域同样是核心算法思想。数值计算的核心需求是对大规模数据进行精准、高效的运算(如矩阵乘法、信号变换等),传统算法往往面临时间复杂度高、运算效率低的问题。分治思想通过将大规模数值计算问题拆解为若干小规模子问题,并行或递归求解后合并结果,能显著降低时间复杂度。本文将聚焦两个经典实例——矩阵乘法的Strassen算法、傅里叶变换的快速傅里叶变换(FFT)算法,详细讲解分治思想在数值计算中的应用。


分治算法解决数值计算问题的通用步骤:

  1. 分解(Divide):将大规模数值问题拆分为若干个规模更小、结构一致的子问题(如将n×n矩阵拆分为4个(n/2)×(n/2)子矩阵);
  2. 治理(Conquer):递归求解每个子问题(小规模子问题可直接计算,避免递归过深);
  3. 合并(Combine):通过特定的数学规则,将子问题的运算结果合并,得到原问题的最终解(合并步骤是数值计算分治算法的核心,直接决定算法效率)。

教程目录导航

一、矩阵乘法(Strassen算法)——高效矩阵运算

  矩阵乘法是线性代数、机器学习、图像处理等领域的基础运算(如神经网络中的权重更新、图像的卷积运算)。传统矩阵乘法算法(三重循环)的时间复杂度为O(n³),对于大规模矩阵(如1000×1000及以上),运算效率极低。Strassen算法基于分治思想,通过优化子矩阵的合并策略,将时间复杂度降至O(n^log7)≈O(n².81),大幅提升了大规模矩阵乘法的运算效率。

问题描述:

算法解析:

  1. 分 (Divide)

    将矩阵A、B、C分别拆分为4个(n/2)×(n/2)的子矩阵,拆分规则如下:

    A = [[A₁₁, A₁₂], [A₂₁, A₂₂]],B = [[B₁₁, B₁₂], [B₂₁, B₂₂]],C = [[C₁₁, C₁₂], [C₂₁, C₂₂]]

      根据矩阵乘法规则,传统分治算法的子矩阵关系为:
    • C₁₁ = A₁₁×B₁₁ + A₁₂×B₂₁
    • C₁₂ = A₁₁×B₁₂ + A₁₂×B₂₂
    • C₂₁ = A₂₁×B₁₁ + A₂₂×B₂₁
    • C₂₂ = A₂₁×B₁₂ + A₂₂×B₂₂
  2. 治 (Conquer)

    Strassen提出通过构造7个中间矩阵(仅需7次(n/2)×(n/2)矩阵乘法),替代传统的8次乘法,中间矩阵定义如下:

    • M₁ = (A₁₁ + A₂₂) × (B₁₁ + B₂₂)
    • M₂ = (A₂₁ + A₂₂) × B₁₁
    • M₃ = A₁₁ × (B₁₂ - B₂₂)
    • M₄ = A₂₂ × (B₂₁ - B₁₁)
    • M₅ = (A₁₁ + A₁₂) × B₂₂
    • M₆ = (A₂₁ - A₁₁) × (B₁₁ + B₁₂)
    • M₇ = (A₁₂ - A₂₂) × (B₂₁ + B₂₂)

    通过7次乘法计算出M₁~M₇后,再通过8次矩阵加法得到子矩阵C₁₁~C₂₂:

    • C₁₁ = M₁ + M₄ - M₅ + M₇
    • C₁₂ = M₃ + M₅
    • C₂₁ = M₂ + M₄
    • C₂₂ = M₁ - M₂ + M₃ + M₆
  3. 合 (Combine)

    将计算得到的子矩阵C₁₁、C₁₂、C₂₁、C₂₂按原拆分规则拼接,得到最终的n×n矩阵C=A×B。

算法步骤:

给定4×4矩阵A和B,计算C=A×B,步骤如下:

  1. 给定4×4矩阵A和B,计算C=A×B,步骤如下:
  2. 计算中间矩阵M₁~M₇:每个M均为2×2矩阵,通过对应子矩阵的乘法和加法计算(如M₁=(A₁₁+A₂₂)×(B₁₁+B₂₂));
  3. 计算子矩阵C₁₁~C₂₂:根据中间矩阵的加减规则,得到4个2×2子矩阵;
  4. 合并子矩阵:将C₁₁、C₁₂、C₂₁、C₂₂拼接为4×4矩阵C,完成运算。

代码包含矩阵拆分、矩阵加减、Strassen算法核心函数,以及矩阵扩展(处理非2的幂次方规模矩阵)功能,使用vector容器存储矩阵,保证代码的灵活性和可读性。


#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

// 定义矩阵类型:vector<vector<int>>(此处使用int,若需处理浮点数可改为double)
typedef vector<vector<int>> Matrix;

// 矩阵加法:两个矩阵A和B相加,返回新矩阵
Matrix matrixAdd(const Matrix& A, const Matrix& B) {
    int n = A.size();
    Matrix res(n, vector<int>(n, 0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            res[i][j] = A[i][j] + B[i][j];
        }
    }
    return res;
}

// 矩阵减法:两个矩阵A和B相减,返回新矩阵
Matrix matrixSub(const Matrix& A, const Matrix& B) {
    int n = A.size();
    Matrix res(n, vector<int>(n, 0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            res[i][j] = A[i][j] - B[i][j];
        }
    }
    return res;
}

// 矩阵拆分:将n×n矩阵拆分为4个(n/2)×(n/2)子矩阵
void matrixSplit(const Matrix& A, Matrix& A11, Matrix& A12, Matrix& A21, Matrix& A22) {
    int n = A.size();
    int mid = n / 2;
    A11.resize(mid, vector<int>(mid));
    A12.resize(mid, vector<int>(mid));
    A21.resize(mid, vector<int>(mid));
    A22.resize(mid, vector<int>(mid));
    for (int i = 0; i < mid; ++i) {
        for (int j = 0; j < mid; ++j) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + mid];
            A21[i][j] = A[i + mid][j];
            A22[i][j] = A[i + mid][j + mid];
        }
    }
}

// 矩阵合并:将4个(n/2)×(n/2)子矩阵合并为n×n矩阵
Matrix matrixMerge(const Matrix& A11, const Matrix& A12, const Matrix& A21, const Matrix& A22) {
    int mid = A11.size();
    int n = mid * 2;
    Matrix res(n, vector<int>(n, 0));
    for (int i = 0; i < mid; ++i) {
        for (int j = 0; j < mid; ++j) {
            res[i][j] = A11[i][j];
            res[i][j + mid] = A12[i][j];
            res[i + mid][j] = A21[i][j];
            res[i + mid][j + mid] = A22[i][j];
        }
    }
    return res;
}

// 传统矩阵乘法:用于小规模子矩阵计算,避免递归过深
Matrix traditionalMult(const Matrix& A, const Matrix& B) {
    int n = A.size();
    int m = B[0].size();
    int p = B.size();
    Matrix res(n, vector<int>(m, 0));
    for (int i = 0; i < n; ++i) {
        for (int k = 0; k < p; ++k) {
            if (A[i][k] == 0) continue; // 剪枝:跳过0元素,提升效率
            for (int j = 0; j < m; ++j) {
                res[i][j] += A[i][k] * B[k][j];
            }
        }
    }
    return res;
}

// Strassen算法核心函数
Matrix strassenMult(const Matrix& A, const Matrix& B) {
    int n = A.size();
    // 递归终止条件:当矩阵规模小于等于2时,使用传统乘法
    if (n <= 2) {
        return traditionalMult(A, B);
    }
    // 1. 分解矩阵
    Matrix A11, A12, A21, A22;
    Matrix B11, B12, B21, B22;
    matrixSplit(A, A11, A12, A21, A22);
    matrixSplit(B, B11, B12, B21, B22);
    // 2. 计算7个中间矩阵M1-M7
    Matrix M1 = strassenMult(matrixAdd(A11, A22), matrixAdd(B11, B22));
    Matrix M2 = strassenMult(matrixAdd(A21, A22), B11);
    Matrix M3 = strassenMult(A11, matrixSub(B12, B22));
    Matrix M4 = strassenMult(A22, matrixSub(B21, B11));
    Matrix M5 = strassenMult(matrixAdd(A11, A12), B22);
    Matrix M6 = strassenMult(matrixSub(A21, A11), matrixAdd(B11, B12));
    Matrix M7 = strassenMult(matrixSub(A12, A22), matrixAdd(B21, B22));
    // 3. 计算子矩阵C11-C22
    Matrix C11 = matrixAdd(matrixSub(matrixAdd(M1, M4), M5), M7);
    Matrix C12 = matrixAdd(M3, M5);
    Matrix C21 = matrixAdd(M2, M4);
    Matrix C22 = matrixAdd(matrixSub(matrixAdd(M1, M3), M2), M6);
    // 4. 合并子矩阵,返回结果
    return matrixMerge(C11, C12, C21, C22);
}

// 矩阵扩展:将矩阵扩展为2的幂次方规模(补0)
Matrix matrixExpand(const Matrix& A, int targetSize) {
    int n = A.size();
    Matrix res(targetSize, vector<int>(targetSize, 0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            res[i][j] = A[i][j];
        }
    }
    return res;
}

// 对外接口:处理任意规模的方阵乘法
Matrix strassenMatrixMult(const Matrix& A, const Matrix& B) {
    // 检查输入矩阵是否为方阵且维度一致
    if (A.size() != A[0].size() || B.size() != B[0].size() || A.size() != B.size()) {
        throw invalid_argument("输入必须是维度一致的方阵");
    }
    int n = A.size();
    // 计算最小的2的幂次方,使其大于等于n
    int targetSize = 1;
    while (targetSize < n) {
        targetSize <<= 1; // 等价于targetSize *= 2
    }
    // 扩展矩阵(若需)
    Matrix AExpanded = matrixExpand(A, targetSize);
    Matrix BExpanded = matrixExpand(B, targetSize);
    // 执行Strassen乘法
    Matrix CExpanded = strassenMult(AExpanded, BExpanded);
    // 截取结果矩阵的前n×n部分(去除扩展的0元素)
    Matrix C(n, vector<int>(n, 0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            C[i][j] = CExpanded[i][j];
        }
    }
    return C;
}

// 打印矩阵
void printMatrix(const Matrix& mat) {
    for (const auto& row : mat) {
        for (int val : row) {
            cout << val << " ";
        }
        cout << endl;
    }
}

// 测试案例
int main() {
    // 4×4矩阵A和B
    Matrix A = {
        {1, 2, 3, 4},
        {5, 6, 7, 8},
        {9, 10, 11, 12},
        {13, 14, 15, 16}
    };
    Matrix B = {
        {17, 18, 19, 20},
        {21, 22, 23, 24},
        {25, 26, 27, 28},
        {29, 30, 31, 32}
    };
    try {
        // 使用Strassen算法计算
        Matrix CStrassen = strassenMatrixMult(A, B);
        // 输出结果
        cout << "Strassen算法结果:" << endl;
        printMatrix(CStrassen);
    } catch (const exception& e) {
        cout << "错误:" << e.what() << endl;
        return 1;
    }
    return 0;
}
        

输出结果



        

二、快速傅里叶变换(FFT算法)——高效信号变换

傅里叶变换(Fourier Transform)是数字信号处理、图像处理、通信等领域的核心技术,用于将时域离散信号转换为频域离散信号;,从而分析信号的频率成分(如声音的音调、图像的纹理)。传统离散傅里叶变换(DFT)的时间复杂度为O(n²),对于大规模信号(如高清图像、长时音频)处理效率极低。快速傅里叶变换(FFT)基于分治思想,通过利用复数单位根的对称性和周期性,将时间复杂度降至O(n log n),是傅里叶变换工程应用的基础。

核心概念

离散傅里叶变换(DFT)的定义:

算法解析:

  1. 分 (Divide)

    将长度为n的信号x[j]按索引的奇偶性拆分为两个长度为n/2的子信号:

    • 偶序列:x₀[j] = x[2j](j=0,1,...,n/2-1)
    • 奇序列:x₁[j] = x[2j+1](j=0,1,...,n/2-1)

    根据DFT定义,原信号的DFT结果可拆分为偶序列和奇序列的DFT结果的组合:

    • X[k] = X₀[k] + Wₙ^k × X₁[k](k=0,1,...,n/2-1)
    • X[k+n/2] = X₀[k] - Wₙ^k × X₁[k](k=0,1,...,n/2-1)

    其中X₀[k]为偶序列x₀[j]的DFT结果,X₁[k]为奇序列x₁[j]的DFT结果。

  2. 治 (Conquer)

    递归求解偶序列x₀[j]和奇序列x₁[j]的DFT结果(即X₀[k]和X₁[k])。当子信号长度为1时,DFT结果等于信号本身(递归终止条件)。

  3. 合 (Combine)

    利用单位根的周期性(Wₙ^k = Wₙ/2^(k mod (n/2)))和对称性(Wₙ^(k+n/2) = -Wₙ^k),将X₀[k]和X₁[k]组合为原信号的DFT结果X[k]和X[k+n/2],完成合并。

算法步骤:

给定长度为8的离散信号x = [x0, x1, x2, x3, x4, x5, x6, x7],计算其FFT结果X,步骤如下:

  1. 分解信号:按奇偶索引拆分为偶序列x0 = [x0, x2, x4, x6]和奇序列x1 = [x1, x3, x5, x7];
  2. 递归分解:将x0和x1分别拆分为长度为2的子序列(如x0拆分为[x0,x4]和[x2,x6]),继续拆分至长度为1;
  3. 递归求解:计算每个长度为1的子序列的DFT(即自身),再合并为长度为2的子序列的DFT结果;
  4. 合并结果:利用单位根组合规则,将长度为2的DFT结果合并为长度为4的结果(X0和X1),再合并为长度为8的最终FFT结果X。

代码实现按时间抽取的FFT算法,使用复数容器存储信号和变换结果,包含信号扩展、FFT核心计算、逆FFT(用于信号还原验证)及信号打印功能。


#include <iostream>
#include <vector>
#include <complex>
#include <cmath>
#include <iomanip>
using namespace std;

typedef complex<double> Complex;
const double PI = acos(-1.0);

// 信号扩展:将信号长度扩展为2的幂次方(零填充)
vector<Complex> signalExpand(const vector<Complex>& x, int targetSize) {
    vector res(targetSize, 0);
    for (int i = 0; i < x.size(); ++i) {
        res[i] = x[i];
    }
    return res;
}

// FFT核心函数(按时间抽取,递归实现)
void fftRecursive(vector<Complex>& x) {
    int n = x.size();
    // 递归终止条件:信号长度为1时,无需变换
    if (n == 1) return;
    // 1. 分解:按奇偶索引拆分信号
    vector<Complex> even(n / 2), odd(n / 2);
    for (int i = 0; 2 * i < n; ++i) {
        even[i] = x[2 * i];
        odd[i] = x[2 * i + 1];
    }
    // 递归处理子信号
    fftRecursive(even);
    fftRecursive(odd);
    // 2. 合并:计算FFT结果
    for (int k = 0; 2 * k < n; ++k) {
        // 计算n次单位根 W_n^k = e^(-2πik/n)
        Complex W = exp(Complex(0, -2 * PI * k / n));
        x[k] = even[k] + W * odd[k];
        x[k + n / 2] = even[k] - W * odd[k];
    }
}

// 逆FFT(用于信号还原,验证FFT正确性)
void ifftRecursive(vector<Complex>& X) {
    int n = X.size();
    if (n == 1) return;
    // 逆FFT的单位根为 W_n^k = e^(2πik/n)(正号)
    vector<Complex> even(n / 2), odd(n / 2);
    for (int i = 0; 2 * i < n; ++i) {
        even[i] = X[2 * i];
        odd[i] = X[2 * i + 1];
    }
    ifftRecursive(even);
    ifftRecursive(odd);
    for (int k = 0; 2 * k < n; ++k) {
        Complex W = exp(Complex(0, 2 * PI * k / n));
        X[k] = (even[k] + W * odd[k]) / (double)n; // 逆变换需除以n
        X[k + n / 2] = (even[k] - W * odd[k]) / (double)n;
    }
}

// 对外接口:处理任意长度的信号,返回FFT结果
vector<Complex> fft(vector<Complex> x) {
    int n = x.size();
    // 计算最小的2的幂次方,使其大于等于n
    int targetSize = 1;
    while (targetSize < n) {
        targetSize <<= 1;
    }
    // 扩展信号
    x = signalExpand(x, targetSize);
    // 执行FFT
    fftRecursive(x);
    return x;
}

// 打印信号(实部,保留4位小数)
void printSignal(const vector<Complex>& x, int nSamples) {
    for (int i = 0; i < nSamples; ++i) {
        cout << fixed << setprecision(4) << x[i].real() << " ";
        if ((i + 1) % 10 == 0) cout << endl; // 每10个数据换行
    }
    cout << endl;
}

// 测试案例:生成正弦信号,验证FFT与逆FFT
int main() {
    // 生成信号:2Hz + 5Hz的正弦波混合信号(采样频率100Hz,采样点数64)
    int fs = 100; // 采样频率
    int nSamples = 64; // 采样点数
    vector<Complex> x(nSamples);
    for (int i = 0; i < nSamples; ++i) {
        double t = (double)i / fs; // 时间轴
        // 混合正弦信号:sin(2π*2t) + sin(2π*5t)
        double val = sin(2 * PI * 2 * t) + sin(2 * PI * 5 * t);
        x[i] = Complex(val, 0);
    }
    cout << "原始信号(前64个点,实部):" << endl;
    printSignal(x, nSamples);

    // 1. 执行FFT
    vector<Complex> X = fft(x);
    int fftSize = X.size();
    // 计算频率轴(仅取前半部分,因为FFT结果对称)
    vector<Complex> freq(fftSize / 2);
    for (int k = 0; k < fftSize / 2; ++k) {
        freq[k] = (double)k * fs / fftSize;
    }
    // 计算幅度谱(FFT结果的模,归一化)
    vector<Complex> amplitude(fftSize / 2);
    for (int k = 0; k < fftSize / 2; ++k) {
        amplitude[k] = abs(X[k]) / fftSize * 2;
    }

    // 2. 执行逆FFT,还原信号
    vector<Complex> xRecovered = X; // 复制FFT结果用于逆变换
    ifftRecursive(xRecovered);

    // 3. 输出结果验证
    cout << "\nFFT频率轴(前20个频率,单位:Hz):" << endl;
    for (int k = 0; k < 20; ++k) {
        cout << fixed << setprecision(2) << freq[k] << " ";
    }
    cout << endl;

    cout << "\nFFT幅度谱(前20个频率的幅度):" << endl;
    for (int k = 0; k < 20; ++k) {
        cout << fixed << setprecision(4) << amplitude[k] << " ";
    }
    cout << endl;

    // 计算原始信号与还原信号的误差
    double maxError = 0.0;
    cout << "\n原始信号与还原信号的误差(前10个点):" << endl;
    for (int i = 0; i < 10; ++i) {
        double orig = x[i].real();
        double rec = xRecovered[i].real();
        double error = abs(orig - rec);
        if (error > maxError) maxError = error;
        cout << fixed << setprecision(4) << "原始:" << orig << ", 还原:" << rec << ", 误差:" << error << endl;
    }

    cout << "\n最大还原误差:" << fixed << setprecision(6) << maxError << endl;
    cout << "还原是否成功:" << (maxError < 1e-10 ? "是" : "否") << endl;

    return 0;
}
        

输出结果



            

三、算法对比

对比维度 Strassen算法(矩阵乘法) FFT算法(傅里叶变换)
分治核心 将大矩阵拆分为子矩阵,通过优化中间矩阵减少乘法运算次数 将信号按奇偶索引拆分,利用单位根的对称性和周期性减少运算次数
时间复杂度 O(n^log7)≈O(n².81) O(n log n)
分解方式 按矩阵维度均匀拆分(n×n→4个(n/2)×(n/2)) 按信号索引奇偶性拆分(n点→2个n/2点)
合并关键 通过中间矩阵的加减组合得到最终结果,合并逻辑复杂 利用单位根的对称性组合子信号FFT结果,合并逻辑为“蝴蝶操作”
应用领域 线性代数、机器学习、图像处理(矩阵运算相关) 信号处理、通信、图像处理(频谱分析相关)
核心挑战 中间加减运算带来的精度损失和额外开销 频谱泄漏、零填充选择、复数运算的解读

四、总结

  1. 分治思想在数值计算中的核心价值是“通过问题拆分与优化合并,突破传统算法的复杂度瓶颈”;
  2. Strassen算法通过7次中间矩阵乘法替代传统8次,将矩阵乘法的时间复杂度从O(n³)降至O(n².81),解决了大规模矩阵运算的效率问题;
  3. FFT算法利用单位根的对称性,将DFT的O(n²)复杂度降至O(n log n),成为信号处理领域的基石。
  4. 两者的共性是“拆分后的子问题独立可解,合并步骤通过数学规则优化,实现整体效率提升”。

学习建议:先理解两种算法的分治拆分逻辑(矩阵拆分、信号奇偶拆分),再深入研究合并步骤的数学原理(Strassen中间矩阵组合、FFT单位根特性);通过对比传统算法与分治算法的效率差异,体会分治思想的优化价值;结合实际应用场景(如矩阵乘法的神经网络应用、FFT的音频频谱分析),加深对算法的理解与应用能力。


返回顶部