import { InMemoryStorage } from '@st/util/key-value-storage'
import { Channel, Socket } from 'phoenix'

/**
 * A state payload tagged with a version number.
 * NOTE(review): not referenced in this file; presumably a full-state message
 * received over a channel — confirm against the producer.
 */
export type StateLoad = {
  /** The state payload itself; left untyped here. */
  state: any
  /** Version number of `state` — presumably the same numbering `StateDiff` uses; TODO confirm. */
  vsn: number
}

/**
 * A diff payload annotated with the versions it spans.
 * NOTE(review): not referenced in this file; presumably it transforms the state
 * at version `fromVsn` into the state at version `toVsn` — confirm the diff
 * format with the producer.
 */
export type StateDiff = {
  /** The diff payload itself; left untyped here. */
  diff: any
  /** Version the diff applies on top of (assumed; verify). */
  fromVsn: number
  /** Version produced after applying the diff (assumed; verify). */
  toVsn: number
}

/** Local alias for the phoenix `Socket` type held by `RemoteSocket`. */
type InnerSocket = Socket
/** Options accepted by the `RemoteSocket` constructor. */
type RemoteSocketOpts = {
  /** Endpoint handed straight to `new Socket(endpoint, ...)`. */
  endpoint: string
}
/**
 * Wrapper around a phoenix `Socket` that keeps the set of channels it has created
 * in line with a reference-counted set of desired topics.
 *
 * `subscribe(topic)` declares interest in a topic: the first subscriber causes the
 * channel to be joined, and releasing the last subscription causes it to be left.
 * `on(topic, ...)` attaches event handlers to the channel currently held for a topic.
 */
export class RemoteSocket {
  private innerSocket: InnerSocket

  /**
   * A reference counted map of topic -> count.
   * When a count drops to zero the topic is deleted from this map.
   *
   * The keys of this map are the declarative list of channels we want to be
   * connected to; `syncChannelStatesWithTopics` reconciles `channels` against it.
   */
  private topics: Record<string, number> = {}

  /** Channels created by this wrapper, keyed by topic. */
  private channels: Record<string, Channel> = {}

  constructor(opts: RemoteSocketOpts) {
    this.innerSocket = new Socket(opts.endpoint, {
      sessionStorage: new InMemoryStorage({ syncTabs: false }),
      logger(kind, message, data) {
        // console.log(kind, message, data)
      }
    }) as InnerSocket
  }

  /**
   * Opens the underlying socket, then joins every desired topic that has no live
   * channel. This covers topics that stayed subscribed across a
   * `disconnect()` / `connect()` cycle.
   */
  connect({ token }: { token: string }) {
    this.innerSocket.connect({ token })
    this.syncChannelStatesWithTopics()
  }

  /**
   * Leaves every channel this wrapper created and closes the underlying socket.
   *
   * Subscriptions are kept, so a later `connect()` rejoins them. Handlers added
   * through `on()` belong to the old channel objects and must be registered again.
   */
  disconnect() {
    // Leave explicitly: the phoenix socket keeps its own list of channels and
    // rejoins errored ones when it reopens. If we only dropped our references,
    // those orphans would come back and duplicate the channels we create later.
    for (const channel of Object.values(this.channels)) {
      channel.leave()
    }
    this.channels = {}
    this.innerSocket.disconnect()
  }

  /** Returns the live topic -> subscriber-count map (not a copy). */
  getTopics() {
    return this.topics
  }

  /**
   * Declares interest in `topic`, joining its channel if this is the first subscriber.
   *
   * @returns a release function. When every subscriber of a topic has released,
   *   the channel is left. Calling the same release function more than once has
   *   no further effect.
   */
  subscribe(topic: string) {
    this.topics[topic] = (this.topics[topic] ?? 0) + 1
    this.syncChannelStatesWithTopics()

    let released = false
    // unsubscribe
    return () => {
      // Guard against double release. Without it, a second call would steal
      // another subscriber's reference. Or, once the entry was deleted, it would
      // store NaN, which never equals 0 and keeps the topic desired forever.
      if (released) {
        return
      }
      released = true

      const remaining = (this.topics[topic] ?? 0) - 1
      if (remaining <= 0) {
        delete this.topics[topic]
      } else {
        this.topics[topic] = remaining
      }
      this.syncChannelStatesWithTopics()
    }
  }

