Rust: Non-blocking I/O programming with mio::Poll

Takami Torao Rust 1.48 mio 0.7 #Rust #async #poll
  • このエントリーをはてなブックマークに追加

概要

Linux カーネルのシステムコール epoll や FreeBSD (Mac OS) の kqueue、またはより古典的な POSIX 準拠の select や poll システムコールは大量のクライアント接続を効率的に処理するノンブロッキング I/O プログラミングに必要な機能である。これらの類似機能は複数のソケット (ファイル記述子) をまとめて監視し、読み込み / 書き込み可能になったソケットに対して通知を受けることができる。つまり、1 つのスレッドで複数の接続に対する入出力処理に対応できることから、多くのクライアントと接続して通信を行うサーバサイドに置いて C10K 問題を解決するためにイベント駆動型フレームワークの低レベル層で使用されている。

この記事では Rust の mio 0.7 の Poll を使用した実装について説明する。なお Poll ライクなプログラミングモデルとしては Java 標準の Selector の説明も参考になる。

Table of Contents

  1. 概要
    1. edge trigger と level trigger
    2. 外部からポーリングへの介入
    3. EOF の判断
  2. Echo サーバ実装

edge trigger と level trigger

epoll や kqueue を使ったプログラミングでは 2 種類の動作に注意しなければならない。その一つは、通知対象の準備ができた時点で一回のみ通知を行う edge trigger という動作である。edge trigger の動作は特に readable イベントの時に注意する必要があり、イベント処理でデータを読み残してしまうと、次の通知が発生するまで残ったデータを読み込む機会を失ったり、それ以降のデータを読み込めなくなったりする可能性がある。したがって edge trigger に対処するプログラミングでは readable イベント処理で読み出し可能なデータを全て読み出さなければならない

もう一つは読み出し可能なデータが存在する限り準備通知を行う level trigger という動作である。この動作では edge trigger とは逆に writable イベントに注意する必要がある。level trigger 環境において書込み可能なデータが存在しないのに writable 通知を有効にしておくと、無駄な処理のループによって CPU リソースを食いつぶすことになる。このため level trigger に対処するプログラミングでは、書込み可能なデータが無くなった時に writable 通知を無効化し、新しい書き込みデータが発生した時に有効化するといった切り替えを行う必要がある

mio::Poll を使って edge trigger と level trigger のランタイム環境の互換性を考慮したプログラミングを行うには、readable 通知に対して WouldBlock が発生するまですべてのデータを読み込みを行い、書き込み可能なデータが存在するかどうかで writable 通知の切り替えを行う必要がある。

外部からポーリングへの介入

mio::Poll はスレッドセーフではない (mio に限らず Java の Selector でのキーセット操作もそうだが) ため、一般に poll を使用したプログラミングモデルは 1 つのスレッド内で準備通知のポーリングとイベント処理を行う。しかし、スレッドの外部から新しいソケットを登録したり、readable / writable を変更したり、またポーリングのイベントループを終了するようなケースでは、外部からポーリングの中断を行う必要がある。

このようなケースでは、外部スレッドから mio::Waker に対して wake() を呼び出すことで Waker の構築時に指定した mio::Poll の (多分別スレッドでの処理をブロックしている) ポーリングを中断することができる。また Waker のみでは単にポーリングを中断するだけであるため、中断後の処理にデータを渡すために channel (Sender / Receiver) を併用する必要があるだろう。

EOF の判断

poll を使ったノンブロッキングの読み込みでは、相手側からのクローズによる EOF は Read::read() が 0 を返すことで検出することができる (ただし読み込みバッファサイズが 0 でなければ; API リファレンス参照)。読み込み可能だが単にデータが到着していないだけであれば read() は WouldBlock が発生するため EOF と状況を区別することができる。

読み出し側で EOF を検出した場合、そのソケットはすでにピアによって正しい手続きでクローズされている可能性が高く、したがって以後ソケットが writable になることを期待することができず、送信バッファに残っているデータは破棄してソケットを poll の登録から外す必要がある。

mio には is_read_closed() が用意されているが、動作が保証されている準備通知は is_readable()is_writable() のみであることからヒントとしての使用に限定される。

Echo サーバ実装

以下は mio::Poll を使用した Echo サーバの実装サンプルである。サーバソケット (TcpListener) と接続したピアのソケット (TcpStream) を同じ Poll で処理処理している。

Show Source
use mio::event::Event;
use mio::net::{TcpListener, TcpStream};
use mio::{Events, Interest, Poll, Token, Waker};
use std::collections::HashMap;
use std::io::{ErrorKind, Read, Write};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::thread::spawn;

