Convert Recursion into Iteration

code
Data structures and algorithms
Author

0warning0error

Published

June 19, 2024

Recursion is a simple, readable way to write code. For certain problems, such as the familiar traversal of tree structures, it gives code that is both more concise and more intuitive.

However, when the recursion goes too deep, it runs into a common problem: stack overflow. Every nested call consumes stack space, and the stack usually has a strict size limit. The limit can be raised through configuration parameters, but that is only a stopgap and does not address the root of the problem.

From another perspective, each time a program calls a function, a stack frame is pushed onto the call stack. The frame records information such as the return address (the instruction to resume at once the function returns) and the local variables, so it implicitly contains the state of that call. The function inspects this state to decide whether the stopping condition is met. It is therefore natural to use an explicit stack data structure to simulate the recursion.

Borrowing the idea of a state machine, it is not difficult to work out how to convert a recursive function into an iterative one. Suppose we have the following recursive function:

void func(...args){
    // List of local variables (function parameters are also part of local variables)
    doSomething();
    
    // Recursive call appears somewhere
    func(...args1);
}

We can then define a state structure that holds what the recursive version keeps in each stack frame. One part stores the local variables of the current call, and the other part records the position the call has reached during execution. If the function has a return value, an extra variable can be used to store it:

enum Loc{
    START,
    FIRST,
    SECOND,
    THIRD,
    // ...other position information
};
struct State{
    // List of local parameters
    struct Info context;
    // Position during execution
    enum Loc location;
};
void func(...args){
    stack<State> s;
    State init_state = {}; // Set the initial state
    s.push(std::move(init_state));
    while(!s.empty()){
        State &cnt_state = s.top();
        switch(cnt_state.location){
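            case START:
                doSomething();
                // Intentional fall-through to the first recursive call point
            case FIRST: {
                // Braces give new_state its own scope inside the switch
                cnt_state.location = SECOND; // Resume here once the simulated call returns
                cnt_state.context = ...;
                State new_state = {...};     // State of the simulated recursive call
                s.push(new_state);
                continue;
            }
            //.... (the last position pops the finished state with s.pop())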
        }
    }
}
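
To make the template concrete, here is a minimal sketch that applies it to the naive recursive Fibonacci function, which has two recursive call points and a return value. The names `Loc`, `Frame`, `fibIterative`, and `ret` are hypothetical and chosen only for this illustration, and the sketch assumes `n >= 0`.

```cpp
#include <cassert>
#include <cstdint>
#include <stack>

// Hypothetical names for this sketch. Each value marks a resume point.
enum class Loc { START, AFTER_FIRST, AFTER_SECOND };

struct Frame {
    int n;                // Parameter of the simulated call
    std::uint64_t first;  // Local variable: result of fib(n - 1)
    Loc location;         // Where this call resumes
};

// Simulates: fib(n) = n < 2 ? n : fib(n - 1) + fib(n - 2)
std::uint64_t fibIterative(int n) {
    assert(n >= 0);  // Assumption of this sketch
    std::stack<Frame> s;
    std::uint64_t ret = 0;  // Return value of the most recently finished call
    s.push({n, 0, Loc::START});
    while (!s.empty()) {
        Frame &cur = s.top();
        switch (cur.location) {
            case Loc::START:
                if (cur.n < 2) {  // Base case: "return n"
                    ret = static_cast<std::uint64_t>(cur.n);
                    s.pop();
                    break;
                }
                cur.location = Loc::AFTER_FIRST;  // Resume after fib(n - 1)
                s.push({cur.n - 1, 0, Loc::START});
                break;
            case Loc::AFTER_FIRST:
                cur.first = ret;                   // Save the result of fib(n - 1)
                cur.location = Loc::AFTER_SECOND;  // Resume after fib(n - 2)
                s.push({cur.n - 2, 0, Loc::START});
                break;
            case Loc::AFTER_SECOND:
                ret = cur.first + ret;  // "return fib(n - 1) + fib(n - 2)"
                s.pop();
                break;
        }
    }
    return ret;
}
```

Unlike the template, this sketch does not fall through between cases: each loop iteration handles exactly one position, which costs a few extra iterations but keeps every case self-contained.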

Take LeetCode's binary tree in-order traversal problem as an example. The recursive version is not hard to write:

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    vector<int> inorderTraversal(TreeNode* root) {
        vector<int> res;
        if(root == nullptr){
            return res;
        }
        vector<int> left = inorderTraversal(root -> left);
        res.insert(res.end(),left.begin(),left.end());
        res.push_back(root -> val);

        vector<int> right = inorderTraversal(root -> right);
        res.insert(res.end(),right.begin(),right.end());
        return res;
    }
};

In-order traversal has two recursive call points: one when visiting the left subtree, and another, after the middle node has been read, when visiting the right subtree. The only local state that has to survive across those calls is the parameter root, so the structure needs to save just that one pointer. The values are appended directly to a single result vector, so no return value has to be passed between states. The iterative version of in-order traversal is as follows:

enum Loc{
    START,           // Initial state, indicating the node has just been visited
    VISITED_LEFT,    // Left child node has been visited
    VISITED_MIDDLE,  // Middle node has been visited
};

struct State {
    TreeNode *node;  // Current node
    Loc location;    // Visit state of the current node
};

class Solution {
public:
    vector<int> inorderTraversal(TreeNode* root) {
        vector<int> result;
        stack<State> s;

        // Helper function to push the node and its state into the stack
        auto pushState = [&](TreeNode *node) {
            s.push({node, START});
        };

        // Initial condition: push the root node if it's not null
        if (root != nullptr) {
            pushState(root);
        }

        // Loop to simulate recursive calls
        while (!s.empty()) {
            State &cnt_state = s.top(); // Get the top element of the stack, simulating the parameters and local variables of the current function call

            switch (cnt_state.location) {
                case START:
                    // First time visiting the node, visit the left subtree first
                    if (cnt_state.node->left != nullptr) {
                        cnt_state.location = VISITED_LEFT;  // Change state to left child visited
                        pushState(cnt_state.node->left);   // Push the left child node into the stack
                        continue;                          // End the current loop early
                    }
                    [[fallthrough]]; // No left child: go straight on to visit the current node
                case VISITED_LEFT:
                    // Left subtree visited, visit the current node
                    result.push_back(cnt_state.node->val); // Add the node value to the result
                    cnt_state.location = VISITED_MIDDLE;   // Change state to middle node visited
                    // Visit the right subtree
                    if (cnt_state.node->right != nullptr) {
                        pushState(cnt_state.node->right);  // Push the right child node into the stack
                        continue;
                    }
                    [[fallthrough]]; // No right child: this node is finished
                case VISITED_MIDDLE:
                    // Current node and its left and right subtrees have been visited, return and pop from the stack
                    s.pop();
                    continue;
            }
        }

        return result;
    }
};
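
The same technique carries over to other traversals. As a hedged sketch, the block below applies it to post-order traversal and includes a helper that builds a very deep, left-leaning tree, the kind of input on which a recursive traversal may exhaust the call stack depending on the platform's limits. The names `Step`, `Frame`, `postorderTraversal`, and `makeLeftChain` are hypothetical, and `TreeNode` is a simplified stand-in for the LeetCode definition.

```cpp
#include <cassert>
#include <stack>
#include <vector>

// Simplified stand-in for the LeetCode node type (assumption of this sketch).
struct TreeNode {
    int val;
    TreeNode *left = nullptr;
    TreeNode *right = nullptr;
    explicit TreeNode(int x) : val(x) {}
};

// Hypothetical names. Each value marks a resume point of the simulated call.
enum class Step { START, VISITED_LEFT, VISITED_RIGHT };

struct Frame {
    TreeNode *node;  // Parameter of the simulated call
    Step step;       // Where this call resumes
};

std::vector<int> postorderTraversal(TreeNode *root) {
    std::vector<int> result;
    std::stack<Frame> s;
    if (root != nullptr) s.push({root, Step::START});
    while (!s.empty()) {
        Frame &cur = s.top();
        switch (cur.step) {
            case Step::START:  // Simulate the call on the left subtree
                cur.step = Step::VISITED_LEFT;
                if (cur.node->left != nullptr) s.push({cur.node->left, Step::START});
                break;
            case Step::VISITED_LEFT:  // Simulate the call on the right subtree
                cur.step = Step::VISITED_RIGHT;
                if (cur.node->right != nullptr) s.push({cur.node->right, Step::START});
                break;
            case Step::VISITED_RIGHT:  // Both subtrees done: emit the node and return
                result.push_back(cur.node->val);
                s.pop();
                break;
        }
    }
    return result;
}

// Builds a chain in which every node has only a left child.
// The root is the last element and holds the value depth - 1; the leaf holds 0.
std::vector<TreeNode> makeLeftChain(int depth) {
    assert(depth > 0);
    std::vector<TreeNode> nodes;
    nodes.reserve(depth);  // No reallocation afterwards, so the pointers below stay valid
    for (int i = 0; i < depth; ++i) nodes.emplace_back(i);
    for (int i = 1; i < depth; ++i) nodes[i].left = &nodes[i - 1];
    return nodes;
}
```

For example, `postorderTraversal(&chain.back())` on `auto chain = makeLeftChain(200000);` walks all 200000 levels using heap memory for the explicit stack, so the depth it can handle is bounded by available memory rather than by the call stack limit.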