手写数字识别(Android版)服务端

介绍

服务端实现的需求:

  • 接收手机发来的图片并存储到本地
  • 识别本地存储的图片
  • 将识别结果传回手机

服务端使用 Flask 框架搭建,通过 HTTP 协议与手机通信

实现

训练和预测部分和之前的类似,可以参考之前写的文章

网络

识别网络可以自定义成识别效果更好的,文件名为:network.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import torch.nn as nn

# Custom network for handwritten-digit recognition.
class net(nn.Module):
    """Two-layer fully connected classifier for 784-element input vectors.

    The layer container keeps the name ``Conn_layers`` so that the parameter
    keys stored in a ``state_dict`` (``Conn_layers.0.weight`` ...) stay the
    same.
    """

    def __init__(self):
        super(net, self).__init__()

        # 784 inputs -> 100 hidden units -> 10 outputs; a sigmoid follows
        # each linear layer.
        self.Conn_layers = nn.Sequential(
            nn.Linear(784, 100),
            nn.Sigmoid(),
            nn.Linear(100, 10),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Run ``input`` through the layer stack.

        Args:
            input: Tensor whose last dimension has 784 elements.

        Returns:
            Tensor whose last dimension has 10 elements (sigmoid outputs).
        """
        return self.Conn_layers(input)

训练

和之前写的类似,超参数、损失函数和优化器可根据自己实际情况调整,文件名为:train.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader

from network import *

import os  # needed below to create the checkpoint directory

# Load the MNIST training split from ./data/.
# NOTE: download=False, so the dataset files are expected to already exist
# under ./data/; switch to download=True to let torchvision fetch them.
train_dataset = datasets.MNIST(root='./data/',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=False)
# Load the MNIST test split (same location, same transform).
test_dataset = datasets.MNIST(root='./data/',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=False)

# Number of samples per mini-batch.
batch_size = 100

# Loader for the training set.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# Loader for the test set.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

# Learning rate.
LR = 0.1

# Instantiate the network. A separate name is used so that the ``net`` class
# imported from network.py is not shadowed by its own instance.
model = net()

# Cross-entropy loss.
loss_function = nn.CrossEntropyLoss()

# SGD optimizer with momentum and weight decay.
optimizer = optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=0.9,
    weight_decay=0.0005,
)

# Number of passes over the training set.
num_epochs = 20

for epoch in range(num_epochs):
    # ---- training pass ----
    for inputs, labels in train_loader:
        # Flatten each image to a 784-element vector. The batch dimension is
        # taken from the tensor itself so a short final batch also works.
        inputs = inputs.reshape(inputs.size(0), 784)

        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- evaluation pass ----
    # Number of correctly classified test samples in this epoch.
    test_result = 0

    # Gradients are not needed while measuring accuracy.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.reshape(images.size(0), 784)
            output_test = model(images)

            # A sample is correct when the index of its largest output
            # equals its label.
            test_result += int((output_test.argmax(dim=1) == labels).sum())

    # Report the number of correct predictions after each epoch.
    print("Epoch {} : {} / {}".format(epoch, test_result, len(test_dataset)))

# Save the trained weights; create the target directory first so that
# torch.save does not fail when it is missing.
os.makedirs('weight', exist_ok=True)
torch.save(model.state_dict(), 'weight/test.pth')

图像预处理

文件名为:pretreatment.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import cv2
import numpy as np

def _peak_run_bounds(profile):
    """Locate the contiguous non-zero run that contains the peak of ``profile``.

    Returns:
        ``(before, after)``: ``before`` is the index of the first zero entry
        in front of the run, or -1 when the run reaches index 0; ``after`` is
        the index of the first zero entry behind the run, or ``len(profile)``
        when the run reaches the end.
    """
    peak = int(np.argmax(profile))

    before = peak
    # Stop at the array start instead of wrapping around via negative indices.
    while before >= 0 and profile[before] != 0:
        before -= 1

    after = peak
    # Stop at the array end instead of raising IndexError.
    while after < len(profile) and profile[after] != 0:
        after += 1

    return before, after


def image_preprocessing(image_path="getImage/image.jpg"):
    """Crop the digit out of an image and convert it to a 28x28 array.

    Steps: grayscale -> Gaussian blur -> Canny edges; the edge map's row and
    column sums give the bounding region around the strongest edge row and
    column; that region is padded to a square plus a 5-pixel margin, cropped
    from the original image, blurred, resized to 28x28 and thresholded so
    that bright pixels (> 100) become 0 and darker pixels are inverted.

    Args:
        image_path: Path of the image to read. Defaults to the path the
            original implementation had hard-coded.

    Returns:
        A 28x28 ``uint8`` array.

    Raises:
        OSError: If the image cannot be read.
        ValueError: If edge detection finds nothing in the image.
    """
    # Read the picture; cv2.imread signals failure by returning None.
    img = cv2.imread(image_path)
    if img is None:
        raise OSError("could not read image: {}".format(image_path))

    # ===================== image processing ===================== #

    # Convert to grayscale.
    gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Gaussian filtering.
    gauss_img = cv2.GaussianBlur(gray_img, (5, 5), 0, 0, cv2.BORDER_DEFAULT)

    # Edge detection.
    img_edge1 = cv2.Canny(gauss_img, 100, 200)

    # ===================== image segmentation ===================== #

    # Sum of the edge map along every row and along every column.
    row_sums = img_edge1.sum(axis=1, dtype=int)
    col_sums = img_edge1.sum(axis=0, dtype=int)

    if not row_sums.any():
        raise ValueError("no edges detected in image: {}".format(image_path))

    # Extent of the edge region around the strongest row / column.
    top, bottom = _peak_run_bounds(row_sums)
    left, right = _peak_run_bounds(col_sums)

    # Difference between the width and the height of that region.
    poor = (right - left) - (bottom - top)

    # Grow the shorter side so the crop is square, which keeps the digit's
    # aspect ratio when the crop is resized afterwards.
    if poor > 0:
        top -= poor // 2
        bottom += poor - poor // 2
    else:
        left += poor // 2
        right += poor // 2 - poor

    # Add a 5-pixel margin. Start offsets are clamped at 0 because a negative
    # slice start would wrap around; end offsets past the image are clipped
    # by slicing itself.
    margin = 5
    top = max(top - margin, 0)
    left = max(left - margin, 0)
    tailor_image = img[top:bottom + margin, left:right + margin]

    # ===================== small-image processing ===================== #

    # Grayscale version of the cropped picture.
    gray_img = cv2.cvtColor(tailor_image, cv2.COLOR_BGR2GRAY)

    # Gaussian denoising.
    gauss_img = cv2.GaussianBlur(gray_img, (5, 5), 0, 0, cv2.BORDER_DEFAULT)

    # Resize to 28x28.
    zoom_image = cv2.resize(gauss_img, (28, 28))

    # Pixels brighter than 100 are treated as background and set to 0; the
    # remaining (dark) pixels are inverted so that the strokes become bright.
    background = zoom_image > 100
    zoom_image = 255 - zoom_image
    zoom_image[background] = 0

    return zoom_image

预测

与之前的文章不同,这里把预测部分封装成了一个函数,方便把识别结果传递给服务器,文件名为:predict.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from network import *

# pretreatment.py为上面图片预处理的文件名,导入图片预处理文件
import pretreatment as PRE

# Build the network once at import time and restore the trained weights.
# A separate name is used so that the ``net`` class imported from network.py
# is not shadowed by its own instance. map_location='cpu' lets a checkpoint
# be loaded on a machine without a GPU.
model = net()
model.load_state_dict(torch.load('weight/test.pth', map_location='cpu'))


def predict_number():
    """Recognise the digit in the image produced by the preprocessing step.

    Returns:
        A zero-dimensional NumPy integer array holding the index of the
        largest network output, i.e. the predicted digit.
    """
    # Preprocessed picture: the ``zoom_image`` returned by pretreatment.py.
    img = PRE.image_preprocessing()

    # Flatten to one row of 784 values and convert to a float tensor.
    inputs = torch.from_numpy(img.reshape(-1, 784)).float()

    # train.py feeds the network data produced by transforms.ToTensor(),
    # which scales pixels to [0, 1]; the preprocessed image is 0-255, so it
    # is scaled the same way before inference.
    inputs = inputs / 255.0

    # Inference only: no gradients are needed.
    with torch.no_grad():
        predict = model(inputs)

    # Index of the largest output is the predicted digit.
    return torch.argmax(predict).numpy()

服务器

文件名为:my_server.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from flask import Flask
from flask import request
import os
from werkzeug.utils import secure_filename
from predict import *

app = Flask(__name__)


@app.route('/')
def test():
    """Health-check endpoint: report that the server is running."""
    return '服务器正常运行'

# Receives a picture, runs the recognition on it and returns the result.
@app.route('/upload', methods=['POST'])
def upload():
    """Store the uploaded picture and answer with the recognised digit.

    The picture is read from the ``file`` field of the multipart form.

    Returns:
        The prediction converted to a string.
    """
    f = request.files['file']
    print('连接成功')

    # Directory of this file; uploads go into its ``getImage`` sub-directory,
    # which is created on first use.
    basepath = os.path.dirname(os.path.abspath(__file__))
    upload_dir = os.path.join(basepath, 'getImage')
    os.makedirs(upload_dir, exist_ok=True)

    # The preprocessing step always reads ``getImage/image.jpg``, so the
    # upload is stored under that fixed name rather than the client-supplied
    # one; otherwise a differently named upload would leave a stale picture
    # (or none) to be recognised.
    # NOTE(review): that path is relative to the working directory, so the
    # server is presumably started from the directory containing this file.
    upload_path = os.path.join(upload_dir, 'image.jpg')

    # Save the file.
    f.save(upload_path)

    # Run the prediction on the stored picture.
    my_result = predict_number()
    print(my_result)

    # Send the result back to the phone.
    return str(my_result)

if __name__ == '__main__':
    # host='0.0.0.0' listens on every network interface (presumably so the
    # phone can reach the server over the network); port 5555.
    app.run(host='0.0.0.0', port=5555)

效果