一直没有发文感觉有点咕太多了所以随便想了个话题来灌一篇(问题发言)。因为本人对C++模板有着特殊的执念(?)所以这次来讲一讲表达式模板。作为一个开头废我仍然不知道文章开头是要怎么写,所以这次我来讲一个故事。
为防止被打我先说明一下:以下故事中的“你”只是一个虚拟人物,不是指正在读这篇文章的您。好的故事开始。
第一章 事倍功半
想象一下你 是一个数学人 数学人才不会自己造这种轮子 需要写一个线性代数的C++库 (这不是问题更大吗)。当然一开始肯定要先实现一些简单的功能,比如说向量的基本运算。你非常快速地写出了一个向量类的框架:(P.S. 下面的示例很烂,而且严重缺功能,但是这不是重点所以看看就完了(草))
#include <vector>
#include <iostream>
template <typename T>
class vector
{
protected:
std::vector<T> data_;
public:
explicit vector(const size_t size) :data_(size) {}
vector(const std::initializer_list<T> ilist) :data_(ilist) {}
size_t size() const { return data_.size(); }
T& operator[](const size_t i) { return data_[i]; }
T operator[](const size_t i) const { return data_[i]; }
};
template <typename T>
std::ostream& operator<<(std::ostream& str, const vector<T>& vec)
{
for (size_t i = 0; i < vec.size(); i++)
{
if (i != 0) str << ", ";
str << vec[i];
}
return str;
}
int main()
{
const vector<float> v{ 3, 4, 2 };
std::cout << "v = " << v << '\n';
return 0;
}
然后程序输出v = 3, 4, 2
,与预期的一致,接下来你试着实现另外的两个函数:标量与向量的乘积以及向量之间的加法。你想了想,“这很简单啊”,于是写下了下面的代码。(在此为方便起见只实现了标量在左侧的情况,并且略去了向量加法中向量维度不同的错误处理)
template <typename T>
vector<T> operator*(const T lhs, vector<T> rhs)
{
for (size_t i = 0; i < rhs.size(); i++) rhs[i] *= lhs;
return rhs;
}
template <typename T>
vector<T> operator+(vector<T> lhs, const vector<T>& rhs)
{
for (size_t i = 0; i < lhs.size(); i++) lhs[i] += rhs[i];
return lhs;
}
你实现好了这两个函数,准备看一看它的效率。你生成了几个维度很大的向量测试运算需要消耗的时间。
#include <random>
#include <chrono>
static std::mt19937 gen{ std::random_device{}() };
const std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
vector<float> rand_vec(const size_t dimension)
{
vector<float> result(dimension);
for (size_t i = 0; i < dimension; i++) result[i] = dist(gen);
return result;
}
int main()
{
const size_t size = 67108864;
const auto v1 = rand_vec(size), v2 = rand_vec(size), v3 = rand_vec(size);
const float a1 = dist(gen), a2 = dist(gen), a3 = dist(gen);
using namespace std::literals;
using clock = std::chrono::high_resolution_clock;
{
const auto begin = clock::now();
const vector<float> result = a1 * v1 + a2 * v2 + a3 * v3;
const auto end = clock::now();
std::cout << "Class abstraction elapsed " << (end - begin) / 1.0ms << "ms\n";
}
return 0;
}
运行了一遍,好像耗时还不错?俗话说得好,没有对比就没有伤害。你决定把自己的这个vector
的实现跟手写的循环进行比较:
// Hand rolled loop
{
const auto begin = clock::now();
vector<float> result(size);
for (size_t i = 0; i < size; i++)
result[i] = a1 * v1[i] + a2 * v2[i] + a3 * v3[i];
const auto end = clock::now();
std::cout << "Hand rolled loop elapsed " << (end - begin) / 1.0ms << "ms\n";
}
// Abstraction
{
const auto begin = clock::now();
const vector<float> result = a1 * v1 + a2 * v2 + a3 * v3;
const auto end = clock::now();
std::cout << "Class abstraction elapsed " << (end - begin) / 1.0ms << "ms\n";
}
于是你运行了一下,运行的结果震撼了(数据删除)一整年:
Hand rolled loop elapsed 303.733ms
Class abstraction elapsed 1712.79ms
“怎么可能啊!” 你对着屏幕上冰冷的数字发出了无能狂怒,不信邪重新运行了好几遍,可是你写出来的函数效率是真的菜,没有拯救的办法。冷静下来以后你仔细想了想自己的实现为什么就这么菜,于是你发现了一个问题。这个最简单的实现的运行时间主要都花在了循环变量的自增和比较上,真正有效的运算只占了很少一部分。这种实现在每一个运算符的地方都循环了一遍,而手动写出来的那个循环只跑了一遍循环就完成了任务,省去了很多的无效的i++
以及i < size
的执行。并且,这种最简单的实现产生了过多不必要的临时值,也就出现了很大的不必要内存开销。
第二章 模板与抽象语法树
话是这样说没错,你发觉到了这种普通的计算方式的缺点。但是显然,手写这些循环也并不是非常好的做法,万一哪次复制粘贴的时候忘了改一个运算符怎么办呢?C++语言以提供“零额外成本”抽象 (zero-overhead abstractions) 而备受各界的青睐,那我们能否用C++给我们的这些工具同时实现直观和高效两个目标呢?
答案是肯定的。我们可以用表达式模板让编译器帮我们做这个“合并”循环的工作。那么具体要怎么实现呢?
我们显然不能像之前那样,在每一个运算符的位置都急切地进行循环对表达式求值(eager-evaluation),而是要做个懒人,在获得要计算的表达式的全部信息之后再惰性地在赋值等号处求值 (lazy-evaluation)。这已经不是惰性求值的概念第一次在这个blog里面登场了。
我们需要一种方式来捕获一个表达式的所有信息,计算机课上讲过的基础知识告诉我们表达式可以用语法树来表达。比如说u + a * v
这个表达式就可以看成是一棵树,树的根节点是最后执行的+
运算符 (operator),而它的两个孩子就分别是加运算符的两个操作数 (operand):左孩子是叶子节点u
,右孩子是*
运算符,而它的两个孩子就是两个操作数叶子节点a
和v
。
既然我们的表达式都是编译时可知的,我们就可以用编译时计算的一些方法来进行表达式的表示。很自然地,就可以想到用类模板来表示各种运算符节点。比如说,你现在要实现的标量与向量的乘积运算就可以用如下的类模板表示。因为树是一种递归的结构,所以这里标量向量积的右式(向量部分)也可能是一个子表达式,这个子表达式又可能是其他的类型,因此我们必须要使用模板解决这个问题。
template <typename T, typename E>
struct scalar_mul
{
T scalar;
const E& expr;
};
类似地,我们用类模板表示向量加法:
template <typename L, typename R>
struct vector_plus
{
const L& lhs;
const R& rhs;
};
举个例子,类似u + a * v
这样的表达式对应的类型就是
vector_plus<
vector<float>, // u
scalar_mul<
float, // a
vector<float> // v
>
>
第三章 优雅而高效
现在有了表达式的信息,那么需要怎么高效地对表达式求值呢?因为目前的运算符都是对向量各个元素依次求值,很自然地我们就想给每个运算符实现一个operator[]
计算某个维度上元素的值。另外,当然还需要实现一个size
方法获取表达式的计算结果的维度。为后续处理方便,这里我们给所有的运算符都加上一个成员类型定义et_value_type
代表其结果向量各分量值的类型。
template <typename T, typename E>
struct scalar_mul
{
using et_value_type = T;
T scalar;
const E& expr;
size_t size() const { return expr.size(); }
auto operator[](const size_t i) const { return scalar * expr[i]; }
};
template <typename L, typename R>
struct vector_plus
{
using et_value_type = typename L::et_value_type;
const L& lhs;
const R& rhs;
size_t size() const { return lhs.size(); }
auto operator[](const size_t i) const { return lhs[i] + rhs[i]; }
};
下面就要实现对应的运算符重载函数了。因为这里左右的两个子表达式的类型都是不固定的,所以这里要实现的两个运算符左右操作数的类型都需要是模板参数。我们显然不能不对两边操作数的类型进行限制,像下面这样:
template <typename L, typename R>
auto operator+(const L& lhs, const R& rhs);
因为这样的话就等于是给所有的类型都重载了operator+
,这显然不是我们的目的。我们需要给这个模板添加一些限制。比如这里的标量向量乘法运算符,我们的限制就是右式的分量值类型(即typename R::et_value_type
)需要与左式类型L
一致。
效仿标准库<type_traits>
里面类型特征的实现,我们可以写出这样的限制条件:
template <typename E, typename T, typename U = void>
struct is_expr_of
{
static constexpr bool value = false;
};
template <typename E, typename T>
struct is_expr_of<E, T, std::void_t<typename E::et_value_type>>
{
static constexpr bool value = std::is_same_v<T, typename E::et_value_type>;
};
template <typename E, typename T>
constexpr bool is_expr_of_v = is_expr_of<E, T>::value;
这里对于一般的任意类型E
和T
,我们有is_expr_of
的主模板,其中定义了一个静态编译时量value
值为false
,也就是表明如果只是任意给的E
和T
,那么E
不是分量类型为T
的表达式。如果类型E
定义了成员类型定义et_value_type
,那么后面给出的模板偏特化就会被选中,其中根据typename E::et_value_type
与T
是否相同给出了value
的定义。最后为了简化检测所需的代码,定义变量模板is_expr_of_v
获取is_expr_of
中静态成员value
的值。
这样,我们就可以用这个is_expr_of_v
作为运算符重载的限制了:
template <typename T, typename E,
typename = std::enable_if_t<is_expr_of_v<E, T>>>
auto operator*(const T lhs, const E& rhs)
{
return scalar_mul<T, E>{ lhs, rhs };
}
template <typename L, typename R,
typename = std::enable_if_t<is_expr_of_v<L, typename R::et_value_type>>>
auto operator+(const L& lhs, const R& rhs)
{
return vector_plus<L, R>{ lhs, rhs };
}
我们最后还需要定义vector
的隐式转型构造函数和赋值运算符,因为我们想通过到向量的赋值运算符来调用求值函数。就像这样:vector<float> result = u + v;
这个例子中就需要调用从vector_plus<vector<float>, vector<float>>
到vector<float>
的转型构造函数,在调用这个构造函数的时候执行运算。
template <typename T>
class vector
{
...
public:
using et_value_type = T;
template <typename E> vector(const E& expr);
template <typename E> vector& operator=(const E& expr);
...
};
当然,与前面原因相同,我们必须对这个转型构造函数模板的参数进行限制。首先与前面一样,E
必须是分量类型为T
的表达式;同时,在这里还有另外的一个限制:这个转型运算符不能将复制和移动构造/拷贝函数隐藏掉,所以E
也不能是诸如vector<T>&
,const vector<T>&&
之类的类型。这里使用一个变量模板实现后者所述的限制。
template <typename T, typename U>
constexpr bool is_cvref_of_v = std::is_same_v<std::decay_t<T>, U>;
有了这些铺垫,这两个函数模板的实现也就不困难了:
template <typename E,
typename = std::enable_if_t<!is_cvref_of_v<E, T>&& is_expr_of_v<E, T>>>
vector(const E& expr) :data_(expr.size())
{
const size_t size = data_.size();
for (size_t i = 0; i < size; i++) data_[i] = expr[i];
}
template <typename E,
typename = std::enable_if_t<!is_cvref_of_v<E, T>&& is_expr_of_v<E, T>>>
vector& operator=(const E& expr)
{
const size_t size = expr.size();
vector result(size);
for (size_t i = 0; i < size; i++) result.data_[i] = expr[i];
return *this;
}
你花了不少时间将上面说的这些东西都实现了一遍,终于,测试的时间到了。你重新运行了测试,稍等片刻之后,屏幕上写出了测试的结果。
Hand rolled loop elapsed 229.565ms
Class abstraction elapsed 230.397ms
“呼——” 你长舒了一口气,这心结总算是解开了。经过这次的程序优化,你变强了 ,也变秃了。
第四章 不仅仅是向量运算
就线性代数而言,表达式模板能够起到的优化效果很大,不仅是像这次“合并”循环这么简单。
想一想多个矩阵连续求乘积的问题。如果我们有这样的几个矩阵:\(A\in\mathbb{R}^{10000\times 2}\),\(B\in\mathbb{R}^{2\times 5000}\),\(C\in\mathbb{R}^{5000\times 10}\),需要计算 \(ABC\) 的值,一共需要多少次乘法运算呢?
用一般的方法计算矩阵乘法的话,大小分别是 \(m\times n\) 和 \(n\times p\) 的两个矩阵相乘需要用 \(m\times n\times p\) 次乘法运算。上面的问题如果按顺序从左到右运算的话,光是计算 \(AB\) 就需要\(10^8\)次乘法运算,再将这个结果与矩阵 \(C\) 相乘,又需要进行\(5\times 10^8\)次乘法运算。总共进行了\(6\times 10^8\)次乘法运算不说,计算 \(AB\) 的时候还生成了一个维度为\(10000\times 5000\)的巨大的临时矩阵。
我们都知道矩阵乘法满足结合律,也就是说可以先计算 \(BC\) 的值,这一部分只需要进行\(10^5\)次乘法运算,这里只使用了一个\(2\times 10\)的很小的临时矩阵;而之后计算 \(A\cdot BC\) 的时候也只需要再用\(2\times 10^5\)次乘法运算,一共\(3\times 10^5\)次乘法运算就能完成任务。比较起来,对于这个问题,前面从左到右计算的方法要比这种经过考虑地选择计算顺序的方法慢上2000倍,额外内存开销高出多达250万倍,高下立判。虽然这个问题中矩阵的大小是故意选成这样维度严重不一致的,但是现实问题中也确实有不少这样的数据,只是没这么夸张而已。
利用表达式模板的方法,我们可以在获取整个表达式的信息之后,用动态规划算法求出乘法运算次数最少的计算序列,从而获得最佳的计算效率。
感兴趣的读者们,要不要把这个当作课后练习呢?
(完)