Nenhuma Descrição

yolo.py 4.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import cv2
  2. import numpy as np
  3. import argparse
  4. import time
  def _str2bool(value):
      """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

      argparse does not turn strings into booleans by itself. Without this
      converter, ``--webcam False`` yields the non-empty (hence truthy) string
      ``"False"``, and the mode is switched on anyway.

      Raises:
          argparse.ArgumentTypeError: if *value* is not a recognised boolean.
      """
      if isinstance(value, bool):
          return value
      normalized = str(value).strip().lower()
      if normalized in ("true", "t", "yes", "y", "1", "on"):
          return True
      if normalized in ("false", "f", "no", "n", "0", "off"):
          return False
      raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


  parser = argparse.ArgumentParser()
  # nargs='?' with const=True lets a bare flag (``--webcam``) mean True. The
  # original ``--webcam True`` / ``--webcam False`` form still works.
  parser.add_argument('--webcam', help="True/False", default=False,
                      type=_str2bool, nargs='?', const=True)
  parser.add_argument('--play_video', help="True/False", default=False,
                      type=_str2bool, nargs='?', const=True)
  parser.add_argument('--image', help="True/False", default=False,
                      type=_str2bool, nargs='?', const=True)
  parser.add_argument('--video_path', help="Path of video file", default="videos/car_on_road.mp4")
  parser.add_argument('--image_path', help="Path of image to detect objects", default="Images/bicycle.jpg")
  parser.add_argument('--verbose', help="To print statements", default=True,
                      type=_str2bool, nargs='?', const=True)
  args = parser.parse_args()
  # Load YOLO
  def load_yolo(weights_path="yolov3.weights", config_path="yolov3.cfg",
                names_path="coco.names"):
      """Load the YOLO network, its class names and a random colour per class.

      The defaults reproduce the original hard-coded files. Pass, for example,
      ``"yolov3-tiny.weights", "yolov3-tiny.cfg"`` to use another model.

      Args:
          weights_path: weights file passed to ``cv2.dnn.readNet``.
          config_path: config file passed to ``cv2.dnn.readNet``.
          names_path: text file with one class name per line. The line order
              defines the class ids.

      Returns:
          ``(net, classes, colors, output_layers)``, where ``colors`` is a
          ``(len(classes), 3)`` array drawn from ``np.random.uniform(0, 255)``
          and ``output_layers`` lists the net's unconnected output layer names.
      """
      net = cv2.dnn.readNet(weights_path, config_path)
      with open(names_path, "r", encoding="utf-8") as f:
          # Keep every line, even blank ones, so list indices match class ids.
          classes = [line.strip() for line in f]
      output_layers = list(net.getUnconnectedOutLayersNames())
      colors = np.random.uniform(0, 255, size=(len(classes), 3))
      return net, classes, colors, output_layers
  def load_image(img_path, scale=0.4):
      """Read an image from disk and resize it by *scale* on both axes.

      Args:
          img_path: path of the image file.
          scale: resize factor for width and height (default 0.4, as before).

      Returns:
          ``(img, height, width, channels)`` of the resized image.

      Raises:
          FileNotFoundError: if OpenCV cannot read *img_path*.
      """
      img = cv2.imread(img_path)
      if img is None:
          # cv2.imread reports failure by returning None instead of raising.
          # Without this check the error surfaces later as an obscure resize error.
          raise FileNotFoundError(f"could not read image: {img_path}")
      img = cv2.resize(img, None, fx=scale, fy=scale)
      height, width, channels = img.shape
      return img, height, width, channels
  def start_webcam(camera_index=0):
      """Open a capture on camera *camera_index* (default 0, as before).

      Returns:
          The opened ``cv2.VideoCapture``.

      Raises:
          RuntimeError: if the camera cannot be opened. Previously this
              surfaced later as an AttributeError on the first ``None`` frame.
      """
      cap = cv2.VideoCapture(camera_index)
      if not cap.isOpened():
          cap.release()
          raise RuntimeError(f"could not open camera {camera_index}")
      return cap
  def display_blob(blob):
      """Show each channel plane of each image in *blob* in its own window.

      *blob* is iterated as a sequence of images, and each image as a sequence
      of 2-D channel planes (presumably the NCHW blob from
      ``cv2.dnn.blobFromImage`` — confirm at the call site). The window for
      channel ``k`` is titled ``str(k)``. Later images therefore reuse the same
      windows, and only the last image's channels stay visible.
      """
      for image in blob:
          for channel_idx, plane in enumerate(image):
              window_title = str(channel_idx)
              cv2.imshow(window_title, plane)
  def detect_objects(img, net, outputLayers, input_size=(320, 320)):
      """Run one forward pass of *net* on *img*.

      Args:
          img: image as read by OpenCV (BGR channel order).
          net: network returned by ``cv2.dnn.readNet``.
          outputLayers: names of the layers whose outputs are returned.
          input_size: ``(width, height)`` the image is resized to before
              inference. Defaults to the original ``(320, 320)``. It should
              match what the model supports (presumably a multiple of 32 for
              YOLOv3 — confirm against the .cfg).

      Returns:
          ``(blob, outputs)``: the network input blob and the outputs of
          *outputLayers*.
      """
      # 0.00392 ~= 1/255 rescales pixel values to [0, 1]. swapRB turns
      # OpenCV's BGR order into RGB.
      blob = cv2.dnn.blobFromImage(img, scalefactor=0.00392, size=input_size,
                                   mean=(0, 0, 0), swapRB=True, crop=False)
      net.setInput(blob)
      outputs = net.forward(outputLayers)
      return blob, outputs
  def get_box_dimensions(outputs, height, width, conf_threshold=0.3):
      """Convert raw detection rows into pixel boxes, keeping confident ones.

      Each row is read as ``[cx, cy, w, h, <ignored>, score_0, score_1, ...]``.
      Coordinates are relative to the image size. Index 4 (presumably YOLO's
      objectness score) is not used. Only the class scores from index 5
      onward decide whether a row is kept.

      Args:
          outputs: iterable of 2-D arrays of detection rows.
          height: image height in pixels, used to scale ``cy`` and ``h``.
          width: image width in pixels, used to scale ``cx`` and ``w``.
          conf_threshold: a row is kept only if its best class score is
              strictly greater than this (default 0.3, as before).

      Returns:
          ``(boxes, confs, class_ids)``. ``boxes`` are ``[x, y, w, h]`` ints with
          ``(x, y)`` the top-left corner. ``confs`` are the best-class scores as
          floats, and ``class_ids`` the best-class indices as ints.
      """
      boxes, confs, class_ids = [], [], []
      for output in outputs:
          for detection in output:
              scores = detection[5:]
              class_id = int(np.argmax(scores))
              # Compare as a Python float so the cut-off does not depend on
              # NumPy's float32-vs-Python-float promotion rules.
              conf = float(scores[class_id])
              if conf <= conf_threshold:
                  continue
              center_x = int(detection[0] * width)
              center_y = int(detection[1] * height)
              w = int(detection[2] * width)
              h = int(detection[3] * height)
              boxes.append([int(center_x - w / 2), int(center_y - h / 2), w, h])
              confs.append(conf)
              class_ids.append(class_id)
      return boxes, confs, class_ids
  def draw_labels(boxes, confs, colors, class_ids, classes, img,
                  score_threshold=0.5, nms_threshold=0.4):
      """Draw the boxes kept by non-maximum suppression and show *img*.

      Args:
          boxes: ``[x, y, w, h]`` boxes with (x, y) the top-left corner.
          confs: score of each box, in the same order as *boxes*.
          colors: per-class colour table indexed by class id.
          class_ids: class id of each box.
          classes: class names indexed by class id.
          img: image that is drawn on in place and shown in the "Image" window.
          score_threshold: score filter passed to ``cv2.dnn.NMSBoxes``
              (default 0.5, as before).
          nms_threshold: overlap threshold passed to ``cv2.dnn.NMSBoxes``
              (default 0.4, as before).
      """
      font = cv2.FONT_HERSHEY_PLAIN
      keep = set()
      if boxes:
          indexes = cv2.dnn.NMSBoxes(boxes, confs, score_threshold, nms_threshold)
          # NMSBoxes may return an empty tuple or an index array (N x 1 in
          # older OpenCV, flat in newer ones). Flattening handles every form,
          # and the set makes each per-box membership test O(1).
          keep = {int(i) for i in np.array(indexes).flatten()}
      for i, (x, y, w, h) in enumerate(boxes):
          if i not in keep:
              continue
          class_id = class_ids[i]
          label = str(classes[class_id])
          # Colour by class, not by box index. The table has one row per class,
          # so indexing it by box overflowed once there were more boxes than
          # classes, and it gave one class several colours.
          color = tuple(int(c) for c in colors[class_id])
          cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
          cv2.putText(img, label, (x, y - 5), font, 3, color, 3)
      cv2.imshow("Image", img)
  def image_detect(img_path):
      """Detect objects in the image at *img_path* and display it until ESC."""
      model, classes, colors, output_layers = load_yolo()
      image, height, width, _channels = load_image(img_path)
      _blob, outputs = detect_objects(image, model, output_layers)
      boxes, confs, class_ids = get_box_dimensions(outputs, height, width)
      draw_labels(boxes, confs, colors, class_ids, classes, image)
      while True:
          # Mask to the low byte (as start_video does). Some platforms set
          # extra bits in the waitKey result, which would stop ESC from
          # matching 27. A 50 ms timeout avoids waking 1000 times a second
          # while idle.
          key = cv2.waitKey(50) & 0xFF
          if key == 27:  # ESC
              break
  def webcam_detect():
      """Run live detection on the webcam until ESC or until frames stop arriving.

      The capture device is released even if detection raises.
      """
      model, classes, colors, output_layers = load_yolo()
      cap = start_webcam()
      try:
          while True:
              ok, frame = cap.read()
              if not ok or frame is None:
                  # The camera stopped delivering frames. Stop cleanly instead
                  # of crashing on ``None.shape``.
                  break
              height, width, _channels = frame.shape
              _blob, outputs = detect_objects(frame, model, output_layers)
              boxes, confs, class_ids = get_box_dimensions(outputs, height, width)
              draw_labels(boxes, confs, colors, class_ids, classes, frame)
              key = cv2.waitKey(1) & 0xFF
              if key == 27:  # ESC
                  break
      finally:
          cap.release()
  def start_video(video_path):
      """Run detection on each frame of *video_path* until ESC or end of video.

      Raises:
          OSError: if OpenCV cannot open *video_path*.
      """
      model, classes, colors, output_layers = load_yolo()
      cap = cv2.VideoCapture(video_path)
      if not cap.isOpened():
          cap.release()
          raise OSError(f"could not open video: {video_path}")
      try:
          while True:
              ok, frame = cap.read()
              if not ok or frame is None:
                  # At end of stream cap.read() returns (False, None). Before
                  # this check, every video ended in an AttributeError on
                  # frame.shape.
                  break
              height, width, _channels = frame.shape
              _blob, outputs = detect_objects(frame, model, output_layers)
              boxes, confs, class_ids = get_box_dimensions(outputs, height, width)
              draw_labels(boxes, confs, colors, class_ids, classes, frame)
              key = cv2.waitKey(1) & 0xFF
              if key == 27:  # ESC
                  cv2.destroyAllWindows()
                  break
      finally:
          cap.release()
  if __name__ == '__main__':
      # Every enabled mode runs in turn: webcam, then video file, then image.
      if args.webcam:
          if args.verbose:
              print('---- Starting Web Cam object detection ----')
          webcam_detect()

      if args.play_video:
          video_path = args.video_path
          if args.verbose:
              print(f"Opening {video_path} .... ")
          start_video(video_path)

      if args.image:
          image_path = args.image_path
          if args.verbose:
              print(f"Opening {image_path} .... ")
          image_detect(image_path)

      # Close any windows the modes above left open.
      cv2.destroyAllWindows()