naritoブログ

このブログはDjangoとBootstrap4で作成されました
ソースコード

Djangoで、手書き数字の画像認識(Deep Learning)

プログラミング関連 Bootstrap4 Django 画像認識 ディープラーニング Python 約234日前
2017年2月26日3:58
「ゼロから作るDeep Learning」という書籍を今読んでいるのですが、これをDjangoに組み込みたくなりました。
これはChainerやTensorFlowなどのディープラーニングのフレームワークを使わず、0からディープラーニングを実装し仕組みを理解しよう、という趣旨の本です。お勧めです。
今回は書籍の7章、畳み込みニューラルネットワークの実装サンプルをDjangoに埋め込みます。
ネットワークの構成は以下の流れです。
Convolution
Relu
Pooling

Affine
Relu

Affine
Softmax


書籍のソースコードは以下の部分です。
https://github.com/oreilly-japan/deep-learning-from-scratch/tree/master/ch07
https://github.com/oreilly-japan/deep-learning-from-scratch/tree/master/common

まずはこのような画面です。


このように手書き数字(MNIST)をいくつか選択し...


送信を押すと、このような結果が表示されます。



環境は以下です。PYthon3.6とdjango、numpy、pillowをpipするだけ!
Python3.6
Django==1.10.5
numpy==1.12.0
olefile==0.44
Pillow==4.0.0


プロジェクト名は「tes」
アプリケーションは一つで「main」

ディレクトリは以下のように。tesとmainはプロジェクトとアプリケーションです。
deep-learning-from-scratch/ch07/の、params.pklをこのディレクトリ内に置きます。これは学習済みの重みパラメータです。
cnnにディープラーニング関連のプログラムを入れます。


cnnディレクトリの中身。今回使うのは以下だけです。
simple_convnet.pyは、deep-learning-from-scratch/ch07/
それ以外は、deep-learning-from-scratch/common/
にあります。



今回はDjangoに埋め込むため、importのパスを変更しています。変更はimport部分のみです。
common.layers等を.layersに変更するだけです。

tes/cnn/simple_convnet.py
import pickle
import numpy as np
from collections import OrderedDict
from .layers import *  # 変更
from .gradient import numerical_gradient  # 変更


tes/cnn/layers.py
import numpy as np
from .functions import *  # 変更
from .util import im2col, col2im  # 変更



tes/tes/settings.py
INSTALLED_APPSにmainを足しておけば、大丈夫です。
"""
Django settings for tes project.

Generated by 'django-admin startproject' using Django 1.10.5.

For more information on this file, see
https://docs.djangoproject.com/en/1.10/topics/settings/

For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.10/ref/settings/
"""

import os

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.10/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'r&8x(hf_^x=0eeahvb5&*vyso!zhhh4vfi6x&ggohngdzg2+oq'

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = ['*']


# Application definition

INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    'main',
]

MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'tes.urls'

TEMPLATES = [
    {
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        'DIRS': [],
        'APP_DIRS': True,
        'OPTIONS': {
            'context_processors': [
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
            ],
        },
    },
]

WSGI_APPLICATION = 'tes.wsgi.application'


# Database
# https://docs.djangoproject.com/en/1.10/ref/settings/#databases

DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.sqlite3',
        'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
    }
}


# Password validation
# https://docs.djangoproject.com/en/1.10/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
    {
        'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
    },
]


# Internationalization
# https://docs.djangoproject.com/en/1.10/topics/i18n/

LANGUAGE_CODE = 'ja'

TIME_ZONE = 'Asia/Tokyo'

USE_I18N = True

USE_L10N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.10/howto/static-files/

STATIC_URL = '/static/'


tes/tes/urls.py
from django.conf.urls import url
from django.contrib import admin
from main.views import Home, upload

urlpatterns = [
    url(r'^admin/', admin.site.urls),
    url(r'^upload/$', upload, name='upload'),
    url(r'^$', Home.as_view(), name='home'),
]


tes/main/views.py
import base64

from django.shortcuts import render, redirect
from django.views import generic
import numpy as np
from PIL import Image

from cnn.simple_convnet import SimpleConvNet

network = SimpleConvNet(
    input_dim=(1, 28, 28), hidden_size=100, output_size=10)
network.load_params('params.pkl')


class Home(generic.TemplateView):
    template_name = 'main/home.html'


def upload(request):
    files = request.FILES.getlist("files[]")
    if request.method == 'POST' and files:
        array_list = []
        for file in files:
            img = Image.open(file)
            array = np.asarray(img)
            array_list.append(array)

        x = np.array(array_list).reshape(len(array_list), 1, 28, 28)
        labels = network.predict(x).argmax(axis=1)
        result = []
        for file, label in zip(files, labels):
            file.seek(0)
            src = base64.b64encode(file.read())
            result.append((src, label))
        context = {
            'result': result,
        }
        return render(request, 'main/result.html', context)
    else:
        return redirect('home')


モジュールがロードされたら、SimpleConvNetをインスタンス化し重みパラメータを読み込みます。
後は、好きなときにpredictなんかを呼んであげるだけ!
network = SimpleConvNet(
    input_dim=(1, 28, 28), hidden_size=100, output_size=10)
network.load_params('params.pkl')



これはトップページを表示するTemplateViewです。
class Home(generic.TemplateView):
    template_name = 'main/home.html'



送信ボタンでupload関数が呼ばれます。
複数ファイルのアップロードは、request.FILES.getlist("files[]") のようにします。
何かアップロードがあり、POSTメソッドならばその後の処理を、そうでなければトップページへリダイレクトです。
def upload(request):
    files = request.FILES.getlist("files[]")
    if request.method == 'POST' and files:
        ...
        ...
    else:
        return redirect('home')


