-
Notifications
You must be signed in to change notification settings - Fork 1
/
input_data.py
executable file
·59 lines (57 loc) · 2.04 KB
/
input_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import numpy as np
import glob
from skimage import io,transform
import pickle
# ============================================================================
# ----------------- Build the lists of images and labels ----------------------
# Step 1: collect the path of every image into a list and, at the same time,
# record the matching class label in the label list.
def read_img(file_dir, output_shape=(100, 100, 3)):
    """Load every ``*.jpg`` found in the class sub-directories of ``file_dir``.

    Each sub-directory of ``file_dir`` is treated as one class. Classes are
    numbered by their position in the alphabetically sorted list of
    sub-directory names, so the numbering is reproducible between runs.
    The ``{label index: class name}`` mapping is pickled to ``dict.pickle``
    in the current working directory.

    Args:
        file_dir: Directory containing one sub-directory per class. A
            trailing path separator is optional.
        output_shape: Shape handed to ``skimage.transform.resize`` for every
            image. Defaults to ``(100, 100, 3)``.

    Returns:
        A tuple ``(images, labels, num_classes)`` where ``images`` is a
        ``float32`` array of the resized images, ``labels`` is an ``int32``
        array holding the class index of each image, and ``num_classes`` is
        the number of class sub-directories found.
    """
    # Only real sub-directories are classes: stray files such as '.DS_Store'
    # or a README would otherwise inflate the class count. Sorting makes the
    # label numbering independent of the order os.listdir happens to return.
    classes = sorted(
        entry
        for entry in os.listdir(file_dir)
        if entry != '.DS_Store'
        and os.path.isdir(os.path.join(file_dir, entry))
    )

    imgs = []
    labels = []
    for index, name in enumerate(classes):
        # os.path.join works whether or not file_dir ends with a separator.
        pattern = os.path.join(file_dir, name, '*.jpg')
        for img_ in glob.glob(pattern):
            print('reading the images:%s' % (img_))
            img = io.imread(img_)
            img = transform.resize(img, output_shape)
            imgs.append(img)
            labels.append(index)

    # Persist the index -> class-name mapping so predictions can be decoded.
    with open('dict.pickle', 'wb') as fh:
        pickle.dump(dict(enumerate(classes)), fh)

    return (
        np.asarray(imgs, np.float32),
        np.asarray(labels, np.int32),
        len(classes),
    )
# ============================================================================
# ----------------- Split the data into training and validation sets ----------
def set_val(data, label, ratio):
    """Shuffle the samples and split them into training and validation sets.

    Args:
        data: Array of samples, indexed along its first axis.
        label: Array of labels aligned with ``data`` along the first axis.
        ratio: Fraction of the samples assigned to the training set; the
            remainder becomes the validation set.

    Returns:
        A tuple ``(x_train, y_train, x_val, y_val)``.
    """
    num_example = data.shape[0]

    # Apply one random permutation to both arrays so each sample stays paired
    # with its label.
    order = np.random.permutation(num_example)
    data = data[order]
    label = label[order]

    # The ``np.int`` alias was removed in NumPy 1.24; the builtin ``int``
    # truncates in the same way.
    split = int(num_example * ratio)

    x_train, x_val = data[:split], data[split:]
    y_train, y_val = label[:split], label[split:]
    return x_train, y_train, x_val, y_val