
Commit dba5d7d

Merge pull request #371 from GFNOrg/constant_pb
minor change to allow for pb to be None when the gflownet's DAG is a tree
2 parents 80f0941 + 90a4060 commit dba5d7d

File tree

6 files changed (+147, -47 lines)

src/gfn/gflownet/base.py

Lines changed: 57 additions & 4 deletions
@@ -1,4 +1,5 @@
 import math
+import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Generic, Tuple, TypeVar
 
@@ -151,19 +152,52 @@ class PFBasedGFlowNet(GFlowNet[TrainingSampleType], ABC):
 
     Attributes:
         pf: The forward policy estimator.
-        pb: The backward policy estimator.
+        pb: The backward policy estimator, or None if it can be ignored (e.g., the
+            gflownet DAG is a tree, and pb is therefore always 1).
+        constant_pb: Whether to ignore the backward policy estimator.
     """
 
-    def __init__(self, pf: Estimator, pb: Estimator) -> None:
+    def __init__(
+        self, pf: Estimator, pb: Estimator | None, constant_pb: bool = False
+    ) -> None:
         """Initializes a PFBasedGFlowNet instance.
 
         Args:
             pf: The forward policy estimator.
-            pb: The backward policy estimator.
+            pb: The backward policy estimator, or None if the gflownet DAG is a tree,
+                and pb is therefore always 1.
+            constant_pb: Whether to ignore the backward policy estimator, e.g., if the
+                gflownet DAG is a tree, and pb is therefore always 1. Must be set
+                explicitly by the user to ensure that pb is an Estimator except under
+                this special case.
         """
         super().__init__()
+        # Technical note: pb may be constant in a variety of edge cases. For example,
+        # if all terminal states can be reached by exactly the same number of
+        # trajectories and we assume a uniform backward policy, then we can omit the
+        # pb term (see section 6 of "Discrete Probabilistic Inference as Control in
+        # Multi-path Environments" by Tristan Deleu, Padideh Nouri, Nikolay Malkin,
+        # Doina Precup, and Yoshua Bengio for more details). We do not intend to
+        # document all of these cases for now.
+        if pb is None and not constant_pb:
+            raise ValueError(
+                "pb must be an Estimator unless constant_pb is True. "
+                "If you intend to ignore pb, e.g., because the gflownet DAG is a "
+                "tree, set constant_pb to True."
+            )
+        if isinstance(pb, Estimator) and constant_pb:
+            warnings.warn(
+                "constant_pb is True, so the supplied backward policy estimator "
+                "will be ignored. Under normal circumstances, pb should be None "
+                "when it is constant (e.g., the GFlowNet DAG is a tree and the "
+                "backward policy probability is always 1), because learning a "
+                "backward policy estimator is unnecessary and will slow down "
+                "training. Please ensure this is the intended experimental setup."
+            )
+
         self.pf = pf
         self.pb = pb
+        self.constant_pb = constant_pb
 
     def sample_trajectories(
         self,
@@ -221,9 +255,28 @@ class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
 
     Attributes:
         pf: The forward policy module.
-        pb: The backward policy module.
+        pb: The backward policy module, or None if the gflownet DAG is a tree, and
+            pb is therefore always 1.
+        constant_pb: Whether to ignore the backward policy estimator, e.g., if the
+            gflownet DAG is a tree, and pb is therefore always 1.
     """
 
+    def __init__(
+        self, pf: Estimator, pb: Estimator | None, constant_pb: bool = False
+    ) -> None:
+        """Initializes a TrajectoryBasedGFlowNet instance.
+
+        Args:
+            pf: The forward policy estimator.
+            pb: The backward policy estimator, or None if the gflownet DAG is a tree,
+                and pb is therefore always 1.
+            constant_pb: Whether to ignore the backward policy estimator, e.g., if the
+                gflownet DAG is a tree, and pb is therefore always 1. Must be set
+                explicitly by the user to ensure that pb is an Estimator except under
+                this special case.
+        """
+        super().__init__(pf, pb, constant_pb=constant_pb)
+
     def get_pfs_and_pbs(
         self,
         trajectories: Trajectories,
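As a quick illustration of the new constructor contract, the guard above can be reproduced standalone; FakeEstimator and validate_pb are hypothetical stand-ins, so this snippet runs without the library:

    import warnings

    class FakeEstimator:
        """Stand-in for the library's Estimator class, for illustration only."""

    def validate_pb(pb, constant_pb: bool = False) -> None:
        # Mirrors the guard added to PFBasedGFlowNet.__init__ in this commit.
        if pb is None and not constant_pb:
            raise ValueError("pb must be an Estimator unless constant_pb is True.")
        if isinstance(pb, FakeEstimator) and constant_pb:
            warnings.warn("pb will be ignored; pass pb=None when it is constant.")

    validate_pb(FakeEstimator())         # usual setup: learned backward policy
    validate_pb(None, constant_pb=True)  # tree-shaped DAG: pb is constant, opt in
    try:
        validate_pb(None)                # forgetting constant_pb raises
    except ValueError as err:
        print(err)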

src/gfn/gflownet/detailed_balance.py

Lines changed: 38 additions & 21 deletions
@@ -63,30 +63,38 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]):
         log_reward_clip_min: If finite, clips log rewards to this value.
         safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
             to avoid numerical instability, otherwise uses -1e38.
+        constant_pb: Whether to ignore the backward policy estimator, e.g., if the
+            gflownet DAG is a tree, and pb is therefore always 1.
     """
 
     def __init__(
         self,
         pf: Estimator,
-        pb: Estimator,
+        pb: Estimator | None,
         logF: ScalarEstimator | ConditionalScalarEstimator,
         forward_looking: bool = False,
         log_reward_clip_min: float = -float("inf"),
         safe_log_prob_min: bool = True,
+        constant_pb: bool = False,
     ) -> None:
         """Initializes a DBGFlowNet instance.
 
         Args:
             pf: The forward policy estimator.
-            pb: The backward policy estimator.
+            pb: The backward policy estimator, or None if the gflownet DAG is a tree,
+                and pb is therefore always 1.
             logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
                 flow of the states.
             forward_looking: Whether to use the forward-looking GFN loss.
             log_reward_clip_min: If finite, clips log rewards to this value.
             safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
                 to avoid numerical instability, otherwise uses -1e38.
+            constant_pb: Whether to ignore the backward policy estimator, e.g., if the
+                gflownet DAG is a tree, and pb is therefore always 1. Must be set
+                explicitly by the user to ensure that pb is an Estimator except under
+                this special case.
         """
-        super().__init__(pf, pb)
+        super().__init__(pf, pb, constant_pb=constant_pb)
         assert any(
             isinstance(logF, cls)
             for cls in [ScalarEstimator, ConditionalScalarEstimator]
@@ -285,15 +293,19 @@ class ModifiedDBGFlowNet(PFBasedGFlowNet[Transitions]):
 
     Attributes:
         pf: The forward policy estimator.
-        pb: The backward policy estimator.
-        logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
-            flow of the states.
-        forward_looking: Whether to use the forward-looking GFN loss.
-        log_reward_clip_min: If finite, clips log rewards to this value.
-        safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
-            to avoid numerical instability, otherwise uses -1e38.
+        pb: The backward policy estimator, or None if the gflownet DAG is a tree, and
+            pb is therefore always 1.
+        constant_pb: Whether to ignore the backward policy estimator, e.g., if the
+            gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly
+            by the user to ensure that pb is an Estimator except under this special case.
     """
 
+    def __init__(
+        self, pf: Estimator, pb: Estimator | None, constant_pb: bool = False
+    ) -> None:
+        """Initializes a ModifiedDBGFlowNet instance."""
+        super().__init__(pf, pb, constant_pb=constant_pb)
+
     def get_scores(
         self, transitions: Transitions, recalculate_all_logprobs: bool = True
     ) -> torch.Tensor:
@@ -371,18 +383,23 @@ def get_scores(
 
         non_exit_actions = actions[~actions.is_exit]
 
-        if transitions.conditioning is not None:
-            with has_conditioning_exception_handler("pb", self.pb):
-                module_output = self.pb(
-                    valid_next_states, transitions.conditioning[mask]
-                )
+        if self.pb is not None:
+            if transitions.conditioning is not None:
+                with has_conditioning_exception_handler("pb", self.pb):
+                    module_output = self.pb(
+                        valid_next_states, transitions.conditioning[mask]
+                    )
+            else:
+                with no_conditioning_exception_handler("pb", self.pb):
+                    module_output = self.pb(valid_next_states)
+
+            valid_log_pb_actions = self.pb.to_probability_distribution(
+                valid_next_states, module_output
+            ).log_prob(non_exit_actions.tensor)
         else:
-            with no_conditioning_exception_handler("pb", self.pb):
-                module_output = self.pb(valid_next_states)
-
-        valid_log_pb_actions = self.pb.to_probability_distribution(
-            valid_next_states, module_output
-        ).log_prob(non_exit_actions.tensor)
+            # If pb is None, we assume that the gflownet DAG is a tree, and therefore
+            # the backward policy probability is always 1 (log probs are 0).
+            valid_log_pb_actions = torch.zeros_like(valid_log_pf_s_exit)
 
         preds = all_log_rewards[:, 0] + valid_log_pf_actions + valid_log_pf_s_prime_exit
         targets = all_log_rewards[:, 1] + valid_log_pb_actions + valid_log_pf_s_exit
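To see what the new else branch does to the Modified DB score, here is a toy check with made-up tensors; the variable names mirror get_scores above, but the values are placeholders:

    import torch

    # Placeholder log-probabilities standing in for the quantities in get_scores.
    all_log_rewards = torch.tensor([[0.0, -0.1], [-0.2, -0.3]])
    valid_log_pf_actions = torch.tensor([-0.7, -0.9])
    valid_log_pf_s_exit = torch.tensor([-0.5, -0.6])
    valid_log_pf_s_prime_exit = torch.tensor([-0.4, -0.8])

    # pb is None: backward probability is 1 on every edge, so its log-prob is 0.
    valid_log_pb_actions = torch.zeros_like(valid_log_pf_s_exit)

    preds = all_log_rewards[:, 0] + valid_log_pf_actions + valid_log_pf_s_prime_exit
    targets = all_log_rewards[:, 1] + valid_log_pb_actions + valid_log_pf_s_exit
    # The pb term vanishes: targets reduces to the reward plus the exit term.
    assert torch.equal(targets, all_log_rewards[:, 1] + valid_log_pf_s_exit)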

src/gfn/gflownet/sub_trajectory_balance.py

Lines changed: 11 additions & 3 deletions
@@ -33,7 +33,8 @@ class SubTBGFlowNet(TrajectoryBasedGFlowNet):
 
     Attributes:
         pf: The forward policy estimator.
-        pb: The backward policy estimator.
+        pb: The backward policy estimator, or None if the gflownet DAG is a tree, and
+            pb is therefore always 1.
         logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow
             of the states.
         weighting: The sub-trajectories weighting scheme.
@@ -60,12 +61,14 @@ class SubTBGFlowNet(TrajectoryBasedGFlowNet):
         lamda: Discount factor for longer trajectories (used in geometric weighting).
         log_reward_clip_min: If finite, clips log rewards to this value.
         forward_looking: Whether to use the forward-looking GFN loss.
+        constant_pb: Whether to ignore the backward policy estimator, e.g., if the
+            gflownet DAG is a tree, and pb is therefore always 1.
     """
 
     def __init__(
         self,
         pf: Estimator,
-        pb: Estimator,
+        pb: Estimator | None,
         logF: ScalarEstimator | ConditionalScalarEstimator,
         weighting: Literal[
             "DB",
@@ -79,6 +82,7 @@ def __init__(
         lamda: float = 0.9,
         log_reward_clip_min: float = -float("inf"),
         forward_looking: bool = False,
+        constant_pb: bool = False,
     ):
         """Initializes a SubTBGFlowNet instance.
 
@@ -92,8 +96,12 @@ def __init__(
             lamda: Discount factor for longer trajectories (used in geometric weighting).
             log_reward_clip_min: If finite, clips log rewards to this value.
             forward_looking: Whether to use the forward-looking GFN loss.
+            constant_pb: Whether to ignore the backward policy estimator, e.g., if the
+                gflownet DAG is a tree, and pb is therefore always 1. Must be set
+                explicitly by the user to ensure that pb is an Estimator except under
+                this special case.
         """
-        super().__init__(pf, pb)
+        super().__init__(pf, pb, constant_pb=constant_pb)
         assert any(
             isinstance(logF, cls)
             for cls in [ScalarEstimator, ConditionalScalarEstimator]
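Putting the widened signature together, constructing a SubTB GFlowNet for a tree-shaped DAG now looks like the sketch below. The import path follows this repo's package layout; pf_estimator and logF_estimator are assumed to be estimators built elsewhere, so this is illustrative rather than copy-paste runnable:

    from gfn.gflownet import SubTBGFlowNet

    gfn = SubTBGFlowNet(
        pf=pf_estimator,      # forward policy Estimator, assumed built elsewhere
        pb=None,              # tree-shaped DAG: P_B is identically 1
        logF=logF_estimator,  # ScalarEstimator of state log-flows, assumed built
        weighting="DB",       # one of the Literal options in the signature above
        lamda=0.9,
        constant_pb=True,     # must be set explicitly whenever pb is None
    )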

src/gfn/gflownet/trajectory_balance.py

Lines changed: 11 additions & 4 deletions
@@ -32,30 +32,37 @@ class TBGFlowNet(TrajectoryBasedGFlowNet):
 
     Attributes:
         pf: The forward policy estimator.
-        pb: The backward policy estimator.
+        pb: The backward policy estimator, or None if the gflownet DAG is a tree, and
+            pb is therefore always 1.
         logZ: A learnable parameter or a ScalarEstimator instance (for conditional GFNs).
         log_reward_clip_min: If finite, clips log rewards to this value.
+        constant_pb: Whether to ignore pb, e.g., if the gflownet DAG is a tree.
     """
 
     def __init__(
         self,
         pf: Estimator,
-        pb: Estimator,
+        pb: Estimator | None,
         logZ: nn.Parameter | ScalarEstimator | None = None,
         init_logZ: float = 0.0,
         log_reward_clip_min: float = -float("inf"),
+        constant_pb: bool = False,
     ):
         """Initializes a TBGFlowNet instance.
 
         Args:
             pf: The forward policy estimator.
-            pb: The backward policy estimator.
+            pb: The backward policy estimator, or None if the gflownet DAG is a tree,
+                and pb is therefore always 1.
             logZ: A learnable parameter or a ScalarEstimator instance (for
                 conditional GFNs).
             init_logZ: The initial value for the logZ parameter (used if logZ is None).
             log_reward_clip_min: If finite, clips log rewards to this value.
+            constant_pb: Whether to ignore pb, e.g., if the GFlowNet DAG is a tree and
+                pb is therefore always 1. Must be set explicitly by the user to ensure
+                that pb is an Estimator except under this special case.
         """
-        super().__init__(pf, pb)
+        super().__init__(pf, pb, constant_pb=constant_pb)
 
         self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ))
         self.log_reward_clip_min = log_reward_clip_min
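The same pattern applies to the common trajectory-balance case; again a sketch, with pf_estimator and pb_estimator assumed to exist and the import path taken from this repo's package layout:

    from gfn.gflownet import TBGFlowNet

    # Tree-shaped DAG: omit pb and acknowledge it explicitly.
    gfn = TBGFlowNet(pf=pf_estimator, pb=None, init_logZ=0.0, constant_pb=True)

    # Passing a pb Estimator together with constant_pb=True still constructs the
    # object, but triggers the warning added in PFBasedGFlowNet.__init__, since
    # the estimator would be trained for no benefit.
    gfn = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, constant_pb=True)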

src/gfn/utils/prob_calculations.py

Lines changed: 29 additions & 15 deletions
@@ -46,7 +46,7 @@ def check_cond_forward(
 
 def get_trajectory_pfs_and_pbs(
     pf: Estimator,
-    pb: Estimator,
+    pb: Estimator | None,
     trajectories: Trajectories,
     fill_value: float = 0.0,
     recalculate_all_logprobs: bool = True,
@@ -55,7 +55,8 @@ def get_trajectory_pfs_and_pbs(
 
     Args:
         pf: The forward policy estimator.
-        pb: The backward policy estimator.
+        pb: The backward policy estimator, or None if the gflownet DAG is a tree, and
+            pb is therefore always 1.
         trajectories: The trajectories to calculate probabilities for.
         fill_value: The value to fill for invalid states (e.g., sink states).
         recalculate_all_logprobs: Whether to recalculate log probabilities even if they
@@ -157,7 +158,7 @@ def get_trajectory_pfs(
 
 
 def get_trajectory_pbs(
-    pb: Estimator,
+    pb: Estimator | None,
     trajectories: Trajectories,
     fill_value: float = 0.0,
 ) -> torch.Tensor:
@@ -210,11 +211,16 @@ def get_trajectory_pbs(
         # We need to index it with the state_mask to get the valid states
         masked_cond = trajectories.conditioning[state_mask]
 
-    estimator_outputs = check_cond_forward(pb, "pb", valid_states, masked_cond)
+    if pb is not None:
+        estimator_outputs = check_cond_forward(pb, "pb", valid_states, masked_cond)
+        valid_log_pb_actions = pb.to_probability_distribution(
+            valid_states, estimator_outputs
+        ).log_prob(valid_actions.tensor)
 
-    valid_log_pb_actions = pb.to_probability_distribution(
-        valid_states, estimator_outputs
-    ).log_prob(valid_actions.tensor)
+    else:
+        # If pb is None, we assume that the gflownet DAG is a tree, and therefore
+        # the backward policy probability is always 1 (log probs are 0).
+        valid_log_pb_actions = torch.zeros_like(valid_actions.tensor)
 
     log_pb_trajectories[action_mask] = valid_log_pb_actions
 
@@ -233,15 +239,16 @@ def get_trajectory_pbs(
 
 def get_transition_pfs_and_pbs(
     pf: Estimator,
-    pb: Estimator,
+    pb: Estimator | None,
     transitions: Transitions,
     recalculate_all_logprobs: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Calculates the log probabilities of forward and backward transitions.
 
     Args:
         pf: The forward policy estimator.
-        pb: The backward policy estimator.
+        pb: The backward policy estimator, or None if the gflownet DAG is a tree, and
+            pb is therefore always 1.
         transitions: The transitions to calculate probabilities for.
         recalculate_all_logprobs: Whether to recalculate log probabilities even if they
             already exist in the transitions object.
@@ -301,11 +308,12 @@ def get_transition_pfs(
     return log_pf_actions
 
 
-def get_transition_pbs(pb: Estimator, transitions: Transitions) -> torch.Tensor:
+def get_transition_pbs(pb: Estimator | None, transitions: Transitions) -> torch.Tensor:
     """Calculates the log probabilities of backward transitions.
 
     Args:
-        pb: The backward policy Estimator.
+        pb: The backward policy Estimator, or None if the gflownet DAG is a tree, and
+            pb is therefore always 1.
         transitions: The transitions to calculate probabilities for.
     """
     # automatically removes invalid transitions (i.e. s_f -> s_f)
@@ -318,18 +326,24 @@ def get_transition_pbs(pb: Estimator, transitions: Transitions) -> torch.Tensor:
         if transitions.conditioning is not None
         else None
     )
-    estimator_outputs = check_cond_forward(pb, "pb", valid_next_states, masked_cond)
 
-    # Evaluate the log PB of the actions.
+    # TODO: We support a fill_value for trajectories, but not for transitions.
+    # Should we add it here, or remove it for trajectories?
     log_pb_actions = torch.zeros(
         (transitions.n_transitions,), device=transitions.states.device
     )
 
-    if len(valid_next_states) != 0:
+    # If pb is None, we assume that the gflownet DAG is a tree, and therefore
+    # the backward policy probability is always 1 (log probs are 0).
+    if pb is not None:
+        estimator_outputs = check_cond_forward(pb, "pb", valid_next_states, masked_cond)
+
+        # Evaluate the log PB of the actions.
         valid_log_pb_actions = pb.to_probability_distribution(
             valid_next_states, estimator_outputs
         ).log_prob(non_exit_actions.tensor)
 
-        log_pb_actions[~transitions.is_terminating] = valid_log_pb_actions
+        if len(valid_next_states) != 0:
+            log_pb_actions[~transitions.is_terminating] = valid_log_pb_actions
 
     return log_pb_actions
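Downstream, the helpers behave as sketched below; trajectories and pf_estimator are assumed to come from elsewhere (a sampled batch and a built estimator), so this is a sketch of the contract rather than a runnable script:

    from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs

    log_pfs, log_pbs = get_trajectory_pfs_and_pbs(
        pf=pf_estimator,            # forward policy Estimator, assumed built elsewhere
        pb=None,                    # tree-shaped DAG: backward probability is always 1
        trajectories=trajectories,  # assumed sampled from the gflownet
    )
    # With pb=None, every valid step's backward log-probability is exactly 0 and
    # padded positions keep fill_value, so the pb term drops out of any
    # trajectory-level loss computed from log_pfs and log_pbs.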
