傻乎乎地分不清楚树状数组与线段树?

tech2023-02-07  102

“树状数组和线段树都是用于维护数列信息的数据结构,支持单点/区间修改,单点/区间询问信息。以增加权值与询问区间权值和为例,其余的信息需要维护也都类似。时间复杂度均为 O ( l o g n ) O(logn) O(logn)。 ”

详细的数学证明

练习题目 计算右侧小于当前元素的个数 最大子序和

I. 树状数组

Fenwick Tree

地中海的程序猿们研究数组,时候遇到这样一个问题: 有一个数组 S S S 0 − n − 1 0 - n-1 0n1,现在要在 O ( l o g n ) O(logn) O(logn) 的时间复杂度内,搜索一个确定的值(或修改) w w w并且对区间 [ a , b ] [a,b] [a,b] 求和。空间复杂度必须严格限制在 O ( n ) O(n) O(n).

他们想到了二叉搜索树(BST),对于平衡二叉树其插入和删除的时间复杂度都是 O ( l o g n ) O(logn) O(logn),因为树是类似于嵌套列表的思想,进而可以想到二叉堆,这是一种非嵌套列表,也可以实现 O ( l o g n ) O(logn) O(logn)。于是有了下面这张图:

解释一下,编号为 x x x的节点上统计着 [ x − l o w b i t ( x ) + 1 , x ] [x-lowbit(x)+1,x] [xlowbit(x)+1,x]这一段区间的信息, x x x的父亲就是 x + l o w b i t ( x ) x+lowbit(x) x+lowbit(x),我们要维护数组 C C C上的信息,存储在数组 A A A中。

按照Peter M. Fenwick的说法,正如所有的整数都可以表示成2的幂和,我们也可以把一串序列表示成一系列子序列的和。采用这个想法,我们可将一个前缀和划分成多个子序列的和,而划分的方法与数的2的幂和具有极其相似的方式。一方面,子序列的个数是其二进制表示中1的个数,另一方面,子序列代表的 f [ i ] f[i] f[i]的个数也是2的幂。

1. Lowbit函数

返回参数转换为二进制后,最后一个1的位置所代表的数值。

比如34转换为二进制就是0010 0010, Lowbit(34)返回2. 程序上 我们可以用((Not I)+1) AND I, 比如NOT(0010 0010) = 1101 1101, 加1之后为 1101 1110,再与上I,为0000 0010(2)。

int lowbit(int x) { return x&(-x); }

2. 新建数组

我们定义一个数组BIT,用以维护A的前缀和,

B I T i = ∑ j = i − l o w b i t ( i ) + 1 i A j BIT_i = \sum\limits_{j=i-lowbit(i)+1}^{i} A_{j} BITi=j=ilowbit(i)+1iAj

void build() { for (int i = 1; i <= MAX_N; i++) { BIT[i] = A[i - 1]; for (int j = i - 2; j >= i - lowbit(i); j--) BIT[i] += A[j]; } }

3. 修改

假设现在要在 A [ i ] A[i] A[i]的值增加 δ \delta δ, 那么需要将 B I T BIT BIT在所有含 A [ i ] A[i] A[i]的区间都加上一个数,

