Published in Analytics Vidhya: https://medium.com/analytics-vidhya/tensorflow-lite-tflite-with-golang-37a326c089ff
TensorFlow Lite, commonly known as TFLite, is used to convert and run machine learning models on mobile and IoT (edge) devices. TFLite makes on-device (offline) inference easier across multiple device architectures, such as Android, iOS, Raspberry Pi and even backend servers. With TFLite you can build a lightweight server-based inference application in any programming language using lightweight models, rather than heavy TensorFlow models.
As developers, we can either use existing optimized research models or convert existing TensorFlow models to TFLite. There are multiple ways of using TFLite in your mobile, IoT or server applications, including:
- Implement inference separately for each platform (Android, iOS, etc.) using the standard libraries and SDKs provided by TFLite.
- Use the TFLite C API for inference from a platform-independent programming language such as Golang, and cross-compile for platforms like Android, iOS, etc.
In this post I’m going to showcase a TFLite inference application written in the platform-independent language Golang and cross-compiled to a shared library, which can then be consumed by Android, iOS, etc.
First, thanks to mattn, who created the TFLite Go bindings (github.com/mattn/go-tflite). We will start by implementing a simple Golang application for TFLite inference (you can find the example here). Here I’m using a simple text classifier that classifies its input as ‘Positive’ or ‘Negative’.
Here is classifier.go, which contains the Go functions that are exported for use by C code.
```go
package main

//#include <stdlib.h>
import "C"

import (
	"bufio"
	"encoding/json"
	"log"
	"os"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"unsafe"

	gopointer "github.com/mattn/go-pointer"
	"github.com/mattn/go-tflite"
)

// Classifier bundles the vocabulary, the labels and the TFLite interpreter.
type Classifier struct {
	dictionary  map[string]int
	labels      []string
	interpreter *tflite.Interpreter
}

// Special tokens expected in vocab.txt.
const (
	START   = "<START>"
	PAD     = "<PAD>"
	UNKNOWN = "<UNKNOWN>"
)

// SENTENCE_LEN is the fixed input length of the model.
const SENTENCE_LEN = 256

// tokenSeparator splits a sentence on spaces, punctuation and newlines.
var tokenSeparator = regexp.MustCompile(" |\\,|\\.|\\!|\\?|\n")

// Build loads the vocabulary, labels and model, and returns an opaque
// handle to the classifier for use in later calls.
//
//export Build
func Build() unsafe.Pointer {
	dic, err := loadDictionary("vocab.txt")
	if err != nil {
		log.Fatal(err)
	}
	labels, err := loadLabels("labels.txt")
	if err != nil {
		log.Fatal(err)
	}
	model := tflite.NewModelFromFile("text_classification.tflite")
	if model == nil {
		log.Fatal("cannot load model")
	}
	interpreter := tflite.NewInterpreter(model, nil)
	// Save a pointer: Classify and Close assert the restored value to *Classifier.
	classifier := &Classifier{dictionary: dic, labels: labels, interpreter: interpreter}
	return gopointer.Save(classifier)
}

// Classify runs the model on word and returns the result as a C string.
// The string is allocated with malloc, so the caller must free it.
//
//export Classify
func Classify(appPointer unsafe.Pointer, word *C.char) *C.char {
	if c, ok := gopointer.Restore(appPointer).(*Classifier); ok && c != nil {
		return C.CString(c.classify(C.GoString(word)))
	}
	return C.CString("Error Occurred")
}

// Close releases the interpreter and the handle returned by Build.
//
//export Close
func Close(appPointer unsafe.Pointer) {
	// The comma-ok assertion avoids a panic on an unknown or nil handle.
	if c, ok := gopointer.Restore(appPointer).(*Classifier); ok && c != nil {
		c.interpreter.Delete()
	}
	gopointer.Unref(appPointer)
}

// loadDictionary reads "token id" pairs, one per line, from fname.
func loadDictionary(fname string) (map[string]int, error) {
	f, err := os.Open(fname) // was hard-coded to "vocab.txt"
	if err != nil {
		return nil, err
	}
	defer f.Close()
	dic := make(map[string]int)
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := strings.Split(scanner.Text(), " ")
		if len(line) < 2 {
			continue
		}
		n, err := strconv.Atoi(line[1])
		if err != nil {
			continue
		}
		dic[line[0]] = n
	}
	return dic, scanner.Err()
}

// loadLabels reads one label per line from fname.
func loadLabels(fname string) ([]string, error) {
	f, err := os.Open(fname)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	var labels []string
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		labels = append(labels, scanner.Text())
	}
	return labels, scanner.Err()
}

// classify encodes word as vocabulary ids, runs inference and returns the
// per-label scores as a JSON string.
func (c *Classifier) classify(word string) string {
	tokens := tokenSeparator.Split(strings.TrimSpace(word), -1)
	index := 0
	tmp := make([]float32, SENTENCE_LEN)
	if n, ok := c.dictionary[START]; ok {
		tmp[index] = float32(n)
		index++
	}
	for _, token := range tokens {
		if index >= SENTENCE_LEN {
			break
		}
		if v, ok := c.dictionary[token]; ok {
			tmp[index] = float32(v)
		} else {
			tmp[index] = float32(c.dictionary[UNKNOWN])
		}
		index++
	}
	// Pad the rest of the input up to the fixed sentence length.
	for i := index; i < SENTENCE_LEN; i++ {
		tmp[i] = float32(c.dictionary[PAD])
	}
	c.interpreter.AllocateTensors()
	copy(c.interpreter.GetInputTensor(0).Float32s(), tmp)
	c.interpreter.Invoke()

	// Fields must be exported, otherwise json.Marshal emits empty objects.
	type rank struct {
		Label string  `json:"label"`
		Poll  float32 `json:"poll"`
	}
	ranks := []rank{}
	for i, v := range c.interpreter.GetOutputTensor(0).Float32s() {
		ranks = append(ranks, rank{Label: c.labels[i], Poll: v})
	}
	// Ascending by score, so the most likely label comes last.
	sort.Slice(ranks, func(i, j int) bool {
		return ranks[i].Poll < ranks[j].Poll
	})
	strResponse, err := json.Marshal(ranks)
	if err != nil {
		return "Error Occurred"
	}
	return string(strResponse)
}

// main is required by -buildmode=c-shared and c-archive, but is never called.
func main() {
}
```
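To make the preprocessing inside classify easier to follow, here is a minimal standalone sketch of the same idea: split the text, look each token up in the vocabulary, and pad to a fixed length. The encode helper, the toy vocabulary and the decision to skip empty tokens are assumptions made for illustration; classifier.go itself keeps empty tokens and maps them to <UNKNOWN>.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// Special tokens, mirroring the constants used by the classifier.
const (
	startToken   = "<START>"
	padToken     = "<PAD>"
	unknownToken = "<UNKNOWN>"
)

// separators matches the same delimiters the classifier splits on.
var separators = regexp.MustCompile(`[ ,.!?\n]`)

// encode is a hypothetical helper: it maps text to exactly maxLen vocabulary
// ids, prefixed with <START> (if present in dict) and right-padded with <PAD>.
func encode(text string, dict map[string]int, maxLen int) []float32 {
	ids := make([]float32, 0, maxLen)
	if id, ok := dict[startToken]; ok && maxLen > 0 {
		ids = append(ids, float32(id))
	}
	for _, tok := range separators.Split(strings.TrimSpace(text), -1) {
		if len(ids) >= maxLen {
			break // truncate long inputs
		}
		if tok == "" {
			continue // assumption: drop empty tokens produced by ", " and similar
		}
		id, ok := dict[tok]
		if !ok {
			id = dict[unknownToken]
		}
		ids = append(ids, float32(id))
	}
	for len(ids) < maxLen {
		ids = append(ids, float32(dict[padToken]))
	}
	return ids
}

func main() {
	// Toy vocabulary; the real one is loaded from vocab.txt.
	dict := map[string]int{padToken: 0, startToken: 1, unknownToken: 2, "great": 10, "movie": 11}
	fmt.Println(encode("great movie, really!", dict, 6)) // [1 10 11 2 0 0]
}
```

Keeping the encoding in a pure function like this makes it possible to unit test the tokenization without loading a model.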
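Build(): initializes all the object references in the Go runtime environment and returns a pointer that acts as a handle for future calls.
Classify(appPointer unsafe.Pointer, word *C.char): accepts the appPointer returned by Build() and word, a C string, and returns the result as a newly allocated C string that the caller is responsible for freeing.
Close(appPointer unsafe.Pointer): releases the interpreter and the handle returned by Build().
//export: marks a Go function as exported so that it can be called from C code.
See the cgo documentation for more information.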
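The reason Build() hands back a handle instead of the Go object itself is that the cgo rules do not allow C code to hold on to a Go pointer after a call returns. The sketch below shows the general handle idea in plain Go. The registry type and its save, restore and release methods are hypothetical names made up for this illustration, and it is not meant to describe how go-pointer is implemented.

```go
package main

import (
	"fmt"
	"sync"
)

// registry maps opaque integer handles to Go values, so that only the
// handle (never a Go pointer) has to cross the language boundary.
type registry struct {
	mu     sync.Mutex
	next   uintptr
	values map[uintptr]interface{}
}

func newRegistry() *registry {
	return &registry{values: make(map[uintptr]interface{})}
}

// save stores v and returns a non-zero handle for it.
func (r *registry) save(v interface{}) uintptr {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.next++
	r.values[r.next] = v
	return r.next
}

// restore returns the value behind h; ok is false for unknown or released handles.
func (r *registry) restore(h uintptr) (v interface{}, ok bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	v, ok = r.values[h]
	return v, ok
}

// release forgets h so the value can be garbage collected.
func (r *registry) release(h uintptr) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.values, h)
}

// classifier stands in for the real Classifier type.
type classifier struct{ labels []string }

func main() {
	reg := newRegistry()
	h := reg.save(&classifier{labels: []string{"Negative", "Positive"}})

	if v, ok := reg.restore(h); ok {
		c := v.(*classifier) // assert to the same pointer type that was saved
		fmt.Println(c.labels)
	}

	reg.release(h)
	_, ok := reg.restore(h)
	fmt.Println(ok) // false: the handle is no longer valid
}
```

Note that the value is saved and restored as the same type (a pointer here); asserting to a different type than the one saved would panic.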
Now we have the Golang implementation of text classification in classifier.go. The next big challenge is cross-compiling the text classifier to a shared library. In Golang you can build your application for a specific operating system and architecture. The important Go environment variables and build flags used in our example are:
- GOARCH: target architecture
- GOOS: target operating system
- -buildmode: which kind of object file is to be built (see `go help buildmode`)
- CGO_ENABLED: whether the cgo command is supported. In this case it's 1
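To see how these settings fit together, here is a small sketch that assembles a cgo-enabled go build invocation for a given target from Go itself. The buildCommand helper and the output path are assumptions for illustration only; the command is prepared but not executed, because actually running it needs a C toolchain for the target.

```go
package main

import (
	"fmt"
	"os"
	"os/exec"
)

// buildCommand is a hypothetical helper that prepares (but does not run)
// a cgo-enabled `go build` for the given target and build mode.
func buildCommand(goos, goarch, buildmode, out string) *exec.Cmd {
	cmd := exec.Command("go", "build", "-buildmode="+buildmode, "-o", out, ".")
	// Entries appended last win over duplicates inherited from os.Environ().
	cmd.Env = append(os.Environ(),
		"CGO_ENABLED=1",
		"GOOS="+goos,
		"GOARCH="+goarch,
	)
	return cmd
}

func main() {
	cmd := buildCommand("linux", "amd64", "c-shared", "gen/linux_amd64/libtextclassification.so")
	fmt.Println(cmd.Args)
	// Call cmd.Run() to start the build once a matching C compiler is installed.
}
```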
You can find the supported GOOS and GOARCH combinations here, or you can list them by running the command below.
```shell
go tool dist list | column -c 75 | column -t
```
Output (the exact list depends on your Go version):
```
aix/ppc64        freebsd/amd64   linux/mipsle   openbsd/386
android/386      freebsd/arm     linux/ppc64    openbsd/amd64
android/amd64    illumos/amd64   linux/ppc64le  openbsd/arm
android/arm      js/wasm         linux/s390x    openbsd/arm64
android/arm64    linux/386       nacl/386       plan9/386
darwin/386       linux/amd64     nacl/amd64p32  plan9/amd64
darwin/amd64     linux/arm       nacl/arm       plan9/arm
darwin/arm       linux/arm64     netbsd/386     solaris/amd64
darwin/arm64     linux/mips      netbsd/amd64   windows/386
dragonfly/amd64  linux/mips64    netbsd/arm     windows/amd64
freebsd/386      linux/mips64le  netbsd/arm64   windows/arm
```
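If you want to check for a target in a build script rather than by eye, the listing is easy to parse because every entry has the form os/arch. The sketch below is one possible way to do it; the supports helper is a hypothetical name, and the sample text is just a few entries copied from the output above.

```go
package main

import (
	"fmt"
	"strings"
)

// supports reports whether target ("os/arch") appears in the output of
// `go tool dist list`. It splits on any whitespace, so it works on both
// the raw one-entry-per-line output and the column layout.
func supports(distList, target string) bool {
	for _, field := range strings.Fields(distList) {
		if field == target {
			return true
		}
	}
	return false
}

func main() {
	sample := `
android/arm     js/wasm      linux/s390x
android/arm64   linux/386    nacl/386
`
	fmt.Println(supports(sample, "android/arm64")) // true
	fmt.Println(supports(sample, "windows/arm"))   // false: not in this sample
}
```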
Here is the Makefile I created to generate the libraries for Linux, macOS, Android and iOS.
```makefile
# .PHONY takes a space-separated list of targets (no commas).
.PHONY: build_library_darwin build_library_linux
.PHONY: build_library_android_arm build_library_android_arm64
.PHONY: build_library_ios_64 build_library_ios_simulator

# iOS device config
IOS_ARCH=arm64
IPHONEOS_PLATFORM := $(shell xcrun --sdk iphoneos --show-sdk-platform-path)
IPHONEOS_SYSROOT := $(shell xcrun --sdk iphoneos --show-sdk-path)
MIN_SDK_VERSION=9.0
CC_FLAGS=$(shell xcrun --sdk iphoneos --find clang)

# iOS simulator config
IPHONEOS_SYSROOT_SIMULATOR := $(shell xcrun --sdk iphonesimulator --show-sdk-path)
MIN_SDK_VERSION_SIMULATOR=12.0
IOS_ARCH_SIMULATOR=x86_64
CC_FLAGS_SIMULATOR=$(shell xcrun --sdk iphonesimulator --find clang)

build_library_darwin:
	CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o $(CURDIR)/gen/darwin_amd64/libtextclassification.so -buildmode=c-shared .

build_library_linux:
	CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o $(CURDIR)/gen/linux_amd64/libtextclassification.so -buildmode=c-shared .

build_library_android_arm:
	CGO_ENABLED=1 GOOS=android GOARCH=arm GOARM=7 CC=arm-linux-androideabi-gcc CXX=false \
	go build -o $(CURDIR)/gen/armeabi-v7a/libtextclassification.so -buildmode=c-shared .

# VAR="$$VAR ..." replaces VAR+="...", which is not understood by every /bin/sh.
build_library_android_arm64:
	CGO_LDFLAGS="$$CGO_LDFLAGS -fuse-ld=gold" \
	CGO_ENABLED=1 GOOS=android GOARCH=arm64 CC=aarch64-linux-android21-clang CXX=aarch64-linux-android21-clang++ \
	go build -o $(CURDIR)/gen/arm64-v8a/libtextclassification.so -buildmode=c-shared .

# Note: since Go 1.16 the iOS port is selected with GOOS=ios;
# GOOS=darwin GOARCH=arm64 now targets macOS on Apple silicon.
build_library_ios_64:
	CGO_CFLAGS="$$CGO_CFLAGS -miphoneos-version-min=$(MIN_SDK_VERSION) \
	-isysroot $(IPHONEOS_SYSROOT) \
	-arch $(IOS_ARCH) \
	-O3" \
	CGO_LDFLAGS="$$CGO_LDFLAGS -miphoneos-version-min=$(MIN_SDK_VERSION) \
	-isysroot $(IPHONEOS_SYSROOT) \
	-arch $(IOS_ARCH) \
	-lc++" \
	CGO_ENABLED=1 \
	GOOS=darwin \
	GOARCH=$(IOS_ARCH) \
	CC=$(CC_FLAGS) \
	CXX=clang++ \
	go build -ldflags -w -v -installsuffix goi -o $(CURDIR)/gen/ios_$(IOS_ARCH)/libtextclassification.a -buildmode=c-archive .

# The backslash after CXX=c++ was missing, so go build ran without these settings.
# GOARCH=amd64 is added on the assumption that it should match -arch x86_64.
build_library_ios_simulator:
	CGO_CFLAGS="$$CGO_CFLAGS -isysroot $(IPHONEOS_SYSROOT_SIMULATOR) \
	-arch $(IOS_ARCH_SIMULATOR)" \
	CGO_LDFLAGS="$$CGO_LDFLAGS -mios-simulator-version-min=$(MIN_SDK_VERSION_SIMULATOR) \
	-isysroot $(IPHONEOS_SYSROOT_SIMULATOR) \
	-arch $(IOS_ARCH_SIMULATOR) \
	-lc++" \
	CGO_ENABLED=1 \
	GOOS=darwin \
	GOARCH=amd64 \
	CC=cc \
	CXX=c++ \
	go build -ldflags -w -v -installsuffix opsim -o $(CURDIR)/gen/ios_$(IOS_ARCH_SIMULATOR)/libtextclassification.a -buildmode=c-archive .
```
Finally, you can use the generated libraries (shared .so or static .a) in Android and iOS. On Android, you can load the shared library through JNI (Java Native Interface), and on iOS you can use the library as a framework module. With Flutter you can also use dart:ffi (foreign function interface).
You can find the implementation of the text classifier at https://github.com/duladissa/go-tflite/blob/cross_compilation_support/_example/cross_complie