How to test a Scala Play Framework websocket?
Adapted to Play Framework 2.7:
import java.util.concurrent.ExecutionException
import java.util.function.Consumer
import com.typesafe.scalalogging.StrictLogging
import play.shaded.ahc.org.asynchttpclient.AsyncHttpClient
import play.shaded.ahc.org.asynchttpclient.netty.ws.NettyWebSocket
import play.shaded.ahc.org.asynchttpclient.ws.{WebSocket, WebSocketListener, WebSocketUpgradeHandler}
import scala.compat.java8.FutureConverters
import scala.concurrent.Future
/**
 * A [[WebSocketListener]] that logs connection lifecycle events and forwards the payload of
 * every received text frame to a callback.
 *
 * As soon as the connection opens, a `"hello"` text frame is sent to the server.
 *
 * @param onMessageCallback invoked with the payload of each received text frame. It is called
 *                          once per frame, so a fragmented message (`finalFragment == false`)
 *                          arrives in several calls.
 */
class LoggingListener(onMessageCallback: Consumer[String]) extends WebSocketListener with StrictLogging {

  override def onOpen(websocket: WebSocket): Unit = {
    logger.info("onOpen: ")
    websocket.sendTextFrame("hello")
  }

  override def onClose(webSocket: WebSocket, i: Int, s: String): Unit =
    logger.info("onClose: ")

  override def onError(t: Throwable): Unit =
    logger.error("onError: ", t)

  override def onTextFrame(payload: String, finalFragment: Boolean, rsv: Int): Unit = {
    logger.debug(s"$payload $finalFragment $rsv")
    onMessageCallback.accept(payload)
  }
}
/**
 * Thin wrapper around a (Play-shaded) [[AsyncHttpClient]] that opens WebSocket connections
 * and exposes the result as a Scala [[Future]].
 *
 * @param client the AsyncHttpClient used to perform the upgrade request.
 */
class WebSocketClient(client: AsyncHttpClient) {

  /**
   * Sends a GET upgrade request to `url` with the given `Origin` header and attaches `listener`
   * to the resulting WebSocket.
   *
   * @param url      the ws:// URL to connect to.
   * @param origin   value sent as the `Origin` request header.
   * @param listener receives the WebSocket lifecycle and frame callbacks.
   * @return a future completed with the connected socket once the upgrade finishes.
   *         NOTE(review): failures presumably surface through this future rather than through
   *         the declared exceptions. Confirm against the AHC `execute` contract.
   */
  @throws[ExecutionException]
  @throws[InterruptedException]
  def call(url: String, origin: String, listener: WebSocketListener): Future[NettyWebSocket] = {
    val upgradeHandler =
      new WebSocketUpgradeHandler.Builder()
        .addWebSocketListener(listener)
        .build
    val pending =
      client
        .prepareGet(url)
        .addHeader("Origin", origin)
        .execute(upgradeHandler)
    FutureConverters.toScala(pending.toCompletableFuture)
  }
}
And in the test:
// Target endpoint on the test server, plus the Origin header value to send with the upgrade.
val myPublicAddress = s"localhost:$port"
val serverURL = s"ws://$myPublicAddress/api/alarm/ws"
val origin = "ws://example.com/ws"

// Reuse the AsyncHttpClient that backs Play's WS client rather than building a new one.
val asyncHttpClient = client.underlying[AsyncHttpClient]
val webSocketClient = new WebSocketClient(asyncHttpClient)

// Every received text frame is logged at debug level.
val consumer: Consumer[String] = (message: String) => logger.debug(message)
val listener = new LoggingListener(consumer)

// Wait at most one second for the WebSocket upgrade to complete.
val f = webSocketClient.call(serverURL, origin, listener)
Await.result(f, 1000.millis)
Play 2.6
I followed this example: play-scala-websocket-example
Main steps:
Create or provide a WebSocketClient that you can use in your tests.
Create the client:
// Unwrap the AsyncHttpClient underneath Play's WSClient and build the WebSocket client on it.
val asyncHttpClient = wsClient.underlying[AsyncHttpClient]
val webSocketClient: WebSocketClient = new WebSocketClient(asyncHttpClient)
Connect to the `serverURL`:
// Each message pushed by the server is put onto `queue` so the test can inspect it later.
val listener = new WebSocketClient.LoggingListener(message => queue.put(message))
// Convert the Java CompletionStage returned by `call` into a Scala Future for `whenReady`.
val f = FutureConverters.toScala(webSocketClient.call(serverURL, origin, listener))
Test the messages sent by the server:
// Wait up to one second for the connection future, then verify the server's messages.
whenReady(f, timeout = Timeout(1.second)) { webSocket =>
// Poll until the socket is open and at least one message is queued
// (await() presumably comes from Awaitility — confirm against the test's imports).
await().until(() => webSocket.isOpen && queue.peek() != null)
// Messages are taken in queue order: the first must satisfy checkMsg1, the second checkMsg2.
checkMsg1(queue.take())
checkMsg2(queue.take())
// No further messages had been received at this point.
assert(queue.isEmpty)
}
For example:
/**
 * Asserts that `msg` is the JSON encoding of `AdapterNotRunning(None)` and fails the test
 * otherwise.
 *
 * Malformed JSON makes `Json.parse` throw before the match is reached.
 *
 * @param msg raw text message received from the server.
 */