const WAKER: Token = Token(0);
const SERVER: Token = Token(1);

/// TCP 接続を受け付けて受信したデータをそのまま返す Echo サーバです。これは [mio::Poll] のサンプル実装を目的として
/// います。サーバスレッドは [EchoServer] のスコープが終了すると停止します。
///
/// ```
/// use echo_server::EchoServer;
/// use std::io::{Read, Write};
/// use std::net::TcpStream;
///
/// let expected = "hello, world".as_bytes();
/// let mut actual = (0..expected.len()).map(|_| 0u8).collect::<Vec<u8>>();
///
/// let mut server = EchoServer::new("127.0.0.1:0");
/// let mut socket = TcpStream::connect(server.local_address()).unwrap();
/// socket.write(expected).unwrap();
/// socket.read_exact(&mut actual).unwrap();
/// assert_eq!(expected, actual);
/// ```
///
pub struct EchoServer {
  server: Arc<Mutex<TcpListener>>,
  waker: Waker,
}

/// 別スレッドで実行されるサーバ処理。[EchoServer] とは waker のみで接続されており、[EchoServer] 側で
/// `waker.wake()` が実行されると poll のイベントループが終了する。
///
/// `mio::Poll` 機能はソケットに対して [Token] と呼ばれる usize 値のみしか保持しないため、[Token] 値と
/// 対象のソケットや入出力バッファを保持する [HashMap] が必要。また、すでに登録されているソケットと同じ Token
/// を割り当てないように番号を管理する必要がある。`Token(std::usize::Max)` は [Poll] によって予約されている。
///
struct EchoServant {
  poll: Poll,
  server: Arc<Mutex<TcpListener>>,
  clients: HashMap<usize, (TcpStream, Vec<u8>)>,
  sequence: usize,
}

impl EchoServer {
  /// 指定されたアドレスで Echo サーバを起動して [EchoServer] を返します。
  pub fn new(address: &str) -> EchoServer {
    let poll = Poll::new().unwrap();
    let clients = HashMap::new();
    let waker = Waker::new(poll.registry(), Token(0)).unwrap();

    let bind_address = address.parse().unwrap();
    let mut server = mio::net::TcpListener::bind(bind_address).unwrap();
    poll.registry().register(&mut server, SERVER, Interest::READABLE).unwrap();

    let server = Arc::new(Mutex::new(server));
    let mut state = EchoServant { poll, server: server.clone(), clients, sequence: 0 };
    spawn(move || state.start());
    EchoServer { server, waker }
  }

  /// この [EchoServer] の bind アドレスを参照します。
  pub fn local_address(&self) -> SocketAddr {
    self.server.clone().lock().unwrap().local_addr().unwrap()
  }
}

impl Drop for EchoServer {
  fn drop(&mut self) {
    self.waker.wake().unwrap();
  }
}

