Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support async ort inference (re. progress bars) #35

Open
github-actions bot opened this issue Mar 17, 2024 · 0 comments
Open

support async ort inference (re. progress bars) #35

github-actions bot opened this issue Mar 17, 2024 · 0 comments
Assignees
Labels

Comments

@github-actions
Copy link

// TODO: support async ort inference (re. progress bars)

}


/// Rotates every raw frame of each session by its stream's descriptor rotation
/// and attaches the resulting [`RotatedFrames`] component to the session entity.
///
/// Output frames are written to `<rotated_frames.directory>/<stream id>/<frame stem>.png`,
/// with frames of a stream processed in parallel. When `RotatedFrames::exists`
/// reports existing output for a session, the frames are loaded via
/// `RotatedFrames::load_from_session` instead of regenerated. Entities whose
/// `rotate_raw_frames` flag is off are skipped and receive no component.
///
/// A stream with no matching entry in [`StreamDescriptors`] is processed with a
/// rotation of `0.0` rather than panicking on a missing map key.
///
/// # Panics
///
/// Panics if an output directory cannot be created, a frame path has no UTF-8
/// file stem, or `rotate_image` fails.
fn generate_rotated_frames(
    mut commands: Commands,
    descriptors: Res<StreamDescriptors>,
    raw_frames: Query<
        (
            Entity,
            &PipelineConfig,
            &RawFrames,
            &Session,
        ),
        Without<RotatedFrames>,
    >,
) {
    // TODO: create a caching/loading system wrapper over run_node interior
    for (entity, config, frames, session) in raw_frames.iter() {
        if !config.rotate_raw_frames {
            continue;
        }

        let run_node = !RotatedFrames::exists(session);
        let mut rotated_frames = RotatedFrames::load_from_session(session);

        if run_node {
            // Descriptor position doubles as the stream id.
            let rotations: HashMap<StreamId, f32> = descriptors.0.iter()
                .enumerate()
                .map(|(id, descriptor)| (StreamId(id), descriptor.rotation.unwrap_or_default()))
                .collect();

            info!("generating rotated frames for session {}", session.id);

            for (stream_id, stream_frames) in frames.frames.iter() {
                // Looked up once per stream (not per frame) and tolerant of
                // streams that have no descriptor.
                let rotation = rotations.get(stream_id).copied().unwrap_or_default();

                let output_directory = format!("{}/{}", rotated_frames.directory, stream_id.0);
                std::fs::create_dir_all(&output_directory).unwrap();

                let output_paths = stream_frames.par_iter()
                    .map(|frame| {
                        let input_path = std::path::Path::new(frame);
                        let frame_idx = input_path.file_stem().unwrap().to_str().unwrap();
                        let output_path = format!("{}/{}.png", output_directory, frame_idx);

                        rotate_image(
                            input_path,
                            std::path::Path::new(&output_path),
                            rotation,
                        ).unwrap();

                        output_path
                    })
                    .collect::<Vec<_>>();

                rotated_frames.frames.insert(*stream_id, output_paths);
            }
        } else {
            info!("rotated frames already exist for session {}", session.id);
        }

        commands.entity(entity).insert(rotated_frames);
    }
}


