Commit

Fix TorchWrapper of training args
james77777778 committed Jul 22, 2024
1 parent 8ef4d64 commit 603525c
Showing 3 changed files with 54 additions and 11 deletions.
keras/src/trainers/trainer.py (9 changes: 2 additions & 7 deletions)
@@ -1013,9 +1013,6 @@ def _symbolic_build(self, iterator=None, data_batch=None):
             self.optimizer is not None and not self.optimizer.built
         )
         if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:
-            if backend.backend() == "torch":
-                original_training = self.training
-                self.eval()
             # Create symbolic tensors matching an input batch.
 
             def to_symbolic_input(v):
@@ -1038,7 +1035,7 @@ def to_symbolic_input(v):
 
             # Build all model state with `backend.compute_output_spec`.
             try:
-                y_pred = backend.compute_output_spec(self, x)
+                y_pred = backend.compute_output_spec(self, x, training=False)
             except Exception as e:
                 raise RuntimeError(
                     "Unable to automatically build the model. "
@@ -1068,10 +1065,8 @@ def to_symbolic_input(v):
                     y,
                     y_pred,
                     sample_weight=sample_weight,
+                    training=False,
                 )
-        if backend.backend() == "torch":
-            if original_training:
-                self.train()
         if optimizer_unbuilt:
             # Build optimizer
             self.optimizer.build(self.trainable_variables)
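Context for the change above: the deleted torch-specific eval()/train() toggling existed because, in train mode, a forward pass through a torch module mutates state such as BatchNorm running statistics; passing training=False through compute_output_spec requests inference behavior per call instead of flipping global module state. A minimal sketch in plain PyTorch (no Keras involved) of the side effect the symbolic build has to avoid:

```python
import torch

# BatchNorm keeps running statistics that a train-mode forward pass
# mutates as a side effect; an eval-mode pass leaves them untouched.
bn = torch.nn.BatchNorm1d(4)
ref = bn.running_mean.clone()

bn.eval()  # the training=False path
bn(torch.randn(8, 4))
assert torch.allclose(bn.running_mean, ref)  # stats unchanged

bn.train()  # the training=True / training=None path
bn(torch.randn(8, 4))
assert not torch.allclose(bn.running_mean, ref)  # stats updated
```

Scoping the behavior to the call also removes the need to remember and restore the module's previous mode, which is what the deleted original_training bookkeeping did.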
keras/src/utils/torch_utils.py (6 changes: 5 additions & 1 deletion)
@@ -112,7 +112,11 @@ def _track_module_parameters(self):
             self._track_variable(variable)
         self.built = True
 
-    def call(self, *args, **kwargs):
+    def call(self, *args, training=None, **kwargs):
+        if training is False:
+            self.eval()
+        else:
+            self.train()
         return self.module(*args, **kwargs)
 
     def save_own_variables(self, store):
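A short usage sketch of the new signature, assuming the torch backend (the wrapper is exported as keras.layers.TorchModuleWrapper); per the logic above, only an explicit training=False selects eval mode, while training=True and the default training=None both run in train mode:

```python
import numpy as np
import torch
from keras.layers import TorchModuleWrapper

bn = torch.nn.BatchNorm1d(2)
wrapper = TorchModuleWrapper(bn)  # assumes KERAS_BACKEND="torch"
x = np.random.random((3, 2)).astype("float32")

ref = bn.running_mean.clone()
wrapper(x, training=False)  # call() runs self.eval(): stats unchanged
assert torch.allclose(bn.running_mean, ref)

wrapper(x)  # training=None falls through to self.train(): stats updated
assert not torch.allclose(bn.running_mean, ref)
```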
keras/src/utils/torch_utils_test.py (50 changes: 47 additions & 3 deletions)
@@ -29,9 +29,9 @@ def __init__(
             self.torch_wrappers.append(TorchModuleWrapper(torch_model))
         self.fc = layers.Dense(1)
 
-    def call(self, x):
+    def call(self, x, training=None):
         for wrapper in self.torch_wrappers:
-            x = wrapper(x)
+            x = wrapper(x, training=training)
         return self.fc(x)
 
     def get_config(self):
@@ -49,7 +49,7 @@ def __init__(self, *args, **kwargs):
         self.fc2 = torch.nn.Linear(4, 4)
         self.fc3 = layers.Dense(2)
 
-    def call(self, x):
+    def call(self, x, training=None):
         return self.fc3(self.fc2(self.bn1(self.fc1(x))))
 
 
@@ -82,6 +82,50 @@ def test_basic_usage(self, use_batch_norm, num_torch_layers):
         model.compile(optimizer="sgd", loss="mse")
         model.fit(np.random.random((3, 2)), np.random.random((3, 1)))
 
+    @parameterized.named_parameters(
+        (
+            "explicit_torch_wrapper",
+            Classifier,
+            {"use_batch_norm": True, "num_torch_layers": 1},
+        ),
+        ("implicit_torch_wrapper", ClassifierWithNoSpecialCasing, {}),
+    )
+    def test_training_args(self, cls, kwargs):
+        model = cls(**kwargs)
+        model(np.random.random((3, 2)), training=False)  # Eager call to build
+        ref_weights = model.get_weights()
+        ref_running_mean = backend.convert_to_numpy(
+            model.torch_wrappers[0].module[-1].running_mean
+            if cls is Classifier
+            else model.bn1.module.running_mean
+        )
+
+        # Test training=False doesn't affect model weights
+        model(np.random.random((3, 2)), training=False)
+        weights = model.get_weights()
+        for w, ref_w in zip(weights, ref_weights):
+            self.assertAllClose(w, ref_w)
+
+        # Test training=None affects BN's stats
+        model.set_weights(ref_weights)  # Restore previous weights
+        model(np.random.random((3, 2)))
+        running_mean = backend.convert_to_numpy(
+            model.torch_wrappers[0].module[-1].running_mean
+            if cls is Classifier
+            else model.bn1.module.running_mean
+        )
+        self.assertNotAllClose(running_mean, ref_running_mean)
+
+        # Test training=True affects BN's stats
+        model.set_weights(ref_weights)  # Restore previous weights
+        model(np.random.random((3, 2)), training=True)
+        running_mean = backend.convert_to_numpy(
+            model.torch_wrappers[0].module[-1].running_mean
+            if cls is Classifier
+            else model.bn1.module.running_mean
+        )
+        self.assertNotAllClose(running_mean, ref_running_mean)
+
     def test_module_autowrapping(self):
         model = ClassifierWithNoSpecialCasing()
         self.assertIsInstance(model.fc1, TorchModuleWrapper)
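For the implicit case, a condensed sketch of the behavior the new test pins down, again assuming the torch backend; the names here (TinyNet, bn, out) are illustrative, not from the test suite. A raw torch module assigned as an attribute is auto-wrapped (as test_module_autowrapping asserts), and the outer call's training flag reaches it even though call() does not forward it explicitly:

```python
import numpy as np
import torch
from keras import Model, layers

class TinyNet(Model):
    def __init__(self):
        super().__init__()
        # A raw torch module is auto-wrapped in TorchModuleWrapper here.
        self.bn = torch.nn.BatchNorm1d(2)
        self.out = layers.Dense(1)

    def call(self, x, training=None):
        # `training` is not forwarded; Keras propagates the outer flag.
        return self.out(self.bn(x))

net = TinyNet()
x = np.random.random((3, 2)).astype("float32")
net(x, training=False)  # wrapped BN runs in eval mode; stats untouched
net(x, training=True)   # wrapped BN runs in train mode; stats updated
```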
