-
Notifications
You must be signed in to change notification settings - Fork 1
/
hungary.py
101 lines (87 loc) · 3.23 KB
/
hungary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict
class Hungary:
    """Maximum matching of a bipartite graph via augmenting-path search.

    Adjacency is keyed by node identifier alone, so identifiers must be
    hashable and must not be shared between the two partitions.
    """

    def __init__(self, nodes_one, nodes_two, edges) -> None:
        """Store the graph description.

        Args:
            nodes_one: Nodes of the first partition.
            nodes_two: Nodes of the second partition.
            edges: Iterable of ``(u, v)`` pairs joining the two partitions;
                either endpoint order is accepted.
        """
        self.nodes_one = nodes_one  # nodes of the first partition
        self.nodes_two = nodes_two  # nodes of the second partition
        self.edges = edges  # all edges of the bipartite graph

    def max_match(self):
        """Compute a maximum matching.

        Returns:
            A tuple ``(num_match, match)``. ``num_match`` is the number of
            matched pairs. ``match`` is a ``defaultdict`` mapping each matched
            neighbour node to its partner taken from ``nodes_one``; looking up
            an unmatched node yields ``None``.
        """
        match = defaultdict(lambda: None)  # matched neighbour -> partner
        node_neighbors = defaultdict(list)  # node -> adjacent nodes
        for u, v in self.edges:
            # Record both directions so the endpoint order does not matter.
            node_neighbors[u].append(v)
            node_neighbors[v].append(u)

        def dfs(u, visited):
            """Try to find an augmenting path starting at ``u``."""
            # .get() avoids inserting empty entries for isolated nodes.
            for v in node_neighbors.get(u, ()):
                if v in visited:
                    continue
                visited.add(v)
                partner = match.get(v)
                # Take v if it is free, or if its current partner can be
                # re-routed to some other neighbour.
                if partner is None or dfs(partner, visited):
                    match[v] = u
                    return True
            return False

        num_match = 0
        # dict.fromkeys() drops duplicates while keeping order: augmenting
        # from the same node twice would match it to two partners and
        # overstate the matching size.
        for node in dict.fromkeys(self.nodes_one):
            # Each augmentation attempt starts with a fresh visited set.
            if dfs(node, set()):
                num_match += 1
        return num_match, match
def draw(G, nodes_one, nodes_two, color_edges):
    """Plot a bipartite graph in two columns and highlight selected edges.

    Nodes listed in ``nodes_one`` are drawn red and placed in the left
    column; all other nodes are drawn blue. Nodes listed in ``nodes_two``
    are placed in the right column. Edges present in ``color_edges`` (in
    either endpoint order) are drawn red, the rest blue.

    Args:
        G: Graph to draw (a networkx graph).
        nodes_one: Nodes of the first partition.
        nodes_two: Nodes of the second partition.
        color_edges: Iterable of ``(u, v)`` pairs to highlight.

    NOTE(review): positions are only assigned to nodes in ``nodes_one`` and
    ``nodes_two``; presumably every node of ``G`` must appear in one of
    them for the drawing call to succeed - confirm against networkx.
    """
    # Membership in nodes_one decides the colour. Comparing against the type
    # of nodes_one[0] fails on an empty partition and cannot tell the sides
    # apart when both use the same node type.
    first_part = set(nodes_one)
    # Build the lookup once so each edge test does not rescan color_edges.
    highlighted = {tuple(edge) for edge in color_edges}

    node_color = ['r' if node in first_part else 'b' for node in G.nodes]

    edge_color = []
    for edge in G.edges:
        u, v = edge[0], edge[1]
        # The graph is undirected, so accept either endpoint order.
        is_marked = (u, v) in highlighted or (v, u) in highlighted
        edge_color.append('r' if is_marked else 'b')

    # Custom layout: one vertical column per partition, filled top-down.
    # The layout is basic and could be improved.
    pos = dict()
    size = max(len(nodes_one), len(nodes_two)) + 2
    one_x, two_x = size // 3, 2 * size // 3
    for offset, node_one in enumerate(nodes_one):
        pos[node_one] = [one_x, size - 1 - offset]
    for offset, node_two in enumerate(nodes_two):
        pos[node_two] = [two_x, size - 1 - offset]

    plt.title('Hungary Algorithm: Maximum Matching')
    nx.draw(G, pos, with_labels=True, node_color=node_color, edge_color=edge_color)
    # plt.savefig('hungary.png', format='PNG')
    plt.show()
def main():
    """Run the matcher on a sample bipartite graph, print and plot the result."""
    # Alternative sample 0:
    # nodes_one = [0, 1, 2, 3]
    # nodes_two = ['0', '1', '2', '3']
    # edges = [(0, '0'), (0, '1'), (1, '1'), (1, '2'), (2, '0'), (2, '1'), (3, '2')]
    #
    # Alternative sample 1:
    # nodes_one = [0, 1, 2, 3]
    # nodes_two = ['a', 'b', 'c', 'd']
    # edges = [(0, 'a'), (0, 'b'), (1, 'b'), (1, 'c'), (2, 'a'), (2, 'b'), (3, 'c'), (3, 'd')]
    nodes_one = [0, 1, 2, 3, 4]
    nodes_two = ['0', '1', '2', '3', '4']
    # Adjacency of each first-partition node, expanded into an edge list.
    adjacency = {
        0: ['1', '2'],
        1: ['0', '1', '3', '4'],
        2: ['1', '2'],
        3: ['1', '2'],
        4: ['3', '4'],
    }
    edges = [(src, dst) for src, targets in adjacency.items() for dst in targets]

    solver = Hungary(nodes_one, nodes_two, edges)
    num_match, match = solver.max_match()
    # Each item pairs a matched node with its partner.
    match_edges = list(match.items())
    print('{} | {}'.format(match_edges, num_match))

    G = nx.Graph()
    G.add_nodes_from(nodes_one + nodes_two)
    G.add_edges_from(edges)
    draw(G, nodes_one, nodes_two, match_edges)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()