fastAPIでwebsocketを使ったpub/sub機能の実装方法

こんにちは!seiです。

 

業務でpub/sub機能を作成する必要があり、初めてwewbsocketを使ったので備忘録です。

 

fastAPIでwebsocketを使ったpub/sub機能の実装手順

 

Graphql ライブラリはariadneを使用しています。
ariadneを使っている理由は生graphql定義ファイルから定義をインポートしやすいためです。
ariadneはstarletteと統合することができます。
FastAPI内部ではstarletteが使われているので、FastAPIとも統合することができます。

 

pub/sub機能の実装手順としては以下です。

  • websocketプロトコルによる通信ができるようにする
  • pubしたメッセージを貯めるqueueを用意する
  • いい感じにsubsccriptionを書く!

 

websocketプロトコルによる通信ができるようにする

websocketとは何ぞや?

クライアントサイドからhttpプロトコルを使って通信する場合、サーバーからの通信はレスポンスとして送信されます。

しかし、pub/subの機能を考えたときサーバはレスポンスとしてではなく、サーバ発の通信を送る必要があります。この時どのクライアント(ブラウザ)に対し送信するかの特定が必要です。でもクライアントの情報(IP等)をサーバ側に保存しておくとスケールしたときに大変です。

そこで、クライアントとサーバ間で持続的に通信しようというのがwebsocketの仕組みです。内部的にはping-pongでパケットをやり取りして接続を維持しています。

 

長くなっちゃったので、実装を見ていきましょう。

以下ariadneの公式ドキュメントからの引用です。

from ariadne import QueryType, make_executable_schema
from ariadne.asgi import GraphQL
from ariadne.asgi.handlers import GraphQLTransportWSHandler
from fastapi import FastAPI, Depends, Request
from fastapi.websockets import WebSocket
from myapp.database import get_database_session

type_defs = """
    type Query {
        hello: String!
    }
"""

query = QueryType()


@query.field("hello")
def resolve_hello(*_):
    return "Hello world!"


# Create executable schema instance
schema = make_executable_schema(type_defs, query)

# Custom context setup method
def get_context_value(request_or_ws: Request | WebSocket, _data) -> dict:
    return {
        "request": request_or_ws,
        "db": request_or_ws.scope["db"],
    }

# Create GraphQL App instance
# ここが大事です
<strong>graphql_app = GraphQL(
    schema,
    debug=True,
    context_value=get_context_value,
    websocket_handler=GraphQLTransportWSHandler(),
)</strong>

# Create FastAPI instance
app = FastAPI()


# Handle GET requests to serve GraphQL explorer
# Handle OPTIONS requests for CORS
@app.get("/graphql/")
@app.options("/graphql/")
async def handle_graphql_explorer(request: Request):
    return await graphql_app.handle_request(request)

# Handle POST requests to execute GraphQL queries
@app.post("/graphql/")
async def handle_graphql_query(
    request: Request,
    db = Depends(get_database_session),
):
    # Expose database connection to the GraphQL through request's scope
    request.scope["db"] = db
    return await graphql_app.handle_request(request)


# Handle GraphQL subscriptions over websocket
@app.websocket("/graphql")
async def graphql_subscriptions(
    websocket: WebSocket,
    db = Depends(get_database_session),
):
    # Expose database connection to the GraphQL through request's scope
    websocket.scope["db"] = db
    await graphql_app.handle_websocket(websocket)

 

これだけでwebsocketが使えるようになります!

pubしたメッセージを貯めるqueueを用意する

pubsubを使いたい場合は、publishしたメッセージを格納するqueueが必要です。
queueにはApache KafkaやRabbitMQなどの高機能なメッセージブローカーやRedisやPostgreが使えるようです。

 

pubsubを使った全体の処理の流れは以下のようになります。

  • クライアントがsubscriptionAにリクエストを送ることでsubscriptionAのsubscriberとなる
  • API(pubsub)サーバはRedis内のチャネルAに向けてメッセージを送信する
  • Redisがメッセージを受け取る
  • チャネルAを持続的に監視しているライブラリが、Redisからメッセージを取り出す
  • 取り出したメッセージをsubscriptionAが受け取り、subscriptionAはsubscriber全員にメッセージを送信する

 

ここではariadneのドキュメントに載っているBroadcaster(pubsubのライブラリ)を使ってRedisサーバをqueueとして用います。

 

Redisサーバの作成

redisサーバを作成します。
僕は同一インスタンス内にdockerコンテナを作成して、FastAPIコンテナとredhisコンテナを通信させました。(FastAPIサーバが死んだときにメッセージを永続化させる必要がないため、それ以外の場合は別インスタンス推奨)

追加設定をしない場合は、公式イメージからひっぱってくるだけです。

docker-compose.yaml

version: '3.8'
services:
  fastapi-app:
    build: ./fastapi
    ports:
      - "8000:8000"
    depends_on:
      - redis

  redis:
 #好きなバージョンを選ぶ
    image: redis:6.2
    ports:
      - "6379:6379"

broadcasterの設定

ますはインストール

pip install broadcaster

 

FastAPIの起動時に接続する関数、停止時に切断する関数を呼び出します。
starletteを内部に用いているので、on_startupとon_shutdownに関数を定義してあげればよいです。

