All Nodes Distance K in Binary Tree

给定一个 target 节点和一个距离 k, 在二叉树上找出所有与 target 距离恰好为 k 的节点.

符合条件的节点不一定在 target 的子树里: 路径可以向上经过 parent, 也可以跨过 root 到另一侧的子树.

我的做法是把整棵树重建成一个图: 复制每个节点并加上父节点指针, 然后从 target 出发做 k 层 BFS.

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
/**
 * Finds every node that is exactly {@code k} edges away from a target node in a binary tree.
 *
 * <p>Approach: copy the tree into {@link Node} objects that also carry a parent link, so the
 * tree can be walked as an undirected graph, then run a level-by-level BFS from the target.
 */
class Solution {
    /** Copy of a tree node that additionally knows its parent. */
    class Node{
        Node(int x) {this.val = x;}
        int val;
        Node p; // parent copy, null for the root
        Node l; // left child copy
        Node r; // right child copy
    }

    /**
     * Returns the values of all nodes exactly {@code k} edges away from {@code target}.
     *
     * <p>The target is located in the copied tree by its value (see {@link #find}), so the
     * result is only meaningful when node values are unique — NOTE(review): confirm callers
     * guarantee this.
     *
     * @param root   root of the tree; {@code null} yields an empty result
     * @param target node to measure from; {@code null} yields an empty result
     * @param k      required distance in edges; a negative value yields an empty result
     * @return values of the nodes at distance {@code k}, in BFS discovery order
     *         (left child, right child, parent); empty if there are none or if no node
     *         with the target's value exists in the tree
     */
    public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
        List<Integer> res = new ArrayList<>();
        // No tree, no target, or an impossible distance: nothing can match.
        if (root == null || target == null || k < 0) {
            return res;
        }

        Node nr = new Node(root.val);
        dfs(root, null, nr, null);

        Node nt = find(target.val, nr);
        if (nt == null) {
            // Target value is not present in the tree; avoid queueing null.
            return res;
        }

        Queue<Node> q = new LinkedList<>();
        Set<Node> set = new HashSet<>(); // identity-based: Node does not override equals/hashCode
        q.add(nt);
        set.add(nt);

        // Each pass replaces the queue contents with the next BFS ring. Stop early once
        // the frontier is empty so a large k does not spin on an exhausted tree.
        while (k > 0 && !q.isEmpty()) {
            int size = q.size();
            for (int i = 0; i < size; i++) {
                Node c = q.poll();
                enqueueIfUnseen(c.l, set, q);
                enqueueIfUnseen(c.r, set, q);
                enqueueIfUnseen(c.p, set, q);
            }
            k--;
        }

        // Whatever is left in the queue is exactly the ring at the requested distance.
        for (Node cc : q) {
            res.add(cc.val);
        }
        return res;
    }

    /** Queues {@code next} if it exists and has not been visited yet. */
    private void enqueueIfUnseen(Node next, Set<Node> seen, Queue<Node> frontier) {
        // Set.add returns false when the element was already present.
        if (next != null && seen.add(next)) {
            frontier.add(next);
        }
    }

    /**
     * Pre-order search of the copied tree for a node holding {@code val}.
     *
     * @param val value to look for
     * @param nr  root of the copied subtree to search; may be {@code null}
     * @return the first matching node (node itself, then left subtree, then right subtree),
     *         or {@code null} if no node holds {@code val}
     */
    public Node find(int val, Node nr) {
        if (nr == null) {
            return null;
        }
        if (nr.val == val) {
            return nr;
        }
        Node l = find(val, nr.l);
        if (l != null) {
            // A left-subtree hit wins, so the right subtree does not need to be searched.
            return l;
        }
        return find(val, nr.r);
    }

    /**
     * Recursively copies the subtree rooted at {@code cur} into {@code nr}, wiring parent links.
     *
     * @param cur    original node to copy from; {@code null} is a no-op
     * @param parent original parent of {@code cur}; not used by this method
     * @param nr     copy that receives {@code cur}'s value and children
     * @param nrp    copy of the parent, stored as {@code nr.p} ({@code null} for the root)
     */
    public void dfs(TreeNode cur, TreeNode parent, Node nr, Node nrp){
        if (cur == null) {
            return;
        }
        nr.val = cur.val;
        nr.p = nrp;
        if (cur.left != null) {
            nr.l = new Node(cur.left.val);
            dfs(cur.left, cur, nr.l, nr);
        }
        if (cur.right != null) {
            nr.r = new Node(cur.right.val);
            dfs(cur.right, cur, nr.r, nr);
        }
    }
}