Skip to content

Commit

Permalink
Merge branch 'vito'
Browse files Browse the repository at this point in the history
  • Loading branch information
vitostamatti committed Aug 23, 2023
2 parents 0e183c6 + 3b191b1 commit b4891bc
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion scripts/dataset_from_videos.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def video_into_images(input, output, frames_freq=1, max_frames=10):

if n_frames % frames_freq == 0:
output_name = output + f"_frame_{n_frames}.jpg"
save_image(output_name, frame)
save_image(frame, output_name)
n_frames += 1


Expand Down
7 changes: 3 additions & 4 deletions src/fall_detection/fall/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,12 @@ def load_pose_samples_from_dir(
with open(os.path.join(landmarks_dir, file_name)) as csv_file:
csv_reader = csv.reader(csv_file, delimiter=file_separator)
for row in csv_reader:
if len(row) == 0: continue
if len(row) == 0:
continue
assert (
len(row) == n_landmarks * 3 + 1
), "Wrong number of values: {}".format(len(row))
landmarks = np.array(row[1:], np.float32).reshape(
[n_landmarks, 3]
)
landmarks = np.array(row[1:], np.float32).reshape([n_landmarks, 3])
pose_samples.append(
PoseSample(
name=row[0],
Expand Down
38 changes: 20 additions & 18 deletions src/fall_detection/pose/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(
images_in_folder,
images_out_folder,
csvs_out_folder,
pose_augmentators: List[PoseAugmentation],
pose_augmentators: List[PoseAugmentation] = [],
per_pose_class_limit=None,
):
self._images_in_folder = images_in_folder
Expand Down Expand Up @@ -81,42 +81,43 @@ def __call__(self, pose_model: PoseModel):
# Bootstrap every image.
for image_name in tqdm.tqdm(image_names):
# Load image.
initial_frame = cv2.imread(os.path.join(images_in_folder, image_name))
initial_frame = cv2.imread(
os.path.join(images_in_folder, image_name)
)

base_image_name = ".".join(image_name.split(".")[:-1])
image_extension = image_name.split('.')[-1]
image_extension = image_name.split(".")[-1]

# Input frames
input_frames = [(image_name, initial_frame)]

# Check if any pose augmentations required
if len(self._pose_augmentators):
if len(self._pose_augmentators) > 0:
for pose_augmentator in self._pose_augmentators:
input_frames.append((
f"{base_image_name}_{pose_augmentator.get_pose_augmentaion_name()}.{image_extension}",
pose_augmentator(initial_frame)
))

input_frames.append(
(
f"{base_image_name}_{pose_augmentator.get_pose_augmentaion_name()}.{image_extension}",
pose_augmentator(initial_frame),
)
)

for input_frame_name, input_frame in input_frames:
# Initialize fresh pose tracker and run it.
results = pose_model.predict(input_frame)

# Save image with pose prediction (if pose was detected).
output_frame = input_frame.copy()

<<<<<<< HEAD
cv2.imwrite(
os.path.join(images_out_folder, image_name), output_frame
)
=======
if results is not None:
output_frame = pose_model.draw_landmarks(
image=output_frame,
results=results,
)
>>>>>>> b1c4cb19cf767b2b66c29dad46b92007d58821c4

cv2.imwrite(os.path.join(images_out_folder, input_frame_name), output_frame)
cv2.imwrite(
os.path.join(images_out_folder, input_frame_name),
output_frame,
)

# Save landmarks if pose was detected.
if results is not None:
Expand All @@ -129,7 +130,8 @@ def __call__(self, pose_model: PoseModel):
results, frame_height, frame_width
)
csv_out_writer.writerow(
[input_frame_name] + pose_landmarks.flatten().astype(str).tolist()
[input_frame_name]
+ pose_landmarks.flatten().astype(str).tolist()
)
self.align_images_and_csvs()

Expand Down
3 changes: 1 addition & 2 deletions tests/test_fall.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_fall_pipeline():

pose_sample_generator(pose_model=pose_model)

embedder = PoseEmbedder(landmark_names=COCO_POSE_KEYPOINTS)
embedder = PoseEmbedder(landmark_names=COCO_POSE_KEYPOINTS, dims=2)

classifier = EstimatorClassifier(
estimator=make_pipeline(StandardScaler(), LogisticRegression(random_state=42)),
Expand All @@ -208,7 +208,6 @@ def test_fall_pipeline():

pose_samples = load_pose_samples_from_dir(
pose_embedder=embedder,
n_dimensions=3,
n_landmarks=17,
landmarks_dir="./tests/test_data/test_dataset_csv",
file_extension="csv",
Expand Down

0 comments on commit b4891bc

Please sign in to comment.