无限长序列、惰性计算与C++(1)

宝宝的函数式编程


我一直以来都喜欢尽力滥用编程语言的一些特性达到一些奇妙的结果,虽然最终得到的成品都很劣质,但是这在某种程度上也是对自己编程思维能力的一个提升。之前看到了C++20中范围库的各种骚操作,我就在想,自己能不能实现一个类似的东西。先给自己定个小目标吧,就比如说,试着用范围for循环来遍历从2开始的素数列。

for (const auto i : /* something here */)
    // Do something with i

要达成这个目标,并且如果想让算法有普适性,我们分步实现这个目标。首先来实现一个无限长自然数列iota

for (const auto i : iota)
    // Do something with i

// Equivalent to
for (size_t i = 0;; i++)
    // Do something with i

我们声明一个类作为iota的类型。这里以防万一以后要用到一些C++要求的类型别名,后面的类型名就都采用标准库的全小写格式 (虽然个人不是非常爽……)

struct iota_t final {} inline const iota;

下面就要开始填类的实现了。若想让iota支持范围for循环,那么iota需要有对应的begin以及end方法,返回首尾迭代器。我们写一个成员类作为迭代器的类型,因为我们想要得到迭代器对应的当前值,所以里面显然至少需要一个size_t用来存放当前值。这样,首迭代器就可以用存放零值的迭代器了。

struct iota_t final
{
    struct iterator final { size_t value = 0; };
    iterator begin() const { return {}; }
} inline const iota;

作为一个可以用范围for循环来遍历的迭代器,这个迭代器类至少需要实现三个成员函数:

  1. 前置的operator++,用来将迭代器指向下一个位置;
  2. 一元operator*,即解引用运算符,返回迭代器指向的值;
  3. operator!=,用来判断当前迭代器是否与尾迭代器不等。

有了这些基础,我们就可以很简单地实现出这三个函数:

struct iterator final
{
    size_t value = 0;
    iterator& operator++() { ++value; return *this; }
    size_t operator*() const { return value; }
    bool operator!=(const iterator& other) const { return value != other.value; }
};

但是,这里还有一个问题需要解决,就是尾迭代器的问题。我们希望序列是无限长的,意思也就是我们想让每一个不越界的迭代器都与尾迭代器判断不等。我们可以在迭代器里面使用一个bool作为尾迭代器的标签,也即:(下面代码中都只写出改变的主要部分)

struct iota_t final
{
    struct iterator final
    {
        bool is_end = false;
        size_t value = 0;
        bool operator!=(const iterator& other) const
        {
            return value != other.value
                || is_end != other.is_end;
        }
    };
    iterator end() const { return { true }; }
} inline const iota;

这样我们就已经能达到目的了,但是我们还有一个更好的解法:比起使用一个flag来标识尾迭代器,还不如直接用一个哨兵类作为尾迭代器的类型,并让两种迭代器类型间做不等比较始终返回true

struct iota_t final
{
    struct sentinel final {};
    struct iterator final
    {
        size_t value = 0;
        iterator& operator++() { ++value; return *this; }
        size_t operator*() const { return value; }
        bool operator!=(const iterator& other) const { return value != other.value; }
        bool operator!=(sentinel) const { return true; }
    };
    iterator begin() const { return {}; }
    sentinel end() const { return {}; }
} inline const iota;

这样我们就可以把一部分工作挪到编译时,获得更好的优化效果。不过C++17之前的范围for循环不支持首尾迭代器类型不同,所以就只能用前面添加一个标识的方法了。不过话说回来,我们其实只用提供iteratorsentinel比较不等的函数,而不用提供iterator自己比较的函数。因为范围for循环展开以后只需要比较当前迭代器和尾迭代器的值。

下一个目标就是实现一个筛选器,提供一个一元谓词从原来的序列中筛选出一个新的序列,实现了这个的话这样我们最初的目标就可以用下面的代码表示了,其中is_prime就是一个判断一个数是否是素数的函数:

for (const auto i : filter(iota, is_prime))
    // Do something with i

仿照前面的例子,我们可以类似地写出这个filter函数的骨架。因为返回值需要支持范围for所以之前的那套操作还是需要走一遍。这里给了返回值类型一个名字,方便后面迭代器的实现。同时,定义输入范围range头尾迭代器类型(也就是“正常”迭代器和“哨兵”迭代器类型)的别名以便后用。

template <typename Range, typename Predicate>
auto filter(Range&& range, Predicate&& predicate)
{
    using iter_type = decltype(range.begin());
    using sentinel_type = decltype(range.end());
    struct filter_t final
    {
        struct iterator final
        {
            iterator& operator++();
            auto operator*();
            bool operator!=(const iterator&) const;
        };
        auto begin();
        auto end() const;
    } result{};
    return result;
}

为了防止悬垂引用,我们把函数的这两个参数转发到filter_t类里,留存一份值。为了后面操作方便,我们同时也存下来range.end()的值以便后用。这三个变量的值在遍历的过程中是不会改变的,所以才选择存在filter_t中而不是iterator中。iterator中需要存储range中的迭代器,也要存一份filter_t&以访问filter_t里面存的这三个成员。并且,像iota的实现一样,我们使用一个哨兵类标识filter_t的尾迭代器。这样我们就可以实现一大部分了。

struct filter_t final
{
    struct sentinel final {};
    struct iterator final
    {
        filter_t& parent;
        iter_type iter;
        decltype(auto) operator*() const { return *iter; }
        bool operator!=(const iterator& other) const { return iter != other.iter; }
        bool operator!=(sentinel) const { return iter != parent.range_end; }
    };
    Range range;
    Predicate predicate;
    sentinel_type range_end = range.end();
    sentinel end() const { return {}; }
} result{ std::forward<Range>(range), std::forward<Predicate>(predicate) };

重点就在operator++的实现上了。我们先实现一个帮助函数to_next,用来将迭代器移动到下一个使谓词成立的位置。

void to_next()
{
    while (iter != parent.range_end && !parent.predicate(*iter))
        ++iter;
}

利用这个函数我们就可以轻松地实现iterator::operator++以及filter_t::begin了。

struct filter_t final
{
    struct iterator final
    {
        iterator& operator++()
        {
            ++iter;
            to_next();
            return *this;
        }
    };
    iterator begin()
    {
        iterator result{ *this, range.begin() };
        result.to_next();
        return result;
    }
};

这样我们最初的目标就达成了!用下面的代码测试一下?

bool is_prime(const size_t n)
{
    if (n < 2) return false;
    for (size_t i = 2; i * i <= n; i++)
        if (n % i == 0)
            return false;
    return true;
}

int main()
{
    for (const auto i : filter(iota, is_prime))
        std::cout << i << '\n';
}

本系列的下一期如果我不会咕还会出下一期的话中,我将会使用一些TMP的技巧求你不要再TMP了讲解实现起来更困难的zip以及cartesian_product的实现方法。敬请期待~

(8月15日更新:将代码中的类型名全部改为小写形式)