Zeyuan (Faradawn) Yang

Segment Tree: Max and Sum

Max Segment Tree (Accepted)

// Segment tree over an int array supporting point assignment and range-maximum queries.
//
// Nodes live in a flat vector: node `cur` has children 2*cur+1 and 2*cur+2.
// build/update/query all take the node index (`cur`) together with the inclusive
// range of input indices that node covers (`tl`..`tr`); a top-level call passes
// cur = 0, tl = 0, tr = n-1, where n is the size of the vector given to the constructor.
class MaxSegTree{
    std::vector<int> t;
public:
    // Builds the tree from nums. An empty nums produces an empty tree.
    MaxSegTree(const std::vector<int> &nums){
        int n = static_cast<int>(nums.size());
        t.resize(4 * n, 0);
        if(n > 0){
            build(nums, 0, 0, n-1);
        }
    }

    // Fills node `cur`, which covers nums[tl..tr], and everything below it.
    void build(const std::vector<int> &nums, int cur, int tl, int tr){
        if(tl > tr){
            return; // empty range: nothing to store
        }
        if(tl == tr){
            t[cur] = nums[tl];
            return;
        }
        int mid = tl + (tr - tl) / 2;
        build(nums, 2*cur+1, tl, mid);
        build(nums, 2*cur+2, mid+1, tr);
        t[cur] = std::max(t[2*cur+1], t[2*cur+2]);
    }

    // Sets element idx to val and refreshes the maxima on the path back to `cur`.
    // An idx outside tl..tr is ignored rather than overwriting an unrelated leaf.
    void update(int idx, int val, int cur, int tl, int tr){
        if(tl > tr || idx < tl || idx > tr){
            return;
        }
        if(tl == tr){
            t[cur] = val;
            return;
        }
        int mid = tl + (tr - tl) / 2;
        if(idx <= mid){
            update(idx, val, 2*cur+1, tl, mid);
        }else{
            update(idx, val, 2*cur+2, mid+1, tr);
        }
        t[cur] = std::max(t[2*cur+1], t[2*cur+2]);
    }

    // Returns the maximum of elements l..r (inclusive) within node `cur` (tl..tr).
    // Returns -1 when l..r does not intersect tl..tr at all; note that this is
    // indistinguishable from a genuine maximum of -1.
    int query(int l, int r, int cur, int tl, int tr){
        if(tl > tr || l > tr || r < tl){
            return -1;
        }
        if(l <= tl && r >= tr){
            return t[cur];
        }
        int mid = tl + (tr - tl) / 2;
        // Only descend into children that intersect the query, so the -1
        // "no overlap" value is never mixed into a real maximum (it would
        // otherwise win over ranges whose values are all below -1).
        if(r <= mid){
            return query(l, r, 2*cur+1, tl, mid);
        }
        if(l > mid){
            return query(l, r, 2*cur+2, mid+1, tr);
        }
        return std::max(query(l, r, 2*cur+1, tl, mid), query(l, r, 2*cur+2, mid+1, tr));
    }
};

Sum Segment Tree (Solution)

// Segment tree over an int array supporting point assignment and range-sum queries.
//
// Nodes live in a flat vector: node `pos` has children 2*pos+1 and 2*pos+2.
// Each method takes the node index (`pos`) and the inclusive range of input
// indices that node covers (`left`..`right`); a top-level call passes pos = 0,
// left = 0, right = n-1, where n is the size of the vector given to the constructor.
class SegmentTree {
    std::vector<int> seg; // Flat storage for the tree nodes.
public:
    // Builds the tree from nums. An empty nums produces an empty tree.
    SegmentTree(const std::vector<int>& nums) {
        int n = static_cast<int>(nums.size());
        seg.resize(4 * n, 0);  // 4n slots are allocated for an input of size n
        if (n > 0) {
            buildTree(nums, 0, 0, n - 1);
        }
    }

    // Fills node `pos`, which covers nums[left..right], and everything below it.
    void buildTree(const std::vector<int>& nums, int pos, int left, int right) {
        if (left > right) return; // empty range: nothing to store
        if (left == right) {
            seg[pos] = nums[left];
            return;
        }
        int mid = left + (right - left) / 2;
        buildTree(nums, 2 * pos + 1, left, mid);
        buildTree(nums, 2 * pos + 2, mid + 1, right);
        seg[pos] = seg[2 * pos + 1] + seg[2 * pos + 2];
    }