private def checkMsg1(msg: String): Unit = {
  val json: JsValue = Json.parse(msg)
  json.validate[AdapterMsg] match {
    case JsSuccess(AdapterNotRunning(None), _) => // ok
    case other => fail(s"Unexpected result: $other")
  }
}
The whole example can be found here: scala-adapters (JobCockpitControllerSpec)
This is a complete example that uses the Akka HTTP WebSocket client to test a WebSocket controller. It contains some custom code, but it shows multiple test scenarios. It works with Play 2.7.
package controllers
import java.util.concurrent.{ LinkedBlockingDeque, TimeUnit }
import actors.WSBridge
import akka.Done
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.headers.{ Origin, RawHeader }
import akka.http.scaladsl.model.ws.{ BinaryMessage, Message, TextMessage, WebSocketRequest }
import akka.http.scaladsl.model.{ HttpResponse, StatusCodes, Uri }
import akka.stream.scaladsl.{ Flow, Keep, Sink, Source, SourceQueueWithComplete }
import akka.stream.{ ActorMaterializer, OverflowStrategy }
import models.WSTopic
import org.specs2.matcher.JsonMatchers
import play.api.Logging
import play.api.inject.guice.GuiceApplicationBuilder
import play.api.test._
import scala.collection.immutable.Seq
import scala.concurrent.Future
/**
 * Tests for the `socket` method of [[WSController]], exercised through a real WebSocket
 * connection against a running server.
 */
class WSControllerSpec extends ForServer with WSControllerSpecContext with JsonMatchers {
  "The `socket` method" should {
    "return a 403 status code if the origin doesn't match" >> { implicit rs: RunningServer =>
      // The request carries no Origin header at all.
      val maybeSocket = await(websocketClient.connect(WebSocketRequest(endpoint)))
      maybeSocket must beLeft[HttpResponse].like { case response =>
        response.status must be equalTo StatusCodes.Forbidden
      }
    }

    "return a 400 status code if the topic cannot be found" >> { implicit rs: RunningServer =>
      // Valid origin, but neither a `topic` query param nor an X-TOPIC header.
      val headers = Seq(Origin("http://localhost:9443"))
      val maybeSocket = await(websocketClient.connect(WebSocketRequest(endpoint, headers)))
      maybeSocket must beLeft[HttpResponse].like { case response =>
        response.status must be equalTo StatusCodes.BadRequest
      }
    }

    "return a 400 status code if the topic syntax isn't valid in query param" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"))
      // The raw query string is the part after `?`. Passing "?topic=." would yield `/ws??topic=.`
      // and the test would only pass because no `topic` param exists.
      val request = WebSocketRequest(endpoint.withRawQueryString("topic=."), headers)
      val maybeSocket = await(websocketClient.connect(request))
      maybeSocket must beLeft[HttpResponse].like { case response =>
        response.status must be equalTo StatusCodes.BadRequest
      }
    }

    "return a 400 status code if the topic syntax isn't valid in header param" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"), RawHeader("X-TOPIC", "."))
      val maybeSocket = await(websocketClient.connect(WebSocketRequest(endpoint, headers)))
      maybeSocket must beLeft[HttpResponse].like { case response =>
        response.status must be equalTo StatusCodes.BadRequest
      }
    }

    "receive an acknowledge message when connecting to a topic via query param" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"))
      val request = WebSocketRequest(endpoint.withRawQueryString("topic=%2Fflowers%2Frose"), headers)
      val maybeSocket = await(websocketClient.connect(request))
      maybeSocket must beRight[(SourceQueue, MessageQueue)].like { case (_, messages) =>
        messages.poll(1000, TimeUnit.MILLISECONDS) must be equalTo
          WSBridge.Ack(WSTopic("/flowers/rose")).message.toJson.toString()
      }
    }

    "receive an acknowledge message when connecting to a topic via header param" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"), RawHeader("X-TOPIC", "/flowers/tulip"))
      val maybeSocket = await(websocketClient.connect(WebSocketRequest(endpoint, headers)))
      maybeSocket must beRight[(SourceQueue, MessageQueue)].like { case (_, messages) =>
        messages.poll(1000, TimeUnit.MILLISECONDS) must be equalTo
          WSBridge.Ack(WSTopic("/flowers/tulip")).message.toJson.toString()
      }
    }

    "receive a pong message when sending a ping" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"), RawHeader("X-TOPIC", "/flowers/tulip"))
      val maybeSocket = await(websocketClient.connect(WebSocketRequest(endpoint, headers)))
      maybeSocket must beRight[(SourceQueue, MessageQueue)].like { case (queue, messages) =>
        queue.offer(WSBridge.Ping.toJson.toString())
        // The topic acknowledgement is expected first, followed by the answer to the ping.
        messages.poll(1000, TimeUnit.MILLISECONDS) must be equalTo
          WSBridge.Ack(WSTopic("/flowers/tulip")).message.toJson.toString()
        messages.poll(1000, TimeUnit.MILLISECONDS) must be equalTo
          WSBridge.Pong.toJson.toString()
      }
    }
  }
}
/**
 * The context for the [[WSControllerSpec]].
 *
 * Provides the application under test, the WebSocket endpoint URI and a test-oriented
 * Akka HTTP WebSocket client.
 */
