retuve.classes.seg

Segmentation Classes

These are the classes you need to store your model results in.

They are used so Retuve can understand the results of your model.

Below is an example of how to write a custom AI model and config for Retuve.

https://github.com/radoss-org/retuve/tree/main/examples/ai_plugins/custom_ai_and_config.py

# Copyright 2024 Adam McArthur
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import my_model
import pydicom

from retuve.classes.seg import SegFrameObjects, SegObject
from retuve.defaults.hip_configs import default_US
from retuve.funcs import analyse_hip_3DUS
from retuve.keyphrases.config import Config
from retuve.keyphrases.enums import HipMode
from retuve.testdata import Cases, download_case


# NOTE: Can accept any number of arguments.
def custom_ai_model(dcm, keyphrase, kwarg1, kwarg2) -> List[SegFrameObjects]:
    """
    Example model hook: run ``my_model`` on ``dcm`` and wrap its output in
    Retuve's segmentation classes.

    :param dcm: The DICOM handed over by Retuve.
    :param keyphrase: A Config, or a keyphrase naming a registered Config.
    :param kwarg1: Example extra argument (see ``mode_func_args`` in setup()).
    :param kwarg2: Example extra argument (see ``mode_func_args`` in setup()).

    :return: One SegFrameObjects per frame predicted by the model.
    """
    # Turn the keyphrase into a Config if it is not one already.
    config = Config.get_config(keyphrase)

    predictions = my_model.predict(dcm, ...)

    def _frame_from_prediction(prediction) -> SegFrameObjects:
        # One prediction covers exactly one frame. A frame may hold several
        # objects, including several of the same class.
        frame = SegFrameObjects(img=prediction.img)
        for box, class_, points, conf, mask in prediction:
            frame.append(SegObject(points, class_, mask, box=box, conf=conf))
        return frame

    return [_frame_from_prediction(p) for p in predictions]


# NOTE: Needs to be called "setup" to be picked up by the Retuve API.
def setup():
    """
    Build and register the "chop" config used by this example.

    Starts from a copy of ``default_US`` and routes batch processing through
    ``custom_ai_model`` with two example keyword arguments.

    :return: The registered Config.
    """
    cfg = default_US.get_copy()
    cfg.device = 0
    cfg.batch.mode_func = custom_ai_model
    cfg.batch.mode_func_args = dict(kwarg1="value1", kwarg2="value2")
    cfg.batch.hip_mode = HipMode.US3D
    cfg.register(name="chop")
    return cfg


# Example usage.
# Guarded so that importing this file (e.g. via
# ``retuve --task trak --keyphrase_file config.py``) does not download data
# or run a full analysis as a side effect.
if __name__ == "__main__":
    dcm_file, seg_file = download_case(Cases.ULTRASOUND_DICOM)

    CHOP = setup()

    # Read the DICOM once (it was previously read twice).
    dcm = pydicom.dcmread(dcm_file)

    hip_datas, video_clip, visual_3d, dev_metrics = analyse_hip_3DUS(
        dcm,
        keyphrase=CHOP,  # can also be "chop"
        modes_func=custom_ai_model,
        # These must match custom_ai_model's extra parameters
        # (kwarg1, kwarg2); it does not accept a "seg" argument.
        modes_func_kwargs_dict={"kwarg1": "value1", "kwarg2": "value2"},
    )

    video_clip.write_videofile("3dus.mp4")
    visual_3d.write_html("3dus.html")

    # Dump metrics with the same config that produced them.
    metrics = hip_datas.json_dump(CHOP)
    print(metrics)

This file can then be used with the UI and Trak as follows:

