what.examples.mobilenet_ssd_demo
import os

import cv2
import torch

from what.cli.model import *
from what.utils.file import get_file

from what.models.detection.ssd.mobilenet_v1_ssd import MobileNetV1SSD
from what.models.detection.ssd.mobilenet_v2_ssd_lite import MobileNetV2SSDLite
from what.models.detection.datasets.voc import VOC_CLASS_NAMES

from what.models.detection.utils.box_utils import draw_bounding_boxes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Slice of the global model table holding the SSD entries. The demo below
# assumes entry 0 is the MobileNetV1 SSD and entry 1 the MobileNetV2 SSD-Lite
# weights -- TODO confirm against what_model_list if that table is reordered.
what_ssd_model_list = what_model_list[6:8]


def mobilenet_ssd_inference_demo():
    """Interactively pick a MobileNet SSD model and run it on a live camera.

    Lists the entries of ``what_ssd_model_list`` (``[x]`` marks weights that
    already exist under ``WHAT_MODEL_PATH``), prompts for a model index and an
    OpenCV capture device number, downloads the weights if missing, then shows
    VOC detections in an OpenCV window until ``q`` is pressed. Any exception
    raised while streaming is printed; the camera and windows are always
    released.
    """
    max_len = max(len(entry[WHAT_MODEL_NAME_INDEX]) for entry in what_ssd_model_list)
    for i, entry in enumerate(what_ssd_model_list, start=1):
        if os.path.isfile(os.path.join(WHAT_MODEL_PATH, entry[WHAT_MODEL_FILE_INDEX])):
            downloaded = 'x'
        else:
            downloaded = ' '
        print('[{}] {} : {:<{w}s}\t{}\t{}'.format(
            downloaded, i,
            entry[WHAT_MODEL_NAME_INDEX],
            entry[WHAT_MODEL_TYPE_INDEX],
            entry[WHAT_MODEL_DESC_INDEX],
            w=max_len))

    index = input("Please input the model index: ")
    # Only 1..len(list) is valid; "0" would otherwise map to index -1 and no
    # model would be constructed below.
    while not index.isdigit() or not 1 <= int(index) <= len(what_ssd_model_list):
        index = input(f"Model [{index}] does not exist. Please try again: ")

    index = int(index) - 1
    selected = what_ssd_model_list[index]
    model_file = os.path.join(WHAT_MODEL_PATH, selected[WHAT_MODEL_FILE_INDEX])

    # Download the model first if it does not exist yet.
    # Check what_model_list for all available models.
    if not os.path.isfile(model_file):
        get_file(selected[WHAT_MODEL_FILE_INDEX],
                 WHAT_MODEL_PATH,
                 selected[WHAT_MODEL_URL_INDEX],
                 selected[WHAT_MODEL_HASH_INDEX])

    # Position in what_ssd_model_list selects the network class.
    model_classes = (MobileNetV1SSD, MobileNetV2SSDLite)
    model = model_classes[index](model_file,
                                 VOC_CLASS_NAMES,
                                 is_test=True,
                                 device=device)

    video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")
    while not video.isdigit():
        video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")

    # Capture from camera. A resolution can be requested with e.g.
    # cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920) / cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080).
    cap = cv2.VideoCapture(int(video))

    try:
        if not cap.isOpened():
            # Without this check the read loop below would spin forever on
            # empty frames with no window to press 'q' in.
            print(f"Unable to open capture device {video}")
            return

        while True:
            _, orig_image = cap.read()
            if orig_image is None:
                continue

            # The model is fed RGB; OpenCV delivers BGR.
            image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)

            # Run inference
            _, boxes, labels, probs = model.predict(image, 10, 0.4)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            # Draw bounding boxes onto the image; show the plain frame when
            # nothing was detected so `output` is always the current frame.
            if len(boxes) > 0:
                output = draw_bounding_boxes(image, boxes, labels, model.class_names, probs)
            else:
                output = image

            cv2.imshow('MobileNet SSD Demo', output)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    except Exception as e:
        print(e)

    finally:
        # Always free the camera and close windows, even after an error.
        cap.release()
        cv2.destroyAllWindows()
def mobilenet_ssd_inference_demo():
def mobilenet_ssd_inference_demo():
    """Interactively pick a MobileNet SSD model and run it on a live camera.

    Lists the entries of ``what_ssd_model_list`` (``[x]`` marks weights that
    already exist under ``WHAT_MODEL_PATH``), prompts for a model index and an
    OpenCV capture device number, downloads the weights if missing, then shows
    VOC detections in an OpenCV window until ``q`` is pressed. Any exception
    raised while streaming is printed; the camera and windows are always
    released.
    """
    max_len = max(len(entry[WHAT_MODEL_NAME_INDEX]) for entry in what_ssd_model_list)
    for i, entry in enumerate(what_ssd_model_list, start=1):
        if os.path.isfile(os.path.join(WHAT_MODEL_PATH, entry[WHAT_MODEL_FILE_INDEX])):
            downloaded = 'x'
        else:
            downloaded = ' '
        print('[{}] {} : {:<{w}s}\t{}\t{}'.format(
            downloaded, i,
            entry[WHAT_MODEL_NAME_INDEX],
            entry[WHAT_MODEL_TYPE_INDEX],
            entry[WHAT_MODEL_DESC_INDEX],
            w=max_len))

    index = input("Please input the model index: ")
    # Only 1..len(list) is valid; "0" would otherwise map to index -1 and no
    # model would be constructed below.
    while not index.isdigit() or not 1 <= int(index) <= len(what_ssd_model_list):
        index = input(f"Model [{index}] does not exist. Please try again: ")

    index = int(index) - 1
    selected = what_ssd_model_list[index]
    model_file = os.path.join(WHAT_MODEL_PATH, selected[WHAT_MODEL_FILE_INDEX])

    # Download the model first if it does not exist yet.
    # Check what_model_list for all available models.
    if not os.path.isfile(model_file):
        get_file(selected[WHAT_MODEL_FILE_INDEX],
                 WHAT_MODEL_PATH,
                 selected[WHAT_MODEL_URL_INDEX],
                 selected[WHAT_MODEL_HASH_INDEX])

    # Position in what_ssd_model_list selects the network class
    # (entry 0 -> MobileNetV1SSD, entry 1 -> MobileNetV2SSDLite).
    model_classes = (MobileNetV1SSD, MobileNetV2SSDLite)
    model = model_classes[index](model_file,
                                 VOC_CLASS_NAMES,
                                 is_test=True,
                                 device=device)

    video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")
    while not video.isdigit():
        video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")

    # Capture from camera. A resolution can be requested with e.g.
    # cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920) / cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080).
    cap = cv2.VideoCapture(int(video))

    try:
        if not cap.isOpened():
            # Without this check the read loop below would spin forever on
            # empty frames with no window to press 'q' in.
            print(f"Unable to open capture device {video}")
            return

        while True:
            _, orig_image = cap.read()
            if orig_image is None:
                continue

            # The model is fed RGB; OpenCV delivers BGR.
            image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)

            # Run inference
            _, boxes, labels, probs = model.predict(image, 10, 0.4)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            # Draw bounding boxes onto the image; show the plain frame when
            # nothing was detected so `output` is always the current frame.
            if len(boxes) > 0:
                output = draw_bounding_boxes(image, boxes, labels, model.class_names, probs)
            else:
                output = image

            cv2.imshow('MobileNet SSD Demo', output)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    except Exception as e:
        print(e)

    finally:
        # Always free the camera and close windows, even after an error.
        cap.release()
        cv2.destroyAllWindows()