from fastapi import FastAPI
from broadcaster import Broadcast
// 先ほどのファイルからapiをインポート
from some import app as graphql_api

app = FastAPI()

// 作成したredisサーバのホストを指定
pubsub = Broadcast("redis://somehost:6379")

@app.on_event("startup")
async def startup_event():
  await pupsub.connect()

@app.on_event("shutdown"):
async def shutdown_event():
  await pubsub.disconnect()


app.mount("/",graphql_api)
    

 

※僕のように最上位のファイルで複数のappをmountしている場合最上位のファイルのstartup,shutdownイベントに設定してください。よくわからないエラーで原因を突き止めるのに2日苦しみました(笑)

 

いい感じにsubsccriptionを書く!

次にSubscriptionのリゾルバを定義します。
queryやmutationと違って@subscription.fieldとは別に@subscription.sourceが必要です。

軽く役割を説明すると、以下のようになります。

source:監視しているチャネル「chatroom」にメッセージが来ると、メッセージを生成してsubscription.fieldに渡す
field:sourceから受け取ったメッセージを自身のsubscriber全員に送信する
subscription = SubscriptionType()

@subscription.source("messageReceived")
async def message_received_generator(obj, info):
   // 先ほど定義したpubsub
 async with pubsub.subscribe(channel="chatroom") as subscriber:
   async for event in subscriber:
        yield json.loads(event.message)

@subscription.field("messageReceived")
def message_received_resolver(message, info):
    return message

メッセージをpublishする

あとは任意の場所でメッセージをpublishするだけです。
先ほど作成したリゾルバが監視しているチャネルに対してメッセージを送りましょう。

 

お疲れ様でした!

@mutation.field("some_mutation")
async def some_func(_,info,input):
  await broadcast.publish(channel="chatroom", message="Hello world!")

 

 

Broadcasterはどうやってqueueを持続的に監視しているのか?

Redisにメッセージが格納されたときに、自動でメッセージを取り出している仕組みが気になったので、ソースコードを見てみました。

 


class Broadcast:
    def __init__(self, url: str):
        from broadcaster._backends.base import BroadcastBackend

        parsed_url = urlparse(url)
        self._backend: BroadcastBackend
        self._subscribers: Dict[str, Any] = {}
        if parsed_url.scheme in ("redis", "rediss"):
            from broadcaster._backends.redis import RedisBackend

            self._backend = RedisBackend(url)

        elif parsed_url.scheme in ("postgres", "postgresql"):
            from broadcaster._backends.postgres import PostgresBackend

            self._backend = PostgresBackend(url)

        if parsed_url.scheme == "kafka":
            from broadcaster._backends.kafka import KafkaBackend

            self._backend = KafkaBackend(url)

        elif parsed_url.scheme == "memory":
            from broadcaster._backends.memory import MemoryBackend

            self._backend = MemoryBackend(url)

... 省略

    async def connect(self) -> None:
        await self._backend.connect()
        self._listener_task = asyncio.create_task(self._listener())

    async def disconnect(self) -> None:
        if self._listener_task.done():
            self._listener_task.result()
        else:
            self._listener_task.cancel()
        await self._backend.disconnect()

    async def _listener(self) -> None:
        while True:
          event = await self._backend.next_published()
          for queue in list(self._subscribers.get(event.channel, [])):
            await queue.put(event)

先ほどのconnect関数でasyncioにタスクを登録しています。
connect関数の2行目を見てみましょう。

self._listener_task = asyncio.create_task(self._listener())

これでasyincioイベントループにself._lisner()のタスクが登録されますね。
では_listener関数を見てみます。

async def _listener(self) -> None:
      while True:
          event = await self._backend.next_published()
          for queue in list(self._subscribers.get(event.channel, [])):
              await queue.put(event)

while Trueで無限ループする処理ですね。メッセージがpublishされるまでawaitで待っていることが分かります。

awaitを使っているので、待ちが発生している間はこのスレッドは他の処理を行うことができます。

 

まとめると、「無限ループ内でawaitする処理をasyncイベントループに登録する」ことで、持続的に監視していることが分かりました!

フロントのgraphQLのwebsocketライブラリはどうすればよい?

フロントエンド(graphqlのリクエストを送る側)のgraphQLライブラリはariadneを使用している場合、graphql-wsを使えばよいです。

サーバー側とフロント側で同じライブラリを使用する必要がありそうです。(websocketだからだと思われます)

まとめ

今回はFastAPIを使ったpub/sub機能の実装を紹介しました。graphqlにはもともとpub/sub機能が想定されているので、pub/subを使いたいならgraphqlという選択肢もありなのかなと感じました。

websocketを利用する場合の認証の方法は今回紹介しませんでしたが、余裕があれば追記しようと思っています。websocketの場合、ブラウザとサーバで最初の接続(hand shake)の際に認証をかけてあげる必要があります。websocketにはHTTPヘッダーがないので認証情報をパラメータに含めて送信、受信する処理を記載してあげる必要があります。

pub/sub機能はチャット機能を実装する際によく使われています。

プログラミング学習方法を発信してます!