Akka implementation of OAuth 2 service: access_token management

Posted by charlieholder on Fri, 10 Jan 2020 09:03:54 +0100

Implementing an OAuth 2 service involves several core points:

  1. Parsing the OAuth 2 protocol
  2. There may be many connected users, so the system needs to scale horizontally
  3. State control of each connected user's access_token: validity control
  4. The service should be fault tolerant, recoverable, scalable and able to handle high concurrency

When an OAuth 2 service is implemented with Akka, the logic turns out to be very clear, and the core points above are covered well.

Each connected user or access_token can be modeled as an Actor, so that many connected users or access_tokens can be served concurrently. The expiration time and other state can be managed inside the Actor.

With Akka Cluster Sharding, connected users can be deployed across a cluster and scaled horizontally. Akka Persistence provides EventSourcedBehavior, which adds persistence to an Actor and makes it recoverable. By relying on the Akka Cluster mechanism, dependence on an external cache system can be reduced.
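The recoverability mentioned above comes from replaying persisted events through a pure event handler. The following is a minimal, dependency-free sketch of that idea, not Akka's implementation; the `Issued` and `Cleared` events and the `recover` function are hypothetical names chosen for illustration.

```scala
object RecoverySketch {
  // Hypothetical, simplified events; the absolute due time is stored in the event
  // so that replaying it later yields the same state.
  sealed trait Event
  final case class Issued(token: String, dueEpochMillis: Long) extends Event
  final case class Cleared(tokens: Set[String]) extends Event

  final case class State(tokens: Map[String, Long] = Map.empty)

  // Pure event handler: the same events always produce the same state.
  def applyEvent(state: State, event: Event): State = event match {
    case Issued(token, due) => state.copy(tokens = state.tokens + (token -> due))
    case Cleared(tokens)    => state.copy(tokens = state.tokens -- tokens)
  }

  // Recovery is a left fold of the journal over the empty state.
  def recover(journal: Seq[Event]): State =
    journal.foldLeft(State())(applyEvent)
}
```

Because the handler is replayed on every recovery, it should not read the clock or any other external state.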

Akka Actor provides a supervision mechanism, so that errors can be handled quickly and fault tolerance is achieved.

access_token

In Akka, there are two main schemes for modeling access_token with Actors:

  1. One Actor per access_token, scaled horizontally with ClusterSharding; the Akka Actor serves as a stateful cache.
  2. One Actor per User; multiple access_tokens are kept in the User Actor's state.

One Actor for each access token

Having one Actor per access_token is relatively simple in design. You only need to make sure the Actor is stopped at the expiration time. The sample code is as follows:

import akka.actor.typed.{ ActorRef, ActorSystem, Behavior }
import akka.actor.typed.scaladsl.{ Behaviors, TimerScheduler }
import akka.cluster.sharding.typed.{ ClusterShardingSettings, ShardingEnvelope }
import akka.cluster.sharding.typed.scaladsl.{ ClusterSharding, Entity, EntityTypeKey }
import akka.persistence.typed.{ PersistenceId, RecoveryCompleted }
import akka.persistence.typed.scaladsl.{ Effect, EventSourcedBehavior }

import scala.concurrent.duration._

// AccessToken, CborSerializable and createAccessToken are not shown in this excerpt;
// they come from the project source linked at the end of the post.
object AccessTokenEntity {
  final case class State(accessToken: String, expiresEpochMillis: Long = 0L, refreshToken: String = "")
      extends CborSerializable

  sealed trait Command extends CborSerializable
  case object StopSelf extends Command
  final case class Check(replyTo: ActorRef[Int]) extends Command
  final case class Create(refreshToken: String, replyTo: ActorRef[AccessToken]) extends Command

  sealed trait Event extends CborSerializable
  // The event stores the absolute expiry time, so replaying it during recovery rebuilds the same state
  final case class Created(expiresEpochMillis: Long, refreshToken: String) extends Event

  // Grace period: the stored expiry is 5 seconds later than the advertised one, so that when a client
  // refreshes right at the expiration time the old and the new access token overlap briefly
  private val DEVIATION = 5000L
  val TypeKey: EntityTypeKey[Command] = EntityTypeKey("AccessTokenEntity")

  def init(system: ActorSystem[_]): ActorRef[ShardingEnvelope[Command]] =
    ClusterSharding(system).init(
      Entity(TypeKey)(ec => apply(ec.entityId))
        // Automatic passivation is disabled; the Actor stops itself when the token expires
        .withSettings(ClusterShardingSettings(system).withPassivateIdleEntityAfter(Duration.Zero)))

  private def apply(accessToken: String): Behavior[Command] =
    Behaviors.setup(context =>
      Behaviors.withTimers { timers =>
        val behavior = EventSourcedBehavior[Command, Event, State](
          PersistenceId.of(TypeKey.name, accessToken),
          State(accessToken),
          (state, command) => commandHandler(timers, state, command),
          (state, event) => eventHandler(state, event))
        behavior.receiveSignal {
          case (state, RecoveryCompleted) =>
            val now = System.currentTimeMillis()
            // expiresEpochMillis == 0 means nothing was persisted yet; the first command decides what happens
            if (state.expiresEpochMillis > 0L) {
              if (state.expiresEpochMillis <= now) context.self ! StopSelf
              else timers.startSingleTimer(StopSelf, (state.expiresEpochMillis - now).millis)
            }
        }
      })

  private def commandHandler(timers: TimerScheduler[Command], state: State, command: Command): Effect[Event, State] =
    command match {
      case Check(replyTo) =>
        if (state.expiresEpochMillis == 0L) {
          // Unknown token: reply 401 and stop the Actor that was created only for this check
          Effect.stop().thenReply(replyTo)(_ => 401)
        } else {
          val status = if (System.currentTimeMillis() < state.expiresEpochMillis) 200 else 401
          Effect.reply(replyTo)(status)
        }

      case Create(refreshToken, replyTo) =>
        if (state.expiresEpochMillis > 0L) // Return the existing AccessToken
          Effect.reply(replyTo)(createAccessToken(state))
        else {
          // Read the clock here, in the command handler, not in the event handler
          val expiresEpochMillis = System.currentTimeMillis() + 2.hours.toMillis + DEVIATION
          Effect.persist(Created(expiresEpochMillis, refreshToken)).thenReply(replyTo) { st =>
            timers.startSingleTimer(StopSelf, (st.expiresEpochMillis - System.currentTimeMillis()).millis)
            createAccessToken(st)
          }
        }

      case StopSelf => Effect.stop()
    }

  // Pure function of state and event: it is also invoked when events are replayed during recovery
  private def eventHandler(state: State, event: Event): State = event match {
    case Created(expiresEpochMillis, refreshToken) =>
      state.copy(expiresEpochMillis = expiresEpochMillis, refreshToken = refreshToken)
  }
}
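The expiry arithmetic in the sample above can be isolated as pure functions, which makes it easy to check without starting an Actor. This is a minimal sketch; `ExpirySketch`, `expiresAt`, `isValid` and `stopDelay` are hypothetical names that do not appear in the original source, and the 5-second grace period mirrors `DEVIATION`.

```scala
import scala.concurrent.duration._

object ExpirySketch {
  // Same 5-second grace period as DEVIATION in the sample (assumption: kept in milliseconds).
  val GraceMillis: Long = 5000L

  // Absolute expiry time for a token issued at `nowMillis`.
  def expiresAt(nowMillis: Long, expiresIn: FiniteDuration): Long =
    nowMillis + expiresIn.toMillis + GraceMillis

  // A token is valid if it was created (expiry > 0) and the expiry has not been reached.
  def isValid(nowMillis: Long, expiresEpochMillis: Long): Boolean =
    expiresEpochMillis > 0L && nowMillis < expiresEpochMillis

  // Delay to use for the single StopSelf timer; never negative.
  def stopDelay(nowMillis: Long, expiresEpochMillis: Long): FiniteDuration =
    math.max(0L, expiresEpochMillis - nowMillis).millis
}
```

Passing the current time as a parameter, instead of reading the clock inside the functions, keeps them deterministic and testable.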

The benefits of this approach are:

  1. With Cluster Sharding the cluster can, in theory, scale horizontally without limit, so any number of access_tokens can be stored
  2. Fault tolerant and recoverable. After a node goes down, the token state can be recovered on other machines
  3. The generated access_token does not need to contain business information; it only needs to be unique
  4. Intuitive code logic

The disadvantages of this scheme are as follows:

  1. Each access_token has one Actor. The Actor does very little and has no obvious advantage over a cache data store with TTL
  2. Whether an access_token is valid can only be determined after its Actor has been created, so invalid tokens cause many useless Actors to be created

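One Actor per user

With one Actor per user, the Actor's state holds that user's access_tokens and refresh_tokens together with their due times, and a periodic ClearTick removes expired entries. The sample code is as follows: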

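import akka.actor.typed.{ ActorRef, ActorSystem, Behavior }
import akka.actor.typed.scaladsl.{ ActorContext, Behaviors, TimerScheduler }
import akka.cluster.sharding.typed.{ ClusterShardingSettings, ShardingEnvelope }
import akka.cluster.sharding.typed.scaladsl.{ ClusterSharding, Entity, EntityContext, EntityTypeKey }
import akka.persistence.typed.PersistenceId
import akka.persistence.typed.scaladsl.{ Effect, EventSourcedBehavior }

import scala.concurrent.duration._

// AccessToken, CborSerializable, DueEpochMillis and OAuthUtils are not shown in this excerpt;
// they come from the project source linked at the end of the post.
object UserEntity {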
  final case class State(
      tokens: Map[String, DueEpochMillis] = Map(),
      refreshTokens: Map[String, DueEpochMillis] = Map())
      extends CborSerializable {
    def clear(clearTokens: IterableOnce[String], clearRefreshTokens: IterableOnce[String]): State =
      copy(tokens = tokens -- clearTokens, refreshTokens = refreshTokens -- clearRefreshTokens)

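    // Caution: event handlers also run when events are replayed during recovery. If
    // OAuthUtils.expiresInToEpochMillis (not shown here) adds expiresIn to the current clock,
    // a recovered token gets a later due time than the one originally issued; storing the
    // absolute due time in the event would avoid that.
    def addToken(created: TokenCreated): State = {
      val tokenDue = OAuthUtils.expiresInToEpochMillis(created.accessTokenExpiresIn)
      val refreshTokenDue = OAuthUtils.expiresInToEpochMillis(created.refreshTokenExpiresIn)
      State(tokens + (created.accessToken -> tokenDue), refreshTokens + (created.refreshToken -> refreshTokenDue))
    }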

    def addToken(accessToken: String, expiresIn: FiniteDuration): State =
      copy(tokens = tokens + (accessToken -> OAuthUtils.expiresInToEpochMillis(expiresIn)))
  }

  sealed trait Command extends CborSerializable

  final case class CreateToken(replyTo: ActorRef[AccessToken]) extends Command
  final case class CheckToken(accessToken: String, replyTo: ActorRef[Int]) extends Command
  final case class RefreshToken(refreshToken: String, replyTo: ActorRef[Option[AccessToken]]) extends Command
  final case object ClearTick extends Command

  sealed trait Event extends CborSerializable
  final case class TokenCreated(
      accessToken: String,
      accessTokenExpiresIn: FiniteDuration,
      refreshToken: String,
      refreshTokenExpiresIn: FiniteDuration)
      extends Event
  final case class TokenRefreshed(accessToken: String, expiresIn: FiniteDuration) extends Event
  final case class ClearEvent(clearTokens: Set[String], clearRefreshTokens: Set[String]) extends Event

  val TypeKey: EntityTypeKey[Command] = EntityTypeKey("UserEntity")

  def init(system: ActorSystem[_]): ActorRef[ShardingEnvelope[Command]] =
    ClusterSharding(system).init(
      Entity(TypeKey)(ec => apply(ec))
        .withSettings(ClusterShardingSettings(system).withPassivateIdleEntityAfter(Duration.Zero)))

  private def apply(ec: EntityContext[Command]): Behavior[Command] = {
    val userId = ec.entityId
    Behaviors.setup(
      context =>
        Behaviors.withTimers(timers =>
          new UserEntity(PersistenceId.of(ec.entityTypeKey.name, ec.entityId), userId, timers, context)
            .eventSourcedBehavior()))
  }
}

import blog.oauth2.peruser.UserEntity._
class UserEntity private (
    persistenceId: PersistenceId,
    userId: String,
    timers: TimerScheduler[Command],
    context: ActorContext[Command]) {
  timers.startTimerWithFixedDelay(ClearTick, 2.hours)

  def eventSourcedBehavior(): EventSourcedBehavior[Command, Event, State] =
    EventSourcedBehavior(
      persistenceId,
      State(),
      (state, command) =>
        command match {
          case CheckToken(accessToken, replyTo)    => processCheckToken(state, accessToken, replyTo)
          case RefreshToken(refreshToken, replyTo) => processRefreshToken(state, refreshToken, replyTo)
          case CreateToken(replyTo)                => processCreateToken(replyTo)
          case ClearTick                           => processClear(state)
        },
      (state, event) =>
        event match {
          case TokenRefreshed(accessToken, expiresIn)      => state.addToken(accessToken, expiresIn)
          case created: TokenCreated                       => state.addToken(created)
          case ClearEvent(clearTokens, clearRefreshTokens) => state.clear(clearTokens, clearRefreshTokens)
        })

  private def processRefreshToken(
      state: State,
      refreshToken: String,
      replyTo: ActorRef[Option[AccessToken]]): Effect[Event, State] = {
    if (state.refreshTokens.get(refreshToken).exists(due => System.currentTimeMillis() < due)) {
      val refreshed = TokenRefreshed(OAuthUtils.generateToken(userId), 2.hours)
      Effect
        .persist(refreshed)
        .thenReply(replyTo)(_ => Some(AccessToken(refreshed.accessToken, refreshed.expiresIn.toSeconds, refreshToken)))
    } else {
      Effect.reply(replyTo)(None)
    }
  }

  private def processCheckToken(state: State, accessToken: String, replyTo: ActorRef[Int]): Effect[Event, State] = {
    val status = state.tokens.get(accessToken) match {
      case Some(dueTimestamp) => if (System.currentTimeMillis() < dueTimestamp) 200 else 401
      case None               => 401
    }
    Effect.reply(replyTo)(status)
  }

  private def processCreateToken(replyTo: ActorRef[AccessToken]): Effect[Event, State] = {
    val createdEvent =
      TokenCreated(OAuthUtils.generateToken(userId), 2.hours, OAuthUtils.generateToken(userId), 30.days)
    Effect.persist(createdEvent).thenReply(replyTo) { _ =>
      AccessToken(createdEvent.accessToken, createdEvent.accessTokenExpiresIn.toSeconds, createdEvent.refreshToken)
    }
  }

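  private def processClear(state: State): Effect[Event, State] = {
    if (state.tokens.isEmpty && state.refreshTokens.isEmpty) {
      // Nothing left to manage: stop the Actor; sharding recreates it on the next message for this user
      Effect.stop()
    } else {
      val now = System.currentTimeMillis()
      val clearTokens = state.tokens.view.filterNot { case (_, due) => now < due }.keys.toSet
      val clearRefreshTokens = state.refreshTokens.view.filterNot { case (_, due) => now < due }.keys.toSet
      // Skip the write when nothing has expired, so the journal is not filled with empty events
      if (clearTokens.isEmpty && clearRefreshTokens.isEmpty) Effect.none
      else Effect.persist(ClearEvent(clearTokens, clearRefreshTokens))
    }
  }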
}
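The cleanup performed on each ClearTick boils down to splitting a map of due times into entries that are still valid and keys that have expired. A minimal, dependency-free sketch of that step follows; `ClearSketch`, `expiredKeys` and `sweep` are hypothetical names, and due times are assumed to be epoch milliseconds stored as `Long`.

```scala
object ClearSketch {
  // Keys whose due time has been reached at `nowMillis`
  // (the complement of the `now < due` validity check used in the sample).
  def expiredKeys(entries: Map[String, Long], nowMillis: Long): Set[String] =
    entries.collect { case (key, due) if due <= nowMillis => key }.toSet

  // Splits the entries into (still valid entries, expired keys).
  def sweep(entries: Map[String, Long], nowMillis: Long): (Map[String, Long], Set[String]) = {
    val expired = expiredKeys(entries, nowMillis)
    (entries -- expired, expired)
  }
}
```

For example, `ClearSketch.sweep(Map("a" -> 100L, "b" -> 300L), 200L)` keeps `b` and reports `a` as expired.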

The advantages of one Actor per user are:

  1. With Cluster Sharding the cluster can, in theory, scale horizontally without limit, so any number of access_tokens can be stored
  2. Fault tolerant and recoverable. After a node goes down, the token state can be recovered on other machines
  3. Compared with one Actor per access_token, this scheme significantly reduces the number of Actors in the system
  4. Compared with a cache data store with TTL, an Actor is more flexible and can run in the same cluster as the business (user) system
  5. An invalid access_token does not cause extra Actors to be created

The disadvantages of one Actor per user are as follows:

  1. The generated access_token needs to contain business information, such as the user ID, so that a request can be routed to the right User Actor
  2. As the number of lines of code shows, the logic of this scheme is more complex, although it is also more powerful than one Actor per access_token
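In the per-user scheme a request must reach the owning User Actor, so the user ID has to be recoverable from the token; the sample passes `userId` to `OAuthUtils.generateToken`, whose implementation is not shown. The following is one possible sketch under a hypothetical token format, base64url of `<userId>:<random UUID>`; it is not the author's implementation, and the token is neither signed nor encrypted.

```scala
import java.nio.charset.StandardCharsets.UTF_8
import java.util.{ Base64, UUID }
import scala.util.Try

object TokenSketch {
  private val encoder = Base64.getUrlEncoder.withoutPadding()
  private val decoder = Base64.getUrlDecoder

  // Hypothetical format: base64url("<userId>:<random uuid>").
  def generateToken(userId: String): String =
    encoder.encodeToString(s"$userId:${UUID.randomUUID()}".getBytes(UTF_8))

  // Recovers the user ID, which would serve as the sharding entity ID.
  // Returns None for input that is not in the expected format.
  def userIdOf(token: String): Option[String] =
    Try(new String(decoder.decode(token), UTF_8)).toOption.flatMap { decoded =>
      // A UUID contains no ':', so the last ':' is the separator.
      decoded.lastIndexOf(':') match {
        case idx if idx > 0 => Some(decoded.substring(0, idx))
        case _              => None
      }
    }
}
```

A malformed token is rejected before any Actor is involved, and a well-formed but unknown token only reaches the Actor of the user it names.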

Summary

The full source code can be found on GitHub: https://github.com/yangbajing/yangbajing-blog/tree/master/src/main/scala/blog/oauth2 .
