Prog blog

How to classify flowers with Tensorflow.js


This post describes how I tried to port the flower classification example from the Udacity TensorFlow course to JavaScript using TensorFlow.js.

The complete example is available in my flower-photos repository on GitHub.

The hardest part was the image augmentation generator. I couldn't find one on npm, so I had to write my own functions on top of jimp that work similarly to the original generator in the Python example.

augment images

In addition, image generation had to be spread across all available CPU cores using the workerpool library.

import { random, sampleSize } from "lodash";

import jimp from "jimp";
import { rotate } from "./rotate";
import { scale } from "./scale";
import { shiftHeight } from "./shift-height";
import { shiftWidth } from "./shift-width";
import workerpool from "workerpool";

// Candidate augmentations; each takes a jimp image and returns the transformed image.
const operations = [
  (image: jimp) => image.flip(true, false), // horizontal flip
  (image: jimp) => shiftWidth(image, 0.15),
  (image: jimp) => shiftHeight(image, 0.15),
  (image: jimp) => scale(image, 0.5),
  (image: jimp) => rotate(image, 45)
];

const generateAugmentImage = async (path: string, shape: number) => {
  // Load the photo and resize it to the square input size of the network.
  const image = (await jimp.read(path)).resize(shape, shape);
  // Apply a random subset of the operations (from none to all of them) in random order;
  // lodash's random(0, n) returns an integer between 0 and n inclusive.
  const operationsImageResult = sampleSize(
    operations,
    random(0, operations.length)
  ).reduce(
    (current: jimp, operation: (image: jimp) => jimp) => operation(current),
    image
  );
  // Encode the result as PNG so it can be sent from the worker back to the main thread.
  return operationsImageResult.getBufferAsync(jimp.MIME_PNG);
};

// Expose the function to workerpool, which runs it in a separate worker.
workerpool.worker({
  generateAugmentImage,
});
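The post does not show the ./shift-width module, so here is a hypothetical sketch of how a width shift can work on raw pixels. It operates on an RGBA byte buffer (4 bytes per pixel, row-major, the same layout jimp uses for `image.bitmap.data`) and fills the vacated columns with the nearest edge pixel, similar in spirit to the "nearest" fill mode of the Keras generator. The function names `shiftWidthPixels` and `randomShiftWidth` are illustrative assumptions, not the author's code.

```typescript
/**
 * Hypothetical horizontal shift on a raw RGBA buffer (4 bytes per pixel, row-major).
 * A positive offset moves the content to the right; columns that become empty
 * repeat the nearest edge pixel ("nearest" fill).
 */
function shiftWidthPixels(
  data: Uint8Array,
  width: number,
  height: number,
  offset: number
): Uint8Array {
  const out = new Uint8Array(data.length);
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      // Clamp the source column to the image so out-of-range reads copy the edge.
      const srcX = Math.min(width - 1, Math.max(0, x - offset));
      const src = (y * width + srcX) * 4;
      const dst = (y * width + x) * 4;
      out.set(data.subarray(src, src + 4), dst);
    }
  }
  return out;
}

/**
 * Shifts by a random offset of at most range * width pixels in either direction,
 * e.g. range = 0.15 like the factor passed to shiftWidth in the post.
 */
function randomShiftWidth(
  data: Uint8Array,
  width: number,
  height: number,
  range: number,
  rng: () => number = Math.random
): Uint8Array {
  const maxOffset = Math.round(range * width);
  const offset = Math.round((rng() * 2 - 1) * maxOffset);
  return shiftWidthPixels(data, width, height, offset);
}

// Usage: a 4x1 image whose red channel is 10, 20, 30, 40 from left to right.
const row = Uint8Array.from([10, 0, 0, 255, 20, 0, 0, 255, 30, 0, 0, 255, 40, 0, 0, 255]);
const red = (img: Uint8Array) => Array.from(img.filter((_, i) => i % 4 === 0));
console.log(red(shiftWidthPixels(row, 4, 1, 1))); // [ 10, 10, 20, 30 ]
```

Clamping the source coordinate keeps the loop branch-free; a height shift is the same idea applied to rows instead of columns.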

Another problem I ran into was that TensorFlow.js lacked the SparseCategoricalCrossentropy loss function. I replaced it with categoricalCrossentropy, which, as explained on Stack Exchange in "sparse categorical crossentropy vs categorical crossentropy", computes the same loss; the only difference is that the labels must be one-hot encoded instead of given as integer class indices.

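import * as tf from "@tensorflow/tfjs-node";

// Assumed shape of a dataset record: the image path and its class name.
interface Record {
  path: string;
  label: string;
}

// Returns a generator function for tf.data.generator that yields a one-hot
// label tensor per record, in the same order as the image features.
const labels = (records: Record[], labelNames: string[]) =>
  function* () {
    for (const record of records) {
      const indexOfLabel = labelNames.indexOf(record.label);
      if (indexOfLabel === -1) {
        throw new Error(
          `Something wrong. Missing label: ${record.label} in labels: ${labelNames.toString()}`
        );
      }
      // e.g. with 5 classes, index 2 becomes [0, 0, 1, 0, 0]
      yield tf.oneHot(indexOfLabel, labelNames.length);
    }
  };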
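Why the swap is safe can be checked with plain numbers. In categorical crossentropy every term is multiplied by the one-hot target, so only the term for the true class survives, which is exactly what the sparse variant computes from the class index. This is a simplified sketch; it ignores the numerical clipping that library implementations typically apply, and the probabilities are made-up softmax outputs for the 5 flower classes.

```typescript
// Categorical crossentropy with a one-hot target: -sum_i y_i * log(p_i).
function categoricalCrossentropy(oneHot: number[], probs: number[]): number {
  return -oneHot.reduce((sum, y, i) => sum + y * Math.log(probs[i]), 0);
}

// Sparse variant: the target is the class index, so only one log term is needed.
function sparseCategoricalCrossentropy(labelIndex: number, probs: number[]): number {
  return -Math.log(probs[labelIndex]);
}

function oneHot(index: number, depth: number): number[] {
  return Array.from({ length: depth }, (_, i) => (i === index ? 1 : 0));
}

// Made-up softmax output for 5 classes; the true class has index 1.
const probs = [0.1, 0.7, 0.05, 0.1, 0.05];
const label = 1;
const dense = categoricalCrossentropy(oneHot(label, probs.length), probs);
const sparse = sparseCategoricalCrossentropy(label, probs);
console.log(dense.toFixed(4), sparse.toFixed(4)); // 0.3567 0.3567
```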

Another essential element was prefetching, which ensured that TensorFlow always had images from the generator ready in time for training.

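// augmentImagePool, features, trainRecords, IMAGE_SHAPE, BATCH_SIZE and
// imageBufferToInputTensor are not shown in this excerpt.
const trainX = tf.data
  .generator(features(trainRecords))
  // Augment every image in the worker pool, off the main thread.
  .mapAsync(async (path: string) => {
    const image = await augmentImagePool.exec("generateAugmentImage", [
      path,
      IMAGE_SHAPE,
    ]);
    return imageBufferToInputTensor(image);
  })
  // Keep BATCH_SIZE * 3 images (about three batches) ready ahead of training.
  .prefetch(BATCH_SIZE * 3);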
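To show what prefetching buys, here is a conceptual sketch in plain TypeScript. It is not how tf.data implements `prefetch`; it only illustrates the idea of keeping up to `bufferSize` slow tasks (here, simulated augmentations) running ahead of the consumer while still yielding results in their original order. The `prefetch` and `makeTask` names are hypothetical.

```typescript
/**
 * Conceptual prefetch: keep up to `bufferSize` tasks started but not yet consumed,
 * and yield their results in order. Error handling is omitted in this sketch.
 */
async function* prefetch<T>(
  tasks: Iterable<() => Promise<T>>,
  bufferSize: number
): AsyncGenerator<T> {
  const iterator = tasks[Symbol.iterator]();
  const buffer: Promise<T>[] = [];
  // Start tasks until the buffer is full or the source runs out.
  const fill = () => {
    while (buffer.length < bufferSize) {
      const next = iterator.next();
      if (next.done) return;
      buffer.push(next.value());
    }
  };
  fill();
  while (buffer.length > 0) {
    const value = await buffer[0]; // oldest task first, so order is preserved
    buffer.shift();
    fill(); // start the next task before handing the value to the consumer
    yield value;
  }
}

// Usage: simulated augmentation tasks where later items finish sooner.
let running = 0;
let maxRunning = 0;
const makeTask = (i: number) => async (): Promise<number> => {
  running++;
  maxRunning = Math.max(maxRunning, running);
  await new Promise((resolve) => setTimeout(resolve, 10 * (5 - i)));
  running--;
  return i * i;
};

async function main(): Promise<number[]> {
  const results: number[] = [];
  for await (const value of prefetch([0, 1, 2, 3, 4, 5].map(makeTask), 3)) {
    results.push(value);
  }
  return results;
}

const done = main().then((results) => {
  console.log(results, maxRunning); // [ 0, 1, 4, 9, 16, 25 ] 3
  return results;
});
```

While the consumer (in the post, a training step) works on one value, the next tasks are already running, which is why a larger buffer hides the cost of slow augmentation.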

After a lot of work on the generator, I managed to reach an amazing validation accuracy of about 75%.

npm run train

> flower-photos@1.0.0 train /flower-photos
> node ./dist/train.js

Overriding the gradient for 'Max'
Overriding the gradient for 'OneHot'
Overriding the gradient for 'PadV2'
Overriding the gradient for 'SpaceToBatchND'
Overriding the gradient for 'SplitV'
2020-08-11 14:51:06.023690: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2400075000 Hz
2020-08-11 14:51:06.024722: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x522b820 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-08-11 14:51:06.024759: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
Epoch 1 / 80
eta=0.0 ======================================================================> 
133405ms 4600184us/step - acc=0.257 loss=2.31 val_acc=0.384 val_loss=1.47 
Epoch 2 / 80
eta=0.0 ======================================================================> 
127012ms 4379732us/step - acc=0.418 loss=1.35 val_acc=0.506 val_loss=1.20 
...
eta=0.0 ======================================================================> 
121768ms 4198905us/step - acc=0.918 loss=0.236 val_acc=0.771 val_loss=0.859 
Epoch 80 / 80
eta=0.0 ======================================================================> 
118522ms 4086980us/step - acc=0.918 loss=0.228 val_acc=0.761 val_loss=0.872 

Before making the model publicly available, it is best to also quantize the weights with tensorflowjs_converter, which halves the size of the model.

tensorflowjs_converter --quantize_float16 --input_format tfjs_layers_model --output_format tfjs_layers_model model/model.json quantized_model/
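The size reduction comes from storing each weight in 16 bits instead of 32. The sketch below illustrates the idea with an IEEE 754 half-precision encoder and decoder in plain TypeScript; it is not the converter's code, and for brevity it flushes values in the float16 subnormal range to zero.

```typescript
// Shared views for reading the raw bits of a 32-bit float.
const f32 = new Float32Array(1);
const u32 = new Uint32Array(f32.buffer);

/** Encodes a number as IEEE 754 half-precision bits (round to nearest even). */
function toFloat16Bits(value: number): number {
  f32[0] = value;
  const bits = u32[0];
  const sign = (bits >>> 16) & 0x8000;
  const rawExponent = (bits >>> 23) & 0xff;
  const mantissa = bits & 0x7fffff;
  if (rawExponent === 0xff) {
    // Infinity stays infinity; NaN keeps a quiet-NaN payload bit.
    return sign | 0x7c00 | (mantissa ? 0x200 : 0);
  }
  const exponent = rawExponent - 127 + 15; // re-bias from float32 to float16
  if (exponent >= 0x1f) return sign | 0x7c00; // too large: overflow to infinity
  if (exponent <= 0) return sign; // too small: flushed to zero in this sketch
  let half = (exponent << 10) | (mantissa >>> 13);
  const dropped = mantissa & 0x1fff; // the 13 mantissa bits float16 cannot keep
  if (dropped > 0x1000 || (dropped === 0x1000 && (half & 1) === 1)) {
    half += 1; // a carry here correctly bumps the exponent
  }
  return sign | half;
}

/** Decodes half-precision bits back to a number. */
function fromFloat16Bits(h: number): number {
  const sign = h & 0x8000 ? -1 : 1;
  const exponent = (h >>> 10) & 0x1f;
  const mantissa = h & 0x3ff;
  if (exponent === 0) return sign * mantissa * 2 ** -24; // zero or subnormal
  if (exponent === 0x1f) return mantissa ? NaN : sign * Infinity;
  return sign * (1 + mantissa / 1024) * 2 ** (exponent - 15);
}

/** Quantizes a whole weight array to 2 bytes per value. */
function quantizeFloat16(weights: Float32Array): Uint16Array {
  return Uint16Array.from(weights, toFloat16Bits);
}

const weights = Float32Array.from([0.1, -2, 1, 0.333]);
const quantized = quantizeFloat16(weights);
const restored = Array.from(quantized, fromFloat16Bits);
console.log(weights.byteLength, quantized.byteLength); // 16 8
```

The restored values differ slightly from the originals (0.1 becomes 0.0999755859375), which is the small precision cost paid for the halved download size.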

It is also recommended to convert the layers model into a graph model to speed up prediction on weaker devices. The effect of such optimizations is described in the publication TensorFlow Graph Optimizations.

tensorflowjs_converter --quantize_float16 --input_format tfjs_layers_model --output_format tfjs_graph_model model/model.json quantized_graph_model/

If you want to see the flower detection in action, try the live flower detector.


Flower photos webcam app