    // Sets element idx to val and refreshes the sums on the path back to `pos`.
    // An idx outside left..right (or an empty node range) leaves the tree untouched.
    void updateTree(int pos, int left, int right, int idx, int val) {
        if (left > right || idx < left || idx > right) return;

        if (left == right) { // reached the leaf holding idx
            seg[pos] = val;
            return;
        }
        // idx lies in exactly one half, so only that child needs visiting.
        int mid = left + (right - left) / 2;
        if (idx <= mid) {
            updateTree(2 * pos + 1, left, mid, idx, val);
        } else {
            updateTree(2 * pos + 2, mid + 1, right, idx, val);
        }
        seg[pos] = seg[2 * pos + 1] + seg[2 * pos + 2];
    }

    // Returns the sum of elements qleft..qright (inclusive) that fall inside
    // node `pos` (left..right); 0 when the two ranges do not intersect.
    int queryTree(int qleft, int qright, int left, int right, int pos) {
        // No overlap, including the empty node range of an empty tree. Checked
        // first so an empty tree never reads seg[0].
        if (left > right || qleft > right || qright < left) {
            return 0;
        }
        if (qleft <= left && qright >= right) { // node fully inside the query
            return seg[pos];
        }
        // partial overlap: combine both halves
        int mid = left + (right - left) / 2;
        return queryTree(qleft, qright, left, mid, 2 * pos + 1) + queryTree(qleft, qright, mid + 1, right, 2 * pos + 2);
    }
};

Segment Tree vector version (TLE)

// Sum segment tree stored in a std::vector, supporting point assignment and
// range-sum queries.
//
// Node `cur` has children 2*cur+1 and 2*cur+2. build/getSum/update take the node
// index (`cur`) and the inclusive range of input indices it covers (`tl`..`tr`);
// a top-level call passes cur = 0, tl = 0, tr = arr.size()-1.
class SegTree{
    public:
    std::vector<int> t; // flat node storage, 4 * arr.size() entries

    // Builds the tree from arr. An empty arr produces an empty tree.
    // arr is taken by const reference so construction does not copy the input.
    SegTree(const std::vector<int> &arr){
        int n = static_cast<int>(arr.size());
        t.resize(4 * n, 0);
        if(n > 0){
            build(arr, 0, 0, n - 1);
        }
    }

    // Fills node `cur`, which covers arr[tl..tr], and everything below it.
    // arr must be passed by reference here: passing it by value would copy the
    // whole input at every recursive call.
    void build(const std::vector<int> &arr, int cur, int tl, int tr){
        if(tl > tr){
            return; // empty range: nothing to store
        }
        if(tl == tr){
            t[cur] = arr[tl];
            return;
        }
        int mid = tl + (tr - tl) / 2;
        build(arr, 2 * cur + 1, tl, mid);
        build(arr, 2 * cur + 2, mid + 1, tr);
        t[cur] = t[2 * cur + 1] + t[2 * cur + 2];
    }

    // Returns the sum of elements l..r (inclusive) that fall inside node `cur`
    // (tl..tr); 0 when the two ranges do not intersect.
    int getSum(int l, int r, int cur, int tl, int tr){
        // Checked first so an empty tree (tl > tr) never reads t[0].
        if(tl > tr || r < tl || l > tr){
            return 0;
        }
        if(l <= tl && r >= tr){
            return t[cur];
        }
        int mid = tl + (tr - tl) / 2;
        return getSum(l, r, 2*cur+1, tl, mid) + getSum(l, r, 2*cur+2, mid+1, tr);
    }

    // Sets element idx to val and refreshes the sums on the path back to `cur`.
    // An idx outside tl..tr is ignored rather than overwriting an unrelated leaf.
    void update(int idx, int val, int cur, int tl, int tr){
        if(tl > tr || idx < tl || idx > tr){
            return;
        }
        if(tl == tr){
            // cur indexes the node array t; it is unrelated to the input index tl
            t[cur] = val;
            return;
        }

        int mid = tl + (tr - tl) / 2;
        if(idx <= mid){
            update(idx, val, 2*cur+1, tl, mid);
        }
        else {
            update(idx, val, 2*cur+2, mid+1, tr);
        }

        t[cur] = t[2*cur+1] + t[2*cur+2];
    }
};

