[Question] CoreML model inference issue #273

Open
jessearmandse opened this issue Jan 24, 2025 · 0 comments
I converted the PyTorch version of the object detection model RTDETRv2 to Core ML. After loading the mlpackage into RectLabel Pro, I ran inference on one image with auto-labeling, but there seems to be an issue with how the bounding boxes and objects are identified.

Is there a way to debug this, or to understand how we should export a Core ML model from a PyTorch model so that it loads correctly in RectLabel?

For example, this is the result of inference:

[Image: inference result]

object28 is not one of the classes listed in the metadata, and the bounding box has zero height.
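One way to narrow this down is to run the mlpackage outside RectLabel and sanity-check the raw output tensors before any app-side decoding. The `check_detections` helper below is hypothetical, shown here against synthetic arrays shaped like the outputs above; the 0.5 score threshold and the `[x1, y1, x2, y2]` box layout are my assumptions (RT-DETR exports sometimes emit normalized `[cx, cy, w, h]`, so verify the layout first):

```python
import numpy as np

def check_detections(boxes, labels, scores, names, conf=0.5):
    """Flag detections with out-of-vocabulary labels or degenerate boxes.

    boxes:  (1, N, 4), assumed [x1, y1, x2, y2]
    labels: (1, N) integer class ids
    scores: (1, N) confidences
    names:  dict mapping class id -> class name
    """
    problems = []
    keep = scores[0] >= conf
    for box, label, score in zip(boxes[0][keep], labels[0][keep], scores[0][keep]):
        x1, y1, x2, y2 = box
        if int(label) not in names:
            problems.append(f"label {int(label)} not in metadata names")
        if x2 - x1 <= 0 or y2 - y1 <= 0:
            problems.append(f"degenerate box {box.tolist()} (score {score:.2f})")
    return problems

# Synthetic outputs shaped like the model's: one good and one bad detection.
boxes = np.zeros((1, 300, 4), dtype=np.float32)
labels = np.zeros((1, 300), dtype=np.int32)
scores = np.zeros((1, 300), dtype=np.float32)
boxes[0, 0] = [0.1, 0.1, 0.4, 0.4]; labels[0, 0] = 1;  scores[0, 0] = 0.9
boxes[0, 1] = [0.2, 0.3, 0.5, 0.3]; labels[0, 1] = 28; scores[0, 1] = 0.8  # zero height, bad id

names = {0: 'object1', 1: 'object2', 2: 'object3', 3: 'object4', 4: 'object5'}
print(check_detections(boxes, labels, scores, names))
```

If the raw tensors already contain the zero-height box and the out-of-range id, the problem is in the export; if they look fine, the problem is in how the outputs are decoded downstream.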

The ground truth annotation:

[Image: ground truth annotation]

The following is the model metadata:

Model Information:
Author: 
License: 
Short description: rtdetrv2

Inputs:
  Name: image
  Type: Image
  Width: 960
  Height: 960
  Color Space: RGB

Outputs:
  Name: labels
  Type: MultiArray
  Shape: [1, 300]
  Data Type: INT32
  Name: boxes
  Type: MultiArray
  Shape: [1, 300, 4]
  Data Type: FLOAT32
  Name: scores_0
  Type: MultiArray
  Shape: [1, 300]
  Data Type: FLOAT32

Output Shapes Summary:
{
  "labels": "shape(1, 300)",
  "boxes": "shape(1, 300, 4)",
  "scores_0": "shape(1, 300)"
}

Metadata:
  User Defined:
    names: {0: 'object1', 1: 'object2', 2: 'object3', 3: 'object4', 4: 'object5'}
    imgsz: [960, 960]
    com.github.apple.coremltools.source: torch==2.2.2
    com.github.apple.coremltools.source_dialect: TorchScript
    com.github.apple.coremltools.version: 8.2
    task: detect
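Since the userDefined metadata values come back as strings, the `names` mapping has to be parsed before predicted label ids can be checked against it. A small sketch, using the metadata string copied from the output above:

```python
import ast

# userDefined metadata stores values as strings; the dict-literal form
# printed above can be recovered with ast.literal_eval.
names_raw = "{0: 'object1', 1: 'object2', 2: 'object3', 3: 'object4', 4: 'object5'}"
names = ast.literal_eval(names_raw)

predicted_label = 28  # the id RectLabel reported for the detection
print(names.get(predicted_label, f"unknown class id {predicted_label}"))
# → unknown class id 28
```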

The above output can be printed with:

import argparse
import coremltools.models as models
import json
from coremltools.proto.FeatureTypes_pb2 import ImageFeatureType, ArrayFeatureType

def get_array_shape(type_info):
    # FeatureType is a protobuf oneof, so hasattr() is always True on it;
    # check which variant is actually set instead.
    if type_info.WhichOneof('Type') == 'multiArrayType':
        return list(type_info.multiArrayType.shape)
    return None

def get_data_type_name(data_type):
    # Resolve the numeric value through the generated protobuf enum
    # rather than a hand-maintained table of magic constants.
    try:
        return ArrayFeatureType.ArrayDataType.Name(data_type)
    except ValueError:
        return str(data_type)

def get_color_space_name(color_space):
    return {
        ImageFeatureType.RGB: "RGB",
        ImageFeatureType.BGR: "BGR",
        ImageFeatureType.GRAYSCALE: "GRAYSCALE"
    }.get(color_space, str(color_space))

def get_image_info(type_info):
    # As in get_array_shape, guard on the oneof variant that is set;
    # hasattr() would also match non-image feature types.
    if type_info.WhichOneof('Type') == 'imageType':
        return {
            'width': type_info.imageType.width,
            'height': type_info.imageType.height,
            'colorSpace': get_color_space_name(type_info.imageType.colorSpace)
        }
    return None

def format_shape(shape):
    return f"shape({', '.join(str(dim) for dim in shape)})"

def print_model_info(model_path):
    try:
        # Load the MLModel
        model = models.MLModel(model_path)
        
        # Get model spec
        spec = model.get_spec()
        
        # Print model information
        print("Model Information:")
        # hasattr() is always True for MLModel properties, so fall back on
        # truthiness to catch empty strings instead.
        print(f"Author: {model.author or 'Not specified'}")
        print(f"License: {model.license or 'Not specified'}")
        print(f"Short description: {spec.description.metadata.shortDescription or 'Not specified'}")
        
        # Print input information
        print("\nInputs:")
        for input_feature in spec.description.input:
            print(f"  Name: {input_feature.name}")
            
            # Handle image input
            image_info = get_image_info(input_feature.type)
            if image_info:
                print(f"  Type: Image")
                print(f"  Width: {image_info['width']}")
                print(f"  Height: {image_info['height']}")
                print(f"  Color Space: {image_info['colorSpace']}")
            
            # Handle array input
            shape = get_array_shape(input_feature.type)
            if shape:
                print(f"  Type: MultiArray")
                print(f"  Shape: {shape}")
                print(f"  Data Type: {get_data_type_name(input_feature.type.multiArrayType.dataType)}")
        
        # Print output information
        print("\nOutputs:")
        shapes_dict = {}
        for output_feature in spec.description.output:
            print(f"  Name: {output_feature.name}")
            
            # Handle array output
            shape = get_array_shape(output_feature.type)
            if shape:
                print(f"  Type: MultiArray")
                print(f"  Shape: {shape}")
                print(f"  Data Type: {get_data_type_name(output_feature.type.multiArrayType.dataType)}")
                shapes_dict[output_feature.name] = format_shape(shape)
        
        if shapes_dict:
            print("\nOutput Shapes Summary:")
            print(json.dumps(shapes_dict, indent=2))
        
        # Print metadata
        print("\nMetadata:")
        if spec.description.metadata.userDefined:
            print("  User Defined:")
            for key, value in spec.description.metadata.userDefined.items():
                print(f"    {key}: {value}")
    except Exception as e:
        print(f"Error loading or processing model: {str(e)}")

# Usage
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Print information about a Core ML model')
    parser.add_argument('model_path', type=str, help='Path to the .mlpackage file')
    args = parser.parse_args()

    # argparse already exits with a usage message when the required
    # positional argument is missing, so no manual check is needed.
    print_model_info(args.model_path)