@@ -63,30 +63,38 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]):
         log_reward_clip_min: If finite, clips log rewards to this value.
         safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
             to avoid numerical instability, otherwise uses -1e38.
+        constant_pb: Whether to ignore the backward policy estimator, e.g., when
+            the GFlowNet DAG is a tree and pb is therefore always 1.
     """
 
     def __init__(
         self,
         pf: Estimator,
-        pb: Estimator,
+        pb: Estimator | None,
         logF: ScalarEstimator | ConditionalScalarEstimator,
         forward_looking: bool = False,
         log_reward_clip_min: float = -float("inf"),
         safe_log_prob_min: bool = True,
+        constant_pb: bool = False,
     ) -> None:
         """Initializes a DBGFlowNet instance.
 
         Args:
             pf: The forward policy estimator.
-            pb: The backward policy estimator.
+            pb: The backward policy estimator, or None if the GFlowNet DAG is a
+                tree, in which case pb is always 1.
             logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
                 flow of the states.
             forward_looking: Whether to use the forward-looking GFN loss.
             log_reward_clip_min: If finite, clips log rewards to this value.
             safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
                 to avoid numerical instability, otherwise uses -1e38.
+            constant_pb: Whether to ignore the backward policy estimator, e.g.,
+                when the GFlowNet DAG is a tree and pb is therefore always 1.
+                Must be set explicitly by the user to ensure that pb is an
+                Estimator except in this special case.
         """
-        super().__init__(pf, pb)
+        super().__init__(pf, pb, constant_pb=constant_pb)
         assert any(
             isinstance(logF, cls)
             for cls in [ScalarEstimator, ConditionalScalarEstimator]
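The two new arguments are meant to travel together: `pb=None` is only sensible when `constant_pb=True`. A minimal usage sketch under that assumption (the estimator variables below are hypothetical placeholders, not defined in this diff):

```python
# Hypothetical usage of the new signature: for a tree-structured DAG the
# backward policy is identically 1, so no pb estimator is needed.
gfn = DBGFlowNet(
    pf=pf_estimator,      # an Estimator for the forward policy (placeholder)
    pb=None,              # no backward estimator for a tree-structured DAG
    logF=logF_estimator,  # a ScalarEstimator for log state flows (placeholder)
    constant_pb=True,     # must be passed explicitly alongside pb=None
)
```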
@@ -285,15 +293,19 @@ class ModifiedDBGFlowNet(PFBasedGFlowNet[Transitions]):
 
     Attributes:
         pf: The forward policy estimator.
-        pb: The backward policy estimator.
-        logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
-            flow of the states.
-        forward_looking: Whether to use the forward-looking GFN loss.
-        log_reward_clip_min: If finite, clips log rewards to this value.
-        safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
-            to avoid numerical instability, otherwise uses -1e38.
+        pb: The backward policy estimator, or None if the GFlowNet DAG is a tree
+            and pb is therefore always 1.
+        constant_pb: Whether to ignore the backward policy estimator, e.g., when the
+            GFlowNet DAG is a tree and pb is therefore always 1. Must be set explicitly
+            by the user to ensure that pb is an Estimator except in this special case.
     """
 
+    def __init__(
+        self, pf: Estimator, pb: Estimator | None, constant_pb: bool = False
+    ) -> None:
+        """Initializes a ModifiedDBGFlowNet instance."""
+        super().__init__(pf, pb, constant_pb=constant_pb)
+
     def get_scores(
         self, transitions: Transitions, recalculate_all_logprobs: bool = True
     ) -> torch.Tensor:
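Both constructors forward `constant_pb` to `PFBasedGFlowNet.__init__`, which is not shown in this diff. Given the docstrings' note that the flag must be set explicitly by the user, presumably the base class enforces the pairing; one plausible sketch of that check (an assumption, not code from this PR):

```python
# Hypothetical validation inside PFBasedGFlowNet.__init__ (outside this diff):
# pb may only be omitted when the caller explicitly opts in via constant_pb.
if pb is None and not constant_pb:
    raise ValueError(
        "pb=None requires constant_pb=True (e.g., a tree-structured DAG)."
    )
```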
@@ -371,18 +383,23 @@ def get_scores(
 
         non_exit_actions = actions[~actions.is_exit]
 
-        if transitions.conditioning is not None:
-            with has_conditioning_exception_handler("pb", self.pb):
-                module_output = self.pb(
-                    valid_next_states, transitions.conditioning[mask]
-                )
+        if self.pb is not None:
+            if transitions.conditioning is not None:
+                with has_conditioning_exception_handler("pb", self.pb):
+                    module_output = self.pb(
+                        valid_next_states, transitions.conditioning[mask]
+                    )
+            else:
+                with no_conditioning_exception_handler("pb", self.pb):
+                    module_output = self.pb(valid_next_states)
+
+            valid_log_pb_actions = self.pb.to_probability_distribution(
+                valid_next_states, module_output
+            ).log_prob(non_exit_actions.tensor)
         else:
-            with no_conditioning_exception_handler("pb", self.pb):
-                module_output = self.pb(valid_next_states)
-
-        valid_log_pb_actions = self.pb.to_probability_distribution(
-            valid_next_states, module_output
-        ).log_prob(non_exit_actions.tensor)
+            # If pb is None, we assume that the GFlowNet DAG is a tree, and
+            # therefore the backward policy probability is always 1 (log probs are 0).
+            valid_log_pb_actions = torch.zeros_like(valid_log_pf_s_exit)
 
         preds = all_log_rewards[:, 0] + valid_log_pf_actions + valid_log_pf_s_prime_exit
         targets = all_log_rewards[:, 1] + valid_log_pb_actions + valid_log_pf_s_exit
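In the `else` branch, `torch.zeros_like(valid_log_pf_s_exit)` is just a convenient way to get a zero tensor with the correct batch shape, dtype, and device: in a tree every state has a unique parent, so P_B(s | s') = 1 and log P_B = 0 for every valid transition, and `targets` reduces to the log reward plus the forward exit term. A self-contained check with stand-in tensors (shapes are illustrative only, not from this PR):

```python
import torch

# Stand-ins for the tensors in get_scores; only the shared batch shape matters.
all_log_rewards = torch.randn(4, 2)
valid_log_pf_s_exit = torch.randn(4)

# Tree-structured DAG: log P_B == log 1 == 0 for every valid transition.
valid_log_pb_actions = torch.zeros_like(valid_log_pf_s_exit)

targets = all_log_rewards[:, 1] + valid_log_pb_actions + valid_log_pf_s_exit
assert torch.equal(targets, all_log_rewards[:, 1] + valid_log_pf_s_exit)
```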