C++基础:lambda函数及用实现argsort, argmax, argmin

tech2026-03-04  2

文章目录

lambda函数语法示例argsortargmaxargmin 测试程序计算TOP-K 索引


相关头文件:

#include <iostream> #include <algorithm> #include <numeric> #include <climits> #include <cassert>

lambda函数语法

C++ 11 中的 Lambda 表达式用于定义并创建匿名的函数对象,以简化编程工作。Lambda 的语法形式如下:

[函数对象参数] (操作符重载函数参数) mutable 或 exception 声明 -> 返回值类型 {函数体}

[函数对象参数] 必须 标识一个 Lambda 表达式的开始,这部分必须存在,不能省略。函数对象参数是传递给编译器自动生成的函数对象类的构造函数的。函数对象参数只能使用 Lambda 所在作用范围内可见的局部变量(包括 Lambda 所在类的 this)。函数对象参数有以下形式:

[] 空,没有任何函数对象参数。[=] 函数体内可以使用 Lambda 所在范围内所有可见的局部变量(包括 Lambda 所在类的 this),并且是值传递方式(相当于编译器自动为我们按值传递了所有局部变量)。[&] 函数体内可以使用 Lambda 所在范围内所有可见的局部变量(包括 Lambda 所在类的 this),并且是引用传递方式(相当于是编译器自动为我们按引用传递了所有局部变量)。[this] 函数体内可以使用 Lambda 所在类中的成员变量。[a] 将 a 按值进行传递。按值进行传递时,函数体内不能修改传递进来的 a 的拷贝,因为默认情况下函数是 const 的,要修改传递进来的拷贝,可以添加 mutable 修饰符。[&a] 将 a 按引用进行传递。[a, &b] 将 a 按值传递,b 按引用进行传递。[=, &a, &b] 除 a 和 b 按引用进行传递外,其他参数都按值进行传递。[&, a, b] 除 a 和 b 按值进行传递外,其他参数都按引用进行传递。

(操作符重载函数参数) 可选 没有参数时,这部分可以省略。参数可以通过按值(如: (a, b))和按引用 (如: (&a, &b)) 两种 方式进行传递。

mutable 或 exception 声明 可选 按值传递函数对象参数时,加上 mutable 修饰符后,可以修改传递进来的拷贝(注意是能修改拷贝,而不是值本身)。exception 声明用于指定函数抛出的异常,如抛出整数类型的异常,可以使用 throw(int)。

-> 返回值类型 可选 当返回值为 void,或者函数体中只有一处 return 的地方(此时编译器可以自动推断出返回值类型) 时,这部分可以省略。

{函数体} 必须 标识函数的实现,这部分不能省略,但函数体可以为空。


示例

argsort

参考 numpy 的C++版本 NumCPP 进行实现:https://github.com/dpilger26/NumCpp/blob/master/include/NumCpp/NdArray/NdArrayCore.hpp

template<class T> void argsort(T *array, int num, int *index) { const auto function = [array](int a, int b) noexcept -> bool { return array[a] < array[b]; }; assert(num < INT_MAX); int *temp = new int[num]; std::iota(temp, temp + num, 0); std::sort(temp, temp + num, function); memcpy(index, temp, num * sizeof(int)); delete[] temp; }

argmax

template<class T> int argmax(T *array, int num) { const auto function = [](T &a, T &b) noexcept -> bool { return a < b; }; return std::max_element(array, array + num, function) - array; }

argmin

直接写出:

template<class T> int argmin(T *array, int num) { const auto function = [](T &a, T &b) noexcept -> bool { return a < b; }; return std::min_element(array, array + num, function) - array; }

测试程序

void test() { double a[] = {9, 8, 7, 6, 5, 4, 3, 2, 1}; int index[9]; argsort(a, 9, index); for (int i : index) { cout << i << ", "; } cout << endl; cout << argmax(a, 9) << endl; cout << argmin(a, 9) << endl; }

输出:

8, 7, 6, 5, 4, 3, 2, 1, 0, 0 8

计算TOP-K 索引

比如计算 MobileNet 输出的前五项类别

template<class T> void top_k(T *array, int num, int *index, int k) { const auto function = [array](int a, int b) noexcept -> bool { return array[a] > array[b]; }; int *temp = new int[num]; std::iota(temp, temp + num, 0); std::sort(temp, temp + num, function); memcpy(index, temp, k * sizeof(int)); delete[] temp; } int top_5[5]; top_k(output, 1001, top_5, 5);
最新回复(0)