LangchainとGeminiを用いたSQLクエリ実行エージェントの作成

LLMTech

本記事では、LangchainとGemini (GoogleのLLM) を用いて、SQLiteデータベースに対して自然言語でクエリを実行するエージェントを作成する方法を紹介します。

必要なライブラリ

まず、必要なライブラリをインストールします。

Bash

pip install langchain langchain-google-genai sqlite3

SQLiteデータベースの準備

SQLiteデータベース sample.db を作成し、products テーブルを作成します。

SQL

CREATE TABLE IF NOT EXISTS products (
    id INTEGER PRIMARY KEY,
    name TEXT,
    price REAL
);

INSERT INTO products (name, price) VALUES ('Product A', 10.0);
INSERT INTO products (name, price) VALUES ('Product B', 20.0);
INSERT INTO products (name, price) VALUES ('Product C', 30.0);

Pythonスクリプト

以下のPythonスクリプトを実行することで、自然言語による質問に回答することができます。

import sqlite3
from langchain.tools import Tool
from langchain.agents import initialize_agent, AgentType
import logging
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate

# ロギングの設定 (必要に応じて調整)
logging.basicConfig(level=logging.DEBUG)

# データベースパス
DATABASE_PATH = 'Your/Database/Path/sample.db'  # 実際のパスはここに記述

# マスクされたデータベースパス
MASKED_DATABASE_PATH = 'your_masked_database_path.db'  # 任意のマスクされたパス

def execute_sql(query, db_path=DATABASE_PATH): # データベースパスを引数に
    query = query.replace("`sql", "").replace("`", "").strip()
    conn = sqlite3.connect(db_path) # マスクされたパスを使用
    cursor = conn.cursor()
    try:
        cursor.execute(query)
        result = cursor.fetchall()
    except sqlite3.OperationalError as e:
        return f"SQL Error: {e}"
    finally:
        conn.close()
    return result

sql_tool = Tool(
    name="SQL",
    func=execute_sql,
    description="Useful for querying a SQL database. Input should be a valid SQL query."
)

def get_product_price(product_name, db_path=DATABASE_PATH): # データベースパスを引数に
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT price FROM products WHERE name = ?", (product_name,))
    result = cursor.fetchone()
    conn.close()
    if result:
        return result[0]
    else:
        return "Product not found"

product_price_tool = Tool(
    name="ProductPrice",
    func=get_product_price,
    description="Useful for getting the price of a specific product. Input should be the name of the product."
)

def get_average_price(tool_input, db_path=DATABASE_PATH): # データベースパスを引数に
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT AVG(price) FROM products")
        result = cursor.fetchone()
    except sqlite3.OperationalError as e:
        return f"SQL Error: {e}"
    finally:
        conn.close()
    if result:
        return result[0]
    else:
        return "No products found"

average_price_tool = Tool(
    name="AveragePrice",
    func=get_average_price,
    description="Useful for getting the average price of all products. Input should be empty."
)

def get_max_price_product(tool_input, db_path=DATABASE_PATH): # データベースパスを引数に
    logging.debug("get_max_price_product called")
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT name, price FROM products ORDER BY price DESC LIMIT 1")
        result = cursor.fetchone()
        logging.debug(f"get_max_price_product result: {result}")
    except sqlite3.OperationalError as e:
        return f"SQL Error: {e}"
    finally:
        conn.close()
    if result:
        return f"The product with the maximum price is {result[0]} at ${result[1]}."
    else:
        return "No products found"

max_price_product_tool = Tool(
    name="MaxPriceProduct",
    func=get_max_price_product,
    description="Useful for getting the product with the maximum price. Input should be empty."
)


# LLMの初期化 (APIキーは環境変数から取得)
api_key = os.getenv("GOOGLEAI_API_KEY")
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=api_key)

# プロンプトテンプレート
template = """
あなたはデータベースに関する質問に答えるエージェントです。
質問:{question}
必要であれば、以下のツールを使ってください。
{tools}
回答:
"""
prompt = PromptTemplate(template=template, input_variables=["question", "tools"])

# エージェントの作成
agent = initialize_agent(
    [product_price_tool, average_price_tool, max_price_product_tool, sql_tool], # sql_toolを追加
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    prompt=prompt
)

# 質問を受け付ける
question = input("質問を入力してください:")

# 質問に回答
response = agent.invoke(question)

# 回答を表示
print(response)
            

実行結果

質問を入力してください:Product A の価格は?
> Entering new AgentExecutor chain...
> Thought: I need to find the price of Product A. I can use the ProductPrice tool for this.
> Action: ProductPrice
> Action Input: Product A
> Observation: 10.0
> Thought: The price of Product A is 10.0.
> Final Answer: 10.0

> ... Finished chain.
10.0

解説

  • 各ツールは、langchain.tools.Tool クラスを使って定義されています。
  • func 引数には、ツールが実行する関数を指定します。
  • description 引数には、ツールがどのようなタスクを実行するのかを記述します。これは、エージェントがツールを選択する際に役立ちます。
  • initialize_agent 関数を使って、ツールとLLMを組み合わせたエージェントを作成します。
  • エージェントは、run メソッドを使って質問に答えます。

まとめ

本記事では、LangchainとGeminiを用いて、SQLiteデータベースに対して自然言語でクエリを実行するエージェントを作成する方法を紹介しました。

コメント

タイトルとURLをコピーしました