roboflow 是一个用于注释图像以用于对象检测 ai 的平台。
我将这个平台用于 c2smr c2smr.fr,我的海上救援计算机视觉协会。
在本文中,我将向您展示如何使用这个平台并使用 python 训练您的模型。
您可以在我的github上找到更多示例代码:https://github.com/c2smr/detector
i - 数据集
要创建数据集,请访问 https://app.roboflow.com/ 并开始注释您的图像,如下图所示。
立即学习“Python免费学习笔记(深入)”;
点击下载“修复网络工具,一键解决电脑无法上网”;
在这个例子中,我绕道所有游泳者来预测他们在未来图像中的位置。
为了获得良好的结果,请裁剪所有游泳者并将边界框放置在对象后面以正确包围它。
您已经可以使用公共 roboflow 数据集,为此检查 https://universe.roboflow.com/
二、培训
在训练阶段,你可以直接使用 roboflow,但是到了第三次你就需要付费了,这就是为什么我向你展示如何使用你的笔记本电脑进行操作。
第一步是导入数据集。为此,您可以导入 roboflow 库。
pip install roboflow
创建模型需要使用yolo算法,可以通过ultralytics库导入该算法。
pip install ultralytics
在我的脚本中,我使用以下命令:
py train.py api-key project-workspace project-name project-version nb-epoch size_model
您必须获得:
- 访问密钥
- 工作空间
- roboflow 项目名称
- 项目数据集版本
- 训练模型的纪元数
- 神经网络大小
最初,脚本下载 yolov8-obb.pt,默认的 yolo 权重和训练前数据,以方便训练。
import sys
import os
import random
from roboflow import roboflow
from ultralytics import yolo
import yaml
import time
class main:
rf: roboflow
project: object
dataset: object
model: object
results: object
model_size: str
def __init__(self):
self.model_size = sys.argv[6]
self.import_dataset()
self.train()
def import_dataset(self):
self.rf = roboflow(api_key=sys.argv[1])
self.project = self.rf.workspace(sys.argv[2]).project(sys.argv[3])
self.dataset = self.project.version(sys.argv[4]).download("yolov8-obb")
with open(f'{self.dataset.location}/data.yaml', 'r') as file:
data = yaml.safe_load(file)
data['path'] = self.dataset.location
with open(f'{self.dataset.location}/data.yaml', 'w') as file:
yaml.dump(data, file, sort_keys=false)
def train(self):
list_of_models = ["n", "s", "m", "l", "x"]
if self.model_size != "all" and self.model_size in list_of_models:
self.model = yolo(f"yolov8{self.model_size}-obb.pt")
self.results = self.model.train(data=f"{self.dataset.location}/"
f"yolov8-obb.yaml",
epochs=int(sys.argv[5]), imgsz=640)
elif self.model_size == "all":
for model_size in list_of_models:
self.model = yolo(f"yolov8{model_size}.pt")
self.results = self.model.train(data=f"{self.dataset.location}"
f"/yolov8-obb.yaml",
epochs=int(sys.argv[5]),
imgsz=640)
else:
print("invalid model size")
if __name__ == '__main__':
main()
三、显示
训练完模型后,得到文件best.py和last.py,它们对应的是权重。
使用ultralytics库,您还可以导入yolo并加载您的体重,然后加载您的测试视频。
在此示例中,我使用跟踪功能来获取每个游泳者的 id。
import cv2
from ultralytics import yolo
import sys
def main():
cap = cv2.videocapture(sys.argv[1])
model = yolo(sys.argv[2])
while true:
ret, frame = cap.read()
results = model.track(frame, persist=true)
res_plotted = results[0].plot()
cv2.imshow("frame", res_plotted)
if cv2.waitkey(1) == 27:
break
cap.release()
cv2.destroyallwindows()
if __name__ == "__main__":
main()
为了分析预测,可以获取模型json如下。
results = model.track(frame, persist=True)
results_json = json.loads(results[0].tojson())