"""what.examples.mobilenet_ssd_demo: interactive MobileNet SSD detection demo on a capture device."""

import os

import cv2
import torch

from what.cli.model import *
from what.utils.file import get_file

from what.models.detection.ssd.mobilenet_v1_ssd import MobileNetV1SSD
from what.models.detection.ssd.mobilenet_v2_ssd_lite import MobileNetV2SSDLite
from what.models.detection.datasets.voc import VOC_CLASS_NAMES

from what.models.detection.utils.box_utils import draw_bounding_boxes

# Run inference on the first CUDA device when one is available, else on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# SSD entries of the shared model registry (slots 6 and 7). The demo function
# builds MobileNetV1SSD for position 0 and MobileNetV2SSDLite for position 1,
# so this slice must stay in sync with the ordering of what_model_list.
what_ssd_model_list = what_model_list[6:8]

def mobilenet_ssd_inference_demo():
    """Run the interactive MobileNet SSD detection demo on a live capture device.

    Steps:
      1. List the models in ``what_ssd_model_list``. Entries whose file already
         exists under ``WHAT_MODEL_PATH`` are marked ``[x]``.
      2. Prompt for a 1-based model index. Fetch the weights with ``get_file``
         (URL and hash taken from the registry entry) if they are not on disk.
      3. Build the detector with ``VOC_CLASS_NAMES`` on ``device``: entry 0 is
         ``MobileNetV1SSD`` and entry 1 is ``MobileNetV2SSDLite``.
      4. Prompt for an OpenCV capture device index. Run detection on every
         frame and display it until ``q`` is pressed.

    Any exception raised while streaming is printed. The capture device and
    OpenCV windows are released in every case.
    """
    import os  # explicit, instead of relying on the wildcard import for it

    ssd_classes = (MobileNetV1SSD, MobileNetV2SSDLite)

    max_len = max(len(entry[WHAT_MODEL_NAME_INDEX]) for entry in what_ssd_model_list)
    for i, entry in enumerate(what_ssd_model_list, start=1):
        on_disk = os.path.isfile(os.path.join(WHAT_MODEL_PATH, entry[WHAT_MODEL_FILE_INDEX]))
        downloaded = 'x' if on_disk else ' '
        print('[{}] {} : {:<{w}s}\t{}\t{}'.format(
            downloaded, i, entry[WHAT_MODEL_NAME_INDEX], entry[WHAT_MODEL_TYPE_INDEX],
            entry[WHAT_MODEL_DESC_INDEX], w=max_len))

    # Only 1..N are valid choices. "0" would otherwise become index -1 and no
    # detector would be built. isdecimal() (unlike isdigit()) guarantees that
    # int() succeeds.
    num_models = len(what_ssd_model_list)
    index = input("Please input the model index: ")
    while not (index.isdecimal() and 1 <= int(index) <= num_models):
        index = input(f"Model [{index}] does not exist. Please try again: ")
    index = int(index) - 1

    selected = what_ssd_model_list[index]
    model_path = os.path.join(WHAT_MODEL_PATH, selected[WHAT_MODEL_FILE_INDEX])

    # Download the model first if it does not exist yet.
    if not os.path.isfile(model_path):
        get_file(selected[WHAT_MODEL_FILE_INDEX],
                 WHAT_MODEL_PATH,
                 selected[WHAT_MODEL_URL_INDEX],
                 selected[WHAT_MODEL_HASH_INDEX])

    # Initialize the detector class that matches the chosen registry position.
    model = ssd_classes[index](model_path, VOC_CLASS_NAMES, is_test=True, device=device)

    video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")
    while not video.isdecimal():
        video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")

    # Capture from camera.
    cap = cv2.VideoCapture(int(video))
    if not cap.isOpened():
        # Without this guard the read loop below would spin forever on None frames.
        print(f"Unable to open capture device {video}.")
        cap.release()
        return

    try:
        while True:
            _, orig_image = cap.read()
            if orig_image is None:
                continue

            # Convert the captured frame from BGR to RGB for inference, then
            # back to BGR for display with OpenCV.
            image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
            # NOTE(review): 10 and 0.4 look like top-k and a score threshold;
            # confirm against the model's predict() signature.
            _, boxes, labels, probs = model.predict(image, 10, 0.4)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            # Fall back to the unannotated frame, so `output` is always bound
            # and never shows a stale previous frame.
            output = image
            if len(boxes) > 0:
                output = draw_bounding_boxes(image, boxes, labels, model.class_names, probs)

            cv2.imshow('MobileNet SSD Demo', output)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    except Exception as e:
        # Report the actual error that stopped the stream.
        print(e)
    finally:
        # Release the device and close windows on normal exit, error, or Ctrl+C.
        cap.release()
        cv2.destroyAllWindows()
