Segment Tree: Max and Sum
Max Segment Tree (Accepted)
// Segment tree answering range-maximum queries with point updates.
//
// Nodes are stored in a flat array: node `cur` has children `2*cur+1` and
// `2*cur+2`, and node 0 is the root. The tree does not remember the input
// length, so every public call starts at the root with
// (cur = 0, tl = 0, tr = n - 1), where n is the size of the input array.
class MaxSegTree {
    std::vector<int> t;  // t[cur] = maximum of the segment owned by node cur

public:
    // Builds the tree over `nums`. An empty input yields an empty tree.
    MaxSegTree(const std::vector<int> &nums) {
        // 4 * n slots are always enough for this child-index layout.
        t.assign(4 * nums.size(), 0);
        if (!nums.empty()) {
            build(nums, 0, 0, static_cast<int>(nums.size()) - 1);
        }
    }

    // Fills node `cur`, which owns nums[tl..tr] (inclusive), and its subtree.
    // Requires tl <= tr and both to be valid indices into `nums`.
    void build(const std::vector<int> &nums, int cur, int tl, int tr) {
        if (tl == tr) {
            t[cur] = nums[tl];
            return;
        }
        const int mid = tl + (tr - tl) / 2;
        const int lhs = 2 * cur + 1;
        const int rhs = 2 * cur + 2;
        build(nums, lhs, tl, mid);
        build(nums, rhs, mid + 1, tr);
        t[cur] = std::max(t[lhs], t[rhs]);
    }

    // Sets element `idx` to `val` and refreshes the maxima on the path back
    // to `cur`. An `idx` outside [tl, tr] is ignored, so a bad index can no
    // longer overwrite an unrelated leaf.
    void update(int idx, int val, int cur, int tl, int tr) {
        if (idx < tl || idx > tr) {
            return;
        }
        if (tl == tr) {
            t[cur] = val;
            return;
        }
        const int mid = tl + (tr - tl) / 2;
        const int lhs = 2 * cur + 1;
        const int rhs = 2 * cur + 2;
        if (idx <= mid) {
            update(idx, val, lhs, tl, mid);
        } else {
            update(idx, val, rhs, mid + 1, tr);
        }
        t[cur] = std::max(t[lhs], t[rhs]);
    }

    // Returns the maximum of elements l..r (inclusive) inside the segment
    // [tl, tr] owned by node `cur`.
    //
    // Returns -1 when the tree is empty or when [l, r] does not intersect
    // [tl, tr] at all; that sentinel is indistinguishable from a real
    // maximum of -1, so callers should pass ranges that overlap the array.
    int query(int l, int r, int cur, int tl, int tr) {
        if (t.empty()) {
            return -1;
        }
        if (l <= tl && r >= tr) {
            return t[cur];
        }
        if (l > tr || r < tl) {
            return -1;
        }
        const int mid = tl + (tr - tl) / 2;
        // Descend only into children that intersect [l, r]. Mixing the -1
        // sentinel of a disjoint child into max() would mask maxima that are
        // smaller than -1.
        if (r <= mid) {
            return query(l, r, 2 * cur + 1, tl, mid);
        }
        if (l > mid) {
            return query(l, r, 2 * cur + 2, mid + 1, tr);
        }
        return std::max(query(l, r, 2 * cur + 1, tl, mid),
                        query(l, r, 2 * cur + 2, mid + 1, tr));
    }
};
Sum Segment Tree (Solution)
// Segment tree answering range-sum queries with point updates.
//
// Nodes are stored in a flat array: node `pos` has children `2*pos+1` and
// `2*pos+2`, and node 0 is the root. The tree does not remember the input
// length, so callers start every operation at the root with
// (pos = 0, left = 0, right = n - 1), where n is the size of the input array.
// Sums are accumulated in `int`; the caller must keep totals within range.
class SegmentTree {
    std::vector<int> seg;  // seg[pos] = sum of the segment owned by node pos

public:
    // Builds the tree over `nums`. An empty input yields an empty tree.
    SegmentTree(const std::vector<int> &nums) {
        // 4 * n slots are always enough for a segment tree over n elements.
        seg.assign(4 * nums.size(), 0);
        if (!nums.empty()) {
            buildTree(nums, 0, 0, static_cast<int>(nums.size()) - 1);
        }
    }