retuve --task trak --keyphrase_file config.py
  1# Copyright 2024 Adam McArthur
  2#
  3# Licensed under the Apache License, Version 2.0 (the "License");
  4# you may not use this file except in compliance with the License.
  5# You may obtain a copy of the License at
  6#
  7#     http://www.apache.org/licenses/LICENSE-2.0
  8#
  9# Unless required by applicable law or agreed to in writing, software
 10# distributed under the License is distributed on an "AS IS" BASIS,
 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12# See the License for the specific language governing permissions and
 13# limitations under the License.
 14
 15"""
 16Segmentation Classes
 17
 18These are the classes you need to store your model results in.
 19
 20They are used so Retuve can understand the results of your model.
 21
 22Below is an example of how to write a custom AI model and config for Retuve.
 23
 24https://github.com/radoss-org/retuve/tree/main/examples/ai_plugins/custom_ai_and_config.py
 25
 26```python
 27.. include:: ../../examples/ai_plugins/custom_ai_and_config.py
 28```
 29
 30This file can then be used with the UI and Trak as follows:
 31
 32```bash
 33retuve --task trak --keyphrase_file config.py
 34```
 35"""
 36
 37from typing import Annotated, Iterable, List, Literal, Tuple, TypeVar
 38
 39import numpy as np
 40import numpy.typing as npt
 41from PIL import Image
 42from radstract.data.colors import get_unique_colours
 43
 44from retuve.hip_us.classes.enums import HipLabelsUS
 45from retuve.hip_xray.classes import HipLabelsXray
 46
 47DType = TypeVar("DType", bound=np.generic)
 48
 49NDArrayImg_NxNx3 = Annotated[npt.NDArray[DType], Literal["N", "N", 3]]
 50NDArrayImg_NxNx3_AllWhite = Annotated[npt.NDArray[DType], Literal["N", "N", 3]]
 51MidLine = List[Tuple[int, int]]
 52
 53
 54class SegObject:
 55    """
 56    Class for holding a single segmentation object.
 57    """
 58
 59    def __init__(
 60        self,
 61        points: List[Tuple[int, int]] = None,
 62        clss: HipLabelsUS = None,
 63        mask: NDArrayImg_NxNx3_AllWhite = None,
 64        conf: float = None,
 65        box: Tuple[int, int, int, int] = None,
 66        empty: bool = False,
 67    ):
 68        """
 69        :param points: List of points that make up the object.
 70        :param clss: Class of the object.
 71        :param mask: Mask of the object.
 72        :param conf: Confidence of the object.
 73        :param box: Bounding box of the object.
 74        :param empty: Is the object empty.
 75
 76        :attr midline: Midline of the object (Only for Hip Ultrasound).
 77        :attr midline_moved: Midline of the object after it
 78              has been moved (Only for Hip Ultrasound).
 79
 80        :raises ValueError: If any of the parameters are not as expected.
 81        """
 82        self.points = points
 83        self.cls = clss
 84        self.mask = mask
 85        self.box = box
 86        self.conf = conf
 87        self.empty = empty
 88
 89        # Only for hip_us
 90        self.midline: MidLine = None
 91        self.midline_moved: MidLine = None
 92
 93        if empty:
 94            return
 95
 96        if points is not None:
 97            for point in points:
 98                if len(point) != 2:
 99                    raise ValueError("Point is not a tuple of length 2")
100
101        if type(mask) != np.ndarray:
102            raise ValueError("mask is not a numpy array")
103
104        # check mask shape
105        if len(mask.shape) != 3 or mask.shape[2] != 3:
106            raise ValueError("mask is not RGB as required")
107
108        colors = get_unique_colours(array=mask)
109        # check each pixel in mask is either rgb white or black
110        if (
111            any(color not in [(0, 0, 0), (255, 255, 255)] for color in colors)
112            or len(colors) > 2
113        ):
114            raise ValueError("mask is not all white")
115
116        if box is not None:
117            if len(box) != 4:
118                raise ValueError("box is not a tuple of length 4")
119
120        if clss is None or not (
121            isinstance(clss, int)
122            or isinstance(clss, HipLabelsUS)
123            or isinstance(clss, HipLabelsXray)
124        ):
125            raise ValueError("clss is None or wrong type")
126
127        if conf is not None and not 0 <= conf <= 1:
128            raise ValueError("conf is not None and not between 0 and 1")
129
130    def __str__(self):
131        return f"SegObject({self.cls}, {self.conf}, {self.points})"
132
133    def area(self):
134        """
135        Returns the area of the object.
136        """
137        # Use the mask to calculate the area
138        if self.mask is None:
139            return 0
140        return np.sum(self.mask[:, :, 0] == 255)
141
142    def flip_horizontally(self, img_width: int):
143        """
144        Flips the object horizontally.
145
146        :param img_width: Width of the image.
147        """
148        if self.empty:
149            return
150
151        if self.box is not None:
152            self.box = (
153                img_width - self.box[2],
154                img_width - self.box[0],
155                self.box[1],
156                self.box[3],
157            )
158
159        if self.points is not None:
160            self.points = [(img_width - x, y) for x, y in self.points]
161
162        if self.midline is not None:
163            self.midline = np.array(
164                [(y, img_width - x) for y, x in self.midline]
165            )
166
167        if self.midline_moved is not None:
168            self.midline_moved = np.array(
169                [(y, img_width - x) for y, x in self.midline_moved]
170            )
171
172        if self.mask is not None:
173            self.mask = np.flip(self.mask, axis=1)
174
175
class SegFrameObjects:
    """
    Class for holding a frame of segmentation objects.

    Behaves like a list of SegObject: supports iteration, indexing,
    item assignment, ``len()`` and ``append()``.
    """

    def __init__(
        self, img: NDArrayImg_NxNx3, seg_objects: list[SegObject] = None
    ):
        """
        :param img: Image of the frame, as a numpy array.
        :param seg_objects: List of segmentation objects in the frame.
                            Defaults to a new empty list.

        :raises ValueError: If img is not a numpy array.
        """
        if not isinstance(img, np.ndarray):
            raise ValueError("img is not a numpy array")

        # Create a fresh list per instance so frames never share object lists.
        if seg_objects is None:
            seg_objects = []

        self.seg_objects = seg_objects
        self.img = img

    def __iter__(self) -> Iterable[SegObject]:
        return iter(self.seg_objects)

    def append(self, seg_result: SegObject):
        """
        Adds a segmentation object to the frame.

        :param seg_result: The object to add.
        """
        self.seg_objects.append(seg_result)

    def __getitem__(self, index):
        return self.seg_objects[index]

    def __setitem__(self, index, value):
        self.seg_objects[index] = value

    def __len__(self):
        return len(self.seg_objects)

    @classmethod
    def empty(cls, img: Image.Image):
        """
        Returns an empty SegFrameObjects object.

        The frame holds a single placeholder ``SegObject(empty=True)``.

        :param img: Image of the frame, as a numpy array or a PIL image.
                    PIL images are converted with ``np.array``.
        """
        # __init__ accepts only numpy arrays, so a PIL image (the annotated
        # type) used to raise ValueError. Convert it first.
        if isinstance(img, Image.Image):
            img = np.array(img)

        return cls(img=img, seg_objects=[SegObject(empty=True)])

    def __str__(self):
        return f"SegFrameObjects({self.seg_objects})"
