naritoブログ

【お知らせ】
新ブログができました。今後そちらで更新し、このサイトは更新されません(ウェブサイト自体は残しておきます)
このブログの内容に関してコメントしたい場合は、新ブログのフリースペースに書き込んでください

このブログの内容を新ブログに移行中です。このブログで見つからない記事は、新ブログにありま

Djangoで、手書き数字の画像認識(Deep Learning)

約596日前 2017年2月26日3:58
プログラミング関連
Bootstrap4 Django 画像認識 ディープラーニング Python
「ゼロから作るDeep Learning」という書籍を今読んでいるのですが、これをDjangoに組み込みたくなりました。
これはChainerやTensorFlowなどのディープラーニングのフレームワークを使わず、0からディープラーニングを実装し仕組みを理解しよう、という趣旨の本です。お勧めです。
今回は書籍の7章、畳み込みニューラルネットワークの実装サンプルをDjangoに埋め込みます。
ネットワークの構成は以下の流れです。

Convolution
Relu
Pooling

Affine
Relu

Affine
Softmax


書籍のソースコードは以下の部分です。
https://github.com/oreilly-japan/deep-learning-from-scratch/tree/master/ch07
https://github.com/oreilly-japan/deep-learning-from-scratch/tree/master/common

まずはこのような画面です。


このように手書き数字(MNIST)をいくつか選択し...


送信を押すと、このような結果が表示されます。



環境は以下です。PYthon3.6とdjango、numpy、pillowをpipするだけ!

Python3.6
Django==1.10.5
numpy==1.12.0
olefile==0.44
Pillow==4.0.0


プロジェクト名は「tes」
アプリケーションは一つで「main」

ディレクトリは以下のように。tesとmainはプロジェクトとアプリケーションです。
deep-learning-from-scratch/ch07/の、params.pklをこのディレクトリ内に置きます。これは学習済みの重みパラメータです。
cnnにディープラーニング関連のプログラムを入れます。


cnnディレクトリの中身。今回使うのは以下だけです。
simple_convnet.pyは、deep-learning-from-scratch/ch07/
それ以外は、deep-learning-from-scratch/common/
にあります。



今回はDjangoに埋め込むため、importのパスを変更しています。変更はimport部分のみです。
common.layers等を.layersに変更するだけです。

tes/cnn/simple_convnet.py

import pickle
import numpy as np
from collections import OrderedDict
from .layers import * # 変更
from .gradient import numerical_gradient # 変更


tes/cnn/layers.py

import numpy as np
from .functions import * # 変更
from .util import im2col, col2im # 変更



tes/tes/settings.py
INSTALLED_APPSにmainを足しておけば、大丈夫です。

"""
Django settings for tes project.

Generated by 'django-admin startproject' using Django 1.10.5.

For more information on this file, see
https://docs.djangoproject.com/en/1.10/topics/settings/

For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.10/ref/settings/
"""

import os

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.10/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'r&8x(hf_^x=0eeahvb5&*vyso!zhhh4vfi6x&ggohngdzg2+oq'

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = ['*']


# Application definition

INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'main',
]

MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'tes.urls'

TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]

WSGI_APPLICATION = 'tes.wsgi.application'


# Database
# https://docs.djangoproject.com/en/1.10/ref/settings/#databases

DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
}
}


# Password validation
# https://docs.djangoproject.com/en/1.10/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
]


# Internationalization
# https://docs.djangoproject.com/en/1.10/topics/i18n/

LANGUAGE_CODE = 'ja'

TIME_ZONE = 'Asia/Tokyo'

USE_I18N = True

USE_L10N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.10/howto/static-files/

STATIC_URL = '/static/'


tes/tes/urls.py

from django.conf.urls import url
from django.contrib import admin
from main.views import Home, upload

urlpatterns = [
url(r'^admin/', admin.site.urls),
url(r'^upload/$', upload, name='upload'),
url(r'^$', Home.as_view(), name='home'),
]


tes/main/views.py

import base64

from django.shortcuts import render, redirect
from django.views import generic
import numpy as np
from PIL import Image

from cnn.simple_convnet import SimpleConvNet

network = SimpleConvNet(
input_dim=(1, 28, 28), hidden_size=100, output_size=10)
network.load_params('params.pkl')


class Home(generic.TemplateView):
template_name = 'main/home.html'


def upload(request):
files = request.FILES.getlist("files[]")
if request.method == 'POST' and files:
array_list = []
for file in files:
img = Image.open(file)
array = np.asarray(img)
array_list.append(array)

