-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathorl_generate.py
76 lines (63 loc) · 2.58 KB
/
orl_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# -*- coding:utf-8 -*-
import tensorflow as tf
from PIL import Image
import numpy as np
import os
import orl_inference
import cv2
# Build an integer-valued feature.
def _int64_feature(value):
    """Wrap one integer in a ``tf.train.Feature`` backed by an ``Int64List``.

    Args:
        value: a single integer; it becomes the only element of the list.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``[value]``.
    """
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
# Build a bytes/string-valued feature.
def _bytes_feature(value):
    """Wrap one bytes value in a ``tf.train.Feature`` backed by a ``BytesList``.

    Args:
        value: a single bytes object (here, raw image bytes); it becomes the
            only element of the list.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
# Root directories of the image data. generate() reads one sub-directory per
# class key (named "1" ... "40") under each root.
train_path = "./train/"
test_path = "./test/"
# Class ids 1..40, each mapped to itself.
classes = dict(zip(range(1, 41), range(1, 41)))
# Output TFRecord files. NOTE: these writers are opened as soon as the module
# is imported, and generate() is what closes them.
writer_train = tf.python_io.TFRecordWriter("orl_train.tfrecords")
writer_test = tf.python_io.TFRecordWriter("orl_test.tfrecords")
def _write_split(directory, label, writer):
    """Write every image in one class directory as a labelled tf.train.Example.

    Each file is read with OpenCV, converted from BGR to single-channel
    grayscale, and stored as raw bytes under 'img_raw' together with an
    integer 'label'. No height/width is stored in the record, so a reader
    must know the image shape in advance.

    Args:
        directory: path of the class directory; joined with each file name.
        label: integer label written into every example from this directory.
        writer: an open TFRecordWriter that receives the serialized examples.

    Raises:
        IOError: if OpenCV cannot decode a file found in ``directory``.
    """
    # Sort so the record order is reproducible; os.listdir order is arbitrary.
    for img_name in sorted(os.listdir(directory)):
        img_path = os.path.join(directory, img_name)
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if img is None:
            raise IOError("could not read image file: %s" % img_path)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _int64_feature(label),
            'img_raw': _bytes_feature(gray.tobytes()),
        }))
        writer.write(example.SerializeToString())


def generate():
    """Convert the train/test image folders into the two TFRecord files.

    For every key ``name`` in ``classes``, images under
    ``train_path + str(name) + '/'`` are written to ``writer_train`` and
    images under ``test_path + str(name) + '/'`` to ``writer_test``. Each
    example is labelled with the class's 1-based position in ``classes``.

    Both module-level writers are closed when this returns, including when an
    error is raised, so this can only be called once per import.

    Raises:
        IOError: if an image file cannot be read by OpenCV.
        OSError: if a class directory does not exist.
    """
    try:
        for index, name in enumerate(classes):
            label = index + 1
            _write_split(train_path + str(name) + '/', label, writer_train)
            _write_split(test_path + str(name) + '/', label, writer_test)
    finally:
        # The writers are created at import time; close them on every exit
        # path so the output files are not left open after a failed read.
        writer_test.close()
        writer_train.close()
# Run the conversion as soon as this file is executed.
# NOTE(review): this also runs when the module is imported; consider wrapping
# it in `if __name__ == "__main__":` if anything ever imports this file.
generate()
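# NOTE(review): generate() writes single-channel (grayscale) image bytes and
# stores no height/width, while the reshape below assumes [28, 28, 3]
# (3 channels). Adjust the shape to the real image size before re-enabling.
# def read_and_decode(filename):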
# # 生成一个队列
# filename_queue = tf.train.string_input_producer([filename])
#
# reader = tf.TFRecordReader()
# # 返回文件名和文件
# _, serialized_example = reader.read(filename_queue)
# features = tf.parse_single_example(serialized_example,
# features={
# 'label': tf.FixedLenFeature([], tf.int64),
# 'img_raw': tf.FixedLenFeature([], tf.string),
# })
# img = tf.decode_raw(features['img_raw'], tf.uint8)
# img = tf.reshape(img, [28, 28, 3])
# img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
# label = tf.cast(features['label'], tf.int32)
# return img, label