# Sample graph as an adjacency matrix. Row 0 and column 0 are zero padding so
# that the vertices can be addressed as 1..5; 999 is stored where two
# vertices have no edge between them.
input_graph = [
    [0, 0, 0, 0, 0, 0],
    [0, 0, 1, 3, 999, 999],
    [0, 1, 0, 3, 6, 999],
    [0, 3, 3, 0, 4, 2],
    [0, 999, 6, 4, 0, 5],
    [0, 999, 999, 2, 5, 0],
]
def prim_algorithm(input_graph: list):
    """Return the total weight of a minimum spanning tree (Prim's algorithm).

    ``input_graph`` is a square adjacency matrix with one padding row and
    column at index 0, so the vertices are numbered
    ``1 .. len(input_graph) - 1``. An entry that is ``<= 0`` or ``>= 999``
    means "no edge"; every other entry is the weight of the edge between the
    two vertices.

    The matrix is only read, never modified, so the same graph can safely be
    passed in more than once.

    Returns:
        The sum of the weights of the edges chosen for the spanning tree, or
        ``0`` when the graph has fewer than two vertices.

    Raises:
        ValueError: if the graph is disconnected, i.e. some vertex cannot be
            reached from the tree through an edge lighter than 999.
    """
    no_edge = 999  # sentinel weight the matrix uses for "not connected"

    vertex_count = len(input_graph) - 1  # index 0 is padding
    if vertex_count <= 1:
        return 0  # zero or one vertex: there is nothing to connect

    # The tree has always been grown from vertex 2; keep that so results on
    # asymmetric matrices do not change.
    start = 2
    in_tree = [False] * (vertex_count + 1)
    in_tree[start] = True

    # cheapest[v] is the lightest known edge joining v to the tree.  Build a
    # fresh list instead of aliasing the caller's row: writing into the
    # matrix would corrupt it for every later call.
    cheapest = [
        weight if 0 < weight < no_edge else no_edge
        for weight in input_graph[start]
    ]

    total_weight = 0
    for _ in range(vertex_count - 1):
        candidates = [
            v
            for v in range(1, vertex_count + 1)
            if not in_tree[v] and cheapest[v] < no_edge
        ]
        if not candidates:
            # Without this check a disconnected graph would loop forever.
            raise ValueError("graph is not connected; no spanning tree exists")

        # min() returns the first of several equal minima, i.e. the lowest
        # vertex number, matching the original tie-breaking.
        nearest = min(candidates, key=lambda v: cheapest[v])
        total_weight += cheapest[nearest]
        in_tree[nearest] = True

        # The new tree vertex may offer cheaper connections to the rest.
        row = input_graph[nearest]
        for v in range(1, vertex_count + 1):
            if not in_tree[v] and 0 < row[v] < cheapest[v]:
                cheapest[v] = row[v]

    return total_weight
# Demo: run prim_algorithm on the sample graph and print the total weight it returns.
print(prim_algorithm(input_graph))
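# Adjacency matrix of the sample graph (vertices 1-5).
# -1 means "no edge"; input_graph stores those entries as 999.
#      1   2   3   4   5
# 1    0   1   3  -1  -1
# 2    1   0   3   6  -1
# 3    3   3   0   4   2
# 4   -1   6   4   0   5
# 5   -1  -1   2   5   0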