
Description
The problem is to count the paths in a binary tree whose node values sum to a given target sum. A path may start and end at any node in the tree, but it must go downward: every step moves from a parent to one of its children.
Intuition
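To solve this problem efficiently, we use a prefix sum approach combined with depth-first search (DFS). The idea is to keep track of the cumulative sum of node values from the root down to the current node (the prefix sum) as we traverse the tree. The sum of a downward path that ends at the current node equals the current prefix sum minus the prefix sum just above the path's first node. Such a path therefore sums to the target exactly when an earlier prefix on the current root-to-node path equals `current_sum - target_sum`. By maintaining a hash map that counts the prefix sums on the current path, we can determine in O(1) expected time how many paths ending at the current node sum up to the target value.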
Approach
- Prefix Sum Storage:
  - Use a hash map (or dictionary) to count how many times each prefix sum occurs on the current root-to-node path. Initialize the map with `{0: 1}` (the empty prefix above the root) so that paths starting at the root, whose sum equals the target directly, are counted.
- DFS Traversal:
  - Traverse the tree using DFS. For each node:
    - Update the current path sum by adding the value of the current node.
    - Calculate the number of valid paths that end at the current node by checking how many times `current_sum - target_sum` has been recorded on the current path.
    - Record `current_sum` in the hash map (increment its count).
    - Recursively apply DFS to the left and right children of the current node.
    - After exploring both subtrees, backtrack by decrementing the count of the current path sum in the hash map, so that it does not affect paths in other branches.
- Return Result:
  - The total count of valid paths is accumulated during the DFS traversal.
Complexity
Time Complexity:
O(N) expected, where N is the number of nodes in the binary tree. Each node is processed exactly once, and each hash-map operation takes O(1) time on average.
Space Complexity:
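O(H), where H is the height of the tree. This space is used by the recursion stack and by the hash map, which only needs the prefix sums on the current root-to-node path (at most H + 1 distinct values). This bound holds provided that entries whose count drops to zero are removed during backtracking. Otherwise, stale keys can make the map grow to O(N).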
Code
C++
class Solution {
public:
int pathSum(TreeNode* root, int targetSum) {
std::unordered_map prefixSumCount;
prefixSumCount[0] = 1; // Initialize for cases where path sum directly equals targetSum
return dfs(root, 0, targetSum, prefixSumCount);
}
private:
int dfs(TreeNode* node, long currentSum, int targetSum, std::unordered_map& prefixSumCount) {
if (node == nullptr) return 0;
currentSum += node->val;
int count = prefixSumCount[currentSum - targetSum];
prefixSumCount[currentSum]++;
// Recursively check left and right subtrees
count += dfs(node->left, currentSum, targetSum, prefixSumCount);
count += dfs(node->right, currentSum, targetSum, prefixSumCount);
// Remove the current path sum count to ensure it doesn't affect other paths
prefixSumCount[currentSum]--;
return count;
}
};
Python
class TreeNode:
    """A binary-tree node: a value plus optional left and right child nodes."""

    def __init__(self, val=0, left=None, right=None):
        # Store the payload first, then the two child links.
        self.val, self.left, self.right = val, left, right
class Solution:
    """Counts downward tree paths whose values sum to a target, using prefix sums."""

    def pathSum(self, root: "Optional[TreeNode]", targetSum: int) -> int:
        """Return the number of downward paths whose node values sum to ``targetSum``.

        A path may start and end at any node but must go from parents to
        children. ``root`` may be ``None``, in which case 0 is returned.
        """
        # Sum 0 represents the empty prefix above the root, so that paths
        # starting at the root itself are counted.
        return self.dfs(root, 0, targetSum, {0: 1})

    def dfs(
        self,
        node: "Optional[TreeNode]",
        current_sum: int,
        target_sum: int,
        prefix_sum_count: "Dict[int, int]",
    ) -> int:
        """Count qualifying paths that end somewhere in ``node``'s subtree.

        ``current_sum`` is the prefix sum of the values above ``node``.
        ``prefix_sum_count`` counts the prefix sums on the path above ``node``.
        Any mapping with ``get`` and item assignment works; it does not have to
        be a ``defaultdict``. The mapping is mutated during the traversal but
        restored before returning; keys whose count drops to zero are removed.
        """
        count = 0
        # Use an explicit stack instead of recursion, so a deep (skewed) tree
        # cannot exceed Python's recursion limit. Each entry is
        # (node, prefix sum above it, is_exit_marker).
        stack = [(node, current_sum, False)]
        while stack:
            cur, prefix, leaving = stack.pop()
            if leaving:
                # Backtrack: this prefix leaves the active path once both
                # subtrees have been processed.
                remaining = prefix_sum_count[prefix] - 1
                if remaining:
                    prefix_sum_count[prefix] = remaining
                else:
                    del prefix_sum_count[prefix]
                continue
            if cur is None:
                continue
            prefix += cur.val
            # Paths ending at `cur` with the target sum start just below an
            # ancestor whose prefix equals prefix - target_sum.
            count += prefix_sum_count.get(prefix - target_sum, 0)
            prefix_sum_count[prefix] = prefix_sum_count.get(prefix, 0) + 1
            # Push the exit marker first, so it is popped after both subtrees.
            stack.append((None, prefix, True))
            stack.append((cur.right, prefix, False))
            stack.append((cur.left, prefix, False))
        return count
Java
class TreeNode {
int val;
TreeNode left;
TreeNode right;
TreeNode() {}
TreeNode(int val) { this.val = val; }
TreeNode(int val, TreeNode left, TreeNode right) {
this.val = val;
this.left = left;
this.right = right;
}
}
class Solution {
    /**
     * Counts the downward paths (parent to child only) whose node values sum to targetSum.
     *
     * @param root      root of the tree; may be null, in which case 0 is returned
     * @param targetSum required sum of the node values along a path
     * @return number of qualifying paths
     */
    public int pathSum(TreeNode root, int targetSum) {
        // Typed map: with a raw HashMap, getOrDefault/get return Object and the
        // arithmetic below does not compile.
        HashMap<Long, Integer> prefixSumCount = new HashMap<>();
        // The empty prefix (sum 0) lets a path that starts at the root be counted.
        prefixSumCount.put(0L, 1);
        return dfs(root, 0L, targetSum, prefixSumCount);
    }

    /**
     * Counts qualifying paths that end somewhere in {@code node}'s subtree.
     *
     * @param node           current node; null contributes 0
     * @param currentSum     prefix sum of the values strictly above {@code node}
     * @param targetSum      required path sum
     * @param prefixSumCount counts of the prefix sums on the current root-to-node
     *                       path; mutated during the call and restored before returning
     * @return number of qualifying paths ending in this subtree
     */
    private int dfs(TreeNode node, long currentSum, int targetSum, HashMap<Long, Integer> prefixSumCount) {
        if (node == null) return 0;
        currentSum += node.val;
        // A path ending here sums to targetSum iff some ancestor prefix equals currentSum - targetSum.
        int count = prefixSumCount.getOrDefault(currentSum - targetSum, 0);
        prefixSumCount.merge(currentSum, 1, Integer::sum);
        // Recursively check left and right subtrees.
        count += dfs(node.left, currentSum, targetSum, prefixSumCount);
        count += dfs(node.right, currentSum, targetSum, prefixSumCount);
        // Backtrack: returning null from the remapping function removes the entry,
        // so the map only ever holds the prefixes on the current path.
        prefixSumCount.computeIfPresent(currentSum, (sum, c) -> c == 1 ? null : c - 1);
        return count;
    }
}
JavaScript
/**
 * Counts the downward paths (parent to child only) whose node values sum to targetSum.
 *
 * Uses the prefix-sum technique: a path ending at the current node has the target
 * sum exactly when an ancestor's prefix equals `runningSum - targetSum`. This runs
 * in O(N) expected time, instead of copying the whole path at every node
 * (O(N·H) time, O(H²) space).
 *
 * @param {{val: number, left: ?Object, right: ?Object}|null} root Root of the tree; may be null.
 * @param {number} targetSum Required sum of the node values along a path.
 * @return {number} Number of qualifying paths.
 */
var pathSum = function (root, targetSum) {
    // Counts of the prefix sums on the current root-to-node path. The entry
    // 0 -> 1 (the empty prefix) lets paths that start at the root be counted.
    const prefixCounts = new Map([[0, 1]]);

    const dfs = (node, runningSum) => {
        if (!node) return 0;
        runningSum += node.val;
        let count = prefixCounts.get(runningSum - targetSum) || 0;
        prefixCounts.set(runningSum, (prefixCounts.get(runningSum) || 0) + 1);
        count += dfs(node.left, runningSum);
        count += dfs(node.right, runningSum);
        // Backtrack: this prefix leaves the active path; drop it once its count reaches zero.
        const remaining = prefixCounts.get(runningSum) - 1;
        if (remaining === 0) {
            prefixCounts.delete(runningSum);
        } else {
            prefixCounts.set(runningSum, remaining);
        }
        return count;
    };

    return dfs(root, 0);
};