Skip to content

Commit

Permalink
data aug: random background, write summary of CERs while training
Browse files Browse the repository at this point in the history
  • Loading branch information
Harald Scheidl committed Feb 2, 2021
1 parent 2f0499f commit cdfb518
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 17 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
data/words*
data/words.txt
src/__pycache__/
model/checkpoint
model/snapshot-*
notes/
*.so
*.pyc
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ The input image and the expected output is shown below.

```
> python main.py
Validation character error rate of saved model: 11.118344571029994%
Init with stored values from ../model/snapshot-76
Recognized: "Hello"
Probability: 0.8462573289871216
Expand Down
Binary file removed data/pixelRelevance.npy
Binary file not shown.
Binary file removed data/translationInvariance.npy
Binary file not shown.
Binary file removed data/translationInvarianceTexts.pickle
Binary file not shown.
4 changes: 4 additions & 0 deletions model/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore
1 change: 0 additions & 1 deletion model/accuracy.txt

This file was deleted.

18 changes: 9 additions & 9 deletions src/SamplePreprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,30 @@ def preprocess(img, imgSize, dataAugmentation=False):
if random.random() < 0.25:
img = cv2.erode(img, np.ones((3, 3)))
if random.random() < 0.5:
img = img * (0.5 + random.random() * 0.5)
img = img * (0.1 + random.random() * 0.9)
if random.random() < 0.25:
img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 50), 0, 255)
img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)
if random.random() < 0.1:
img = 255 - img

# geometric data augmentation
wt, ht = imgSize
h, w = img.shape
f = min(wt / w, ht / h)
fx = f * np.random.uniform(0.75, 1.25)
fy = f * np.random.uniform(0.75, 1.25)
fx = f * np.random.uniform(0.75, 1.5)
fy = f * np.random.uniform(0.75, 1.5)

# random position around center
txc = (wt - w * fx) / 2
tyc = (ht - h * fy) / 2
freedom_x = wt // 10
freedom_y = ht // 10
tx = txc + np.random.randint(-freedom_x, freedom_x)
ty = tyc + np.random.randint(-freedom_y, freedom_y)
freedom_x = wt / 5
freedom_y = ht / 5
tx = txc + np.random.uniform(-freedom_x, freedom_x)
ty = tyc + np.random.uniform(-freedom_y, freedom_y)

# map image into target image
M = np.float32([[fx, 0, tx], [0, fy, ty]])
target = np.ones(imgSize[::-1]) * 255 / 2
target = np.ones(imgSize[::-1]) * np.random.uniform(0, 255)
img = cv2.warpAffine(img, M, dsize=imgSize, dst=target, borderMode=cv2.BORDER_TRANSPARENT)

# no data augmentation
Expand Down
16 changes: 12 additions & 4 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import json

import cv2
import editdistance
Expand All @@ -12,14 +13,20 @@
class FilePaths:
"filenames and paths to data"
fnCharList = '../model/charList.txt'
fnAccuracy = '../model/accuracy.txt'
fnSummary = '../model/summary.json'
fnInfer = '../data/test.png'
fnCorpus = '../data/corpus.txt'


def write_summary(charErrorRates):
with open(FilePaths.fnSummary, 'w') as f:
json.dump(charErrorRates, f)


def train(model, loader):
"train NN"
epoch = 0 # number of training epochs since start
summaryCharErrorRates = []
bestCharErrorRate = float('inf') # best valdiation character error rate
noImprovementSince = 0 # number of epochs no improvement of character error rate occured
earlyStopping = 25 # stop training after this number of epochs without improvement
Expand All @@ -39,14 +46,16 @@ def train(model, loader):
# validate
charErrorRate = validate(model, loader)

# write summary
summaryCharErrorRates.append(charErrorRate)
write_summary(summaryCharErrorRates)

# if best validation accuracy so far, save model parameters
if charErrorRate < bestCharErrorRate:
print('Character error rate improved, save model')
bestCharErrorRate = charErrorRate
noImprovementSince = 0
model.save()
open(FilePaths.fnAccuracy, 'w').write(
f'Validation character error rate of saved model: {charErrorRate * 100.0}%')
else:
print(f'Character error rate not improved, best so far: {charErrorRate * 100.0}%')
noImprovementSince += 1
Expand Down Expand Up @@ -140,7 +149,6 @@ def main():

# infer text on test image
else:
print(open(FilePaths.fnAccuracy).read())
model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True, dump=args.dump)
infer(model, FilePaths.fnInfer)

Expand Down

0 comments on commit cdfb518

Please sign in to comment.