Skip to content

Commit 80f0941

Browse files
Merge pull request #381 from GFNOrg/fix-graph
Don't allow removal of nodes with edges
2 parents b2e2f57 + b32fb45 commit 80f0941

File tree

4 files changed

+28
-10
lines changed

4 files changed

+28
-10
lines changed

src/gfn/gym/graph_building.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat
208208
# Update node features
209209
graph.x = graph.x[mask]
210210

211+
# Update edge indices
212+
assert torch.all(graph.edge_index != node_idx)
213+
graph.edge_index[graph.edge_index > node_idx] -= 1
214+
211215
# Handle ADD_EDGE action
212216
if torch.any(add_edge_mask):
213217
add_edge_index = torch.where(add_edge_mask)[0]
@@ -264,8 +268,13 @@ def is_action_valid(
264268
)
265269

266270
if backward:
267-
# For backward actions, we need at least one matching node
268-
if not torch.any(equal_nodes):
271+
# For backward actions, we need one matching node
272+
num_equal_nodes = torch.sum(equal_nodes)
273+
if num_equal_nodes != 1:
274+
return False
275+
# And no edges from/to the node
276+
equal_node_idx = torch.where(equal_nodes)[0][0]
277+
if torch.any(graph.edge_index == equal_node_idx):
269278
return False
270279
else:
271280
# For forward actions, we should not have any matching nodes

src/gfn/states.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -988,9 +988,11 @@ def backward_masks(self) -> TensorDict:
988988
)
989989

990990
for i, graph in enumerate(self.data.flat):
991-
node_class_masks[i, graph.x.flatten()] = True
992-
if graph.x is None:
993-
continue
991+
node_idxs = torch.arange(len(graph.x.flatten()))
992+
has_edge = torch.any(
993+
node_idxs[:, None] == graph.edge_index.flatten()[None], dim=1
994+
)
995+
node_class_masks[i, graph.x.flatten()] = ~has_edge
994996
ei0, ei1 = get_edge_indices(graph.x.size(0), self.is_directed, self.device)
995997

996998
if graph.edge_index is not None and graph.edge_index.size(1) > 0:

testing/test_states.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,9 +525,9 @@ def test_backward_masks(datas):
525525

526526
# Check action type mask
527527
assert masks[GraphActions.ACTION_TYPE_KEY].shape == (1, 3)
528-
assert masks[GraphActions.ACTION_TYPE_KEY][
528+
assert not masks[GraphActions.ACTION_TYPE_KEY][
529529
0, GraphActionType.ADD_NODE
530-
].item() # Can remove node
530+
].item() # Can't remove node as it has an edge
531531
assert masks[GraphActions.ACTION_TYPE_KEY][
532532
0, GraphActionType.ADD_EDGE
533533
].item() # Can remove edge
@@ -540,7 +540,14 @@ def test_backward_masks(datas):
540540
available_nodes = (
541541
torch.bincount(states.tensor.x.flatten(), minlength=states.num_node_classes) > 0
542542
)
543-
assert torch.all(available_nodes == masks[GraphActions.NODE_CLASS_KEY])
543+
nodes_with_edges = (
544+
torch.bincount(
545+
states.tensor.edge_index.flatten(), minlength=states.num_node_classes
546+
)
547+
> 0
548+
)
549+
removable_node = available_nodes & ~nodes_with_edges
550+
assert torch.all(removable_node == masks[GraphActions.NODE_CLASS_KEY])
544551

545552
# Check edge_class mask
546553
assert masks[GraphActions.EDGE_CLASS_KEY].shape == (1, states.num_edge_classes)

tutorials/examples/train_graph_triangle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def main(args: argparse.Namespace) -> None:
224224
)
225225
parser.add_argument("--seed", type=int, default=1234, help="Random seed")
226226
parser.add_argument(
227-
"--embedding_dim", type=int, default=64, help="Embedding dim for policy heads"
227+
"--embedding_dim", type=int, default=128, help="Embedding dim for policy heads"
228228
)
229229
parser.add_argument(
230230
"--num_conv_layers", type=int, default=1, help="Number of GNN layers"
@@ -235,7 +235,7 @@ def main(args: argparse.Namespace) -> None:
235235
)
236236
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
237237
parser.add_argument(
238-
"--lr_Z", type=float, default=1e-1, help="Learning rate for logZ"
238+
"--lr_Z", type=float, default=5e-2, help="Learning rate for logZ"
239239
)
240240
parser.add_argument(
241241
"--use_buffer", action="store_true", default=True, help="Use replay buffer"

0 commit comments

Comments
 (0)