impl EchoServant {
  /// サーバ処理の開始
  pub fn start(&mut self) {
    let mut events = Events::with_capacity(1024);
    let mut buffer = (0..1024).map(|_| 0u8).collect::<Vec<u8>>();
    'polling: loop {
      self.poll.poll(&mut events, None).unwrap();
      for event in events.iter() {
        match event.token() {
          WAKER => break 'polling,
          SERVER => self.server_acceptable(),
          Token(id) => {
            if event.is_readable() {
              self.client_readable(id, event, &mut buffer);
            }
            if event.is_writable() {
              self.client_writable(id, event);
            }
          }
        }
      }
    }
  }

  /// サーバソケットが accept 可能になった時に実行する処理。
  fn server_acceptable(&mut self) {
    let server = self.server.lock().unwrap();
    loop {
      // NOTE: edge trigger の動作になる環境があるため readiness 通知のあったソケットは WouldBlock が発生する
      // まですべて読み込まなければならない。
      match server.accept() {
        Ok((mut socket, _)) => {
          // 0=Waker, 1=TcpListener, MAX=mio reserved のため TcpStream で使用できる数は MAX-3 個まで
          let max = std::usize::MAX - 3;
          let id = self.sequence + 2;
          self.sequence = (self.sequence + 1) % max;
          assert!(!self.clients.contains_key(&id));

          // 接続したソケットの登録
          self.poll.registry().register(&mut socket, Token(id), Interest::READABLE).unwrap();
          self.clients.insert(id, (socket, Vec::new()));
        }
        Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
        Err(err) => panic!("read failure: {}", err),
      }
    }
  }

  /// クライアントソケットが readable になったときに実行する処理。
  fn client_readable(&mut self, id: usize, event: &Event, read_buffer: &mut [u8]) {
    let (socket, buffer) = self.clients.get_mut(&id).unwrap();
    let mut length = 0usize;
    let mut eof = false;
    loop {
      // NOTE: edge trigger の動作になる環境があるため readable 通知のあったソケットは WouldBlock が発生する
      // まですべて読み込まなければならない。また readable 通知がありながら読み込み可能なデータがないこともある点に
      // も注意。
      match socket.read(read_buffer) {
        Ok(len) if len == 0 => {
          eof = true;
          break;
        }
        Ok(len) => {
          // TODO unsafe を使って効率化すべき処理
          for i in 0..len {
            buffer.push(read_buffer[i]);
          }
          length += len;
        }
        Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
        Err(err) => panic!("read failure: {}", err),
      }
    }
    if eof {
      // then socket.is_read_closed() is probably true
      let (mut socket, _) = self.clients.remove(&id).unwrap();
      self.poll.registry().deregister(&mut socket).unwrap();
    } else if !eof && buffer.len() == length && length != 0 {
      // この readable 処理の前にバッファが空であったなら WRITABLE 通知が設定されていない。この読み込みで
      // echo back するデータが発生したなら WRITABLE 通知を有効にする。
      self
        .poll
        .registry()
        .reregister(socket, event.token(), Interest::READABLE | Interest::WRITABLE)
        .unwrap();
    }
  }

  /// クライアントソケットが writable になった時に実行する処理。ソケットに対する送信バッファにデータが存在する場合に
  /// それを出力します。
  fn client_writable(&mut self, id: usize, event: &Event) {
    let (socket, buffer) = self.clients.get_mut(&id).unwrap();
    let len = socket.write(&buffer).unwrap();
    // TODO unsafe を使って効率化すべき処理
    for _ in 0..len {
      buffer.pop();
    }
    if buffer.is_empty() {
      // 送信バッファが空になったら WRITABLE 通知を受け取らない (level trigger 環境の動作では送信データがない
      // にもかかわらず WRITABLE 通知が繰り返し呼び出されるため)
      self.poll.registry().reregister(socket, event.token(), Interest::READABLE).unwrap();
    }
  }
}

#[cfg(test)]
mod test {
  use crate::EchoServer;
  use rand;
  use rand::RngCore;
  use std::io::{Read, Write};
  use std::thread::{spawn, JoinHandle};

  #[test]
  fn test_echo_server_from_single_client() {
    let server = EchoServer::new("127.0.0.1:0");
    let address = server.local_address();
    let mut socket = std::net::TcpStream::connect(address).unwrap();

    let expected = "hello, world".as_bytes();
    let mut actual = (0..expected.len()).map(|_| 0u8).collect::<Vec<u8>>();
    socket.write(expected).unwrap();
    socket.read_exact(&mut actual).unwrap();
    assert_eq!(expected, actual);
  }

  #[test]
  fn test_echo_server_from_several_clients() {
    let server = EchoServer::new("127.0.0.1:0");
    let address = server.local_address();

    // 並行で同期/ブロッキング echo クライアントを起動
    let mut client_joins = (0..25)
      .map(|_| {
        let addr = address.clone();
        spawn(move || {
          let mut socket = std::net::TcpStream::connect(addr).unwrap();
          for _ in 0..10 {
            // ランダムなバイト列を送信
            let mut send_buffer = [0u8; 512 * 1024];
            rand::thread_rng().fill_bytes(&mut send_buffer);
            let mut written_length = 0;
            while written_length < send_buffer.len() {
              written_length += socket.write(&send_buffer[written_length..]).unwrap();
            }

            // 受信したバイト列が送信したものと一致していることを確認
            let mut byte = [0u8; 4 * 1024];
            let mut read_length = 0;
            while read_length < send_buffer.len() {
              let len = socket.read(&mut byte).unwrap();
              for i in 0..len {
                assert_eq!(byte[i], send_buffer[read_length + i]);
              }
              read_length += len;
            }
          }
        })
      })
      .collect::<Vec<JoinHandle<()>>>();

    // すべてのテスト Echo クライアントが終了するまで待機
    while !client_joins.is_empty() {
      client_joins.pop().unwrap().join().unwrap();
    }
  }
}
[package]
name = "echo-server"
version = "0.1.0"
authors = ["TAKAMI Torao <koiroha@gmail.com>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
mio = { version = "0.7", features = ["os-poll", "net"] }

[dev-dependencies]
rand = "0.7"