NDArrayImg_NxNx3 = typing.Annotated[numpy.ndarray[typing.Any, numpy.dtype[~DType]], typing.Literal['N', 3]]
NDArrayImg_NxNx3_AllWhite = typing.Annotated[numpy.ndarray[typing.Any, numpy.dtype[~DType]], typing.Literal['N', 3]]
MidLine = typing.List[typing.Tuple[int, int]]
class SegObject:
 55class SegObject:
 56    """
 57    Class for holding a single segmentation object.
 58    """
 59
 60    def __init__(
 61        self,
 62        points: List[Tuple[int, int]] = None,
 63        clss: HipLabelsUS = None,
 64        mask: NDArrayImg_NxNx3_AllWhite = None,
 65        conf: float = None,
 66        box: Tuple[int, int, int, int] = None,
 67        empty: bool = False,
 68    ):
 69        """
 70        :param points: List of points that make up the object.
 71        :param clss: Class of the object.
 72        :param mask: Mask of the object.
 73        :param conf: Confidence of the object.
 74        :param box: Bounding box of the object.
 75        :param empty: Is the object empty.
 76
 77        :attr midline: Midline of the object (Only for Hip Ultrasound).
 78        :attr midline_moved: Midline of the object after it
 79              has been moved (Only for Hip Ultrasound).
 80
 81        :raises ValueError: If any of the parameters are not as expected.
 82        """
 83        self.points = points
 84        self.cls = clss
 85        self.mask = mask
 86        self.box = box
 87        self.conf = conf
 88        self.empty = empty
 89
 90        # Only for hip_us
 91        self.midline: MidLine = None
 92        self.midline_moved: MidLine = None
 93
 94        if empty:
 95            return
 96
 97        if points is not None:
 98            for point in points:
 99                if len(point) != 2:
