Merge pull request #15 from hjonnala/master
convert from EdgeTPU API to PyCoral API
scottamain authored Oct 14, 2021
2 parents d910773 + 4375a93 commit 3bcb9bf
Showing 7 changed files with 143 additions and 52 deletions.
22 changes: 20 additions & 2 deletions README.md

@@ -1,8 +1,26 @@
 # Coral Smart Birdfeeder
-A smart birdfeeder that uses the Coral Enterprise Board + Camera module,
-and identifies the birds that use the feeder. It also implements a deterrent
+A smart birdfeeder that identifies the birds that use the feeder. It also implements a deterrent
 for any visiting squirrels.

+# Steps with Coral USB Accelerator and Linux machines
+
+```
+git clone https://github.com/google-coral/project-birdfeeder.git
+cd project-birdfeeder
+sh install_requirements.sh
+sh birdfeeder.sh
+```
+
+# Steps with Coral Dev Board and USB camera
+
+```
+git clone https://github.com/google-coral/project-birdfeeder.git
+cd project-birdfeeder
+sh install_requirements.sh
+sh birdfeeder_dev_board.sh
+```
+
 ## License
 Copyright 2019 Google LLC
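
Both quick-start flows assume the PyCoral runtime can see an Edge TPU. As a sanity check after install_requirements.sh (a minimal sketch, not part of this commit), pycoral can enumerate attached devices:

```
from pycoral.utils.edgetpu import list_edge_tpus

# Prints one entry per detected Edge TPU,
# e.g. {'type': 'usb', 'path': '/sys/bus/usb/devices/...'}.
for index, tpu in enumerate(list_edge_tpus()):
    print(index, tpu)
```
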
107 changes: 63 additions & 44 deletions bird_classify.py

@@ -26,58 +26,59 @@
 model.
 """

 import argparse
 import time
-import re
-import imp
 import logging
-import gstreamer
-from edgetpu.classification.engine import ClassificationEngine
 from PIL import Image
 from playsound import playsound

-def save_data(image,results,path,ext='png'):
+from pycoral.utils.dataset import read_label_file
+from pycoral.utils.edgetpu import make_interpreter
+from pycoral.adapters import common
+from pycoral.adapters.classify import get_classes
+
+import gstreamer
+
+
+def save_data(image, results, path, ext='png'):
   """Saves camera frame and model inference results
   to user-defined storage directory."""
   tag = '%010d' % int(time.monotonic()*1000)
-  name = '%s/img-%s.%s' %(path,tag,ext)
+  name = '%s/img-%s.%s' % (path, tag, ext)
   image.save(name)
-  print('Frame saved as: %s' %name)
-  logging.info('Image: %s Results: %s', tag,results)
+  print('Frame saved as: %s' % name)
+  logging.info('Image: %s Results: %s', tag, results)

-def load_labels(path):
-  """Parses provided label file for use in model inference."""
-  p = re.compile(r'\s*(\d+)(.+)')
-  with open(path, 'r', encoding='utf-8') as f:
-    lines = (p.match(line).groups() for line in f.readlines())
-    return {int(num): text.strip() for num, text in lines}
-
 def print_results(start_time, last_time, end_time, results):
   """Print results to terminal for debugging."""
   inference_rate = ((end_time - start_time) * 1000)
   fps = (1.0/(end_time - last_time))
   print('\nInference: %.2f ms, FPS: %.2f fps' % (inference_rate, fps))
   for label, score in results:
-    print(' %s, score=%.2f' %(label, score))
+    print(' %s, score=%.2f' % (label, score))

-def do_training(results,last_results,top_k):
+
+def do_training(results, last_results, top_k):
   """Compares current model results to previous results and returns
   true if at least one label difference is detected. Used to collect
   images for training a custom model."""
   new_labels = [label[0] for label in results]
   old_labels = [label[0] for label in last_results]
-  shared_labels = set(new_labels).intersection(old_labels)
+  shared_labels = set(new_labels).intersection(old_labels)
   if len(shared_labels) < top_k:
-    print('Difference detected')
-    return True
+    print('Difference detected')
+    return True
   return False

+
 def user_selections():
   parser = argparse.ArgumentParser()
   parser.add_argument('--model', required=True,
                       help='.tflite model path')
   parser.add_argument('--labels', required=True,
                       help='label file path')
+  parser.add_argument('--videosrc', help='Which video source to use', default='/dev/video0')
   parser.add_argument('--top_k', type=int, default=3,
                       help='number of classes with highest score to display')
   parser.add_argument('--threshold', type=float, default=0.1,
@@ -88,7 +89,7 @@ def user_selections():
                       help='File path to deterrent sound')
   parser.add_argument('--print', default=False, required=False,
                       help='Print inference results to terminal')
-  parser.add_argument('--training', default=False, required=False,
+  parser.add_argument('--training', action='store_true',
                       help='Training mode for image collection')
   args = parser.parse_args()
   return args
@@ -100,44 +101,62 @@ def main():
   gather images for custom model creation or in deterrent mode that sounds an
   'alarm' if a defined label is detected."""
   args = user_selections()
-  print("Loading %s with %s labels."%(args.model, args.labels))
-  engine = ClassificationEngine(args.model)
-  labels = load_labels(args.labels)
+  print("Loading %s with %s labels." % (args.model, args.labels))
+  interpreter = make_interpreter(args.model)
+  interpreter.allocate_tensors()
+  labels = read_label_file(args.labels)
+  input_tensor_shape = interpreter.get_input_details()[0]['shape']
+  if (input_tensor_shape.size != 4 or
+      input_tensor_shape[0] != 1):
+    raise RuntimeError(
+        'Invalid input tensor shape! Expected: [1, height, width, channel]')
+
+  output_tensors = len(interpreter.get_output_details())
+  if output_tensors != 1:
+    raise ValueError(
+        ('Classification model should have 1 output tensor only!'
+         'This model has {}.'.format(output_tensors)))
   storage_dir = args.storage

-  #Initialize logging file
-  logging.basicConfig(filename='%s/results.log'%storage_dir,
+  # Initialize logging file
+  logging.basicConfig(filename='%s/results.log' % storage_dir,
                       format='%(asctime)s-%(message)s',
                       level=logging.DEBUG)

   last_time = time.monotonic()
   last_results = [('label', 0)]
-  def user_callback(image,svg_canvas):
+
+  def user_callback(image, svg_canvas):
     nonlocal last_time
     nonlocal last_results
     start_time = time.monotonic()
-    results = engine.classify_with_image(image, threshold=args.threshold, top_k=args.top_k)
+    common.set_resized_input(
+        interpreter, image.size, lambda size: image.resize(size, Image.NEAREST))
+    interpreter.invoke()
+    results = get_classes(interpreter, args.top_k, args.threshold)
     end_time = time.monotonic()
+    play_sounds = [labels[i] for i, score in results]
+    results = [(labels[i], score) for i, score in results]

     if args.print:
-      print_results(start_time,last_time, end_time, results)
+      print_results(start_time, last_time, end_time, results)

     if args.training:
-      if do_training(results,last_results,args.top_k):
-        save_data(image,results, storage_dir)
+      if do_training(results, last_results, args.top_k):
+        save_data(image, results, storage_dir)
     else:
-      #Custom model mode:
-      #The labels can be modified to detect/deter user-selected items
-      if results[0][0] !='background':
-        save_data(image, storage_dir,results)
-      if 'fox squirrel, eastern fox squirrel, Sciurus niger' in results:
-        playsound(args.sound)
-        logging.info('Deterrent sounded')
-
-    last_results=results
+      # Custom model mode:
+      # The labels can be modified to detect/deter user-selected items
+      if len(results):
+        if results[0][0] != 'background':
+          save_data(image, results, storage_dir)

+        if FOX_SQUIRREL_LABEL in play_sounds:
+          playsound(args.sound)
+          logging.info('Deterrent sounded')
+
+    last_results = results
     last_time = end_time
-  result = gstreamer.run_pipeline(user_callback)
+  gstreamer.run_pipeline(user_callback, videosrc=args.videosrc)


 if __name__ == '__main__':
+  FOX_SQUIRREL_LABEL = 'fox squirrel, eastern fox squirrel, Sciurus niger'
   main()
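
The heart of this port is in user_callback above: the one-line ClassificationEngine.classify_with_image() call becomes an explicit resize, invoke, and get_classes sequence. A condensed, self-contained sketch of the new inference path (the paths and the frame image are placeholders; assumes pycoral~=2.0, Pillow, and an attached Edge TPU):

```
from PIL import Image

from pycoral.adapters import common
from pycoral.adapters.classify import get_classes
from pycoral.utils.dataset import read_label_file
from pycoral.utils.edgetpu import make_interpreter

# Placeholder paths; install_requirements.sh downloads these two files.
interpreter = make_interpreter('models/mobilenet_v2_1.0_224_quant_edgetpu.tflite')
interpreter.allocate_tensors()
labels = read_label_file('labels/imagenet_labels.txt')

frame = Image.open('frame.png').convert('RGB')  # stand-in for a camera frame
# Resize the frame into the input tensor, run inference, read top results.
common.set_resized_input(
    interpreter, frame.size, lambda size: frame.resize(size, Image.NEAREST))
interpreter.invoke()
for c in get_classes(interpreter, top_k=3, score_threshold=0.1):
    print('%s, score=%.2f' % (labels.get(c.id, c.id), c.score))
```
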
6 changes: 3 additions & 3 deletions birdfeeder.sh

@@ -18,9 +18,9 @@
 #!/bin/bash

 python3 bird_classify.py \
---model mobilenet_v2_1.0_224_quant_edgetpu.tflite \
---labels imagenet_labels.txt \
+--model models/mobilenet_v2_1.0_224_quant_edgetpu.tflite \
+--labels labels/imagenet_labels.txt \
+--videosrc /dev/video0 \
 --storage sdcard_directory \
 --sound sound_file.wav \
---training True \
 --print True
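
Note the interaction with the --training change in bird_classify.py: the flag is now action='store_true', so it takes no value, which is why the script drops '--training True'. The string-valued --print argument, by contrast, treats any non-empty value as truthy. An illustrative snippet, separate from the repo:

```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--print', default=False, required=False)  # string-valued
parser.add_argument('--training', action='store_true')         # real boolean

args = parser.parse_args(['--print', 'False', '--training'])
# Any non-empty string is truthy, so even '--print False' enables printing;
# '--training' needs no value and is False only when omitted.
print(bool(args.print), args.training)  # True True
```
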
26 changes: 26 additions & 0 deletions birdfeeder_dev_board.sh

@@ -0,0 +1,26 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Coral Smart Bird Feeder Script.
+# Automates running the bird_classify code.
+
+#!/bin/bash
+
+python3 bird_classify.py \
+--model models/mobilenet_v2_1.0_224_quant_edgetpu.tflite \
+--labels labels/imagenet_labels.txt \
+--videosrc /dev/video1 \
+--storage sdcard_directory \
+--sound sound_file.wav \
+--print True
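
The Dev Board variant differs mainly in reading the camera from /dev/video1. Which /dev/video* node a USB camera lands on varies by machine; a quick check before choosing a --videosrc value (a sketch assuming a Linux host):

```
import glob

# List candidate V4L2 capture nodes; pass one of these as --videosrc.
for dev in sorted(glob.glob('/dev/video*')):
    print(dev)
```
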
6 changes: 3 additions & 3 deletions gstreamer.py

@@ -60,8 +60,8 @@ def detectCoralDevBoard():

 def run_pipeline(user_function,
                  src_size=(640,480),
-                 appsink_size=(320, 180)):
-    PIPELINE = 'v4l2src device=/dev/video0 ! {src_caps} ! {leaky_q} ! tee name=t'
+                 appsink_size=(320, 180), videosrc='/dev/video0'):
+    PIPELINE = 'v4l2src device={videosrc} ! {src_caps} ! {leaky_q} ! tee name=t'
     if detectCoralDevBoard():
         SRC_CAPS = 'video/x-raw,format=YUY2,width={width},height={height},framerate=30/1'
         PIPELINE += """
@@ -86,7 +86,7 @@ def run_pipeline(user_function,
     src_caps = SRC_CAPS.format(width=src_size[0], height=src_size[1])
     dl_caps = DL_CAPS.format(width=appsink_size[0], height=appsink_size[1])
     sink_caps = SINK_CAPS.format(width=appsink_size[0], height=appsink_size[1])
-    pipeline = PIPELINE.format(leaky_q=LEAKY_Q,
+    pipeline = PIPELINE.format(videosrc=videosrc, leaky_q=LEAKY_Q,
        src_caps=src_caps, dl_caps=dl_caps, sink_caps=sink_caps,
        sink_element=SINK_ELEMENT)
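
With videosrc threaded through run_pipeline, the capture device is no longer hard-coded. A sketch of how the changed template expands; the leaky_q value below is assumed for illustration and is not taken from this diff:

```
PIPELINE = 'v4l2src device={videosrc} ! {src_caps} ! {leaky_q} ! tee name=t'

print(PIPELINE.format(
    videosrc='/dev/video1',
    src_caps='video/x-raw,width=640,height=480,framerate=30/1',
    leaky_q='queue max-size-buffers=1 leaky=downstream'))
# v4l2src device=/dev/video1 ! video/x-raw,width=640,height=480,framerate=30/1
#   ! queue max-size-buffers=1 leaky=downstream ! tee name=t
```
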
28 changes: 28 additions & 0 deletions install_requirements.sh

@@ -0,0 +1,28 @@
+#!/bin/bash
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+python3 -m pip install svgwrite playsound PyGObject
+python3 -m pip install --extra-index-url https://google-coral.github.io/py-repo/ pycoral~=2.0
+mkdir -p "models"
+cd models
+rm mobilenet_v2_1.0_224_quant_edgetpu.tflite
+wget https://dl.google.com/coral/canned_models/mobilenet_v2_1.0_224_quant_edgetpu.tflite
+cd ..
+mkdir -p "labels"
+cd labels
+rm imagenet_labels.txt
+wget https://dl.google.com/coral/canned_models/imagenet_labels.txt
+cd ..
+mkdir -p "sdcard_directory"
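
To confirm the script put the model and labels where birdfeeder.sh expects them, one hypothetical check, run from the repo root (make_interpreter needs an attached Edge TPU to succeed):

```
from pycoral.utils.edgetpu import make_interpreter

interpreter = make_interpreter('models/mobilenet_v2_1.0_224_quant_edgetpu.tflite')
interpreter.allocate_tensors()
# bird_classify.py validates this shape as [1, height, width, channel]:
print(interpreter.get_input_details()[0]['shape'])  # e.g. [  1 224 224   3]
```
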
Binary file added sound_file.wav
Binary file not shown.