Segment Tree Array version (Didn’t test)

// Sum segment tree stored in a heap-allocated int array, supporting point
// assignment and range-sum queries.
//
// Node `cur` has children 2*cur+1 and 2*cur+2. build/getSum/update take the node
// index (`cur`) and the inclusive range of input indices it covers (`tl`..`tr`);
// a top-level call passes cur = 0, tl = 0, tr = len-1.
//
// The object owns `t`: it is released in the destructor, deep-copied on copy and
// handed over on move.
class SegTree{
    public:
    int *t;     // node storage, t_len entries; owned by this object
    int t_len;  // number of entries in t

    // Builds the tree from the first len elements of arr.
    // A len of 0 or less produces an empty tree and arr is not read.
    SegTree(const int *arr, int len){
        t_len = len > 0 ? 4 * len : 0; // if power of 2, 2 * len - 1 would suffice
        t = new int[t_len]();          // value-initialised so unused nodes read as 0
        if(len > 0){
            build(arr, 0, 0, len-1);
        }
    }

    // Releases the node storage.
    ~SegTree(){
        delete[] t;
    }

    // Deep copy: the new object gets its own node array.
    SegTree(const SegTree &other){
        t_len = other.t_len;
        t = new int[t_len];
        for(int i = 0; i < t_len; i ++){
            t[i] = other.t[i];
        }
    }

    // Deep copy assignment. The replacement array is allocated before the old
    // one is freed, so a failed allocation leaves this object unchanged.
    SegTree &operator=(const SegTree &other){
        if(this != &other){
            int *fresh = new int[other.t_len];
            for(int i = 0; i < other.t_len; i ++){
                fresh[i] = other.t[i];
            }
            delete[] t;
            t = fresh;
            t_len = other.t_len;
        }
        return *this;
    }

    // Move: takes over other's array and leaves other empty.
    SegTree(SegTree &&other) noexcept : t(other.t), t_len(other.t_len){
        other.t = nullptr;
        other.t_len = 0;
    }

    // Move assignment: frees the current array, then takes over other's.
    SegTree &operator=(SegTree &&other) noexcept{
        if(this != &other){
            delete[] t;
            t = other.t;
            t_len = other.t_len;
            other.t = nullptr;
            other.t_len = 0;
        }
        return *this;
    }

    // Fills node `cur`, which covers arr[tl..tr], and everything below it.
    void build(const int *arr, int cur, int tl, int tr){
        if(tl > tr){
            return; // empty range: nothing to store
        }
        if(tl == tr){
            t[cur] = arr[tl];
            return;
        }
        int mid = tl + (tr - tl) / 2;
        build(arr, 2 * cur + 1, tl, mid);
        build(arr, 2 * cur + 2, mid + 1, tr);
        t[cur] = t[2 * cur + 1] + t[2 * cur + 2];
    }

    // Returns the sum of elements l..r (inclusive) that fall inside node `cur`
    // (tl..tr); 0 when the two ranges do not intersect.
    int getSum(int l, int r, int cur, int tl, int tr){
        // Checked first so an empty tree (tl > tr) never reads t[0].
        if(tl > tr || r < tl || l > tr){
            return 0;
        }
        if(l <= tl && r >= tr){
            return t[cur];
        }
        int mid = tl + (tr - tl) / 2;
        return getSum(l, r, 2*cur+1, tl, mid) + getSum(l, r, 2*cur+2, mid+1, tr);
    }

    // Sets element idx to val and refreshes the sums on the path back to `cur`.
    // An idx outside tl..tr is ignored rather than overwriting an unrelated leaf.
    void update(int idx, int val, int cur, int tl, int tr){
        if(tl > tr || idx < tl || idx > tr){
            return;
        }
        if(tl == tr){
            // cur indexes the node array t; it is unrelated to the input index tl
            t[cur] = val;
            return;
        }

        int mid = tl + (tr - tl) / 2;
        if(idx <= mid){
            update(idx, val, 2*cur+1, tl, mid);
        }
        else {
            update(idx, val, 2*cur+2, mid+1, tr);
        }

        t[cur] = t[2*cur+1] + t[2*cur+2];
    }

    // Prints every entry of the node array, space separated, followed by a newline.
    void show(){
        for(int i = 0; i < t_len; i ++){
            std::cout << t[i] << " ";
        }
        std::cout << std::endl;
    }
};