
Commit
Minor updates and add tests
james77777778 committed Jul 22, 2024
1 parent 11edb68 commit 320b7d7
Showing 7 changed files with 109 additions and 17 deletions.
2 changes: 2 additions & 0 deletions keras/src/backend/__init__.py
@@ -14,6 +14,8 @@
 from keras.src.backend.common.stateless_scope import StatelessScope
 from keras.src.backend.common.stateless_scope import get_stateless_scope
 from keras.src.backend.common.stateless_scope import in_stateless_scope
+from keras.src.backend.common.symbolic_scope import SymbolicScope
+from keras.src.backend.common.symbolic_scope import in_symbolic_scope
 from keras.src.backend.common.variables import AutocastScope
 from keras.src.backend.common.variables import get_autocast_scope
 from keras.src.backend.common.variables import is_float_dtype
2 changes: 2 additions & 0 deletions keras/src/backend/common/symbolic_scope.py
@@ -4,6 +4,8 @@

 @keras_export("keras.SymbolicScope")
 class SymbolicScope:
+    """Scope to indicate the symbolic stage."""
+
     def __enter__(self):
         self.original_scope = get_symbolic_scope()
         global_state.set_global_attribute("symbolic_scope", self)
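Note: the hunk above shows only `__enter__`. For orientation, a minimal sketch of the companion helpers, assuming `__exit__` simply restores the previous scope (a reconstruction for illustration, not part of this diff):

from keras.src.backend.common import global_state


def get_symbolic_scope():
    # Return the scope object stored by `SymbolicScope.__enter__`, if any.
    return global_state.get_global_attribute("symbolic_scope")


def in_symbolic_scope():
    # True while any `SymbolicScope` context manager is active.
    return get_symbolic_scope() is not None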
26 changes: 26 additions & 0 deletions keras/src/backend/common/symbolic_scope_test.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+from keras.src import ops
+from keras.src import testing
+from keras.src.backend.common.symbolic_scope import SymbolicScope
+from keras.src.backend.common.symbolic_scope import in_symbolic_scope
+
+
+class TestSymbolicScope(testing.TestCase):
+    def test_basic_flow(self):
+
+        # Define a function that behaves differently according to
+        # `in_symbolic_scope`.
+        def compute_loss(y, y_pred):
+            if in_symbolic_scope():
+                return ops.zeros_like(y)
+            return ops.add(y, y_pred)
+
+        y = ops.ones(shape=(2,))
+        y_pred = ops.ones(shape=(2,))
+        with SymbolicScope():
+            loss = compute_loss(y, y_pred)
+        self.assertAllClose(loss, np.zeros((2,)))
+
+        loss = compute_loss(y, y_pred)
+        self.assertAllClose(loss, 2 * np.ones((2,)))
14 changes: 13 additions & 1 deletion keras/src/backend/numpy/trainer.py
@@ -97,7 +97,10 @@ def _symbolic_build(self, data_batch):
             self._compile_metrics is not None
             and not self._compile_metrics.built
         )
-        if model_unbuilt or compile_metrics_unbuilt:
+        compile_loss_unbuilt = (
+            self._compile_loss is not None and not self._compile_loss.built
+        )
+        if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:
             # Create symbolic tensors matching an input batch.

             def to_symbolic_input(v):
@@ -133,6 +136,15 @@ def to_symbolic_input(v):
                 y_pred,
                 sample_weight=sample_weight,
             )
+        if compile_loss_unbuilt:
+            # Build `CompileLoss` state with `backend.compute_output_spec`.
+            backend.compute_output_spec(
+                self._compute_loss,
+                x,
+                y,
+                y_pred,
+                sample_weight=sample_weight,
+            )
         self._post_build()

     def fit(
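Note: `backend.compute_output_spec` traces the given function with `KerasTensor` placeholders, so `CompileLoss` can build its state without running any real computation. A rough self-contained sketch of that mechanism (the lambda and shapes are illustrative, not from this commit):

import numpy as np

from keras.src import backend


def to_symbolic_input(v):
    # Swap a concrete batch array for a placeholder of the same
    # shape/dtype, mirroring the helper referenced in the hunk above.
    return backend.KerasTensor(v.shape, backend.standardize_dtype(v.dtype))


x = to_symbolic_input(np.ones((4, 3), dtype="float32"))
spec = backend.compute_output_spec(lambda t: t * 2.0, x)
print(spec.shape, spec.dtype)  # (4, 3) float32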
15 changes: 3 additions & 12 deletions keras/src/trainers/compile_utils.py
@@ -445,16 +445,10 @@ def __init__(

     @property
     def metrics(self):
-        if not self.built:
-            return []
         return self._metrics

     @property
     def variables(self):
-        # Avoiding relying on implicit tracking since
-        # CompileLoss may be instantiated or built in a no tracking scope.
-        if not self.built:
-            return []
         vars = []
         for m in self.metrics:
             vars.extend(m.variables)
@@ -639,12 +633,9 @@ def call(self, y_true, y_pred, sample_weight=None):
                 sample_weight = [sample_weight[0] for _ in range(len(y_true))]
             else:
                 sample_weight = [None for _ in y_true]
-            if len(self.metrics) == 0:
-                # This means that the model has a single output. We need to add a
-                # dummy `None` for the following `zip` to function correctly.
-                metrics = [None]
-            else:
-                metrics = self.metrics
+            # We need to add a dummy `None` if the model has only a single
+            # output.
+            metrics = [None] if len(self.metrics) == 0 else self.metrics

             # Iterate all losses in flat form.
             loss_values = []
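Note: the dummy `None` exists so that `zip` still yields one pair for single-output models, where `self.metrics` is empty. A stripped-down illustration (names hypothetical):

flat_losses = ["mse"]  # single-output model: one flat loss entry
metrics = []  # no per-output loss-tracking metrics yet

# Without padding, zipping against an empty list would yield no iterations
# and the loss would be skipped; `[None]` keeps one iteration alive.
metrics = [None] if len(metrics) == 0 else metrics
for loss_fn, metric in zip(flat_losses, metrics):
    print(loss_fn, metric)  # -> mse None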
8 changes: 4 additions & 4 deletions keras/src/trainers/trainer.py
@@ -328,9 +328,10 @@ def metrics(self):
             loss = self._compile_loss(y, y_pred, sample_weight)
             if loss is not None:
                 losses.append(loss)
-        # If in symbolic scope, skip `self.losses` to ensure we don't access
-        # any variables. Otherwise, it might break.
+
         if not in_symbolic_scope():
+            # If in symbolic scope, skip `self.losses` to ensure we don't
+            # access any variables.
             for loss in self.losses:
                 losses.append(ops.sum(ops.cast(loss, dtype=backend.floatx())))
         if backend.backend() != "jax" and len(losses) == 0:
@@ -1042,7 +1043,7 @@ def to_symbolic_input(v):

             # Build all model state with `backend.compute_output_spec`.
             try:
-                y_pred = backend.compute_output_spec(self, x, training=False)
+                y_pred = backend.compute_output_spec(self, x)
             except Exception as e:
                 raise RuntimeError(
                     "Unable to automatically build the model. "
@@ -1072,7 +1073,6 @@ def to_symbolic_input(v):
                     y,
                     y_pred,
                     sample_weight=sample_weight,
-                    training=False,
                 )
         if backend.backend() == "torch":
             if original_training:
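Note: a condensed view of the guard above (hypothetical standalone function; in the trainer this logic lives inside the loss computation). Regularization terms in `self.losses` may read variable values, which is exactly what symbolic tracing must avoid:

from keras.src import backend
from keras.src import ops
from keras.src.backend.common.symbolic_scope import in_symbolic_scope


def total_loss(compile_loss, layer_losses):
    # `compile_loss`: scalar from `self._compile_loss(y, y_pred, ...)`.
    # `layer_losses`: the model's `self.losses`, which may touch variables.
    losses = [compile_loss]
    if not in_symbolic_scope():
        # Safe only outside tracing: these terms read variable values.
        for loss in layer_losses:
            losses.append(ops.sum(ops.cast(loss, dtype=backend.floatx())))
    return ops.sum(ops.stack(losses))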
59 changes: 59 additions & 0 deletions keras/src/trainers/trainer_test.py
@@ -1617,6 +1617,65 @@ def test_loss_weights(self):
             atol=1e-3,
         )

+    def test_symbolic_build(self):
+        class ExampleModelWithTrainingArgs(Trainer, layers.Layer):
+            def __init__(self, units):
+                layers.Layer.__init__(self)
+                Trainer.__init__(self)
+                self.dense = layers.Dense(units)
+                self.bn = layers.BatchNormalization(axis=-1)
+
+            def build(self, input_shape):
+                self.dense.build(input_shape)
+                input_shape = self.dense.compute_output_shape(input_shape)
+                self.bn.build(input_shape)
+
+            def call(self, x, training=None):
+                outputs = self.bn(self.dense(x), training=training)
+                return [outputs, outputs]
+
+        model = ExampleModelWithTrainingArgs(units=3)
+        model.compile(
+            optimizer=optimizers.SGD(),
+            loss=[losses.MeanSquaredError(), losses.MeanSquaredError()],
+            metrics=[metrics.MeanSquaredError(), metrics.MeanSquaredError()],
+        )
+        x = np.ones((4, 4))
+        y = np.zeros((4, 3))
+        model(x)  # Eager call to build model weights
+        ref_weights = model.get_weights()
+
+        # Before `_symbolic_build`
+        self.assertTrue(model.built)
+        self.assertTrue(model._compile_metrics.built)
+        self.assertFalse(model._compile_loss.built)
+        self.assertLen(model._compile_loss.metrics, 0)
+        self.assertLen(model.metrics, 2)
+
+        model._symbolic_build(data_batch=(x, (y, y)))
+        weights = model.get_weights()
+
+        # Ensure weights are intact
+        self.assertEqual(len(weights), len(ref_weights))
+        for w, ref_w in zip(weights, ref_weights):
+            self.assertAllClose(w, ref_w)
+
+        # Ensure `built`
+        self.assertTrue(model.built)
+        self.assertTrue(model._compile_metrics.built)
+        self.assertTrue(model._compile_loss.built)
+
+        # Ensure the len of metrics (original metrics + loss trackers)
+        self.assertLen(model._compile_metrics.metrics, 2)
+        self.assertLen(model._compile_loss.metrics, 2)
+        self.assertLen(model.metrics, 4)
+
+        # Ensure no values in metrics
+        for v in model._compile_metrics.variables:
+            self.assertAllClose(v, 0.0)
+        for v in model._compile_loss.variables:
+            self.assertAllClose(v, 0.0)
+

 class TrainerDistributeTest(testing.TestCase):
     @pytest.mark.skipif(
