From e8249f7c4a34a5b34dd83e4ba8feeec1168a00d2 Mon Sep 17 00:00:00 2001 From: Harald Scheidl Date: Sat, 25 May 2019 13:51:30 +0200 Subject: [PATCH] added dump option for NN output --- .gitignore | 3 ++- README.md | 1 + src/Model.py | 34 +++++++++++++++++++++++++++++++--- src/main.py | 4 +++- 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 7db9a633d0..632b7256c5 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ model/snapshot-* notes/ *.so *.pyc -.idea/ \ No newline at end of file +.idea/ +dump/ \ No newline at end of file diff --git a/README.md b/README.md index 86f538bdf6..e5682bad49 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ Tested with: * `--validate`: validate the NN, details see below. * `--beamsearch`: use vanilla beam search decoding (better, but slower) instead of best path decoding. * `--wordbeamsearch`: use word beam search decoding (only outputs words contained in a dictionary) instead of best path decoding. This is a custom TF operation and must be compiled from source, more information see corresponding section below. It should **not** be used when training the NN. +* `--dump`: dumps the output of the NN to CSV file(s) saved in the `dump/` folder. Can be used as input for the [CTCDecoder](https://github.com/githubharald/CTCDecoder). If neither `--train` nor `--validate` is specified, the NN infers the text from the test image (`data/test.png`). Two examples: if you want to infer using beam search, execute `python main.py --beamsearch`, while you have to execute `python main.py --train --beamsearch` if you want to train the NN and do the validation using beam search. diff --git a/src/Model.py b/src/Model.py index c56e65f46d..09cfb735e6 100644 --- a/src/Model.py +++ b/src/Model.py @@ -4,6 +4,7 @@ import sys import numpy as np import tensorflow as tf +import os class DecoderType: @@ -20,8 +21,9 @@ class Model: imgSize = (128, 32) maxTextLen = 32 - def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False): + def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False, dump=False): "init model: add CNN, RNN and CTC and initialize TF" + self.dump = dump self.charList = charList self.decoderType = decoderType self.mustRestore = mustRestore @@ -217,14 +219,35 @@ def trainBatch(self, batch): return lossVal + def dumpNNOutput(self, rnnOutput): + "dump the output of the NN to CSV file(s)" + dumpDir = '../dump/' + if not os.path.isdir(dumpDir): + os.mkdir(dumpDir) + + # iterate over all batch elements and create a CSV file for each one + maxT, maxB, maxC = rnnOutput.shape + for b in range(maxB): + csv = '' + for t in range(maxT): + for c in range(maxC): + csv += str(rnnOutput[t, b, c]) + ';' + csv += '\n' + fn = dumpDir + 'rnnOutput_'+str(b)+'.csv' + print('Write dump of NN to file: ' + fn) + with open(fn, 'w') as f: + f.write(csv) + + def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False): "feed a batch into the NN to recognize the texts" # decode, optionally save RNN output numBatchElements = len(batch.imgs) - evalList = [self.decoder] + ([self.ctcIn3dTBC] if calcProbability else []) + evalRnnOutput = self.dump or calcProbability + evalList = [self.decoder] + ([self.ctcIn3dTBC] if evalRnnOutput else []) feedDict = {self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False} - evalRes = self.sess.run([self.decoder, self.ctcIn3dTBC], feedDict) + evalRes = self.sess.run(evalList, feedDict) decoded = evalRes[0] texts = self.decoderOutputToText(decoded, numBatchElements) @@ -237,6 +260,11 @@ def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False): feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False} lossVals = self.sess.run(evalList, feedDict) probs = np.exp(-lossVals) + + # dump the output of the NN to CSV file(s) + if self.dump: + self.dumpNNOutput(evalRes[1]) + return (texts, probs) diff --git a/src/main.py b/src/main.py index a33e255bfc..7645320680 100644 --- a/src/main.py +++ b/src/main.py @@ -105,6 +105,8 @@ def main(): parser.add_argument('--validate', help='validate the NN', action='store_true') parser.add_argument('--beamsearch', help='use beam search instead of best path decoding', action='store_true') parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding', action='store_true') + parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true') + args = parser.parse_args() decoderType = DecoderType.BestPath @@ -135,7 +137,7 @@ def main(): # infer text on test image else: print(open(FilePaths.fnAccuracy).read()) - model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True) + model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True, dump=args.dump) infer(model, FilePaths.fnInfer)