-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathinference_tflite_files.py
101 lines (76 loc) · 3.37 KB
/
inference_tflite_files.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os

import numpy as np
import tensorflow as tf
def _inspect_and_invoke(model_path):
    """Load one TFLite model, print its I/O details, and run a single inference.

    Parameters
    ----------
    model_path : str
        Path to a ``.tflite`` file. A leading ``~`` is expanded to the
        user's home directory.

    Raises
    ------
    ValueError
        Propagated from the TFLite interpreter if the file cannot be
        loaded or the tensors cannot be allocated.
    """
    # BUG FIX: tf.lite.Interpreter does not perform shell-style tilde
    # expansion, so the original literal '~/Downloads/...' paths would
    # fail to open. Expand explicitly before handing the path to TFLite.
    resolved = os.path.expanduser(model_path)

    # Load the TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=resolved)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print("== Input details ==")
    print("name:", input_details[0]['name'])
    print("shape:", input_details[0]['shape'])
    print("type:", input_details[0]['dtype'])
    print("\nDUMP INPUT")
    print(input_details[0])

    print("\n== Output details ==")
    print("name:", output_details[0]['name'])
    print("shape:", output_details[0]['shape'])
    print("type:", output_details[0]['dtype'])
    print("\nDUMP OUTPUT")
    print(output_details[0])

    # BUG FIX: the original claimed to "test the model with random data"
    # but never set the input tensor, so invoke() ran on an uninitialized
    # buffer. Feed data that matches the declared input shape and dtype.
    input_shape = input_details[0]['shape']
    input_dtype = input_details[0]['dtype']
    if np.issubdtype(input_dtype, np.floating):
        dummy = np.random.random_sample(input_shape).astype(input_dtype)
    else:
        # Integer inputs (e.g. token ids): zeros are always in range.
        dummy = np.zeros(input_shape, dtype=input_dtype)
    interpreter.set_tensor(input_details[0]['index'], dummy)

    name = os.path.basename(resolved)
    print(f"before {name} inference")
    interpreter.invoke()
    print(f"after {name} inference")


if __name__ == "__main__":
    # The original repeated the same stanza verbatim for each model;
    # loop over the paths instead, separating runs with blank lines.
    model_paths = [
        '~/Downloads/whisper-encoder.tflite',
        '~/Downloads/whisper-decoder_main.tflite',
        '~/Downloads/whisper-decoder_language.tflite',
    ]
    for i, path in enumerate(model_paths):
        _inspect_and_invoke(path)
        if i < len(model_paths) - 1:
            print("\n\n\n\n")