retuve.trak.main

  1# Copyright 2024 Adam McArthur
  2#
  3# Licensed under the Apache License, Version 2.0 (the "License");
  4# you may not use this file except in compliance with the License.
  5# You may obtain a copy of the License at
  6#
  7#     http://www.apache.org/licenses/LICENSE-2.0
  8#
  9# Unless required by applicable law or agreed to in writing, software
 10# distributed under the License is distributed on an "AS IS" BASIS,
 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12# See the License for the specific language governing permissions and
 13# limitations under the License.
 14
 15import multiprocessing
 16import os
 17import shutil
 18import time
 19
 20from retuve.app.classes import File, FileEnum
 21from retuve.app.helpers import API_RESULTS_URL_ACCESS
 22from retuve.keyphrases.config import Config
 23from retuve.keyphrases.enums import HipMode, Outputs
 24from retuve.logs import ulogger
 25from retuve.trak.data import extract_files, insert_files
 26
 27
 28def get_state(config: Config) -> bool:
 29    """
 30    Get the state of the files and update the database.
 31
 32    :param config: The configuration.
 33    """
 34
 35    # turn the above into a list comprehension
 36    files = [
 37        os.path.join(dataset, file)
 38        for dataset in config.trak.datasets
 39        for file in os.listdir(dataset)
 40        if any(file.endswith(ext) for ext in config.batch.input_types)
 41    ]
 42    save_dir = config.api.savedir
 43
 44    new_states = {}
 45
 46    for file in files:
 47        file_id = file.split("/")[-1].split(".")[0]
 48
 49        updated = File(
 50            file_id=file_id,
 51            state=FileEnum.PENDING,
 52            metrics_url="N/A",
 53            video_url="N/A",
 54            img_url="N/A",
 55            figure_url="N/A",
 56            attempts=0,
 57        )
 58
 59        # Check if any case files exist
 60        output_paths = [os.path.join(save_dir, file_id)]
 61        any_case_files_exist = any(
 62            os.path.isfile(os.path.join(path, "metrics.json"))
 63            for path in output_paths
 64        )
 65
 66        url = config.api.url
 67        base_url = os.path.join(
 68            url, API_RESULTS_URL_ACCESS, config.name, file_id
 69        )
 70        updated.img_url = os.path.join(base_url, Outputs.IMAGE)
 71
 72        # If files exist, update URLs and set state to COMPLETED
 73        if any_case_files_exist:
 74            updated.video_url = os.path.join(base_url, Outputs.VIDEO_CLIP)
 75            updated.figure_url = os.path.join(base_url, Outputs.VISUAL3D)
 76            updated.metrics_url = os.path.join(base_url, Outputs.METRICS)
 77            updated.state = FileEnum.COMPLETED
 78
 79        else:
 80            os.makedirs(os.path.join(save_dir, file_id), exist_ok=True)
 81
 82            # Insert Empty Images automatically, if mode is not 3D
 83            if config.batch.hip_mode not in [HipMode.US3D, HipMode.US2DSW]:
 84                shutil.copyfile(
 85                    file, os.path.join(save_dir, file_id, "img.jpg")
 86                )
 87
 88            any_case_videos_exist = any(
 89                os.path.isfile(os.path.join(path, "video.mp4"))
 90                for path in output_paths
 91            )
 92
 93            if any_case_videos_exist:
 94                # Add the video URL
 95                updated.video_url = os.path.join(base_url, Outputs.VIDEO_CLIP)
 96                updated.state = FileEnum.FAILED
 97
 98        new_states[file_id] = updated
 99