100                    raise ValueError("Point is not a tuple of length 2")
101
102        if type(mask) != np.ndarray:
103            raise ValueError("mask is not a numpy array")
104
105        # check mask shape
106        if len(mask.shape) != 3 or mask.shape[2] != 3:
107            raise ValueError("mask is not RGB as required")
108
109        colors = get_unique_colours(array=mask)
110        # check each pixel in mask is either rgb white or black
111        if (
112            any(color not in [(0, 0, 0), (255, 255, 255)] for color in colors)
113            or len(colors) > 2
114        ):
115            raise ValueError("mask is not all white")
116
117        if box is not None:
118            if len(box) != 4:
119                raise ValueError("box is not a tuple of length 4")
120
121        if clss is None or not (
122            isinstance(clss, int)
123            or isinstance(clss, HipLabelsUS)
124            or isinstance(clss, HipLabelsXray)
125        ):
126            raise ValueError("clss is None or wrong type")
127
128        if conf is not None and not 0 <= conf <= 1:
129            raise ValueError("conf is not None and not between 0 and 1")
130
131    def __str__(self):
132        return f"SegObject({self.cls}, {self.conf}, {self.points})"
133
134    def area(self):
135        """
136        Returns the area of the object.
137        """
138        # Use the mask to calculate the area
139        if self.mask is None:
140            return 0
141        return np.sum(self.mask[:, :, 0] == 255)
142
143    def flip_horizontally(self, img_width: int):
144        """
145        Flips the object horizontally.
146
147        :param img_width: Width of the image.
148        """
149        if self.empty:
150            return
151
152        if self.box is not None:
153            self.box = (
154                img_width - self.box[2],
155                img_width - self.box[0],
156                self.box[1],
157                self.box[3],
158            )
159
160        if self.points is not None:
161            self.points = [(img_width - x, y) for x, y in self.points]
162
163        if self.midline is not None:
164            self.midline = np.array(
165                [(y, img_width - x) for y, x in self.midline]
166            )
167
168        if self.midline_moved is not None:
169            self.midline_moved = np.array(
170                [(y, img_width - x) for y, x in self.midline_moved]
171            )
172
173        if self.mask is not None:
174            self.mask = np.flip(self.mask, axis=1)

Class for holding a single segmentation object.

SegObject( points: List[Tuple[int, int]] = None, clss: retuve.hip_us.classes.enums.HipLabelsUS = None, mask: typing.Annotated[numpy.ndarray[typing.Any, numpy.dtype[~DType]], typing.Literal['N', 3]] = None, conf: float = None, box: Tuple[int, int, int, int] = None, empty: bool = False)
 60    def __init__(
 61        self,
 62        points: List[Tuple[int, int]] = None,
 63        clss: HipLabelsUS = None,
 64        mask: NDArrayImg_NxNx3_AllWhite = None,
 65        conf: float = None,
 66        box: Tuple[int, int, int, int] = None,
 67        empty: bool = False,
 68    ):
 69        """
 70        :param points: List of points that make up the object.
 71        :param clss: Class of the object.
 72        :param mask: Mask of the object.
 73        :param conf: Confidence of the object.
 74        :param box: Bounding box of the object.
 75        :param empty: Is the object empty.
 76
 77        :attr midline: Midline of the object (Only for Hip Ultrasound).
 78        :attr midline_moved: Midline of the object after it
 79              has been moved (Only for Hip Ultrasound).
 80
 81        :raises ValueError: If any of the parameters are not as expected.
 82        """
 83        self.points = points
 84        self.cls = clss
 85        self.mask = mask
 86        self.box = box
 87        self.conf = conf
 88        self.empty = empty
 89
 90        # Only for hip_us
 91        self.midline: MidLine = None
 92        self.midline_moved: MidLine = None
 93
 94        if empty:
 95            return
 96
 97        if points is not None:
 98            for point in points:
 99                if len(point) != 2:
100                    raise ValueError("Point is not a tuple of length 2")
101
102        if type(mask) != np.ndarray:
103            raise ValueError("mask is not a numpy array")
104
105        # check mask shape
106        if len(mask.shape) != 3 or mask.shape[2] != 3:
107            raise ValueError("mask is not RGB as required")
108
109        colors = get_unique_colours(array=mask)
110        # check each pixel in mask is either rgb white or black
111        if (
112            any(color not in [(0, 0, 0), (255, 255, 255)] for color in colors)
113            or len(colors) > 2
114        ):
115            raise ValueError("mask is not all white")
116
117        if box is not None:
118            if len(box) != 4:
119                raise ValueError("box is not a tuple of length 4")
120
121        if clss is None or not (
122            isinstance(clss, int)
123            or isinstance(clss, HipLabelsUS)
124            or isinstance(clss, HipLabelsXray)
125        ):
126            raise ValueError("clss is None or wrong type")
127
128        if conf is not None and not 0 <= conf <= 1:
129            raise ValueError("conf is not None and not between 0 and 1")
Parameters
  • points: List of points that make up the object.
  • clss: Class of the object.
  • mask: Mask of the object.
  • conf: Confidence of the object.
  • box: Bounding box of the object.
  • empty: Is the object empty.

Attributes
  • midline: Midline of the object (Only for Hip Ultrasound).
  • midline_moved: Midline of the object after it has been moved (Only for Hip Ultrasound).

Raises
  • ValueError: If any of the parameters are not as expected.
