Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • m-team/ai/ai4os-yolov8-torch
  • valentin.kozlov/ai4os-yolov8-torch-cx
2 results
Show changes
Commits on Source (3)
...@@ -9,18 +9,10 @@ In this repository, we have integrated a DeepaaS API into the Ultralytics YOLOv ...@@ -9,18 +9,10 @@ In this repository, we have integrated a DeepaaS API into the Ultralytics YOLOv
# Install the API and the external submodule requirement # Install the API and the external submodule requirement
To launch the API, first, install the package, and then run DeepaaS: To launch the API, first, install the package, and then run DeepaaS:
``` ``` bash
git clone --depth 1 https://git.scc.kit.edu/m-team/ai/yolov8_api.git git clone --depth 1 https://git.scc.kit.edu/m-team/ai/yolov8_api.git
cd yolov8_api cd yolov8_api
pip install -e . pip install -e .
```
To launch it, first install the package then run [deepaas](https://github.com/indigo-dc/DEEPaaS):
```bash
git clone https://git.scc.kit.edu/m-team/ai/yolov8_api
cd yolov8_api
pip install -e .
deepaas-run --listen-ip 0.0.0.0 deepaas-run --listen-ip 0.0.0.0
``` ```
......
...@@ -137,24 +137,28 @@ def train(**args): ...@@ -137,24 +137,28 @@ def train(**args):
config.DATA_PATH, "raw", args["data"] config.DATA_PATH, "raw", args["data"]
) )
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
# should the project name
args["project"] = config.MODEL_NAME args["project"] = config.MODEL_NAME
# point to the model directory without root directory # point to the model directory without root directory
args["name"] = os.path.join("models", timestamp) args["name"] = os.path.join("models", timestamp)
model = YOLO(args["model"]) if args["weights"] is not None:
if os.path.isfile(args["weights"]):
path=args["weights"]
else:
path= os.path.join(
config.MODELS_PATH, args["weights"]
)
model=YOLO(path)
else:
model = YOLO(args["model"])
os.environ["WANDB_DISABLED"] = str(args["disable_wandb"]) os.environ["WANDB_DISABLED"] = str(args["disable_wandb"])
args.pop("disable_wandb", None) utils.pop_keys_from_dict(args, ["task_type", "disable_wandb", "weights"])
args.pop("task_type", None)
model.train(**args) model.train(**args)
return { return {
f'The model was trained successfully and was saved to: {os.path.join(args["project"], args["name"])}' f'The model was trained successfully and was saved to: {os.path.join(args["project"], args["name"])}'
} }
except Exception as err:
raise HTTPException(reason=err) from err
if __name__ == "__main__": if __name__ == "__main__":
fields = schemas.TrainArgsSchema().fields fields = schemas.TrainArgsSchema().fields
......
...@@ -219,3 +219,6 @@ class DotDict: ...@@ -219,3 +219,6 @@ class DotDict:
setattr(self, key, DotDict(value)) setattr(self, key, DotDict(value))
else: else:
setattr(self, key, value) setattr(self, key, value)
def pop_keys_from_dict(dictionary, keys_to_pop):
for key in keys_to_pop:
dictionary.pop(key, None)