Skip to content

Commit

Permalink
Address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Jul 22, 2024
1 parent 97915e3 commit 8ef4d64
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
6 changes: 5 additions & 1 deletion keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from keras.src.backend import KerasTensor
from keras.src.backend.common import global_state
from keras.src.backend.common.name_scope import current_path
from keras.src.backend.common.symbolic_scope import in_symbolic_scope
from keras.src.distribution import distribution_lib
from keras.src.dtype_policies import DTypePolicyMap
from keras.src.layers import input_spec
Expand Down Expand Up @@ -1139,7 +1140,10 @@ def _get_regularization_losses(self):
for variable in self.trainable_weights:
if variable.regularizer is None:
continue
if backend.in_stateless_scope():
if backend.in_stateless_scope() and not in_symbolic_scope():
# If in symbolic scope, we might get `None` from
# `get_current_value` in `backend.compute_output_spec`. So we
# assign `variable` instead.
v = backend.get_stateless_scope().get_current_value(variable)
else:
v = variable
Expand Down
9 changes: 2 additions & 7 deletions keras/src/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from keras.src import ops
from keras.src import optimizers
from keras.src import tree
from keras.src.backend.common.symbolic_scope import in_symbolic_scope
from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer
from keras.src.saving import serialization_lib
from keras.src.trainers.compile_utils import CompileLoss
Expand Down Expand Up @@ -328,12 +327,8 @@ def metrics(self):
loss = self._compile_loss(y, y_pred, sample_weight)
if loss is not None:
losses.append(loss)

# If in symbolic scope, skip `self.losses` to ensure we don't access
# any variables. Otherwise, it might break.
if not in_symbolic_scope():
for loss in self.losses:
losses.append(ops.sum(ops.cast(loss, dtype=backend.floatx())))
for loss in self.losses:
losses.append(ops.sum(ops.cast(loss, dtype=backend.floatx())))
if backend.backend() != "jax" and len(losses) == 0:
raise ValueError(
"No loss to compute. Provide a `loss` argument in `compile()`."
Expand Down

0 comments on commit 8ef4d64

Please sign in to comment.