100    # get file_ids in cache, find the difference and insert the new states
101    cached_files = extract_files(config.api.db_path)
102    for cached_file in cached_files:
103        if cached_file.file_id not in new_states:
104            # and there is no video + metrics file
105
106            # Check if any case files exist
107            output_paths = [
108                os.path.join(save_dir, cached_file.file_id, output)
109                for output in config.batch.outputs
110            ]
111            any_case_files_exist = any(
112                os.path.exists(path) for path in output_paths
113            )
114
115            if any_case_files_exist:
116                # mark as dead
117                cached_file.state = FileEnum.DEAD_WITH_RESULTS
118                new_states[cached_file.file_id] = cached_file
119
120            else:
121                # mark as pending
122                cached_file.state = FileEnum.DEAD
123                new_states[cached_file.file_id] = cached_file
124
125    insert_files(new_states.values(), config.api.db_path)
126
127
def run_state_machine(config: Config, poll_interval: float = 5):
    """
    Refresh the file states for ``config`` forever.

    Calls :func:`get_state` and then sleeps ``poll_interval`` seconds, in an
    endless loop. An exception raised by one refresh is logged and the loop
    continues, so a single transient error does not stop tracking for good.
    ``KeyboardInterrupt``/``SystemExit`` are not caught and still stop it.

    :param config: The configuration whose files are tracked.
    :param poll_interval: Seconds to wait between refreshes (default 5).
    """
    ulogger.info(f"\nRunning state machine {config.name}!\n")

    while True:
        try:
            get_state(config)
        except Exception:
            # Top-level loop boundary: log with traceback and keep polling.
            ulogger.exception(
                f"State machine {config.name} failed to refresh states"
            )
        time.sleep(poll_interval)
137
138
def run_all_state_machines():
    """
    Start one state-machine process for each registered config.

    The children are always created with the ``spawn`` start method. The
    global start method is also set to ``spawn``, but only if nothing has
    set it yet. ``multiprocessing.set_start_method`` raises ``RuntimeError``
    when called a second time.

    :return: The started processes, so callers may ``join`` them if needed.
    """
    configs = [config for _, config in Config.get_configs()]

    # Preserve the global "spawn" default without failing when a start method
    # has already been fixed (e.g. this function is called twice).
    if multiprocessing.get_start_method(allow_none=True) is None:
        multiprocessing.set_start_method("spawn")

    # A dedicated context guarantees "spawn" for these children regardless
    # of the global default.
    spawn_ctx = multiprocessing.get_context("spawn")

    # Run each state machine in a separate process.
    processes = [
        spawn_ctx.Process(target=run_state_machine, args=(config,))
        for config in configs
    ]

    for process in processes:
        process.start()

    return processes
def get_state(config: retuve.keyphrases.config.Config) -> bool:
 29def get_state(config: Config) -> bool:
 30    """
 31    Get the state of the files and update the database.
 32
 33    :param config: The configuration.
 34    """
 35
 36    # turn the above into a list comprehension
 37    files = [
 38        os.path.join(dataset, file)
 39        for dataset in config.trak.datasets
 40        for file in os.listdir(dataset)
 41        if any(file.endswith(ext) for ext in config.batch.input_types)
 42    ]
 43    save_dir = config.api.savedir
 44
 45    new_states = {}
 46
 47    for file in files:
 48        file_id = file.split("/")[-1].split(".")[0]
 49
 50        updated = File(
 51            file_id=file_id,
 52            state=FileEnum.PENDING,
 53            metrics_url="N/A",
 54            video_url="N/A",
 55            img_url="N/A",
 56            figure_url="N/A",
 57            attempts=0,
 58        )
 59
 60        # Check if any case files exist
 61        output_paths = [os.path.join(save_dir, file_id)]
 62        any_case_files_exist = any(
 63            os.path.isfile(os.path.join(path, "metrics.json"))
 64            for path in output_paths
 65        )
 66
 67        url = config.api.url
 68        base_url = os.path.join(
 69            url, API_RESULTS_URL_ACCESS, config.name, file_id
 70        )
 71        updated.img_url = os.path.join(base_url, Outputs.IMAGE)
 72
 73        # If files exist, update URLs and set state to COMPLETED
 74        if any_case_files_exist:
 75            updated.video_url = os.path.join(base_url, Outputs.VIDEO_CLIP)
 76            updated.figure_url = os.path.join(base_url, Outputs.VISUAL3D)
 77            updated.metrics_url = os.path.join(base_url, Outputs.METRICS)
 78            updated.state = FileEnum.COMPLETED
 79
 80        else:
 81            os.makedirs(os.path.join(save_dir, file_id), exist_ok=True)
 82
 83            # Insert Empty Images automatically, if mode is not 3D
 84            if config.batch.hip_mode not in [HipMode.US3D, HipMode.US2DSW]:
 85                shutil.copyfile(
 86                    file, os.path.join(save_dir, file_id, "img.jpg")
 87                )
 88
 89            any_case_videos_exist = any(
 90                os.path.isfile(os.path.join(path, "video.mp4"))
 91                for path in output_paths
 92            )
 93
 94            if any_case_videos_exist:
 95                # Add the video URL
 96                updated.video_url = os.path.join(base_url, Outputs.VIDEO_CLIP)
 97                updated.state = FileEnum.FAILED
 98
 99        new_states[file_id] = updated
