簡単なチャットボットの作り方

概要

Bloomを使用して、日本語で入力できる簡単なチャットボットを作る方法。

環境

OS Windows11
GPU NVIDIA GeForce RTX 3080Ti

使用するモデル

bloomz-3b

言語生成モデルはbloomを使用。

bigscience/bloomz-3b · Hugging Face

fugumt

翻訳用のモデルにfugumtを使用。

staka/fugumt-en-ja · Hugging Face

staka/fugumt-ja-en · Hugging Face

ライブラリのインストール

以下のコマンドで必要なライブラリをインストール。

pip install transformers accelerate sentencepiece

実装

import tkinter as tk
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

checkpoint = "bigscience/bloomz-3b"

# Bloomのモデルをロード
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")

# fugumtのモデルをダウンロード
je_translator = pipeline("translation", model="staka/fugumt-ja-en")
ej_translator = pipeline("translation", model="staka/fugumt-en-ja")

# GUIのウィンドウを作成
window = tk.Tk()
window.title("チャットボットアプリ")
window.geometry("600x400")

# テキストエリアを作成
text_area = tk.Text(window, width=60, height=20)
text_area.pack()

# テキストボックスを作成
text_box = tk.Entry(window, width=60)
text_box.pack()

# チャットボットの応答を生成する関数
def generate_response():
    # ユーザーの入力を取得
    ja_input = text_box.get()
    en_input= je_translator(text_box.get())[0]['translation_text']
    print(en_input)
    user_input = en_input
    # テキストボックスをクリア
    text_box.delete(0, tk.END)
    # テキストエリアにユーザーの入力を表示
    text_area.insert(tk.END, "あなた: " + ja_input + "\n")
    input_text = user_input
    # テキストをトークン化
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    # テキストを生成
    output_ids = model.generate(input_ids, max_length=32)
    # トークンをテキストに戻す
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).replace(input_text, "")
    # テキストエリアにチャットボットの応答を表示
    en_output = output_text
    print(en_output)
    ja_output = ej_translator(en_output)[0]['translation_text']
    text_area.insert(tk.END, "チャットボット: " + ja_output + "\n")

# ボタンを作成
button = tk.Button(window, text="送信", command=generate_response)
button.pack()

# GUIのウィンドウを表示
window.mainloop()