trait WSControllerSpecContext extends ForServer with PlaySpecification with ApplicationFactories {

  /** Queue through which a test sends text messages to the server. */
  type SourceQueue = SourceQueueWithComplete[String]

  /** Blocking queue into which messages received from the server are collected. */
  type MessageQueue = LinkedBlockingDeque[String]

  /**
   * Provides the application factory.
   */
  protected def applicationFactory: ApplicationFactory = withGuiceApp(GuiceApplicationBuilder())

  /**
   * Gets the WebSocket endpoint.
   *
   * @param rs The running server.
   * @return The WebSocket endpoint (the server's HTTP endpoint for `/ws` with a `ws://` scheme).
   */
  protected def endpoint(implicit rs: RunningServer): Uri =
    Uri(rs.endpoints.httpEndpoint.get.pathUrl("/ws").replace("http://", "ws://"))

  /**
   * Provides an instance of the WebSocket client.
   *
   * This should be a method to return a fresh client for every test.
   */
  protected def websocketClient = new AkkaWebSocketClient

  /**
   * An Akka WebSocket client that is optimized for testing.
   */
  class AkkaWebSocketClient extends Logging {

    /**
     * The queue of received messages, in the order in which they arrived.
     */
    private val messageQueue = new LinkedBlockingDeque[String]()

    /**
     * Connect to the WebSocket.
     *
     * Each call starts its own [[ActorSystem]], which is terminated once the connection stream
     * completes (successfully or not).
     *
     * @param wsRequest The WebSocket request instance.
     * @return Either an [[HttpResponse]] if the upgrade process wasn't successful or a source and a message queue
     *         to which new messages may be offered.
     */
    def connect(wsRequest: WebSocketRequest): Future[Either[HttpResponse, (SourceQueue, MessageQueue)]] = {
      implicit val system: ActorSystem = ActorSystem()
      implicit val materializer: ActorMaterializer = ActorMaterializer()
      import system.dispatcher

      // Turn every incoming message into a String and store it in the message queue.
      // mapAsync(1) keeps arrival order even for streamed messages. It also guarantees that all
      // folding has finished before the sink (and therefore `closed`) completes.
      val incoming: Sink[Message, Future[Done]] =
        Flow[Message]
          .mapAsync[String](parallelism = 1) {
            case TextMessage.Strict(text)      => Future.successful(text)
            case TextMessage.Streamed(parts)   => parts.runFold("")(_ + _)
            case BinaryMessage.Strict(bytes)   => Future.successful(bytes.utf8String)
            case BinaryMessage.Streamed(parts) => parts.runFold("")(_ + _.utf8String)
          }
          .toMat(Sink.foreach[String](msg => messageQueue.offer(msg)))(Keep.right)

      // Our source is a queue to which we can offer messages that will be sent to the WebSocket server.
      // All offered messages will be transformed into WebSocket messages.
      val (sourceMat, source) =
        Source.queue[String](Int.MaxValue, OverflowStrategy.backpressure)
          .map(msg => TextMessage.Strict(msg))
          .preMaterialize()

      // The outgoing flow sends all messages which are offered to the queue (our stream source) to the
      // WebSocket server.
      val flow: Flow[Message, Message, Future[Done]] = Flow.fromSinkAndSourceMat(incoming, source)(Keep.left)

      // upgradeResponse completes or fails when the connection succeeds or fails; closed represents
      // the completion of the incoming stream defined above.
      val (upgradeResponse, closed) = Http().singleWebSocketRequest(wsRequest, flow)

      // Release the per-connection ActorSystem once the connection is over, so that running many
      // tests does not accumulate unterminated actor systems and their thread pools.
      closed.onComplete { result =>
        if (result.isSuccess) logger.info("Channel closed")
        system.terminate()
      }

      upgradeResponse.map { upgrade =>
        if (upgrade.response.status == StatusCodes.SwitchingProtocols) {
          Right((sourceMat, messageQueue))
        } else {
          Left(upgrade.response)
        }
      }
    }
  }
}