x = np.array(array_list).reshape(len(array_list), 1, 28, 28)
labels = network.predict(x).argmax(axis=1)
result = []
for file, label in zip(files, labels):
file.seek(0)
src = base64.b64encode(file.read())
# django2.0からは src = base64.b64encode(file.read()).decode()
result.append((src, label))
context = {
'result': result,
}
return render(request, 'main/result.html', context)
else:
return redirect('home')


モジュールがロードされたら、SimpleConvNetをインスタンス化し重みパラメータを読み込みます。
後は、好きなときにpredictなんかを呼んであげるだけ!

network = SimpleConvNet(
input_dim=(1, 28, 28), hidden_size=100, output_size=10)
network.load_params('params.pkl')



これはトップページを表示するTemplateViewです。

class Home(generic.TemplateView):
template_name = 'main/home.html'



送信ボタンでupload関数が呼ばれます。
複数ファイルのアップロードは、request.FILES.getlist("files[]") のようにします。
何かアップロードがあり、POSTメソッドならばその後の処理を、そうでなければトップページへリダイレクトです。

def upload(request):
files = request.FILES.getlist("files[]")
if request.method == 'POST' and files:
...
...
else:
return redirect('home')


複数ファイルのアップロードは、以前にも書きました。
Djangoで、複数ファイルをアップロードする
https://torina.top/detail/245/

filesには、アップロードされたファイルが入ります。リスト型です。

[<InMemoryUploadedFile: 11.png (image/png)>, <InMemoryUploadedFile: 19.png (image/png)>, <InMemoryUploadedFile: 36.png (image/png)>]


Image.open()は、ファイルのパスか、又はファイルっぽいオブジェクトを受け取ります。
InMemoryUploadedFileはファイルライクなオブジェクトで、無事に渡せます。
やろうと思えば、以下のようにshow()で画像を表示することもできます...

img = Image.open(file)
img.show()


np.assarayで、PILのImageオブジェクトをnumpy配列に変換できます。そして、array_listにappendします。

array_list = []
for file in files:
img = Image.open(file)
array = np.asarray(img)
array_list.append(array)



numpy配列からPILのイメージオブジェクトに戻したいときは、例えば以下のように。

img = Image.open(file)
array = np.asarray(img)
img = Image.fromarray(np.uint8(array)) # これ
img.show()