    // Fills node `pos`, which owns nums[left..right] (inclusive), and its
    // subtree. Requires left <= right and both to be valid indices.
    void buildTree(const std::vector<int> &nums, int pos, int left, int right) {
        if (left == right) {
            seg[pos] = nums[left];
            return;
        }
        const int mid = left + (right - left) / 2;
        const int lhs = 2 * pos + 1;
        const int rhs = 2 * pos + 2;
        buildTree(nums, lhs, left, mid);
        buildTree(nums, rhs, mid + 1, right);
        seg[pos] = seg[lhs] + seg[rhs];
    }

    // Sets element `idx` to `val` and refreshes the sums on the path back to
    // `pos`. An `idx` outside [left, right] leaves the tree untouched.
    void updateTree(int pos, int left, int right, int idx, int val) {
        if (idx < left || idx > right) {
            return;
        }
        if (left == right) {
            seg[pos] = val;
            return;
        }
        const int mid = left + (right - left) / 2;
        const int lhs = 2 * pos + 1;
        const int rhs = 2 * pos + 2;
        // idx lies inside [left, right], so exactly one child owns it.
        if (idx <= mid) {
            updateTree(lhs, left, mid, idx, val);
        } else {
            updateTree(rhs, mid + 1, right, idx, val);
        }
        seg[pos] = seg[lhs] + seg[rhs];
    }

    // Returns the sum of elements qleft..qright (inclusive) inside the
    // segment [left, right] owned by node `pos`. Returns 0 when the tree is
    // empty or the two ranges do not intersect.
    int queryTree(int qleft, int qright, int left, int right, int pos) {
        if (seg.empty()) {
            return 0;
        }
        if (qleft > right || qright < left) {
            return 0;
        }
        if (qleft <= left && qright >= right) {
            return seg[pos];
        }
        const int mid = left + (right - left) / 2;
        const int lowerHalf = queryTree(qleft, qright, left, mid, 2 * pos + 1);
        const int upperHalf = queryTree(qleft, qright, mid + 1, right, 2 * pos + 2);
        return lowerHalf + upperHalf;
    }
};
Segment Tree vector version (TLE)
// Range-sum segment tree with point updates, stored in a std::vector.
//
// Node `cur` has children `2*cur+1` and `2*cur+2`; node 0 is the root. The
// tree does not remember the input length, so callers start every operation
// at the root with (cur = 0, tl = 0, tr = n - 1), where n is the size of the
// input array. Sums are accumulated in `int`.
class SegTree {
public:
    std::vector<int> t;  // t[cur] = sum of the segment owned by node cur

    // Builds the tree over `arr`. An empty input yields an empty tree.
    SegTree(const std::vector<int> &arr) {
        t.resize(4 * arr.size(), 0);
        if (!arr.empty()) {
            build(arr, 0, 0, static_cast<int>(arr.size()) - 1);
        }
    }

    // Fills node `cur`, which owns arr[tl..tr] (inclusive), and its subtree.
    // `arr` is taken by const reference: passing it by value copied the whole
    // input on every recursive call, which made construction quadratic.
    void build(const std::vector<int> &arr, int cur, int tl, int tr) {
        if (tl == tr) {
            t[cur] = arr[tl];
            return;
        }
        const int mid = tl + (tr - tl) / 2;
        const int lhs = 2 * cur + 1;
        const int rhs = 2 * cur + 2;
        build(arr, lhs, tl, mid);
        build(arr, rhs, mid + 1, tr);
        t[cur] = t[lhs] + t[rhs];
    }

    // Returns the sum of elements l..r (inclusive) inside the segment
    // [tl, tr] owned by node `cur`. Returns 0 when the tree is empty or the
    // two ranges do not intersect.
    int getSum(int l, int r, int cur, int tl, int tr) {
        if (t.empty()) {
            return 0;
        }
        if (r < tl || l > tr) {
            return 0;
        }
        if (l <= tl && r >= tr) {
            return t[cur];
        }
        const int mid = tl + (tr - tl) / 2;
        return getSum(l, r, 2 * cur + 1, tl, mid) +
               getSum(l, r, 2 * cur + 2, mid + 1, tr);
    }