  /**
   * Registers `callback` for `eventName` on the channel currently held for `topic`.
   *
   * The handler is bound to that specific channel object. If the channel is later
   * replaced (re-join after an error, or `disconnect()`), the handler is not carried over.
   *
   * @throws Error if no channel exists for `topic` (call `subscribe(topic)` first).
   * @returns a function that removes this handler from the channel.
   */
  on<T = any>(topic: string, eventName: string, callback: (msg: T) => void): () => void {
    const channel = this.channels[topic]
    if (!channel) {
      throw new Error(`No channel named ${topic} is connected to`)
    }

    const ref = channel.on(eventName, callback)

    // unsubscribe
    return () => {
      channel.off(eventName, ref)
    }
  }

  /**
   * Computes join/leave operations from the current channel states and the
   * desired topics, then applies them to the phoenix socket.
   */
  private syncChannelStatesWithTopics() {
    const topicToChannelState = getChannelStatesByTopic(Object.values(this.channels))
    const ops = getChannelOps(topicToChannelState, Object.keys(this.topics))

    for (const op of ops) {
      switch (op.type) {
        case 'join': {
          const stale = this.channels[op.topic]
          if (stale && stale.state === 'errored') {
            // phoenix keeps retrying errored channels on its own. Leave the old
            // one so it doesn't come back and compete with its replacement.
            stale.leave()
          }
          const channel = this.innerSocket.channel(op.topic)
          this.channels[channel.topic] = channel
          channel.join()
          console.log('channel.join', channel.topic)
          break
        }
        case 'leave': {
          this.channels[op.topic]?.leave()
          delete this.channels[op.topic]
          console.log('channel.leave', op.topic)
          break
        }
      }
    }
  }
}

type TopicToChannelState = Record<string, Channel['state']>

/**
 * Indexes the given channels by topic, mapping each topic to its channel's state.
 * When several channels share a topic, the one that appears last in the list wins.
 */
function getChannelStatesByTopic(channels: Channel[]): TopicToChannelState {
  const lookup: TopicToChannelState = {}
  channels.forEach(({ topic, state }) => {
    lookup[topic] = state
  })
  return lookup
}

type ChannelOp = { type: 'join'; topic: string } | { type: 'leave'; topic: string }

/**
 * Diffs current channel states against the desired topics.
 *
 * - Every desired topic gets a `join` op unless its channel is already `joining`
 *   or `joined`. A topic with no channel is treated as `closed`.
 * - A channel that is `joined`, `joining` or `errored` but whose topic is no longer
 *   desired gets a `leave` op. `errored` is included because phoenix schedules
 *   rejoin attempts for errored channels; without a leave they would keep coming back.
 *
 * @param channelStates current state for each topic that has a channel
 * @param desiredTopics topics that should have a live channel
 * @returns join ops (in `desiredTopics` order) followed by leave ops
 */
function getChannelOps(channelStates: TopicToChannelState, desiredTopics: string[]): ChannelOp[] {
  const ops: ChannelOp[] = []
  // add join ops - desiredTopics that are not yet joining
  for (const topic of desiredTopics) {
    const currentState = channelStates[topic] ?? 'closed'
    switch (currentState) {
      case 'closed':
      case 'errored':
      case 'leaving':
        ops.push({ type: 'join', topic })
        break
      case 'joined':
      case 'joining':
        // already joined or joining
        break
    }
  }
  // add leave ops - live (or auto-rejoining) channels whose topic is not desired
  const desired = new Set(desiredTopics)
  for (const [topic, state] of Object.entries(channelStates)) {
    const live = state === 'joined' || state === 'joining' || state === 'errored'
    if (live && !desired.has(topic)) {
      ops.push({ type: 'leave', topic })
    }
  }
  return ops
}
