# main.py
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_io as tfio
import numpy as np
import cv2
import math
import sys
import time
# Import matplotlib libraries
from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection
import matplotlib.patches as patches
# debug
from icecream import ic
# import local modules
from modules.visualisation import *
from modules.formcheck import *
from modules.inference import *
from modules.graphing import *
# Initialize the TFLite interpreter
interpreter = tf.lite.Interpreter(model_path="models/model.tflite")
interpreter.allocate_tensors()
input_size = 192
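# NOTE (assumption): input_size = 192 matches the MoveNet Lightning variant;
# the larger Thunder variant expects 256x256 input instead.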
# Load the input video.
video_path = 'videos/squat_dillon.mp4'
CURRENT_MOVEMENT = 'squat' # 'squat' or 'bench'
# Read the video from specified path
cam = cv2.VideoCapture(video_path)
frames = []
n_frames = 0
while True:
    ret, frame = cam.read()
    if not ret:
        break
    # OpenCV decodes to BGR; convert to RGB for the model and matplotlib
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frames.append(frame)
    n_frames += 1
cam.release()
print(n_frames, 'frames read')
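# Loading every frame into memory can be expensive for long clips. A minimal
# streaming alternative (hypothetical helper, not used below):
def iter_video_frames(path):
    """Yield RGB frames from a video one at a time instead of stacking them."""
    cap = cv2.VideoCapture(path)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()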
# convert frames array to tensor
image = tf.convert_to_tensor(np.array(frames))
ic(image.shape)
num_frames, image_height, image_width, _ = image.shape
crop_region = init_crop_region(image_height, image_width)
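# init_crop_region / determine_crop_region appear to follow the "smart
# cropping" scheme from the TensorFlow MoveNet tutorial: each frame is
# cropped around the previous detection so the subject fills the model input.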
# Run model inference.
current_depth_flag = False
depth_flag = False
output_images = []
depthplot = DepthPlot(movement=CURRENT_MOVEMENT)
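# The depth checks live in modules.formcheck and are not shown here. As a
# rough sketch (an assumption, not the repo's actual logic): MoveNet returns
# keypoints of shape [1, 1, 17, 3] in COCO order with (y, x, score) rows, so
# a squat is "at depth" when the hips sit below the knees in image space.
def squat_depth_sketch(keypoints_with_scores, min_score=0.2):
    kps = keypoints_with_scores[0, 0]           # (17, 3): y, x, score
    hips, knees = kps[[11, 12]], kps[[13, 14]]  # left/right hip and knee
    if min(hips[:, 2].min(), knees[:, 2].min()) < min_score:
        return False  # detection too uncertain to judge depth
    # image y grows downward, so "hip below knee" means a larger y value
    return hips[:, 0].mean() > knees[:, 0].mean()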
for frame_idx in range(num_frames):
    # progress indicator (guard against division by zero for 1-frame clips)
    print("{:.1f}%".format(100 * frame_idx / max(num_frames - 1, 1)), end='\r')
    keypoints_with_scores = run_inference(
        movenet, image[frame_idx, :, :, :], crop_region,
        crop_size=[input_size, input_size], interpreter=interpreter)
    # ic(keypoints_with_scores)
    output_images.append(draw_prediction_on_image(
        image[frame_idx, :, :, :].numpy().astype(np.int32),
        keypoints_with_scores, crop_region=None, close_figure=True,
        output_image_height=300, current_depth_flag=current_depth_flag,
        depth_flag=depth_flag, movement=CURRENT_MOVEMENT))
    crop_region = determine_crop_region(
        keypoints_with_scores, image_height, image_width)
    depthplot.add_keypoints(keypoints_with_scores=keypoints_with_scores)
    if CURRENT_MOVEMENT == 'squat':
        current_depth_flag = check_squat_depth(keypoints_with_scores)
    elif CURRENT_MOVEMENT == 'bench':
        current_depth_flag = check_bench_depth(keypoints_with_scores)
    # latch the flag once depth has been reached on any frame
    depth_flag = depth_flag or current_depth_flag
print()
if depth_flag:
    print("Depth good!")
else:
    print("Insufficient Depth")
# depthplot.plot_depth('inference/' + video_path.split('/')[-1].replace('mp4', 'png'))
depthplot.plot_animation('inference/plot_' + video_path.split('/')[-1])
output = np.stack(output_images, axis=0)
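# NOTE (assumption): fps=30 is hard-coded below; the clip's true frame rate
# could instead be read with cam.get(cv2.CAP_PROP_FPS) before release.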
to_video(output, fps=30, name='inference/'+video_path.split('/')[-1])