void add(int k, int w) {// 在下标k、加上w for(int j = k; j< tr.size();j+=low_bit(j)) tr[j]+=w; }

4. 区间求和

假设我们需要计算 ∑ i = 1 k A i \sum\limits_{i=1}^kA_i i=1kAi的值。

首先,将 a n s ans ans初始化为 k k k a n s ans ans的值加上 B I T [ i ] BIT[i] BIT[i] i i i的值减去 L o w b i t ( i ) Lowbit(i) Lowbit(i)重复2 . 3 步骤直到 i i i的值变为0. int sum (int k) { int ans = 0; for (int i = k; i > 0; i -= lowbit(i)) ans += BIT[i]; return ans; }

应用:求逆序数

练习 LC315. 计算右侧小于当前元素的个数

II. 线段树

Segment Tree

使用线段树可以快速查找某一个节点在若干线段中出现的次数,时间复杂度为 O ( l o g N ) O(logN) O(logN),而未优化的空间复杂度为 2 N 2N 2N,一般要开 4 N 4N 4N的数组防止越界。

除了叶子节点外,对于 [ a , b ] [a,b] [a,b]线段节点,其有两个子节点, 左子节点 [ a , ( a + b ) / 2 ] [a,(a+b)/2] [a,(a+b)/2]和右子节点 [ ( a + b ) / 2 + 1 , b ] [(a+b)/2+1,b] [(a+b)/2+1,b]。由于线段树在程序竞赛中被广泛应用,这种结构被 A C M e r ACMer ACMer O I e r OIer OIer戏谑为必须掌握的数据结构。一般地,我们先定义一个线段树节点结构体:

struct SegmentNode { int start;//线段左节点 int end;//线段右节点 int sum;//线段对应的和 int lazytag;//懒标记 SegmentNode *left; SegmentNode *right; SegmentNode():start(0),end(0),sum(0){} };

请务必熟悉理解上述结构!

1. 建立树

我们对区间 [ l , r ] [l,r] [l,r]建立线段树,是一个自上而下过程。

inline void build(SegmentNode *self, int l, int r) { if(l>r) return; self->start = l;self->end = r; if(l==r) { return; } int mid = (l+r)>>1; self->left = new SegmentNode(); build(self->left,l,mid); self->right = new SegmentNode(); build(self->right,mid+1,r); }

2. 单点修改

从根节点开始,以递归的方式不断更新sum值,直到叶子节点即区间长度为1,每个区间的sum值等于左子区间的sum值,加上右子区间的sum值。

inline void add(SegmentNode *self, int pos, int k) { if(pos<self->start||pos>self->end) return; if(self->start == self->end) { self->sum += k; return; } if(self->right->start>pos) add(self->left,pos,k); else add(self->right,pos,k); self->sum = self->left->sum + self->right->sum; }

3. 区间查询

第一种情况是当前的区间范围完全在 [ l , r ] [l,r] [l,r]内,这个时候把当前区间的 s u m sum sum值返回即可,第二张情况是当前节点的左子节点的右端点和 [ l , r ] [l,r] [l,r]有交集。这个时候就搜索左子节点。第三张情况是当前节点的右子节点的左端点和 [ l , r ] [l,r] [l,r]有交集。这个时候就搜索右子节点。 inline int search(SegmentNode *self, int i,int j) {//这里的i,j分别代表要搜索的区间 if(i>j) return 0; if(i<=self->start && self->end<=j) { return self->sum; } int s = 0; if(self->left->end>=i) s+=search(self->left,i,j); if(self->right->start<=j) s+=search(self->right,i,j); return s; }

4. 延迟标记

对于区间修改,这里会遇到一个问题:为了使所有sum值都保持正确,每一次插入操作可能要更新 O ( N ) O(N) O(N)个sum值,从而使时间复杂度退化为 O ( N ) O(N) O(N)。所以就有了Lazytag,如果一个节点有延迟标记,那么表明这个节点已经被修改过了。

void add_tag(SegmentNode *self,int l,int r,int v) { self->sum += (r-l+1)*v;self->lazytag+=v;//标记只对儿子有影响,自己在打标记的同时一起把统计信息更改了。 } void push_down(SegmentNode *self,int l,int r) { int mid=(l+r)>>1; add_tag(self->left,l,mid,self->lazytag); add_tag(self->right,mid+1,r,self->lazytag); self->lazytag = 0;//把当前标记分别传给两个儿子然后清空 } inline int search(SegmentNode *self, int l, int r,int v) {//[l,r]为当前区间,[L,R]为要修改的区间 if(l<=self->start && self->end<=r) { add_tag(self,l,r,v);//打标记 return; } int s = 0; push_down(self,l,r);//下传标记 if(self->left->end>=i) s+=search(self->left,i,j,v); if(self->right->start<=j) s+=search(self->right,i,j,v); return s; }

III. 树状数组和线段树比较

数据结构时间复杂度空间复杂度适用特点线段树 O ( l o g N ) O(logN) O(logN)O(N)-树状数组 O ( l o g N ) O(logN) O(logN)O(N)空间复杂度略低,容易扩展到多维,适用范围较线段树小

下面看一些经典题目吧

53. 最大子序和

其实这题除了用动态规划,还可以用线段树做。

这个分治方法类似于「线段树求解 LCIS 问题」的 pushUp 操作。 当然,如果读者有兴趣的话,推荐看一看线段树区间合并法解决 多次询问 的「区间最长连续上升序列问题」和「区间最大子段和问题」,还是非常有趣的。

我们定义一个操作get(a,l,r)表示查询a序列 [ l , r ] 区 [l,r]区 [l,r]间内的最大字段和。对于一个区间,我们取 m = [ l + r 2 ] m = [\frac{l+r}{2}] m=[2l+r],然后逐层递归。最关键的问题是:

我们要维护区间什么信息?

我们如何合并这些信息?

对于一个区间 [ l , r ] [l,r] [l,r],lSum表示 [ l , r ] [l,r] [l,r] l l l为左端点的最大子段和;rSum表示 [ l , r ] [l,r] [l,r] r r r为右端点的最大子段和,mSum表示 [ l , r ] [l,r] [l,r]

内的最大子段和。iSum表示 [ l , r ] [l,r] [l,r]的区间和。

iSum是左右区间的子段和的和。对于 [ l , r ] [l, r] [l,r] 的 lSum,存在两种可能,它要么等于「左子区间」的 lSum,要么等于「左子区间」的 iSum 加上「右子区间」的 lSum,二者取大。对于 [ l , r ] [l, r] [l,r] 的 rSum,存在两种可能,它要么等于「右子区间」的 rSum,要么等于「右子区间」的 iSum 加上「左子区间」的 rSum,二者取大。对于mSum,存在三种可能,要么完全在左区间,要么完全在中间,要么两边都有,我想你已经猜到了,就是左区间的rSum加上右区间的lSum。

好的已经可以开始写代码了


struct Status { int lSum, rSum, mSum, iSum; // 分别表示,以l为左端点的最大子序和,以r为右端点的最大子序和, // mSum表示区间[l,r]最大子序和 //iSum表示区间和 }; Status get(vector<int> nums,int l,int r) { if(l==r) return (Status){nums[l],nums[l],nums[l],nums[l]}; int m = (l+r)>>1; Status lpus = get(nums,l,m); Status rpus = get(nums,m+1,r); int lSum = max(lpus.lSum, lpus.iSum + rpus.lSum); int rSum = max(rpus.rSum, rpus.iSum + lpus.rSum); int iSum = lpus.iSum + rpus.iSum; int mSum = max(lpus.rSum+ rpus.lSum,max(lpus.mSum, rpus.mSum)); return (Status){lSum,rSum,mSum,iSum}; } int maxSubArray(vector<int>& nums) { if(!nums.size()) return 0; return get(nums, 0 , nums.size()-1).mSum; }

然后我们分析一下时间和空间复杂度。

时间复杂度: O ( n ) O(n) O(n),我们把递归过程看成二叉树的先序遍历,那么这颗二叉树时间复杂度:假设我们把递归的过程看作是一颗二叉树的先序遍历,那么这颗二叉树的深度的渐进上界为 O ( log ⁡ n ) O(\log n) O(logn),这里的总时间相当于遍历这颗二叉树的所有节点,故总时间的渐进上界是 O ( ∑ i = 1 log ⁡ n 2 i − 1 ) = O ( n ) O(\sum_{i = 1}^{\log n} 2^{i - 1}) = O(n) O(i=1logn2i1)=O(n),故渐进时间复杂度为 O ( n ) O(n) O(n)。 空间复杂度:递归会使用 O(\log n)O(logn) 的栈空间,故渐进空间复杂度为 O ( l o g n ) O(logn) O(logn)


315. 计算右侧小于当前元素的个数

给定一个整数数组 nums,按要求返回一个新数组 counts。数组 counts 有该性质: counts[i] 的值是 nums[i] 右侧小于 nums[i] 的元素的数量。

示例:

输入:[5,2,6,1] 输出:[2,1,1,0] 解释: 5 的右侧有 2 个更小的元素 (2 和 1) 2 的右侧仅有 1 个更小的元素 (1) 6 的右侧有 1 个更小的元素 (1) 1 的右侧有 0 个更小的元素
#include <iostream> #include <vector> #include <algorithm> #include <unordered_map> using namespace std; struct SegmentNode { int start; int end; int sum; SegmentNode *left; SegmentNode *right; SegmentNode():start(0),end(0),sum(0){} }; class Solution{ public: //---------------------------Segment tree solution---------------------------- inline void build(SegmentNode *self, int l, int r) { if(l>r) return; self->start = l;self->end = r; if(l==r) { // self->sum = l; return; } int mid = (l+r)>>1; self->left = new SegmentNode(); build(self->left,l,mid); self->right = new SegmentNode(); build(self->right,mid+1,r); // self->sum = self->left->sum + self->right->sum; } inline void add(SegmentNode *self, int pos, int k) { if(pos<self->start||pos>self->end) return; if(self->start == self->end) { self->sum += k; return; } if(self->right->start>pos) add(self->left,pos,k); else add(self->right,pos,k); self->sum = self->left->sum + self->right->sum; } inline int search(SegmentNode *self, int i,int j) { if(i>j) return 0; if(i<=self->start && self->end<=j) { return self->sum; } int s = 0; if(self->left->end>=i) s+=search(self->left,i,j); if(self->right->start<=j) s+=search(self->right,i,j); return s; } vector<int> countSmaller_SegmentTree(vector<int>&nums) { if(!nums.size()) return nums; SegmentNode *root = new SegmentNode(); //find the min and max val in nums int min_val = INT_MAX, max_val = INT_MIN; for(auto &c:nums){min_val=min(min_val,c);max_val = max(max_val,c);} build(root,min_val,max_val); vector<int> res(nums.size()); // for(auto &c:nums) // add(root,c,1);//All sub interval adds 1 for(int i = nums.size()-1; i>=0;i--) { add(root,nums[i],1); res[i] = search(root,min_val,nums[i]-1); } return res; } //------------------------------Fenwick Tree Solution---------------------------------- //Due to the uncertainty of scale of data, we discretize the array int n; vector<int> tr; int low_bit(int x) {//pow(2,x) return (x&(-x)); } int sum(int k) { int res = 0; for(int j = k; j>0; j-=low_bit(j)) res+=tr[j]; return res; } void add(int k, int w) {// add k to node w for(int j = k; j< tr.size();j+=low_bit(j)) tr[j]+=w; } vector<int> countSmaller_Fenwick(vector<int>&nums) { if(!nums.size()) return {}; int n = nums.size(); vector<int> res(n); //First, we discretize the vector and delete the repeated nums vector<int> tmp = nums; sort(tmp.begin(),tmp.end()); auto c = unique(tmp.begin(),tmp.end()); tmp.erase(c,tmp.end()); int new_len = c - tmp.begin(); // we define a unordered-map to count the number of tmp unordered_map<int,int> ump; tr = vector<int>(new_len + 1);//redefine the tr to (new_len+1) default value int count = 1; for(int i = 0;i<new_len;i++) ump[tmp[i]] = count++;//redefine the discretized values into serialized values using hashmap //we build the Fenwick tree and do summation and addition for(int k = nums.size()-1;k>=0;k--) { count = ump[nums[k]];// count of number res[k] = sum(count-1); add(count,1); } return res; } };

感谢阅读!有任何问题请在评论区提出,笔者看到会及时回答!

最新回复(0)