複数ファイルのアップロードは、以前にも書きました。
Djangoで、複数ファイルをアップロードする
https://torina.top/detail/245/

filesには、アップロードされたファイルが入ります。リスト型です。
[<InMemoryUploadedFile: 11.png (image/png)>, <InMemoryUploadedFile: 19.png (image/png)>, <InMemoryUploadedFile: 36.png (image/png)>]


Image.open()は、ファイルのパスか、又はファイルっぽいオブジェクトを受け取ります。
InMemoryUploadedFileはファイルライクなオブジェクトで、無事に渡せます。
やろうと思えば、以下のようにshow()で画像を表示することもできます...
img = Image.open(file)
img.show()


np.assarayで、PILのImageオブジェクトをnumpy配列に変換できます。そして、array_listにappendします。
        array_list = []
        for file in files:
            img = Image.open(file)
            array = np.asarray(img)
            array_list.append(array)



numpy配列からPILのイメージオブジェクトに戻したいときは、例えば以下のように。
img = Image.open(file)
array = np.asarray(img)
img = Image.fromarray(np.uint8(array))  # これ
img.show()


predictに渡せる形に変換しています。1, 28, 28の部分は(チャンネル, 高さ, 幅)で、SimpleConvNet(input_dim=(1, 28, 28),...
と一致させています。
labelsには、推論した結果が[5 7 8 9]のような形で入ります。
        x = np.array(array_list).reshape(len(array_list), 1, 28, 28)
        labels = network.predict(x).argmax(axis=1)


結果をHTMLに表示させたい訳ですが、今回はアップロードされた画像もテンプレートへ渡そうと思いました。
元画像があった方がわかりやすいですからね。
方法はいくつかあると思いますが、今回はアップロードされた画像データを、imgタグのsrcに埋め込む方法にします。
(画像データ, 推論した結果)というタプルをresult変数に追加していきます。
        result = []
        for file, label in zip(files, labels):
            file.seek(0)
            src = base64.b64encode(file.read())
            result.append((src, label))
        context = {
            'result': result,
        }
        return render(request, 'main/result.html', context)


画像データを取得する必要があります。
InMemoryUploadedFileはファイルライクなオブジェクトで、read()で読み込むことはできます。
ただ、そのままread()しても上手くいかないので、seek(0)を直前で呼び出しています。
更に、imgタグのsrcに直接読み込ませるためにはbase64エンコードも必要です。
        for file, label in zip(files, labels):
            file.seek(0)
            src = base64.b64encode(file.read())
            result.append((src, label))


tes/main/templates/base.html
bootstrap4を使っています。
<!DOCTYPE html>
<html lang="ja">
  <head>
    <!-- Required meta tags -->
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
 
    <!-- Bootstrap CSS -->
    <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-alpha.6/css/bootstrap.min.css" integrity="sha384-rwoIResjU2yc3z8GV/NPeZWAv56rSmLldC3R/AZzGRnGxQQKnKkoFVhFQhNUwEyJ" crossorigin="anonymous">

  </head>
  <body>
      {% block content %}{% endblock %}
    <!-- jQuery first, then Tether, then Bootstrap JS. -->
    <script
      src="https://code.jquery.com/jquery-3.1.1.min.js"
      integrity="sha256-hVVnYaiADRTO2PzUGmuLJr8BLUSjGIZsDYGmIJLv2b8="
      crossorigin="anonymous"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/tether/1.4.0/js/tether.min.js" integrity="sha384-DztdAPBWPRXSA/3eYEEUWrWCy7G5KFbe8fFjk5JAIxUYHKkDx6Qin1DkWx51bBrb" crossorigin="anonymous"></script>
    <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-alpha.6/js/bootstrap.min.js" integrity="sha384-vBWWzlZJ8ea9aCX4pEW3rVHjgjt7zpkNpZk+02D9phzyeVkE+jo0ieGizqPLForn" crossorigin="anonymous"></script>

  </body>
</html>



tes/main/templates/home.html
formのenctype="multipart/form-data"、fileタグのfiles[]、multipleあたりを忘れないようにしましょう。
{% extends "main/base.html" %}
{% block content %}
<div class="mt-5 text-center">
  <h1 class="display-1">ディープラーニングデモ</h1>
  <p class="lead">手書きの画像をアップロードし、OKを押してください</p>
  <form action="{% url 'upload' %}" method="POST" enctype="multipart/form-data">
    <input type="file" name="files[]" multiple>
    <input type="submit" class="btn btn-lg btn-secondary" value="送信">
    {% csrf_token %}
  </form>
</div>
{% endblock %}


tes/main/templates/result.html
{{ data.0 }}や{{ data.1 }}は、data[0]、data[1]と同義です。テンプレートでの書き方ですね。
imgタグのsrcに直接値を書き込むときは、data:image/png;base64,を先頭につけます...
jpgならば、pngの部分をjpgに変更しましょう。
{% extends "main/base.html" %}
{% block content %}
<div class="mt-5 text-center">
  {% for data in result %}
    <img class="img-fluid" src="data:image/png;base64,{{ data.0 }}">
    <p class="lead">画像は{{ data.1 }}です。</p>
    <hr>
  {% endfor %}
  <a class="btn btn-lg btn-secondary"  href="{% url 'home' %}">戻る</a>
</div>

{% endblock %}
約23日前 名無し 2017年9月25日13:01 返信する
test