100
101    # get file_ids in cache, find the difference and insert the new states
102    cached_files = extract_files(config.api.db_path)
103    for cached_file in cached_files:
104        if cached_file.file_id not in new_states:
105            # and there is no video + metrics file
106
107            # Check if any case files exist
108            output_paths = [
109                os.path.join(save_dir, cached_file.file_id, output)
110                for output in config.batch.outputs
111            ]
112            any_case_files_exist = any(
113                os.path.exists(path) for path in output_paths
114            )
115
116            if any_case_files_exist:
117                # mark as dead
118                cached_file.state = FileEnum.DEAD_WITH_RESULTS
119                new_states[cached_file.file_id] = cached_file
120
121            else:
122                # mark as pending
123                cached_file.state = FileEnum.DEAD
124                new_states[cached_file.file_id] = cached_file
125
126    insert_files(new_states.values(), config.api.db_path)

Scan the configured datasets, determine each input file's processing state from its output directory, and write the states to the database.

Parameters
  • config: The configuration.
def run_state_machine(config: retuve.keyphrases.config.Config):
def run_state_machine(config: Config, poll_interval: float = 5):
    """
    Refresh the file states for ``config`` forever.

    Calls :func:`get_state` and then sleeps ``poll_interval`` seconds, in an
    endless loop. An exception raised by one refresh is logged and the loop
    continues, so a single transient error does not stop tracking for good.
    ``KeyboardInterrupt``/``SystemExit`` are not caught and still stop it.

    :param config: The configuration whose files are tracked.
    :param poll_interval: Seconds to wait between refreshes (default 5).
    """
    ulogger.info(f"\nRunning state machine {config.name}!\n")

    while True:
        try:
            get_state(config)
        except Exception:
            # Top-level loop boundary: log with traceback and keep polling.
            ulogger.exception(
                f"State machine {config.name} failed to refresh states"
            )
        time.sleep(poll_interval)

Run the state machine continuously, refreshing the file states every few seconds.

def run_all_state_machines():
def run_all_state_machines():
    """
    Start one state-machine process for each registered config.

    The children are always created with the ``spawn`` start method. The
    global start method is also set to ``spawn``, but only if nothing has
    set it yet. ``multiprocessing.set_start_method`` raises ``RuntimeError``
    when called a second time.

    :return: The started processes, so callers may ``join`` them if needed.
    """
    configs = [config for _, config in Config.get_configs()]

    # Preserve the global "spawn" default without failing when a start method
    # has already been fixed (e.g. this function is called twice).
    if multiprocessing.get_start_method(allow_none=True) is None:
        multiprocessing.set_start_method("spawn")

    # A dedicated context guarantees "spawn" for these children regardless
    # of the global default.
    spawn_ctx = multiprocessing.get_context("spawn")

    # Run each state machine in a separate process.
    processes = [
        spawn_ctx.Process(target=run_state_machine, args=(config,))
        for config in configs
    ]

    for process in processes:
        process.start()

    return processes

Start one state-machine process for each registered config.