Published in Analytics Vidhya
https://medium.com/analytics-vidhya/tensorflow-lite-tflite-with-golang-37a326c089ff


TensorFlow Lite, commonly known as TFLite, is used to deploy and run inference with machine learning models on mobile and IoT (edge) devices. TFLite makes on-device (offline) inference easier across multiple device architectures, such as Android, iOS, Raspberry Pi and even backend servers. With TFLite you can build a lightweight, server-based inference application in any programming language using lightweight models, rather than heavy TensorFlow models.

As developers, we can simply use existing optimized research models or convert existing TensorFlow models to TFLite. There are multiple ways of using TFLite in your mobile, IoT or server applications:

  • Implement the inference separately for each architecture (Android, iOS, etc.) using the standard libraries and SDKs provided by TFLite.
  • Use the TFLite C API for inference together with a platform-independent programming language like Go (Golang), and cross-compile for platforms like Android, iOS, etc.

In this post I’m going to showcase the implementation of a TFLite inference application in the platform-independent language Go, cross-compiled to a shared library that can then be consumed by Android, iOS, etc.

First, thanks to mattn, who created the TFLite Go bindings; you can find the repo here. We will start by implementing a simple Go application for TFLite inference (you can find the example here). Here I’m using a simple text classifier that classifies text as ‘Positive’ or ‘Negative’.

Here is classifier.go, which contains Go functions that are exported for use by C code.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
package main

//#include <stdlib.h>
import "C"
import (
	"bufio"
	"log"
	"os"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"unsafe"
	"encoding/json"

	"github.com/mattn/go-tflite"
	gopointer "github.com/mattn/go-pointer"
)

// Classifier bundles everything needed to classify a piece of text: the
// TFLite interpreter that runs the model, the vocabulary used to encode the
// input, and the label names reported alongside the output scores.
type Classifier struct {
	interpreter *tflite.Interpreter // runs the loaded model
	dictionary  map[string]int      // word -> vocabulary id
	labels      []string            // label names, paired by index with output scores
}

// Special vocabulary tokens looked up in the dictionary while encoding input
// text, and the fixed number of ids written into the model input per call.
const (
	START        = "<START>"
	PAD          = "<PAD>"
	UNKNOWN      = "<UNKNOWN>"
	SENTENCE_LEN = 256
)

// Build loads vocab.txt, labels.txt and text_classification.tflite from the
// current working directory, creates a TFLite interpreter, and registers the
// resulting *Classifier with go-pointer. The returned pointer is the handle C
// callers pass to Classify and Close.
//
// On any loading failure the error is logged and nil is returned. The
// process is not terminated: log.Fatal would call os.Exit from inside the
// shared library and kill the embedding application. Callers must check for
// a nil handle.
//
//export Build
func Build() unsafe.Pointer {
	dic, err := loadDictionary("vocab.txt")
	if err != nil {
		log.Printf("build: loading dictionary: %v", err)
		return nil
	}

	labels, err := loadLabels("labels.txt")
	if err != nil {
		log.Printf("build: loading labels: %v", err)
		return nil
	}

	model := tflite.NewModelFromFile("text_classification.tflite")
	if model == nil {
		log.Print("build: cannot load model")
		return nil
	}

	interpreter := tflite.NewInterpreter(model, nil)
	if interpreter == nil {
		model.Delete()
		log.Print("build: cannot create interpreter")
		return nil
	}

	// Store a pointer, not a value: Classify and Close restore the handle and
	// type-assert it to *Classifier, which fails for a stored Classifier value.
	return gopointer.Save(&Classifier{dictionary: dic, labels: labels, interpreter: interpreter})
}

// Classify runs the classifier registered under appPointer (a handle returned
// by Build) on the C string word and returns the result as a new C string.
//
// The returned string is allocated with C.CString, so the caller owns it and
// must release it with free(). "Error Occurred" is returned when appPointer
// does not resolve to a *Classifier (for example a nil handle from a failed
// Build).
//
//export Classify
func Classify(appPointer unsafe.Pointer, word *C.char) *C.char {
	// comma-ok: an unchecked assertion would panic across the C boundary.
	c, ok := gopointer.Restore(appPointer).(*Classifier)
	if !ok || c == nil {
		return C.CString("Error Occurred")
	}
	return C.CString(c.classify(C.GoString(word)))
}

// Close deletes the TFLite interpreter of the Classifier registered under
// appPointer and removes appPointer from the go-pointer registry.
//
// The interpreter is only deleted when the handle resolves to a *Classifier,
// so a nil handle (e.g. from a failed Build) no longer panics on the type
// assertion.
//
//export Close
func Close(appPointer unsafe.Pointer) {
	if c, ok := gopointer.Restore(appPointer).(*Classifier); ok && c != nil && c.interpreter != nil {
		c.interpreter.Delete()
	}
	gopointer.Unref(appPointer)
}

// loadDictionary reads the vocabulary file fname, where each line has the form
// "<word> <id>", and returns the word -> id mapping.
//
// Lines with fewer than two space-separated fields, or whose second field is
// not an integer, are skipped. An error is returned if the file cannot be
// opened or if scanning fails part-way (e.g. a line longer than
// bufio.MaxScanTokenSize).
func loadDictionary(fname string) (map[string]int, error) {
	// Open the file the caller asked for; this previously ignored fname.
	f, err := os.Open(fname)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	dic := make(map[string]int)
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		fields := strings.Split(scanner.Text(), " ")
		if len(fields) < 2 {
			continue
		}
		n, err := strconv.Atoi(fields[1])
		if err != nil {
			continue
		}
		dic[fields[0]] = n
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return dic, nil
}

// loadLabels reads fname and returns its lines, in file order, as label names.
// classify pairs the i-th label with the i-th score of the output tensor.
//
// An error is returned if the file cannot be opened or if scanning fails
// part-way (e.g. a line longer than bufio.MaxScanTokenSize). Previously such
// failures silently truncated the label list.
func loadLabels(fname string) ([]string, error) {
	f, err := os.Open(fname)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var labels []string
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		labels = append(labels, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return labels, nil
}

// tokenSeparator matches the characters classify splits input text on:
// space, comma, period, '!', '?' and newline. It is compiled once at package
// initialization instead of on every call.
var tokenSeparator = regexp.MustCompile(" |\\,|\\.|\\!|\\?|\n")

// classify encodes word into model input, runs the interpreter, and returns
// the output scores as a JSON array of {"label": ..., "poll": ...} objects,
// sorted by ascending score.
//
// It returns "Error Occurred" (the same string Classify uses for a bad handle)
// if tensor allocation or invocation fails, or if the scores cannot be
// JSON-encoded.
func (c *Classifier) classify(word string) string {
	input := c.encode(word)

	if c.interpreter.AllocateTensors() != tflite.OK {
		return "Error Occurred"
	}
	copy(c.interpreter.GetInputTensor(0).Float32s(), input)
	if c.interpreter.Invoke() != tflite.OK {
		return "Error Occurred"
	}

	// encoding/json ignores unexported fields, so the fields must be exported
	// (with explicit keys) or every rank serializes as an empty object.
	type rank struct {
		Label string  `json:"label"`
		Poll  float32 `json:"poll"`
	}
	scores := c.interpreter.GetOutputTensor(0).Float32s()
	ranks := make([]rank, 0, len(scores))
	for i, v := range scores {
		// Fall back to the index if labels.txt is shorter than the output
		// tensor, rather than panicking on c.labels[i].
		label := strconv.Itoa(i)
		if i < len(c.labels) {
			label = c.labels[i]
		}
		ranks = append(ranks, rank{Label: label, Poll: v})
	}
	sort.Slice(ranks, func(i, j int) bool {
		return ranks[i].Poll < ranks[j].Poll
	})

	out, err := json.Marshal(ranks)
	if err != nil {
		// e.g. a NaN or Inf score, which encoding/json cannot represent.
		return "Error Occurred"
	}
	return string(out)
}

// encode turns text into exactly SENTENCE_LEN vocabulary ids. The sequence is:
//   - the <START> id, if present in the dictionary;
//   - one id per token, using <UNKNOWN>'s id for out-of-vocabulary tokens;
//   - <PAD>'s id for the remaining positions.
//
// Tokens beyond SENTENCE_LEN are dropped.
func (c *Classifier) encode(text string) []float32 {
	tokens := tokenSeparator.Split(strings.TrimSpace(text), -1)
	ids := make([]float32, SENTENCE_LEN)
	n := 0
	if id, ok := c.dictionary[START]; ok {
		ids[n] = float32(id)
		n++
	}
	for _, tok := range tokens {
		if n >= SENTENCE_LEN {
			break
		}
		id, ok := c.dictionary[tok]
		if !ok {
			id = c.dictionary[UNKNOWN]
		}
		ids[n] = float32(id)
		n++
	}
	pad := float32(c.dictionary[PAD])
	for i := n; i < SENTENCE_LEN; i++ {
		ids[i] = pad
	}
	return ids
}

// main is intentionally empty: the library is driven through the exported
// functions (Build, Classify, Close). It exists only because a main package
// built with -buildmode=c-shared or c-archive must still declare it.
func main() {}

Build(): initializes all the object references in the Go runtime environment and returns a pointer that acts as a handle for later calls.

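Classify(appPointer unsafe.Pointer, word *C.char) *C.char: accepts the appPointer returned by Build() and the input text as a C string. It returns the classification result as a JSON-encoded C string. The result is allocated with C.CString, so the caller must release it with free().

Close(appPointer unsafe.Pointer): deletes the TFLite interpreter and releases the pointer created by Build().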

//export: this cgo directive, placed directly above a Go function, makes that function callable from C code.

You can find more information about cgo here.

Now we have the Go implementation of text classification in classifier.go. The next big challenge is cross-compiling this text classifier into a shared library. In Go you can build your application for a specific operating system and architecture. These are the important Go build settings used in our example:

1
2
3
4
GOARCH      - target architecture
GOOS        - target operating system
buildmode   - which kind of object file is to be built. More info here
CGO_ENABLED - whether cgo is enabled; it must be 1 here because the code uses cgo

You can find the supported GOOS and GOARCH combinations here, or you can list them by running the command below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
go tool dist list | column -c 75 | column -t
Output :
aix/ppc64        freebsd/amd64   linux/mipsle   openbsd/386
android/386      freebsd/arm     linux/ppc64    openbsd/amd64
android/amd64    illumos/amd64   linux/ppc64le  openbsd/arm
android/arm      js/wasm         linux/s390x    openbsd/arm64
android/arm64    linux/386       nacl/386       plan9/386
darwin/386       linux/amd64     nacl/amd64p32  plan9/amd64
darwin/amd64     linux/arm       nacl/arm       plan9/arm
darwin/arm       linux/arm64     netbsd/386     solaris/amd64
darwin/arm64     linux/mips      netbsd/amd64   windows/386
dragonfly/amd64  linux/mips64    netbsd/arm     windows/amd64
freebsd/386      linux/mips64le  netbsd/arm64   windows/arm

Here is the Makefile I created to generate the libraries for Linux, macOS, Android and iOS.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

# Phony targets are space-separated: a trailing comma would become part of the
# target name and leave the real target non-phony.
.PHONY: build_library_darwin build_library_linux
.PHONY: build_library_android_arm build_library_android_arm64
.PHONY: build_library_ios_64 build_library_ios_simulator

#IOS device config
IOS_ARCH=arm64
IPHONEOS_PLATFORM := $(shell xcrun --sdk iphoneos --show-sdk-platform-path)
IPHONEOS_SYSROOT := $(shell xcrun --sdk iphoneos --show-sdk-path)
MIN_SDK_VERSION=9.0
CC_FLAGS=$(shell xcrun --sdk iphoneos --find clang)

#IOS Simulator config
IPHONEOS_SYSROOT_SIMULATOR := $(shell xcrun --sdk iphonesimulator --show-sdk-path)
MIN_SDK_VERSION_SIMULATOR=12.0
IOS_ARCH_SIMULATOR=x86_64
CC_FLAGS_SIMULATOR=$(shell xcrun --sdk iphonesimulator --find clang)

# Recipes extend CGO_* flags with VAR="$$VAR ..." instead of VAR+="...":
# `+=` in a command prefix is a bash-ism that POSIX /bin/sh (e.g. dash) rejects.

build_library_darwin:
	CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o $(CURDIR)/gen/darwin_amd64/libtextclassification.so -buildmode=c-shared .

build_library_linux:
	CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o $(CURDIR)/gen/linux_amd64/libtextclassification.so -buildmode=c-shared .

build_library_android_arm:
	CGO_ENABLED=1 GOOS=android GOARCH=arm GOARM=7 CC=arm-linux-androideabi-gcc CXX=false \
	go build -o $(CURDIR)/gen/armeabi-v7a/libtextclassification.so -buildmode=c-shared .

build_library_android_arm64:
	CGO_LDFLAGS="$$CGO_LDFLAGS -fuse-ld=gold" \
	CGO_ENABLED=1 GOOS=android GOARCH=arm64 CC=aarch64-linux-android21-clang CXX=aarch64-linux-android21-clang++ \
	go build -o $(CURDIR)/gen/arm64-v8a/libtextclassification.so -buildmode=c-shared .

# NOTE: Go 1.16+ uses GOOS=ios for iOS device and simulator builds; there,
# GOOS=darwin GOARCH=arm64 targets macOS on Apple silicon. GOOS=darwin below
# matches the older toolchains this Makefile was written for.
build_library_ios_64:
	CGO_CFLAGS="$$CGO_CFLAGS -miphoneos-version-min=$(MIN_SDK_VERSION) \
		-isysroot ${IPHONEOS_SYSROOT} \
		-arch $(IOS_ARCH) \
		-O3" \
	CGO_LDFLAGS="$$CGO_LDFLAGS -miphoneos-version-min=$(MIN_SDK_VERSION) \
		-isysroot ${IPHONEOS_SYSROOT} \
		-arch $(IOS_ARCH) \
		-lc++" \
	CGO_ENABLED=1 \
	GOOS=darwin \
	GOARCH=$(IOS_ARCH) \
	CC=$(CC_FLAGS) \
	CXX=clang++ \
	go build -ldflags -w -v -installsuffix goi -o $(CURDIR)/gen/ios_$(IOS_ARCH)/libtextclassification.a -buildmode=c-archive .

# Every line of the variable prefix must end in a backslash so the variables
# reach `go build`. GOARCH=amd64 is Go's name for IOS_ARCH_SIMULATOR (x86_64).
build_library_ios_simulator:
	CGO_CFLAGS="$$CGO_CFLAGS -mios-simulator-version-min=$(MIN_SDK_VERSION_SIMULATOR) \
		-isysroot ${IPHONEOS_SYSROOT_SIMULATOR} \
		-arch $(IOS_ARCH_SIMULATOR)" \
	CGO_LDFLAGS="$$CGO_LDFLAGS -mios-simulator-version-min=$(MIN_SDK_VERSION_SIMULATOR) \
		-isysroot ${IPHONEOS_SYSROOT_SIMULATOR} \
		-arch $(IOS_ARCH_SIMULATOR) \
		-lc++" \
	CGO_ENABLED=1 \
	GOOS=darwin \
	GOARCH=amd64 \
	CC=$(CC_FLAGS_SIMULATOR) \
	CXX=c++ \
	go build -ldflags -w -v -installsuffix opsim -o $(CURDIR)/gen/ios_$(IOS_ARCH_SIMULATOR)/libtextclassification.a -buildmode=c-archive .

Finally, you can use the generated libraries (.so shared libraries or .a static archives) in Android and iOS apps. On Android, you can load the shared library through JNI (Java Native Interface), and on iOS you can use it as a framework module. With Flutter, you can also use dart:ffi (foreign function interface).

You can find the full implementation of the text classifier here: https://github.com/duladissa/go-tflite/blob/cross_compilation_support/_example/cross_complie