points
cls
mask
box
conf
empty
midline: List[Tuple[int, int]]
midline_moved: List[Tuple[int, int]]
def area(self):
134    def area(self):
135        """
136        Returns the area of the object.
137        """
138        # Use the mask to calculate the area
139        if self.mask is None:
140            return 0
141        return np.sum(self.mask[:, :, 0] == 255)

Returns the area of the object.

def flip_horizontally(self, img_width: int):
143    def flip_horizontally(self, img_width: int):
144        """
145        Flips the object horizontally.
146
147        :param img_width: Width of the image.
148        """
149        if self.empty:
150            return
151
152        if self.box is not None:
153            self.box = (
154                img_width - self.box[2],
155                img_width - self.box[0],
156                self.box[1],
157                self.box[3],
158            )
159
160        if self.points is not None:
161            self.points = [(img_width - x, y) for x, y in self.points]
162
163        if self.midline is not None:
164            self.midline = np.array(
165                [(y, img_width - x) for y, x in self.midline]
166            )
167
168        if self.midline_moved is not None:
169            self.midline_moved = np.array(
170                [(y, img_width - x) for y, x in self.midline_moved]
171            )
172
173        if self.mask is not None:
174            self.mask = np.flip(self.mask, axis=1)

Flips the object horizontally.

Parameters
  • img_width: Width of the image.
class SegFrameObjects:
class SegFrameObjects:
    """
    Class for holding a frame of segmentation objects.

    Behaves like a list of SegObject: supports iteration, indexing,
    item assignment, ``len()`` and ``append()``.
    """

    def __init__(
        self, img: NDArrayImg_NxNx3, seg_objects: list[SegObject] = None
    ):
        """
        :param img: Image of the frame, as a numpy array.
        :param seg_objects: List of segmentation objects in the frame.
                            Defaults to a new empty list.

        :raises ValueError: If img is not a numpy array.
        """
        if not isinstance(img, np.ndarray):
            raise ValueError("img is not a numpy array")

        # Create a fresh list per instance so frames never share object lists.
        if seg_objects is None:
            seg_objects = []

        self.seg_objects = seg_objects
        self.img = img

    def __iter__(self) -> Iterable[SegObject]:
        return iter(self.seg_objects)

    def append(self, seg_result: SegObject):
        """
        Adds a segmentation object to the frame.

        :param seg_result: The object to add.
        """
        self.seg_objects.append(seg_result)

    def __getitem__(self, index):
        return self.seg_objects[index]

    def __setitem__(self, index, value):
        self.seg_objects[index] = value

    def __len__(self):
        return len(self.seg_objects)

    @classmethod
    def empty(cls, img: Image.Image):
        """
        Returns an empty SegFrameObjects object.

        The frame holds a single placeholder ``SegObject(empty=True)``.

        :param img: Image of the frame, as a numpy array or a PIL image.
                    PIL images are converted with ``np.array``.
        """
        # __init__ accepts only numpy arrays, so a PIL image (the annotated
        # type) used to raise ValueError. Convert it first.
        if isinstance(img, Image.Image):
            img = np.array(img)

        return cls(img=img, seg_objects=[SegObject(empty=True)])

    def __str__(self):
        return f"SegFrameObjects({self.seg_objects})"

Class for holding a frame of segmentation objects.

SegFrameObjects( img: typing.Annotated[numpy.ndarray[typing.Any, numpy.dtype[~DType]], typing.Literal['N', 3]], seg_objects: list[SegObject] = None)
182    def __init__(
183        self, img: NDArrayImg_NxNx3, seg_objects: list[SegObject] = None
184    ):
185        """
186        :param img: Image of the frame.
187        :param seg_objects: List of segmentation objects in the frame.
188        """
189        if type(img) != np.ndarray:
190            raise ValueError("img is not a numpy array")
191
192        if seg_objects is None:
193            seg_objects = []
194
195        self.seg_objects = seg_objects
196        self.img = img
Parameters
  • img: Image of the frame.
  • seg_objects: List of segmentation objects in the frame.
seg_objects
img
def append(self, seg_result: SegObject):
    def append(self, seg_result: SegObject):
        """
        Adds a segmentation object to the frame.

        :param seg_result: The object to add.
        """
        self.seg_objects.append(seg_result)
@classmethod
def empty(cls, img: PIL.Image.Image):
213    @classmethod
214    def empty(cls, img: Image.Image):
215        """
216        Returns an empty SegFrameObjects object.
217
218        :param img: Image of the frame.
219        """
220
221        return cls(img=img, seg_objects=[SegObject(empty=True)])

Returns an empty SegFrameObjects object.

Parameters
  • img: Image of the frame.