def mobilenet_ssd_inference_demo():
    """Run the interactive MobileNet SSD detection demo on a live capture device.

    Steps:
      1. List the models in ``what_ssd_model_list``. Entries whose file already
         exists under ``WHAT_MODEL_PATH`` are marked ``[x]``.
      2. Prompt for a 1-based model index. Fetch the weights with ``get_file``
         (URL and hash taken from the registry entry) if they are not on disk.
      3. Build the detector with ``VOC_CLASS_NAMES`` on ``device``: entry 0 is
         ``MobileNetV1SSD`` and entry 1 is ``MobileNetV2SSDLite``.
      4. Prompt for an OpenCV capture device index. Run detection on every
         frame and display it until ``q`` is pressed.

    Any exception raised while streaming is printed. The capture device and
    OpenCV windows are released in every case.
    """
    import os  # explicit, instead of relying on the wildcard import for it

    ssd_classes = (MobileNetV1SSD, MobileNetV2SSDLite)

    max_len = max(len(entry[WHAT_MODEL_NAME_INDEX]) for entry in what_ssd_model_list)
    for i, entry in enumerate(what_ssd_model_list, start=1):
        on_disk = os.path.isfile(os.path.join(WHAT_MODEL_PATH, entry[WHAT_MODEL_FILE_INDEX]))
        downloaded = 'x' if on_disk else ' '
        print('[{}] {} : {:<{w}s}\t{}\t{}'.format(
            downloaded, i, entry[WHAT_MODEL_NAME_INDEX], entry[WHAT_MODEL_TYPE_INDEX],
            entry[WHAT_MODEL_DESC_INDEX], w=max_len))

    # Only 1..N are valid choices. "0" would otherwise become index -1 and no
    # detector would be built. isdecimal() (unlike isdigit()) guarantees that
    # int() succeeds.
    num_models = len(what_ssd_model_list)
    index = input("Please input the model index: ")
    while not (index.isdecimal() and 1 <= int(index) <= num_models):
        index = input(f"Model [{index}] does not exist. Please try again: ")
    index = int(index) - 1

    selected = what_ssd_model_list[index]
    model_path = os.path.join(WHAT_MODEL_PATH, selected[WHAT_MODEL_FILE_INDEX])

    # Download the model first if it does not exist yet.
    if not os.path.isfile(model_path):
        get_file(selected[WHAT_MODEL_FILE_INDEX],
                 WHAT_MODEL_PATH,
                 selected[WHAT_MODEL_URL_INDEX],
                 selected[WHAT_MODEL_HASH_INDEX])

    # Initialize the detector class that matches the chosen registry position.
    model = ssd_classes[index](model_path, VOC_CLASS_NAMES, is_test=True, device=device)

    video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")
    while not video.isdecimal():
        video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")

    # Capture from camera.
    cap = cv2.VideoCapture(int(video))
    if not cap.isOpened():
        # Without this guard the read loop below would spin forever on None frames.
        print(f"Unable to open capture device {video}.")
        cap.release()
        return

    try:
        while True:
            _, orig_image = cap.read()
            if orig_image is None:
                continue

            # Convert the captured frame from BGR to RGB for inference, then
            # back to BGR for display with OpenCV.
            image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
            # NOTE(review): 10 and 0.4 look like top-k and a score threshold;
            # confirm against the model's predict() signature.
            _, boxes, labels, probs = model.predict(image, 10, 0.4)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            # Fall back to the unannotated frame, so `output` is always bound
            # and never shows a stale previous frame.
            output = image
            if len(boxes) > 0:
                output = draw_bounding_boxes(image, boxes, labels, model.class_names, probs)

            cv2.imshow('MobileNet SSD Demo', output)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    except Exception as e:
        # Report the actual error that stopped the stream.
        print(e)
    finally:
        # Release the device and close windows on normal exit, error, or Ctrl+C.
        cap.release()
        cv2.destroyAllWindows()