predictに渡せる形に変換しています。1, 28, 28の部分は(チャンネル, 高さ, 幅)で、SimpleConvNet(input_dim=(1, 28, 28),...
と一致させています。
labelsには、推論した結果が[5 7 8 9]のような形で入ります。

x = np.array(array_list).reshape(len(array_list), 1, 28, 28)
labels = network.predict(x).argmax(axis=1)


結果をHTMLに表示させたい訳ですが、今回はアップロードされた画像もテンプレートへ渡そうと思いました。
元画像があった方がわかりやすいですからね。
方法はいくつかあると思いますが、今回はアップロードされた画像データを、imgタグのsrcに埋め込む方法にします。
(画像データ, 推論した結果)というタプルをresult変数に追加していきます。

result = []
for file, label in zip(files, labels):
file.seek(0)
src = base64.b64encode(file.read())
# django2.0からは src = base64.b64encode(file.read()).decode()
result.append((src, label))
context = {
'result': result,
}
return render(request, 'main/result.html', context)


画像データを取得する必要があります。
InMemoryUploadedFileはファイルライクなオブジェクトで、read()で読み込むことはできます。
ただ、そのままread()しても上手くいかないので、seek(0)を直前で呼び出しています。
更に、imgタグのsrcに直接読み込ませるためにはbase64エンコードも必要です。
これはDjango1.11の例で、Django2.0からはテンプレートでのバイト列の扱いが変わりました。このまま渡してもb'hogehoge'...と出力されてしまうので、decode()が必要となります。

for file, label in zip(files, labels):
file.seek(0)
src = base64.b64encode(file.read())
# django2.0からは src = base64.b64encode(file.read()).decode()
result.append((src, label))


tes/main/templates/base.html
bootstrap4を使っています。

<!DOCTYPE html>
<html lang="ja">
<head>
<!-- Required meta tags -->
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">

<!-- Bootstrap CSS -->
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-alpha.6/css/bootstrap.min.css" integrity="sha384-rwoIResjU2yc3z8GV/NPeZWAv56rSmLldC3R/AZzGRnGxQQKnKkoFVhFQhNUwEyJ" crossorigin="anonymous">

</head>
<body>
{% block content %}{% endblock %}
<!-- jQuery first, then Tether, then Bootstrap JS. -->
<script
src="https://code.jquery.com/jquery-3.1.1.min.js"
integrity="sha256-hVVnYaiADRTO2PzUGmuLJr8BLUSjGIZsDYGmIJLv2b8="
crossorigin="anonymous"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/tether/1.4.0/js/tether.min.js" integrity="sha384-DztdAPBWPRXSA/3eYEEUWrWCy7G5KFbe8fFjk5JAIxUYHKkDx6Qin1DkWx51bBrb" crossorigin="anonymous"></script>
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-alpha.6/js/bootstrap.min.js" integrity="sha384-vBWWzlZJ8ea9aCX4pEW3rVHjgjt7zpkNpZk+02D9phzyeVkE+jo0ieGizqPLForn" crossorigin="anonymous"></script>

</body>
</html>



tes/main/templates/home.html
formのenctype="multipart/form-data"、fileタグのfiles[]、multipleあたりを忘れないようにしましょう。

{% extends "main/base.html" %}
{% block content %}
<div class="mt-5 text-center">
<h1 class="display-1">ディープラーニングデモ</h1>
<p class="lead">手書きの画像をアップロードし、OKを押してください</p>
<form action="{% url 'upload' %}" method="POST" enctype="multipart/form-data">
<input type="file" name="files[]" multiple>
<input type="submit" class="btn btn-lg btn-secondary" value="送信">
{% csrf_token %}
</form>
</div>
{% endblock %}


tes/main/templates/result.html
{{ data.0 }}や{{ data.1 }}は、data[0]、data[1]と同義です。テンプレートでの書き方ですね。
imgタグのsrcに直接値を書き込むときは、data:image/png;base64,を先頭につけます...
jpgならば、pngの部分をjpgに変更しましょう。

{% extends "main/base.html" %}
{% block content %}
<div class="mt-5 text-center">
{% for data in result %}
<img class="img-fluid" src="data:image/png;base64,{{ data.0 }}">
<p class="lead">画像は{{ data.1 }}です。</p>
<hr>
{% endfor %}
<a class="btn btn-lg btn-secondary" href="{% url 'home' %}">戻る</a>
</div>

{% endblock %}
名無し 約385日前 2017年9月25日13:01 返信する
test
名無し 約35日前 2018年9月10日13:55 返信する
ブログに記載され構成でviews.pyとresult.htmlを実装しましたが、
result.htmlの「<img class="img-fluid" src="data:image/png;base64,{{ data.0 }}">」の部分が

うまく動かず、画像が表示されない状態です。(米粒マークだけが表示されるイメージ)
base64エンコード処理でうまく画像をHTML上に表示できていない状態だと思われますが、何か原因となりそうな要因はありますでしょうか?
名無し 約35日前 2018年9月10日14:08
補足ですが、アップロードしたファイルはpng形式です。
テンプレート側で単純に{{ data.0}}とすると、以下の様なエンコードされた値がわたってきているので、イメージデータはテンプレート側まで来ているようです。
{{data.0}}の表示結果
b'Qk1mCQAAAAAAADYAAAAoAAAAHAAAABwAAAABABgAAAAAADAJAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA  以下省略

ちなみにブラウザはChromeですが、IEでもだめでした。「画像表示部分がxで表示される」
なりと 約34日前 2018年9月11日13:53
確認しますので、プロジェクトを添付していただけますか。
名無し 約34日前 2018年9月11日14:03
1つ原因箇所がわかりました。

imgデータが以下の様な形にバイト型?でテンプレートにわたっていました。

b'Qk1mCQAAAAAAADYA ***省略*** AAAAAAA'

先頭の「b'」と後ろの「'」を省いた部分のコードをhtmlの下記の{{data.0}}部分にコピペしたところ画像が表示できました。
そこで、テンプレートに渡す前にsrcを文字列型str(src)に変換して渡してもNGでした。

<img src="data:image/png;base64,{{data.0}" />

ここからどのような対応策が考えられますでしょうか?
名無し 約34日前 2018年9月11日14:35
度々申し訳ないです・・・
再確認したところ、views.pyで以下のようにsrcを文字列型に変換してから不要な部分を削除してテンプレートに渡すことで解消しました。


src = base64.b64encode(file.read())
src = str(src)[2:-1]
なりと 約34日前 2018年9月11日15:14
Django2.0から、テンプレートへ渡したバイト列はb'hogehoge'と表示されます。Django1.11までは、単純な文字列として出力されていました。

対策はその方法でも良いですし、
src = base64.b64encode(file.read()).decode()
とビュー側で対処してもOKです。

記事にも補足しておきました。
名無し 約33日前 2018年9月12日6:35
ご回答ありがとうございます!
勉強になりました。