fn generate_mask_frames(
    mut commands: Commands,
    frames: Query<
        (
            Entity,
            &PipelineConfig,
            &RotatedFrames,
            &Session,
        ),
        Without<MaskFrames>,
    >,
    modnet: Res<Modnet>,
    onnx_assets: Res<Assets<Onnx>>,
) {
    for (
        entity,
        config,
        frames,
        session,
    ) in frames.iter() {
        if config.mask_frames {
            if onnx_assets.get(&modnet.onnx).is_none() {
                return;
            }

            let onnx = onnx_assets.get(&modnet.onnx).unwrap();
            let onnx_session_arc = onnx.session.clone();
            let onnx_session_lock = onnx_session_arc.lock().map_err(|e| e.to_string()).unwrap();
            let onnx_session = onnx_session_lock.as_ref().ok_or("failed to get session from ONNX asset").unwrap();

            let run_node = !MaskFrames::exists(session);
            let mut mask_frames = MaskFrames::load_from_session(session);

            if run_node {
                info!("generating mask frames for session {}", session.id);

                frames.frames.keys()
                    .for_each(|stream_id| {
                        let output_directory = format!("{}/{}", mask_frames.directory, stream_id.0);
                        std::fs::create_dir_all(output_directory).unwrap();
                    });

                let mask_images = frames.frames.iter()
                    .map(|(stream_id, frames)| {
                        let frames = frames.iter()
                            .map(|frame| {
                                let mut decoder = png::Decoder::new(std::fs::File::open(frame).unwrap());
                                decoder.set_transformations(Transformations::EXPAND | Transformations::ALPHA);
                                let mut reader = decoder.read_info().unwrap();
                                let mut img_data = vec![0; reader.output_buffer_size()];
                                let _ = reader.next_frame(&mut img_data).unwrap();

                                assert_eq!(reader.info().bytes_per_pixel(), 3);

                                let width = reader.info().width;
                                let height = reader.info().height;

                                // TODO: separate image loading and onnx inference (so the image loading result can be viewed in the pipeline grid view)
                                let image = Image::new(
                                    Extent3d {
                                        width,
                                        height,
                                        depth_or_array_layers: 1,
                                    },
                                    bevy::render::render_resource::TextureDimension::D2,
                                    img_data,
                                    bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
                                    RenderAssetUsages::all(),
                                );

                                let frame_idx = std::path::Path::new(frame).file_stem().unwrap().to_str().unwrap();

                                (
                                    frame_idx,
                                    modnet_inference(
                                        onnx_session,
                                        &[&image],
                                        Some((512, 512)),
                                    ).pop().unwrap(),
                                )
                            })
                            .collect::<Vec<_>>();

                        (stream_id, frames)
                    })
                    .collect::<Vec<_>>();

                mask_images.iter()
                    .for_each(|(stream_id, frames)| {
                        let output_directory = format!("{}/{}", mask_frames.directory, stream_id.0);
                        let mask_paths = frames.iter()
                            .map(|(frame_idx, frame)| {
                                let path = format!("{}/{}.png", output_directory, frame_idx);

                                let buffer = ImageBuffer::<Luma<u8>, Vec<u8>>::from_raw(
                                    frame.width(),
                                    frame.height(),
                                    frame.data.clone(),
                                ).unwrap();

                                let _ = buffer.save(&path);

                                path
                            })
                            .collect::<Vec<_>>();

                        mask_frames.frames.insert(**stream_id, mask_paths);
                    });
            } else {
                info!("mask frames already exist for session {}", session.id);
            }

            commands.entity(entity).insert(mask_frames);
        }
    }
}


fn generate_yolo_frames(
    mut commands: Commands,
    raw_frames: Query<
        (
            Entity,
            &PipelineConfig,
            &RawFrames,
            &Session,
        ),
        Without<YoloFrames>,
    >,
    yolo: Res<Yolo>,
    onnx_assets: Res<Assets<Onnx>>,
) {
    for (
        entity,
        config,
        raw_frames,
        session,
    ) in raw_frames.iter() {
        if config.yolo {
            if onnx_assets.get(&yolo.onnx).is_none() {
                return;
            }

            let onnx = onnx_assets.get(&yolo.onnx).unwrap();
            let onnx_session_arc = onnx.session.clone();
            let onnx_session_lock = onnx_session_arc.lock().map_err(|e| e.to_string()).unwrap();
            let onnx_session = onnx_session_lock.as_ref().ok_or("failed to get session from ONNX asset").unwrap();

            let run_node = !YoloFrames::exists(session);
            let mut yolo_frames = YoloFrames::load_from_session(session);

            if run_node {
                info!("generating yolo frames for session {}", session.id);

                raw_frames.frames.keys()
                    .for_each(|stream_id| {
                        let output_directory = format!("{}/{}", yolo_frames.directory, stream_id.0);
                        std::fs::create_dir_all(output_directory).unwrap();
                    });

                // TODO: support async ort inference (re. progress bars)
                let bounding_box_streams = raw_frames.frames.iter()
                    .map(|(stream_id, frames)| {
                        let frames = frames.iter()
                            .map(|frame| {
                                let mut decoder = png::Decoder::new(std::fs::File::open(frame).unwrap());
                                decoder.set_transformations(Transformations::EXPAND | Transformations::ALPHA);
                                let mut reader = decoder.read_info().unwrap();
                                let mut img_data = vec![0; reader.output_buffer_size()];
                                let _ = reader.next_frame(&mut img_data).unwrap();

                                assert_eq!(reader.info().bytes_per_pixel(), 3);

                                let width = reader.info().width;
                                let height = reader.info().height;

                                // TODO: separate image loading and onnx inference (so the image loading result can be viewed in the pipeline grid view)
                                let image = Image::new(
                                    Extent3d {
                                        width,
                                        height,
                                        depth_or_array_layers: 1,
                                    },
                                    bevy::render::render_resource::TextureDimension::D2,
                                    img_data,
                                    bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
                                    RenderAssetUsages::all(),
                                );

                                let frame_idx = std::path::Path::new(frame).file_stem().unwrap().to_str().unwrap();

                                (
                                    frame_idx,
                                    yolo_inference(
                                        onnx_session,
                                        &image,
                                        0.5,
                                    ),
                                )
                            })
                            .collect::<Vec<_>>();

                        (stream_id, frames)
                    })
                    .collect::<Vec<_>>();

                bounding_box_streams.iter()
                    .for_each(|(stream_id, frames)| {
                        let output_directory = format!("{}/{}", yolo_frames.directory, stream_id.0);
                        let bounding_boxes = frames.iter()
                            .map(|(frame_idx, bounding_boxes)| {
                                let path = format!("{}/{}.json", output_directory, frame_idx);

                                let _ = serde_json::to_writer(std::fs::File::create(path).unwrap(), bounding_boxes);

                                bounding_boxes.clone()
                            })
                            .collect::<Vec<_>>();

                        yolo_frames.frames.insert(**stream_id, bounding_boxes);
                    });
            } else {
                info!("yolo frames already exist for session {}", session.id);
            }

            commands.entity(entity).insert(yolo_frames);
        }
    }
}


