-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsemanticSegmentation.py
More file actions
64 lines (52 loc) · 2.46 KB
/
semanticSegmentation.py
File metadata and controls
64 lines (52 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from matplotlib import pyplot as plt
from tensorflow.python.framework.ops import Tensor
import tensorflow as tf
import numpy as np
import network
slim = tf.contrib.slim
import os
import argparse
import json
from PIL import Image
class Dotdict(dict):
    """A ``dict`` whose items can also be used as attributes.

    ``d.key`` reads ``d['key']`` and yields ``None`` when the key is absent
    (it does not raise ``AttributeError``), ``d.key = value`` stores an item,
    and ``del d.key`` removes one, raising ``KeyError`` if it is missing.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; same as dict.get.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment writes a dictionary item, not an instance attr.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Attribute deletion removes the dictionary item (KeyError if absent).
        dict.__delitem__(self, name)
class deepLabV3_InferenceEngine:
    """Run semantic segmentation with a restored DeepLabV3 checkpoint.

    The constructor reads the model configuration from
    ``<modelDirectory>/data.json`` (exposed as ``self.modelConfig``, a
    ``Dotdict``), builds the inference graph with ``network.deeplab_v3`` and
    restores the weights from ``<modelDirectory>/model.ckpt`` into ``sess``.
    """

    def __init__(self, modelDirectory, sess):
        """Build the inference graph and restore the trained weights.

        Args:
            modelDirectory: Directory holding ``data.json`` (model
                configuration) and the ``model.ckpt`` checkpoint files.
            sess: The TensorFlow session the weights are restored into; the
                same session is used by ``segmentImage`` to run predictions.
        """
        self.modelDirectory = modelDirectory
        self.sess = sess

        # Portable path building (the former "+ '\data.json'" was Windows-only
        # and relied on the invalid string escape "\d").
        configPath = os.path.join(self.modelDirectory, 'data.json')
        with open(configPath, 'r', encoding='utf-8') as fp:
            self.modelConfig = Dotdict(json.load(fp))

        # Rank-4 input batch with 3 channels; presumably (batch, height,
        # width, RGB) given how segmentImage builds the feed. The op name
        # 'trueTarget' is kept as-is because other code may look the tensor
        # up by name, even though it carries the input image.
        self.imageHolder = tf.placeholder(tf.float32,
                                          shape=[None, None, None, 3],
                                          name='trueTarget')
        self.logits_tf = network.deeplab_v3(self.imageHolder,
                                            self.modelConfig,
                                            is_training=False,
                                            reuse=False)
        # Per-pixel predicted class: argmax over axis 3 of the logits
        # (presumably the class axis of NHWC logits -- confirm in network.py).
        self.predictions_tf = tf.argmax(self.logits_tf, axis=3)

        saver = tf.train.Saver()
        saver.restore(self.sess, os.path.join(self.modelDirectory, "model.ckpt"))
        print('DeeplabV3 Model restored')

    def segmentImage(self, inputImage):
        """Segment one PIL image and return its per-pixel class-index map.

        The image is converted to RGB and resized to a square of side
        ``modelConfig.crop_size`` (the aspect ratio is NOT preserved), so the
        result is expressed in that resized frame, not the input's size.

        Args:
            inputImage: A ``PIL.Image.Image``.

        Returns:
            The first (and only) element of the prediction batch: a numpy
            array of class indices, presumably of shape
            (crop_size, crop_size) if ``network.deeplab_v3`` upsamples its
            logits to the input resolution -- confirm.
        """
        cropSize = self.modelConfig.crop_size
        # convert('RGB') lets grayscale / palette / RGBA inputs match the
        # 3-channel placeholder instead of failing at feed time.
        # Image.LANCZOS is the same filter the old Image.ANTIALIAS alias
        # pointed to; ANTIALIAS was removed in Pillow 10.
        resized = inputImage.convert('RGB').resize((cropSize, cropSize),
                                                   Image.LANCZOS)
        # Add the leading batch axis the placeholder expects. Stored on the
        # instance (as before) for callers that read self.testImage.
        self.testImage = np.expand_dims(np.array(resized), axis=0)
        print('Segmentation In Progress..')
        # Fetch only the predictions; fetching the input placeholder as well
        # just copied the fed image back out of the graph.
        predictions = self.sess.run(self.predictions_tf,
                                    {self.imageHolder: self.testImage})
        print('Segmentation Done !')
        return predictions[0]
def main():
    """Segment the two sample images and display each result with matplotlib."""
    # Hard-coded locations carried over from the original script.
    modelDirectory = os.path.join(os.getcwd(), 'model')
    dataDirectory = 'C:\\DataSets'
    imageNames = ('t140.jpg', 's382.jpg')

    # The with-block guarantees the session is closed when we are done.
    with tf.Session() as sess:
        deepLab = deepLabV3_InferenceEngine(modelDirectory, sess)
        for imageName in imageNames:
            # Close each image file once its segmentation has been computed.
            with Image.open(os.path.join(dataDirectory, imageName)) as testImage:
                segmentedImage = deepLab.segmentImage(testImage)
            plt.imshow(segmentedImage)
            plt.show()


if __name__ == "__main__":
    main()