Mastering Dijkstra's Algorithm

algorithm
Author

0warning0error

Published

June 29, 2024

PAT Advanced Level has a past exam problem called All Roads Lead to Rome:

Indeed there are many different tourist routes from our city to Rome. You are supposed to find your clients the route with the least cost while gaining the most happiness.

Input description:

Each input file contains one test case. For each case, the first line contains 2 positive integers N (2<=N<=200), the number of cities, and K, the total number of routes between pairs of cities; followed by the name of the starting city. The next N-1 lines each gives the name of a city and an integer that represents the happiness one can gain from that city, except the starting city. Then K lines follow, each describes a route between two cities in the format “City1 City2 Cost”. Here the name of a city is a string of 3 capital English letters, and the destination is always ROM which represents Rome.

Output description:

For each test case, we are supposed to find the route with the least cost. If such a route is not unique, the one with the maximum happiness will be recommended. If such a route is still not unique, then we output the one with the maximum average happiness – it is guaranteed by the judge that such a solution exists and is unique. Hence in the first line of output, you must print 4 numbers: the number of different routes with the least cost, the cost, the happiness, and the average happiness (take the integer part only) of the recommended route. Then in the next line, you are supposed to print the route in the format “City1->City2->…->ROM”.

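Although this is a typical Dijkstra problem, the comparisons it requires are fairly involved. Besides finding the least-cost route, we must pick, among routes of equal cost, the one that gains the most happiness, and among those, the one with the highest average happiness. All of these criteria can still be resolved by locally optimal choices that add up to a global optimum, so Dijkstra's algorithm still applies, including for counting the number of least-cost routes.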
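
The tie-breaking order can be sketched as a per-city label compared lexicographically. This is only an illustration, not the solution itself: `Label`, `sort_key`, `extend`, and `better` are hypothetical helper names, and the sketch assumes that, when two routes have the same total happiness, the one visiting fewer cities has the higher average.

```python
from typing import NamedTuple


class Label(NamedTuple):
    """Summary of one route from the start to some city (hypothetical helper)."""
    cost: int       # total route cost, smaller is better
    happiness: int  # total happiness gained, larger is better
    cities: int     # cities visited after the start, fewer is better on ties


def sort_key(label: Label) -> tuple:
    # Negate happiness so that plain tuple comparison prefers larger values.
    return (label.cost, -label.happiness, label.cities)


def extend(label: Label, edge_cost: int, city_happiness: int) -> Label:
    """Label of the route obtained by appending one more city."""
    return Label(label.cost + edge_cost,
                 label.happiness + city_happiness,
                 label.cities + 1)


def better(a: Label, b: Label) -> Label:
    """Return the preferred of two labels for the same city."""
    return a if sort_key(a) <= sort_key(b) else b


# Two made-up routes with equal cost and equal happiness.
start = Label(0, 0, 0)
short_route = extend(extend(start, 1, 95), 2, 100)
long_route = extend(extend(extend(start, 1, 40), 1, 55), 1, 100)
print(better(short_route, long_route))
```

Because extending two routes by the same edge preserves their order under this key, the best label at a city can be built from the best labels of its predecessors, which is what lets Dijkstra's algorithm carry the extra criteria along.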

from collections import defaultdict
import heapq

general_func = lambda x: int(x) if x.isdigit() else x

N, K, starting_city = map(
    general_func, input().strip().split()
)


happiness_cities = {starting_city : 0}

adj = defaultdict(set)
for _ in range(N - 1):
    city, happiness = map(
        general_func, input().strip().split()
    )
    happiness_cities[city] = happiness
for _ in range(K):
    src, dst, cost = map(
        general_func, input().strip().split()
    )
    adj[src].add((dst, cost))
    adj[dst].add((src, cost))
INF = float('inf')
def find_best_route(n, k, start, happiness, graph):
    # Per-city labels maintained by Dijkstra
    dist = defaultdict(lambda: INF)  # least cost from the start
    dist[start] = 0
    num_paths = defaultdict(int)  # number of least-cost routes from the start
    num_paths[start] = 1
    max_happiness = defaultdict(int)  # best total happiness among least-cost routes
    num_cities = defaultdict(int)  # cities after the start on the recommended route
    prev = {}  # predecessor on the recommended route

    visited = set()

    # Priority queue of (cost, city); tie-break data is read from the dicts above
    pq = [(0, start)]

    while pq:
        cur_dist, cur_city = heapq.heappop(pq)
        if cur_city in visited:
            continue  # stale entry: this city is already finalized
        visited.add(cur_city)
        # Assuming positive costs, every least-cost predecessor of cur_city was
        # finalized earlier, so the labels read here are final.
        cur_happiness = max_happiness[cur_city]
        cur_count = num_cities[cur_city]

        # Relax each neighbor that is not finalized yet
        for neighbor, cost in graph[cur_city]:
            if neighbor in visited:
                continue
            new_dist = cur_dist + cost
            new_happiness = cur_happiness + happiness[neighbor]
            new_count = cur_count + 1

            # Strictly cheaper route: reset every label
            if new_dist < dist[neighbor]:
                dist[neighbor] = new_dist
                max_happiness[neighbor] = new_happiness
                num_cities[neighbor] = new_count
                prev[neighbor] = cur_city
                num_paths[neighbor] = num_paths[cur_city]
                heapq.heappush(pq, (new_dist, neighbor))

            # Equally cheap route: always count it, then apply the tie-breakers
            elif new_dist == dist[neighbor]:
                num_paths[neighbor] += num_paths[cur_city]
                if new_happiness > max_happiness[neighbor]:
                    max_happiness[neighbor] = new_happiness
                    num_cities[neighbor] = new_count
                    prev[neighbor] = cur_city
                elif new_happiness == max_happiness[neighbor] and new_count < num_cities[neighbor]:
                    # Same happiness over fewer cities means a higher average
                    num_cities[neighbor] = new_count
                    prev[neighbor] = cur_city
    
    # Rebuild the recommended route by walking predecessors back from ROM
    path = []
    city = "ROM"
    while city:
        path.append(city)
        city = prev.get(city, None)
    path.reverse()
    
    # Format the result (the average excludes the starting city)
    total_cost = dist["ROM"]
    total_happiness = max_happiness["ROM"]
    avg_happiness = total_happiness // (len(path) - 1)
    num_routes = num_paths["ROM"]
    
    return f"{num_routes} {total_cost} {total_happiness} {avg_happiness}\n{'->'.join(path)}"

r = find_best_route(N,K,starting_city,happiness_cities,adj)
print(r)
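
One way to gain confidence in the tie-breaking logic is to compare it against exhaustive search on a tiny graph. The sketch below is a hypothetical cross-check, not part of the solution: `brute_force` and the six-city graph are made up for illustration, and it assumes positive costs, so that every least-cost route is a simple path.

```python
def brute_force(start, happiness, edges, target="ROM"):
    """Enumerate every simple route from start to target (tiny graphs only)."""
    graph = {}
    for a, b, cost in edges:
        graph.setdefault(a, []).append((b, cost))
        graph.setdefault(b, []).append((a, cost))

    routes = []  # (cost, happiness, path) for each complete route

    def dfs(city, cost, happy, path):
        if city == target:
            routes.append((cost, happy, list(path)))
            return
        for nxt, edge_cost in graph.get(city, []):
            if nxt not in path:  # keep the route simple
                path.append(nxt)
                dfs(nxt, cost + edge_cost, happy + happiness[nxt], path)
                path.pop()

    dfs(start, 0, 0, [start])
    best_cost = min(r[0] for r in routes)
    cheapest = [r for r in routes if r[0] == best_cost]
    # Prefer more happiness, then fewer cities (a higher average).
    _, happy, path = min(cheapest, key=lambda r: (-r[1], len(r[2])))
    avg = happy // (len(path) - 1)  # the starting city is not counted
    return f"{len(cheapest)} {best_cost} {happy} {avg}\n{'->'.join(path)}"


happiness = {"HZH": 0, "ROM": 100, "PKN": 40, "GDN": 55, "PRS": 95, "BLN": 80}
edges = [
    ("ROM", "GDN", 1), ("BLN", "ROM", 1), ("HZH", "PKN", 1), ("PRS", "ROM", 2),
    ("BLN", "HZH", 2), ("PKN", "GDN", 1), ("HZH", "PRS", 1),
]
print(brute_force("HZH", happiness, edges))
# 3 3 195 97
# HZH->PRS->ROM
```

In this graph three routes cost 3, two of them gain 195 happiness, and the two-city route wins on average happiness, so all three criteria are exercised. Feeding the same cities and routes to `find_best_route` should produce the same two lines.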