    // Sets element `idx` to `val` and refreshes the sums on the path back to
    // `cur`. An `idx` outside [tl, tr] is ignored instead of overwriting an
    // edge leaf.
    void update(int idx, int val, int cur, int tl, int tr) {
        if (idx < tl || idx > tr) {
            return;
        }
        if (tl == tr) {
            // `cur` is a position in the node array, not an element index.
            t[cur] = val;
            return;
        }
        const int mid = tl + (tr - tl) / 2;
        const int lhs = 2 * cur + 1;
        const int rhs = 2 * cur + 2;
        if (idx <= mid) {
            update(idx, val, lhs, tl, mid);
        } else {
            update(idx, val, rhs, mid + 1, tr);
        }
        t[cur] = t[lhs] + t[rhs];
    }
};
Segment Tree array version (untested)
// Range-sum segment tree with point updates, stored in a heap array that
// this object owns.
//
// Node `cur` has children `2*cur+1` and `2*cur+2`; node 0 is the root. The
// tree does not remember the input length, so callers start every operation
// at the root with (cur = 0, tl = 0, tr = len - 1). Sums are accumulated in
// `int`.
class SegTree {
public:
    int *t;     // node array, owned by this object and released in ~SegTree
    int t_len;  // number of slots in `t` (4 * len, or 0 for an empty tree)

    // Builds the tree over arr[0..len-1]. A non-positive `len` yields an
    // empty tree and `arr` is not read.
    SegTree(const int *arr, int len) {
        // 4 * len is a safe upper bound on the node count; when len is a
        // power of two, 2 * len - 1 slots would already be enough.
        t_len = len > 0 ? 4 * len : 0;
        // Value-initialize so slots that build() never reaches hold 0 and
        // show() does not print indeterminate values.
        t = new int[t_len]();
        if (len > 0) {
            build(arr, 0, 0, len - 1);
        }
    }

    // Releases the node array.
    ~SegTree() { delete[] t; }

    // Deep copy: the new tree gets its own node array, so the two
    // destructors never free the same buffer.
    SegTree(const SegTree &other) : t(new int[other.t_len]), t_len(other.t_len) {
        for (int i = 0; i < t_len; ++i) {
            t[i] = other.t[i];
        }
    }

    // Deep-copy assignment. The replacement buffer is allocated before the
    // old one is released, so a failed allocation leaves *this unchanged.
    SegTree &operator=(const SegTree &other) {
        if (this != &other) {
            int *fresh = new int[other.t_len];
            for (int i = 0; i < other.t_len; ++i) {
                fresh[i] = other.t[i];
            }
            delete[] t;
            t = fresh;
            t_len = other.t_len;
        }
        return *this;
    }

    // Fills node `cur`, which owns arr[tl..tr] (inclusive), and its subtree.
    // Requires tl <= tr and both to be valid indices into `arr`.
    void build(const int *arr, int cur, int tl, int tr) {
        if (tl == tr) {
            t[cur] = arr[tl];
            return;
        }
        const int mid = tl + (tr - tl) / 2;
        const int lhs = 2 * cur + 1;
        const int rhs = 2 * cur + 2;
        build(arr, lhs, tl, mid);
        build(arr, rhs, mid + 1, tr);
        t[cur] = t[lhs] + t[rhs];
    }

    // Returns the sum of elements l..r (inclusive) inside the segment
    // [tl, tr] owned by node `cur`. Returns 0 when the tree is empty or the
    // two ranges do not intersect.
    int getSum(int l, int r, int cur, int tl, int tr) {
        if (t_len == 0) {
            return 0;
        }
        if (r < tl || l > tr) {
            return 0;
        }
        if (l <= tl && r >= tr) {
            return t[cur];
        }
        const int mid = tl + (tr - tl) / 2;
        return getSum(l, r, 2 * cur + 1, tl, mid) +
               getSum(l, r, 2 * cur + 2, mid + 1, tr);
    }

    // Sets element `idx` to `val` and refreshes the sums on the path back to
    // `cur`. An `idx` outside [tl, tr] is ignored instead of overwriting an
    // edge leaf.
    void update(int idx, int val, int cur, int tl, int tr) {
        if (idx < tl || idx > tr) {
            return;
        }
        if (tl == tr) {
            // `cur` is a position in the node array, not an element index.
            t[cur] = val;
            return;
        }
        const int mid = tl + (tr - tl) / 2;
        const int lhs = 2 * cur + 1;
        const int rhs = 2 * cur + 2;
        if (idx <= mid) {
            update(idx, val, lhs, tl, mid);
        } else {
            update(idx, val, rhs, mid + 1, tr);
        }
        t[cur] = t[lhs] + t[rhs];
    }

    // Prints every slot of the node array on one line, each followed by a
    // space, then ends the line.
    void show() {
        for (int i = 0; i < t_len; ++i) {
            cout << t[i] << " ";
        }
        cout << endl;
    }
};