相关头文件:
#include <iostream> #include <algorithm> #include <numeric> #include <climits> #include <cassert>C++ 11 中的 Lambda 表达式用于定义并创建匿名的函数对象,以简化编程工作。Lambda 的语法形式如下:
[函数对象参数] (操作符重载函数参数) mutable 或 exception 声明 -> 返回值类型 {函数体}[函数对象参数] 必须 标识一个 Lambda 表达式的开始,这部分必须存在,不能省略。函数对象参数是传递给编译器自动生成的函数对象类的构造函数的。函数对象参数只能使用 Lambda 所在作用范围内可见的局部变量(包括 Lambda 所在类的 this)。函数对象参数有以下形式:
[] 空,没有任何函数对象参数。[=] 函数体内可以使用 Lambda 所在范围内所有可见的局部变量(包括 Lambda 所在类的 this),并且是值传递方式(相当于编译器自动为我们按值传递了所有局部变量)。[&] 函数体内可以使用 Lambda 所在范围内所有可见的局部变量(包括 Lambda 所在类的 this),并且是引用传递方式(相当于是编译器自动为我们按引用传递了所有局部变量)。[this] 函数体内可以使用 Lambda 所在类中的成员变量。[a] 将 a 按值进行传递。按值进行传递时,函数体内不能修改传递进来的 a 的拷贝,因为默认情况下函数是 const 的,要修改传递进来的拷贝,可以添加 mutable 修饰符。[&a] 将 a 按引用进行传递。[a, &b] 将 a 按值传递,b 按引用进行传递。[=, &a, &b] 除 a 和 b 按引用进行传递外,其他参数都按值进行传递。[&, a, b] 除 a 和 b 按值进行传递外,其他参数都按引用进行传递。(操作符重载函数参数) 可选 没有参数时,这部分可以省略。参数可以通过按值(如: (a, b))和按引用 (如: (&a, &b)) 两种 方式进行传递。
mutable 或 exception 声明 可选 按值传递函数对象参数时,加上 mutable 修饰符后,可以修改传递进来的拷贝(注意是能修改拷贝,而不是值本身)。exception 声明用于指定函数抛出的异常,如抛出整数类型的异常,可以使用 throw(int)。
-> 返回值类型 可选 当返回值为 void,或者函数体中只有一处 return 的地方(此时编译器可以自动推断出返回值类型) 时,这部分可以省略。
{函数体} 必须 标识函数的实现,这部分不能省略,但函数体可以为空。
参考 numpy 的C++版本 NumCPP 进行实现:https://github.com/dpilger26/NumCpp/blob/master/include/NumCpp/NdArray/NdArrayCore.hpp
template<class T> void argsort(T *array, int num, int *index) { const auto function = [array](int a, int b) noexcept -> bool { return array[a] < array[b]; }; assert(num < INT_MAX); int *temp = new int[num]; std::iota(temp, temp + num, 0); std::sort(temp, temp + num, function); memcpy(index, temp, num * sizeof(int)); delete[] temp; }直接写出:
template<class T> int argmin(T *array, int num) { const auto function = [](T &a, T &b) noexcept -> bool { return a < b; }; return std::min_element(array, array + num, function) - array; }输出:
8, 7, 6, 5, 4, 3, 2, 1, 0, 0 8
比如计算 MobileNet 输出的前五项类别
template<class T> void top_k(T *array, int num, int *index, int k) { const auto function = [array](int a, int b) noexcept -> bool { return array[a] > array[b]; }; int *temp = new int[num]; std::iota(temp, temp + num, 0); std::sort(temp, temp + num, function); memcpy(index, temp, k * sizeof(int)); delete[] temp; } int top_5[5]; top_k(output, 1001, top_5, 5);