-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsketch.py
30 lines (21 loc) · 912 Bytes
/
sketch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import gradio as gr
import tensorflow as tf
import numpy as np
def _flatten_and_scale(images):
    """Reshape *images* to one 28*28-value row per sample and divide by 255.0."""
    sample_count = images.shape[0]
    return images.reshape(sample_count, 28 * 28) / 255.0


# Load the MNIST train/test splits bundled with Keras, then flatten and
# rescale both image arrays the same way.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = _flatten_and_scale(x_train)
x_test = _flatten_and_scale(x_test)
# Fitted k-NN model, created on the first prediction and reused afterwards.
_knn_classifier = None


def _get_classifier():
    """Return the 3-nearest-neighbour classifier, fitting it on first use.

    The model is fitted on the module-level ``x_train`` / ``y_train`` arrays.
    Fitting does not depend on the image being classified, so it is done once
    and cached instead of being repeated for every request.
    """
    global _knn_classifier
    if _knn_classifier is None:
        # Imported lazily so scikit-learn is only needed once a prediction
        # is actually requested.
        from sklearn.neighbors import KNeighborsClassifier

        model = KNeighborsClassifier(n_neighbors=3)
        model.fit(x_train, y_train)
        _knn_classifier = model
    return _knn_classifier


def predict(img):
    """Classify a drawn digit image with the cached k-NN model.

    Args:
        img: Array-like image data. It is converted with ``np.array``,
            flattened into a single row and divided by 255, so it is
            expected to hold the same number of pixel values per sample
            as the rows of ``x_train`` (28*28).

    Returns:
        The predicted label for the image, i.e. element 0 of the array
        returned by the classifier's ``predict``.
    """
    # One sample per row: flatten the whole image into a (1, n_pixels) array
    # and apply the same 1/255 scaling used for the training data.
    img_array = np.array(img).reshape(1, -1) / 255
    knn_pred = _get_classifier().predict(img_array)
    print(knn_pred[0])
    return knn_pred[0]
# Web UI: a sketchpad input wired to predict(), showing the result as text,
# with flagging turned off.
iface = gr.Interface(
    fn=predict,
    inputs='sketchpad',
    outputs='text',
    allow_flagging='never',
    description='Draw a Digit Below (Draw in the centre for best results)',
)

# Options forwarded unchanged to Interface.launch().
launch_options = {'share': True, 'width': 300, 'height': 500}
iface.launch(**launch_options)