// TODO: alphablend frames
#[derive(Component, Default)]
pub struct AlphablendFrames {
    pub frames: HashMap<StreamId, Vec<String>>,
    pub directory: String,
}
impl AlphablendFrames {
    pub fn load_from_session(
        session: &Session,
    ) -> Self {
        let directory = format!("{}/alphablend", session.directory);
        std::fs::create_dir_all(&directory).unwrap();

        let mut alphablend_frames = Self {
            frames: HashMap::new(),
            directory,
        };
        alphablend_frames.reload();

        alphablend_frames
    }

    pub fn reload(&mut self) {
        std::fs::read_dir(&self.directory)
            .unwrap()
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.path().is_dir())
            .map(|stream_dir| {
                let stream_id = StreamId(stream_dir.path().file_name().unwrap().to_str().unwrap().parse::<usize>().unwrap());

                let frames = std::fs::read_dir(stream_dir.path()).unwrap()
                    .filter_map(|entry| entry.ok())
                    .filter(|entry| entry.path().is_file() && entry.path().extension().and_then(|s| s.to_str()) == Some("png"))
                    .map(|entry| entry.path().to_str().unwrap().to_string())
                    .collect::<Vec<_>>();

                (stream_id, frames)
            })
            .for_each(|(stream_id, frames)| {
                self.frames.insert(stream_id, frames);
            });
    }

    pub fn exists(
        session: &Session,
    ) -> bool {
        let output_directory = format!("{}/alphablend", session.directory);
        std::fs::metadata(output_directory).is_ok()
    }

    pub fn image(&self, _camera: usize, _frame: usize) -> Option<Image> {
        todo!()
    }
}



// TODO: support loading maskframes -> images into a pipeline mask viewer


/// Raw extracted frames of a session, grouped by stream.
///
/// Populated by `load_from_session` / `reload` from
/// `<session.directory>/frames/<stream id>/*.png`.
#[derive(Component, Default)]
pub struct RawFrames {
    /// Frame file paths per stream, in `std::fs::read_dir` order (not sorted by `reload`).
    pub frames: HashMap<StreamId, Vec<String>>,
    /// Root frames directory, `<session.directory>/frames`.
    pub directory: String,
}
impl RawFrames {
    pub fn load_from_session(
        session: &Session,
    ) -> Self {
        let directory = format!("{}/frames", session.directory);
        std::fs::create_dir_all(&directory).unwrap();

        let mut raw_frames = Self {
            frames: HashMap::new(),
            directory,
        };
        raw_frames.reload();

        raw_frames
    }

    pub fn reload(&mut self) {
        std::fs::read_dir(&self.directory)
            .unwrap()
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.path().is_dir())
            .map(|stream_dir| {
                let stream_id = StreamId(stream_dir.path().file_name().unwrap().to_str().unwrap().parse::<usize>().unwrap());

                let frames = std::fs::read_dir(stream_dir.path()).unwrap()
                    .filter_map(|entry| entry.ok())
                    .filter(|entry| entry.path().is_file() && entry.path().extension().and_then(|s| s.to_str()) == Some("png"))
                    .map(|entry| entry.path().to_str().unwrap().to_string())
                    .collect::<Vec<_>>();

                (stream_id, frames)
            })
            .for_each(|(stream_id, frames)| {
                self.frames.insert(stream_id, frames);
            });
    }

    pub fn exists(
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

1 participant