力扣上周周赛第三题,比赛的时候一开始用 dfs,不停增加剪枝条件(加memo数组,低于res就返回等),从 n 在1k超时,到5k超时,到1w超时,没辙了,只能赛后尝试 bfs。
思路
bfs+大根堆(贪心):Java 用时73ms
bfs 为什么可以比可以比 dfs 优秀呢?不都需要走到 end 才更新结果吗?思考了一下,问题的关键在于「剪枝的顺序上」
- dfs 的剪枝,是一条路走到黑,再走其他路,那么假如说这条路走到后面概率就很小了,而有其中些位置在别的路径上走几步就到了,概率还很高,那么这了路在 bfs 的情况下就不会走下去了,从而节省了很多时间。
- 而在 bfs 正常的思路里,用的是队列,因为我们需要考虑层级(树的深度,步长等)关系,层级越小越好。而这道题里是概率,只有概率大的位置才有可能后面走到的位置概率也大,才能走到 end 的概率也大,那么先把每个位置的大概率确定下来了,剪枝的时候能剪掉的情况也就多很多了。这里有贪心的思想在其中。
那么问题的关键就在于选什么类来对这个队列进行排序,由于我们每次只要取最大的,很容易想到大根堆。大根堆内存位置,并重写比较器,用每个位置的概率作为比较条件。
但大根堆有个问题,某个位置已经被存储在堆里了,但现在所处的位置的下一个位置也是它,而且去到的概率还比它大,直接更新位置概率数组并不会让大根堆更新排序,这时候我们需要让它强制更新,概率数组添加一个位置为 n,概率为1.1的值,然后在大根堆中插入 n,然后删除堆顶(就是 n),从而实现堆的重排序,时间复杂度为 O(2log(h)),h 为堆的长度。(突然发现,还可以再优化一下,不超过堆顶就不需要强制更新,因为再下次取出堆顶的时候就会更新)
至此,bfs 的剪枝操作都完成了,具体实现就看代码吧~
代码
class Solution {
double res = 0;
public double maxProbability(int n, int[][] edges, double[] succProb, int start, int end) {
Map<Integer, List<Integer>> posCanGo = new HashMap<>();
int edgeNum = edges.length;
// 建立从某个位置能去到位置的edges数组索引,方便之后查找。
for (int i = 0; i < edgeNum; i++) {
List<Integer> goList = posCanGo.getOrDefault(edges[i][0], new ArrayList<Integer>());
goList.add(i);
posCanGo.put(edges[i][0], goList);
goList = posCanGo.getOrDefault(edges[i][1], new ArrayList<Integer>());
goList.add(i);
posCanGo.put(edges[i][1], goList);
}
// for (Integer key : posCanGo.keySet())
// System.out.println(posCanGo.get(key));
if (!posCanGo.containsKey(end))
return res;
bfs(posCanGo, edges, succProb, start, end, n);
return res;
}
public void bfs(Map<Integer,List<Integer>> posCanGo, int[][] edges, double[] succProb, int start, int end, int n) {
double[] maxProb = new double[n+1]; // 记录去到每个位置的最大概率
boolean[] inQue = new boolean[n+1]; // 记录有哪些还在队列里
maxProb[start] = 1.0;
maxProb[n] = 1.1; // 加一个用于大根堆更新的大根值,不会有概率比1.1大了
// 大根堆存位置,但根据最大概率进行排序
PriorityQueue<Integer> nextMaxWay = new PriorityQueue<Integer>((a, b) -> (maxProb[b] > maxProb[a] ? 1 : -1));
nextMaxWay.add(start);
inQue[start] = true;
while (!nextMaxWay.isEmpty()) {
int pos = nextMaxWay.poll();
inQue[pos] = false;
// 走到了end,不再向后走了
if (pos == end) {
if (maxProb[end] > res)
res = maxProb[end];
continue;
}
// 停止条件:堆中最大的概率都小于res时就结束
if (maxProb[pos] <= res)
break;
// 找该位置的能去到的下一个位置
List<Integer> goList = posCanGo.getOrDefault(pos, new ArrayList<Integer>());
int goNum = goList.size();
if (goNum == 0)
continue;
for (int i = 0; i < goNum; i++) {
int nxtPosIdx = goList.get(i);
int nxtPos = (edges[nxtPosIdx][0] == pos ? edges[nxtPosIdx][1] : edges[nxtPosIdx][0]);
double thisWaytoNxtProb = maxProb[pos] * succProb[nxtPosIdx];
// 下个去到的位置的概率小于等于res的概率,就不去下个位置了
if (thisWaytoNxtProb <= res)
continue;
// 下一个去到的位置的概率大于历史去这个位置的概率
if (thisWaytoNxtProb > maxProb[nxtPos]) {
maxProb[nxtPos] = thisWaytoNxtProb;
// 放入最大的数,让堆更新,然后再删除堆顶,时间复杂度O(2logh)
// 应该也可以通过插入最小概率并不删除来减小一点时间,但堆中就会有很多无用数值,导致h变大,logh变大
if (inQue[nxtPos] && thisWaytoNxtProb > maxProb[nextMaxWay.peek()]) {
nextMaxWay.add(n);
nextMaxWay.poll();
}
// 不在堆里就放进堆里
else {
nextMaxWay.add(nxtPos);
inQue[nxtPos] = true;
}
}
}
}
}
}