Convert the handwriting recognition example to Keras 3. #1831

Draft
wants to merge 14 commits into base: master
187 changes: 100 additions & 87 deletions examples/vision/handwriting_recognition.py
@@ -2,7 +2,7 @@
Title: Handwriting recognition
Authors: [A_K_Nain](https://twitter.com/A_K_Nain), [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/08/16
Last modified: 2023/07/06
Last modified: 2024/04/12
Description: Training a handwriting recognition model with variable-length sequences.
Accelerator: GPU
"""
@@ -17,6 +17,8 @@
handwritten text, and its corresponding target is the string present in the image.
The IAM Dataset is widely used across many OCR benchmarks, so we hope this example can serve as a
good starting point for building OCR systems.

Note: This example is currently only compatible with TensorFlow.
"""

"""
@@ -45,16 +47,23 @@
## Imports
"""

from tensorflow.keras.layers import StringLookup
from tensorflow import keras
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

from keras.layers import StringLookup
from keras import layers
from keras import ops
from keras_nlp.metrics import EditDistance
import keras

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os

np.random.seed(42)
tf.random.set_seed(42)
k_seed = keras.random.SeedGenerator(42)

k_nlp_edit_distance = EditDistance(normalize=False)

"""
## Dataset splitting
@@ -72,7 +81,7 @@

len(words_list)

np.random.shuffle(words_list)
words_list = keras.random.shuffle(words_list, seed=k_seed)

"""
We will split the dataset into three subsets with a 90:5:5 ratio (train:validation:test).
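As a rough sketch (assuming `words_list` holds one sample per line; the actual split code is collapsed in this view), the 90:5:5 split looks like:

    split_idx = int(0.9 * len(words_list))
    train_samples = words_list[:split_idx]
    remaining = words_list[split_idx:]
    val_split_idx = int(0.5 * len(remaining))
    validation_samples = remaining[:val_split_idx]
    test_samples = remaining[val_split_idx:]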
@@ -210,11 +219,11 @@ def clean_labels(labels):

def distortion_free_resize(image, img_size):
w, h = img_size
image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)
image = ops.image.resize(image, size=(h, w))

# Check tha amount of padding needed to be done.
pad_height = h - tf.shape(image)[0]
pad_width = w - tf.shape(image)[1]
# Check the amount of padding needed.
pad_height = h - ops.shape(image)[0]
pad_width = w - ops.shape(image)[1]

# Only necessary if you want the same amount of padding on both sides.
if pad_height % 2 != 0:
@@ -231,17 +240,16 @@ def distortion_free_resize(image, img_size):
else:
pad_width_left = pad_width_right = pad_width // 2

image = tf.pad(
image = ops.image.pad_images(
image,
paddings=[
[pad_height_top, pad_height_bottom],
[pad_width_left, pad_width_right],
[0, 0],
],
top_padding=pad_height_top,
bottom_padding=pad_height_bottom,
right_padding=pad_width_right,
left_padding=pad_width_left
)

image = tf.transpose(image, perm=[1, 0, 2])
image = tf.image.flip_left_right(image)
image = ops.transpose(image, axes=[1, 0, 2])
image = ops.flip(image, axis=1)
return image
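# Illustrative usage, not part of this PR: pad-resize a dummy single-channel
# image to the example's (image_width, image_height) = (128, 32) target. The
# result is transposed to width-major order and flipped, hence the shape.
sample = ops.ones((40, 200, 1))
print(ops.shape(distortion_free_resize(sample, img_size=(128, 32))))  # (128, 32, 1)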


@@ -267,15 +275,18 @@ def preprocess_image(image_path, img_size=(image_width, image_height)):
image = tf.io.read_file(image_path)
image = tf.image.decode_png(image, 1)
image = distortion_free_resize(image, img_size)
image = tf.cast(image, tf.float32) / 255.0
image = ops.cast(image, dtype="float32") / 255.0
return image


def vectorize_label(label):
label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
length = tf.shape(label)[0]
length = ops.shape(label)[0]
pad_amount = max_len - length
label = tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token)
label = ops.pad(
label, pad_width=[[0, pad_amount]],
mode="constant", constant_values=padding_token
)
return label
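# For intuition (hypothetical indices): vectorize_label("cat") returns the
# three character ids followed by `padding_token` repeated up to `max_len`,
# e.g. [3, 1, 20, 99, 99, ..., 99].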


@@ -295,7 +306,6 @@ def prepare_dataset(image_paths, labels):
"""
## Prepare `tf.data.Dataset` objects
"""

train_ds = prepare_dataset(train_img_paths, train_labels_cleaned)
validation_ds = prepare_dataset(validation_img_paths, validation_labels_cleaned)
test_ds = prepare_dataset(test_img_paths, test_labels_cleaned)
@@ -311,17 +321,18 @@ def prepare_dataset(image_paths, labels):

for i in range(16):
img = images[i]
img = tf.image.flip_left_right(img)
img = tf.transpose(img, perm=[1, 0, 2])
img = ops.flip(img, axis=1)
img = ops.transpose(img, axes=[1, 0, 2])
img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
img = img[:, :, 0]

# Gather indices where label != padding_token.
label = labels[i]
indices = tf.gather(label, tf.where(tf.math.not_equal(label, padding_token)))
indices = ops.take(
label, ops.where(ops.not_equal(label, padding_token))
)
# Convert to string.
label = tf.strings.reduce_join(num_to_char(indices))
label = label.numpy().decode("utf-8")
label = tf.strings.reduce_join(num_to_char(indices)).numpy().decode("utf-8")

ax[i // 4, i % 4].imshow(img, cmap="gray")
ax[i // 4, i % 4].set_title(label)
@@ -338,91 +349,95 @@ def prepare_dataset(image_paths, labels):
"""
## Model

Our model will use the CTC loss as an endpoint layer. For a detailed understanding of the
CTC loss, refer to [this post](https://distill.pub/2017/ctc/).
Our model will use the CTC loss. For a detailed understanding of the CTC loss,
refer to [this post](https://distill.pub/2017/ctc/).
"""


class CTCLayer(keras.layers.Layer):
def __init__(self, name=None):
super().__init__(name=name)
self.loss_fn = keras.backend.ctc_batch_cost

def call(self, y_true, y_pred):
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
loss = self.loss_fn(y_true, y_pred, input_length, label_length)
self.add_loss(loss)

# At test time, just return the computed predictions.
return y_pred
# Ported from keras 2 backend (with keras 3 ops updates).
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
input_shape = tf.shape(y_pred)
num_samples, num_steps = input_shape[0], input_shape[1]
y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.config.epsilon())
input_length = ops.cast(input_length, dtype="int32")

if greedy:
(decoded, log_prob) = tf.nn.ctc_greedy_decoder(
inputs=y_pred, sequence_length=input_length
)
else:
(decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
inputs=y_pred,
sequence_length=input_length,
beam_width=beam_width,
top_paths=top_paths,
)
decoded_dense = []
for st in decoded:
st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
return (decoded_dense, log_prob)
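# Illustrative usage, not part of this PR: greedily decode a dummy batch of
# per-timestep probabilities. Entries in the dense output are label ids,
# padded with -1.
dummy_probs = keras.random.uniform((2, 10, 5), seed=k_seed)
dummy_decoded, _ = ctc_decode(
    dummy_probs, input_length=ops.full((2,), 10, dtype="int32")
)
print(dummy_decoded[0].shape)  # (2, 10)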

def build_model():
# Inputs to the model
input_img = keras.Input(shape=(image_width, image_height, 1), name="image")
labels = keras.layers.Input(name="label", shape=(None,))
labels = layers.Input(name="label", shape=(None,))

# First conv block.
x = keras.layers.Conv2D(
x = layers.Conv2D(
32,
(3, 3),
activation="relu",
kernel_initializer="he_normal",
padding="same",
name="Conv1",
)(input_img)
x = keras.layers.MaxPooling2D((2, 2), name="pool1")(x)
x = layers.MaxPooling2D((2, 2), name="pool1")(x)

# Second conv block.
x = keras.layers.Conv2D(
x = layers.Conv2D(
64,
(3, 3),
activation="relu",
kernel_initializer="he_normal",
padding="same",
name="Conv2",
)(x)
x = keras.layers.MaxPooling2D((2, 2), name="pool2")(x)
x = layers.MaxPooling2D((2, 2), name="pool2")(x)

# We have used two max pooling layers with pool size and strides of 2.
# Hence, downsampled feature maps are 4x smaller. The number of
# filters in the last layer is 64. Reshape accordingly before
# passing the output to the RNN part of the model.
new_shape = ((image_width // 4), (image_height // 4) * 64)
x = keras.layers.Reshape(target_shape=new_shape, name="reshape")(x)
x = keras.layers.Dense(64, activation="relu", name="dense1")(x)
x = keras.layers.Dropout(0.2)(x)
x = layers.Reshape(target_shape=new_shape, name="reshape")(x)
x = layers.Dense(64, activation="relu", name="dense1")(x)
x = layers.Dropout(0.2)(x)

# RNNs.
x = keras.layers.Bidirectional(
keras.layers.LSTM(128, return_sequences=True, dropout=0.25)
x = layers.Bidirectional(
layers.LSTM(128, return_sequences=True, dropout=0.25)
)(x)
x = keras.layers.Bidirectional(
keras.layers.LSTM(64, return_sequences=True, dropout=0.25)
x = layers.Bidirectional(
layers.LSTM(64, return_sequences=True, dropout=0.25)
)(x)

# +2 is to account for the two special tokens introduced by the CTC loss.
# The recommendation comes from here: https://git.io/J0eXP.
x = keras.layers.Dense(
len(char_to_num.get_vocabulary()) + 2, activation="softmax", name="dense2"
output = layers.Dense(
len(char_to_num.get_vocabulary()) + 2, activation="softmax", name="softmax"
)(x)

# Add CTC layer for calculating CTC loss at each step.
output = CTCLayer(name="ctc_loss")(labels, x)

# Define the model.
model = keras.models.Model(
inputs=[input_img, labels], outputs=output, name="handwriting_recognizer"
)
# Optimizer.
opt = keras.optimizers.Adam()

# Compile the model and return.
model.compile(optimizer=opt)
model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.CTC())

return model
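# Since the CTC loss is now applied in compile(), predictions come straight
# from the softmax layer. A sketch (layer names as defined above; assumed to
# mirror the collapsed section below) of an inference-only model for decoding:
model = build_model()
prediction_model = keras.models.Model(
    model.get_layer(name="image").output, model.get_layer(name="softmax").output
)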


@@ -455,22 +470,22 @@ def build_model():

def calculate_edit_distance(labels, predictions):
# Get a single batch and convert its labels to sparse tensors.
saprse_labels = tf.cast(tf.sparse.from_dense(labels), dtype=tf.int64)
sparse_labels = ops.cast(
tf.sparse.from_dense(labels), dtype="int64"
)

# Make predictions and convert them to sparse tensors.
input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
predictions_decoded = keras.backend.ctc_decode(
predictions, input_length=input_len, greedy=True
)[0][0][:, :max_len]
sparse_predictions = tf.cast(
tf.sparse.from_dense(predictions_decoded), dtype=tf.int64
input_len = ops.ones(predictions.shape[0]) * predictions.shape[1]
predictions_decoded = ctc_decode(
predictions, input_length=input_len, greedy=True)[0][0][:, :max_len]
sparse_predictions = ops.cast(
tf.sparse.from_dense(predictions_decoded), dtype="int64"
)

# Compute individual edit distances and average them out.
edit_distances = tf.edit_distance(
sparse_predictions, saprse_labels, normalize=False
)
return tf.reduce_mean(edit_distances)
edit_distances = k_nlp_edit_distance(sparse_labels, sparse_predictions)

return ops.mean(edit_distances)


class EditDistanceCallback(keras.callbacks.Callback):
@@ -484,10 +499,10 @@ def on_epoch_end(self, epoch, logs=None):
for i in range(len(validation_images)):
labels = validation_labels[i]
predictions = self.prediction_model.predict(validation_images[i])
edit_distances.append(calculate_edit_distance(labels, predictions).numpy())
edit_distances.append(calculate_edit_distance(labels, predictions))

print(
f"Mean edit distance for epoch {epoch + 1}: {np.mean(edit_distances):.4f}"
f"Mean edit distance for epoch {epoch + 1}: {ops.mean(edit_distances):.4f}"
)


@@ -521,15 +536,13 @@ def on_epoch_end(self, epoch, logs=None):

# A utility function to decode the output of the network.
def decode_batch_predictions(pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
input_len = ops.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search.
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
:, :max_len
]
results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_len]
# Iterate over the results and get back the text.
output_text = []
for res in results:
res = tf.gather(res, tf.where(tf.math.not_equal(res, -1)))
res = ops.take(res, ops.where(ops.not_equal(res, -1)))
res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
output_text.append(res)
return output_text
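# Illustrative usage, with `prediction_model` as sketched above:
#
# pred_texts = decode_batch_predictions(prediction_model.predict(batch_images))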
@@ -545,8 +558,8 @@ def decode_batch_predictions(pred):

for i in range(16):
img = batch_images[i]
img = tf.image.flip_left_right(img)
img = tf.transpose(img, perm=[1, 0, 2])
img = ops.flip(img, axis=1)
img = ops.transpose(img, axes=[1, 0, 2])
img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
img = img[:, :, 0]
