diff options
author | 2024-11-05 18:17:33 -0500 | |
---|---|---|
committer | 2025-01-03 15:52:15 -0500 | |
commit | 02946ae243e89bef2bae8087e91262e2ae8f499b (patch) | |
tree | a16025b2be173f8d2f649a54bdee9a37bc8a582d | |
parent | 7a10baee303d855ee04179698f680627705a7693 (diff) |
[kairos] generalize node storage
Rather than using (Mutable)Maps for all internal storage, provide a
mechanism by which custom Map impls can be used.
Flag: EXEMPT unused
Test: atest kairos-tests
Change-Id: Ic75286db27426cda7f41d1f2fdba138b1ce66e2f
21 files changed, 1195 insertions, 623 deletions
diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/FrpStateScope.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/FrpStateScope.kt index c7ea6808a53e..058fc1037e58 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/FrpStateScope.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/FrpStateScope.kt @@ -85,7 +85,7 @@ interface FrpStateScope : FrpTransactionScope { * @see merge */ @ExperimentalFrpApi - fun <K : Any, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementally( + fun <K, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementally( initialTFlows: FrpDeferredValue<Map<K, TFlow<V>>> ): TFlow<Map<K, V>> @@ -107,7 +107,7 @@ interface FrpStateScope : FrpTransactionScope { * @see merge */ @ExperimentalFrpApi - fun <K : Any, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementallyPromptly( + fun <K, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementallyPromptly( initialTFlows: FrpDeferredValue<Map<K, TFlow<V>>> ): TFlow<Map<K, V>> @@ -131,7 +131,7 @@ interface FrpStateScope : FrpTransactionScope { * @see merge */ @ExperimentalFrpApi - fun <K : Any, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementally( + fun <K, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementally( initialTFlows: Map<K, TFlow<V>> = emptyMap() ): TFlow<Map<K, V>> = mergeIncrementally(deferredOf(initialTFlows)) @@ -153,7 +153,7 @@ interface FrpStateScope : FrpTransactionScope { * @see merge */ @ExperimentalFrpApi - fun <K : Any, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementallyPromptly( + fun <K, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementallyPromptly( initialTFlows: Map<K, TFlow<V>> = emptyMap() ): TFlow<Map<K, V>> = mergeIncrementallyPromptly(deferredOf(initialTFlows)) diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/TFlow.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/TFlow.kt index 1d8fe116d57b..362a890f44e2 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/TFlow.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/TFlow.kt @@ -26,6 +26,7 @@ import com.android.systemui.kairos.internal.TFlowImpl import com.android.systemui.kairos.internal.activated import com.android.systemui.kairos.internal.cached import com.android.systemui.kairos.internal.constInit +import com.android.systemui.kairos.internal.demuxMap import com.android.systemui.kairos.internal.filterImpl import com.android.systemui.kairos.internal.filterJustImpl import com.android.systemui.kairos.internal.init @@ -35,7 +36,7 @@ import com.android.systemui.kairos.internal.mergeNodes import com.android.systemui.kairos.internal.mergeNodesLeft import com.android.systemui.kairos.internal.neverImpl import com.android.systemui.kairos.internal.switchDeferredImplSingle -import com.android.systemui.kairos.internal.switchPromptImpl +import com.android.systemui.kairos.internal.switchPromptImplSingle import com.android.systemui.kairos.internal.util.hashString import com.android.systemui.kairos.util.Either import com.android.systemui.kairos.util.Left @@ -344,7 +345,7 @@ fun <K, A> Map<K, TFlow<A>>.merge(): TFlow<Map<K, A>> = */ @ExperimentalFrpApi fun <K, A> TFlow<Map<K, A>>.groupByKey(numKeys: Int? = null): GroupedTFlow<K, A> = - GroupedTFlow(DemuxImpl({ init.connect(this) }, numKeys)) + GroupedTFlow(demuxMap({ init.connect(this) }, numKeys)) /** * Shorthand for `map { mapOf(extractKey(it) to it) }.groupByKey()` @@ -417,8 +418,8 @@ class GroupedTFlow<in K, out A> internal constructor(internal val impl: DemuxImp * that takes effect immediately, see [switchPromptly]. */ @ExperimentalFrpApi -fun <A> TState<TFlow<A>>.switch(): TFlow<A> { - return TFlowInit( +fun <A> TState<TFlow<A>>.switch(): TFlow<A> = + TFlowInit( constInit( name = null, switchDeferredImplSingle( @@ -433,7 +434,6 @@ fun <A> TState<TFlow<A>>.switch(): TFlow<A> { ), ) ) -} /** * Returns a [TFlow] that switches to the [TFlow] contained within this [TState] whenever it @@ -444,21 +444,22 @@ fun <A> TState<TFlow<A>>.switch(): TFlow<A> { */ // TODO: parameter to handle coincidental emission from both old and new @ExperimentalFrpApi -fun <A> TState<TFlow<A>>.switchPromptly(): TFlow<A> { - val switchNode = - switchPromptImpl( - getStorage = { - mapOf(Unit to init.connect(this).getCurrentWithEpoch(this).first.init.connect(this)) - }, - getPatches = { - val patches = init.connect(this).changes - mapImpl({ patches }) { newFlow -> mapOf(Unit to just(newFlow.init.connect(this))) } - }, +fun <A> TState<TFlow<A>>.switchPromptly(): TFlow<A> = + TFlowInit( + constInit( + name = null, + switchPromptImplSingle( + getStorage = { + init.connect(this).getCurrentWithEpoch(this).first.init.connect(this) + }, + getPatches = { + mapImpl({ init.connect(this).changes }) { newFlow -> + newFlow.init.connect(this) + } + }, + ), ) - return TFlowInit( - constInit(name = null, mapImpl({ switchNode }) { it.getValue(Unit).getPushEvent(this) }) ) -} /** * A mutable [TFlow] that provides the ability to [emit] values to the flow, handling backpressure diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/TState.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/TState.kt index 8ad5f55adca3..66aa2a950fcf 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/TState.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/TState.kt @@ -36,6 +36,7 @@ import com.android.systemui.kairos.internal.map import com.android.systemui.kairos.internal.mapCheap import com.android.systemui.kairos.internal.mapImpl import com.android.systemui.kairos.internal.util.hashString +import com.android.systemui.kairos.internal.zipStateMap import com.android.systemui.kairos.internal.zipStates import kotlin.reflect.KProperty import kotlinx.coroutines.CompletableDeferred @@ -159,12 +160,12 @@ fun <A> Iterable<TState<A>>.combine(): TState<List<A>> { * @see TState.combineWith */ @ExperimentalFrpApi -fun <K : Any, A> Map<K, TState<A>>.combine(): TState<Map<K, A>> { +fun <K, A> Map<K, TState<A>>.combine(): TState<Map<K, A>> { val operatorName = "combine" val name = operatorName return TStateInit( init(name) { - zipStates( + zipStateMap( name, operatorName, states = mapValues { it.value.init.connect(evalScope = this) }, diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/debug/Debug.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/debug/Debug.kt index 0674a2e75659..6f9612fab70a 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/debug/Debug.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/debug/Debug.kt @@ -133,7 +133,7 @@ internal fun TStateImpl<*>.dump(infoById: MutableMap<Any, InitInfo>, edges: Muta edges.add(Edge(upstream = state.upstream, downstream = state)) Mapped(cheap = false) } - is DerivedZipped<*, *> -> { + is DerivedZipped<*, *, *> -> { state.upstream.forEach { (key, upstream) -> edges.add( Edge(upstream = upstream, downstream = state, tag = "key=$key") diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Demux.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Demux.kt index dd46fe202413..5f652525f036 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Demux.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Demux.kt @@ -16,17 +16,20 @@ package com.android.systemui.kairos.internal +import com.android.systemui.kairos.internal.store.ConcurrentHashMapK +import com.android.systemui.kairos.internal.store.MapHolder +import com.android.systemui.kairos.internal.store.MapK +import com.android.systemui.kairos.internal.store.MutableMapK import com.android.systemui.kairos.internal.util.hashString -import java.util.concurrent.ConcurrentHashMap import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -internal class DemuxNode<K, A>( - private val branchNodeByKey: ConcurrentHashMap<K, DemuxNode<K, A>.BranchNode>, +internal class DemuxNode<W, K, A>( + private val branchNodeByKey: MutableMapK<W, K, DemuxNode<W, K, A>.BranchNode>, val lifecycle: DemuxLifecycle<K, A>, - private val spec: DemuxActivator<K, A>, + private val spec: DemuxActivator<W, K, A>, ) : SchedulableNode { val schedulable = Schedulable.N(this) @@ -34,7 +37,7 @@ internal class DemuxNode<K, A>( inline val mutex get() = lifecycle.mutex - lateinit var upstreamConnection: NodeConnection<Map<K, A>> + lateinit var upstreamConnection: NodeConnection<MapK<W, K, A>> @Volatile private var epoch: Long = Long.MIN_VALUE @@ -52,7 +55,10 @@ internal class DemuxNode<K, A>( mutex.withLock { updateEpoch(evalScope) for ((key, _) in upstreamResult) { - branchNodeByKey[key]?.let { branch -> launch { branch.schedule(evalScope) } } + if (key !in branchNodeByKey) continue + val branch = branchNodeByKey.getValue(key) + // TODO: launchImmediate? + launch { branch.schedule(evalScope) } } } } @@ -75,7 +81,7 @@ internal class DemuxNode<K, A>( override suspend fun moveIndirectUpstreamToDirect( scheduler: Scheduler, oldIndirectDepth: Int, - oldIndirectSet: Set<MuxDeferredNode<*, *>>, + oldIndirectSet: Set<MuxDeferredNode<*, *, *>>, newDirectDepth: Int, ) { coroutineScope { @@ -97,8 +103,8 @@ internal class DemuxNode<K, A>( scheduler: Scheduler, oldDepth: Int, newDepth: Int, - removals: Set<MuxDeferredNode<*, *>>, - additions: Set<MuxDeferredNode<*, *>>, + removals: Set<MuxDeferredNode<*, *, *>>, + additions: Set<MuxDeferredNode<*, *, *>>, ) { coroutineScope { mutex.withLock { @@ -120,7 +126,7 @@ internal class DemuxNode<K, A>( scheduler: Scheduler, oldDirectDepth: Int, newIndirectDepth: Int, - newIndirectSet: Set<MuxDeferredNode<*, *>>, + newIndirectSet: Set<MuxDeferredNode<*, *, *>>, ) { coroutineScope { mutex.withLock { @@ -140,7 +146,7 @@ internal class DemuxNode<K, A>( override suspend fun removeIndirectUpstream( scheduler: Scheduler, depth: Int, - indirectSet: Set<MuxDeferredNode<*, *>>, + indirectSet: Set<MuxDeferredNode<*, *, *>>, ) { coroutineScope { mutex.withLock { @@ -245,35 +251,45 @@ internal class DemuxNode<K, A>( } } -internal fun <K, A> DemuxImpl( +internal fun <W, K, A> DemuxImpl( + upstream: TFlowImpl<MapK<W, K, A>>, + numKeys: Int?, + storeFactory: MutableMapK.Factory<W, K>, +): DemuxImpl<K, A> = + DemuxImpl( + DemuxLifecycle( + DemuxLifecycleState.Inactive(DemuxActivator(numKeys, upstream, storeFactory)) + ) + ) + +internal fun <K, A> demuxMap( upstream: suspend EvalScope.() -> TFlowImpl<Map<K, A>>, numKeys: Int?, ): DemuxImpl<K, A> = - DemuxImpl(DemuxLifecycle(DemuxLifecycleState.Inactive(DemuxActivator(numKeys, upstream)))) + DemuxImpl(mapImpl(upstream) { MapHolder(it) }, numKeys, ConcurrentHashMapK.Factory()) -internal class DemuxActivator<K, A>( +internal class DemuxActivator<W, K, A>( private val numKeys: Int?, - private val upstream: suspend EvalScope.() -> TFlowImpl<Map<K, A>>, + private val upstream: TFlowImpl<MapK<W, K, A>>, + private val storeFactory: MutableMapK.Factory<W, K>, ) { suspend fun activate( evalScope: EvalScope, lifecycle: DemuxLifecycle<K, A>, - ): Pair<DemuxNode<K, A>, Set<K>>? { - val demux = DemuxNode(ConcurrentHashMap(numKeys ?: 16), lifecycle, this) - return upstream - .invoke(evalScope) - .activate(evalScope, downstream = demux.schedulable) - ?.let { (conn, needsEval) -> - Pair( - demux.apply { upstreamConnection = conn }, - if (needsEval) { - demux.updateEpoch(evalScope) - conn.getPushEvent(evalScope).keys - } else { - emptySet() - }, - ) - } + ): Pair<DemuxNode<W, K, A>, Set<K>>? { + val demux = DemuxNode(storeFactory.create(numKeys), lifecycle, this) + return upstream.activate(evalScope, downstream = demux.schedulable)?.let { (conn, needsEval) + -> + Pair( + demux.apply { upstreamConnection = conn }, + if (needsEval) { + demux.updateEpoch(evalScope) + conn.getPushEvent(evalScope).keys + } else { + emptySet() + }, + ) + } } } @@ -295,7 +311,10 @@ internal class DemuxLifecycle<K, A>(@Volatile var lifecycleState: DemuxLifecycle override fun toString(): String = "TFlowDmuxState[$hashString][$lifecycleState][$mutex]" - suspend fun activate(evalScope: EvalScope, key: K): Pair<DemuxNode<K, A>.BranchNode, Boolean>? = + suspend fun activate( + evalScope: EvalScope, + key: K, + ): Pair<DemuxNode<*, K, A>.BranchNode, Boolean>? = mutex.withLock { when (val state = lifecycleState) { is DemuxLifecycleState.Dead -> null @@ -322,11 +341,11 @@ internal class DemuxLifecycle<K, A>(@Volatile var lifecycleState: DemuxLifecycle } internal sealed interface DemuxLifecycleState<out K, out A> { - class Inactive<K, A>(val spec: DemuxActivator<K, A>) : DemuxLifecycleState<K, A> { + class Inactive<K, A>(val spec: DemuxActivator<*, K, A>) : DemuxLifecycleState<K, A> { override fun toString(): String = "Inactive" } - class Active<K, A>(val node: DemuxNode<K, A>) : DemuxLifecycleState<K, A> { + class Active<K, A>(val node: DemuxNode<*, K, A>) : DemuxLifecycleState<K, A> { override fun toString(): String = "Active(node=$node)" } diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/FilterNode.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/FilterNode.kt index 030119394ac0..b60c227bcfbe 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/FilterNode.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/FilterNode.kt @@ -16,6 +16,8 @@ package com.android.systemui.kairos.internal +import com.android.systemui.kairos.internal.store.Single +import com.android.systemui.kairos.internal.store.SingletonMapK import com.android.systemui.kairos.util.Just import com.android.systemui.kairos.util.Maybe import com.android.systemui.kairos.util.just @@ -25,16 +27,15 @@ internal inline fun <A> filterJustImpl( crossinline getPulse: suspend EvalScope.() -> TFlowImpl<Maybe<A>> ): TFlowImpl<A> = DemuxImpl( - { - mapImpl(getPulse) { maybeResult -> - if (maybeResult is Just) { - mapOf(Unit to maybeResult.value) - } else { - emptyMap() - } + mapImpl(getPulse) { maybeResult -> + if (maybeResult is Just) { + Single(maybeResult.value) + } else { + Single<A>() } }, numKeys = 1, + storeFactory = SingletonMapK.Factory(), ) .eventsForKey(Unit) diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Graph.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Graph.kt index 04ce5b6d8785..828f13b026d3 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Graph.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Graph.kt @@ -72,21 +72,21 @@ internal class DepthTracker { @Volatile var snapshotIndirectDepth: Int = 0 @Volatile var snapshotDirectDepth: Int = 0 - private val _snapshotIndirectRoots = HashSet<MuxDeferredNode<*, *>>() + private val _snapshotIndirectRoots = HashSet<MuxDeferredNode<*, *, *>>() val snapshotIndirectRoots get() = _snapshotIndirectRoots.toSet() - private val indirectAdditions = HashSet<MuxDeferredNode<*, *>>() - private val indirectRemovals = HashSet<MuxDeferredNode<*, *>>() + private val indirectAdditions = HashSet<MuxDeferredNode<*, *, *>>() + private val indirectRemovals = HashSet<MuxDeferredNode<*, *, *>>() private val dirty_directUpstreamDepths = TreeMap<Int, Int>() private val dirty_indirectUpstreamDepths = TreeMap<Int, Int>() - private val dirty_indirectUpstreamRoots = Bag<MuxDeferredNode<*, *>>() + private val dirty_indirectUpstreamRoots = Bag<MuxDeferredNode<*, *, *>>() @Volatile var dirty_directDepth = 0 @Volatile private var dirty_indirectDepth = 0 @Volatile private var dirty_depthIsDirect = true @Volatile private var dirty_isIndirectRoot = false - fun schedule(scheduler: Scheduler, node: MuxNode<*, *, *>) { + fun schedule(scheduler: Scheduler, node: MuxNode<*, *, *, *>) { if (dirty_depthIsDirect) { scheduler.schedule(dirty_directDepth, node) } else { @@ -161,9 +161,9 @@ internal class DepthTracker { } fun updateIndirectRoots( - additions: Set<MuxDeferredNode<*, *>>? = null, - removals: Set<MuxDeferredNode<*, *>>? = null, - butNot: MuxDeferredNode<*, *>? = null, + additions: Set<MuxDeferredNode<*, *, *>>? = null, + removals: Set<MuxDeferredNode<*, *, *>>? = null, + butNot: MuxDeferredNode<*, *, *>? = null, ): Boolean { val addsChanged = additions @@ -192,7 +192,7 @@ internal class DepthTracker { return remainder } - suspend fun propagateChanges(scheduler: Scheduler, muxNode: MuxNode<*, *, *>) { + suspend fun propagateChanges(scheduler: Scheduler, muxNode: MuxNode<*, *, *, *>) { if (isDirty()) { schedule(scheduler, muxNode) } @@ -202,7 +202,7 @@ internal class DepthTracker { coroutineScope: CoroutineScope, scheduler: Scheduler, downstreamSet: DownstreamSet, - muxNode: MuxNode<*, *, *>, + muxNode: MuxNode<*, *, *, *>, ) { when { dirty_depthIsDirect -> { @@ -222,7 +222,7 @@ internal class DepthTracker { buildSet { addAll(snapshotIndirectRoots) if (snapshotIsIndirectRoot) { - add(muxNode as MuxDeferredNode<*, *>) + add(muxNode as MuxDeferredNode<*, *, *>) } }, newDirectDepth = dirty_directDepth, @@ -241,7 +241,7 @@ internal class DepthTracker { buildSet { addAll(dirty_indirectUpstreamRoots) if (dirty_isIndirectRoot) { - add(muxNode as MuxDeferredNode<*, *>) + add(muxNode as MuxDeferredNode<*, *, *>) } }, ) @@ -255,14 +255,14 @@ internal class DepthTracker { buildSet { addAll(indirectRemovals) if (snapshotIsIndirectRoot && !dirty_isIndirectRoot) { - add(muxNode as MuxDeferredNode<*, *>) + add(muxNode as MuxDeferredNode<*, *, *>) } }, additions = buildSet { addAll(indirectAdditions) if (!snapshotIsIndirectRoot && dirty_isIndirectRoot) { - add(muxNode as MuxDeferredNode<*, *>) + add(muxNode as MuxDeferredNode<*, *, *>) } }, ) @@ -288,7 +288,7 @@ internal class DepthTracker { buildSet { addAll(snapshotIndirectRoots) if (snapshotIsIndirectRoot) { - add(muxNode as MuxDeferredNode<*, *>) + add(muxNode as MuxDeferredNode<*, *, *>) } }, ) @@ -353,7 +353,7 @@ internal class DownstreamSet { val outputs = HashSet<Output<*>>() val stateWriters = mutableListOf<TStateSource<*>>() - val muxMovers = HashSet<MuxDeferredNode<*, *>>() + val muxMovers = HashSet<MuxDeferredNode<*, *, *>>() val nodes = HashSet<SchedulableNode>() fun add(schedulable: Schedulable) { @@ -390,7 +390,7 @@ internal class DownstreamSet { coroutineScope: CoroutineScope, scheduler: Scheduler, oldIndirectDepth: Int, - oldIndirectSet: Set<MuxDeferredNode<*, *>>, + oldIndirectSet: Set<MuxDeferredNode<*, *, *>>, newDirectDepth: Int, ) = coroutineScope.run { @@ -416,8 +416,8 @@ internal class DownstreamSet { scheduler: Scheduler, oldDepth: Int, newDepth: Int, - removals: Set<MuxDeferredNode<*, *>>, - additions: Set<MuxDeferredNode<*, *>>, + removals: Set<MuxDeferredNode<*, *, *>>, + additions: Set<MuxDeferredNode<*, *, *>>, ) = coroutineScope.run { for (node in nodes) { @@ -443,7 +443,7 @@ internal class DownstreamSet { scheduler: Scheduler, oldDirectDepth: Int, newIndirectDepth: Int, - newIndirectSet: Set<MuxDeferredNode<*, *>>, + newIndirectSet: Set<MuxDeferredNode<*, *, *>>, ) = coroutineScope.run { for (node in nodes) { @@ -467,7 +467,7 @@ internal class DownstreamSet { coroutineScope: CoroutineScope, scheduler: Scheduler, depth: Int, - indirectSet: Set<MuxDeferredNode<*, *>>, + indirectSet: Set<MuxDeferredNode<*, *, *>>, ) = coroutineScope.run { for (node in nodes) { @@ -506,7 +506,7 @@ internal class DownstreamSet { internal sealed interface Schedulable { data class S constructor(val state: TStateSource<*>) : Schedulable - data class M constructor(val muxMover: MuxDeferredNode<*, *>) : Schedulable + data class M constructor(val muxMover: MuxDeferredNode<*, *, *>) : Schedulable data class N constructor(val node: SchedulableNode) : Schedulable diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/InternalScopes.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/InternalScopes.kt index 69ecafd26ba2..80c40ba740a5 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/InternalScopes.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/InternalScopes.kt @@ -61,7 +61,7 @@ internal interface NetworkScope : InitScope { fun scheduleOutput(output: Output<*>) - fun scheduleMuxMover(muxMover: MuxDeferredNode<*, *>) + fun scheduleMuxMover(muxMover: MuxDeferredNode<*, *, *>) fun schedule(state: TStateSource<*>) diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Mux.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Mux.kt index 1fc5470ef354..a479c90cc4de 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Mux.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Mux.kt @@ -18,25 +18,32 @@ package com.android.systemui.kairos.internal -import com.android.systemui.kairos.internal.util.ConcurrentNullableHashMap +import com.android.systemui.kairos.internal.store.MapHolder +import com.android.systemui.kairos.internal.store.MapK +import com.android.systemui.kairos.internal.store.MutableMapK +import com.android.systemui.kairos.internal.store.asMapHolder +import com.android.systemui.kairos.internal.util.asyncImmediate import com.android.systemui.kairos.internal.util.hashString -import java.util.concurrent.ConcurrentHashMap +import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.awaitAll import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -internal typealias MuxResult<K, V> = Map<K, PullNode<V>> +internal typealias MuxResult<W, K, V> = MapK<W, K, PullNode<V>> /** Base class for muxing nodes, which have a (potentially dynamic) collection of upstream nodes. */ -internal sealed class MuxNode<K : Any, V, Output>(val lifecycle: MuxLifecycle<Output>) : - PushNode<Output> { +internal sealed class MuxNode<W, K, V, Output>( + val lifecycle: MuxLifecycle<Output>, + protected val storeFactory: MutableMapK.Factory<W, K>, +) : PushNode<Output> { inline val mutex get() = lifecycle.mutex - // TODO: preserve insertion order? - val upstreamData = ConcurrentNullableHashMap<K, PullNode<V>>() - val switchedIn = ConcurrentHashMap<K, MuxBranchNode<K, V>>() + @Volatile lateinit var upstreamData: MutableMapK<W, K, PullNode<V>> + @Volatile lateinit var switchedIn: MutableMapK<W, K, BranchNode> + val downstreamSet: DownstreamSet = DownstreamSet() // TODO: inline DepthTracker? would need to be added to PushNode signature @@ -110,7 +117,7 @@ internal sealed class MuxNode<K : Any, V, Output>(val lifecycle: MuxLifecycle<Ou suspend fun moveIndirectUpstreamToDirect( scheduler: Scheduler, oldIndirectDepth: Int, - oldIndirectRoots: Set<MuxDeferredNode<*, *>>, + oldIndirectRoots: Set<MuxDeferredNode<*, *, *>>, newDepth: Int, ) { mutex.withLock { @@ -128,8 +135,8 @@ internal sealed class MuxNode<K : Any, V, Output>(val lifecycle: MuxLifecycle<Ou scheduler: Scheduler, oldDepth: Int, newDepth: Int, - removals: Set<MuxDeferredNode<*, *>>, - additions: Set<MuxDeferredNode<*, *>>, + removals: Set<MuxDeferredNode<*, *, *>>, + additions: Set<MuxDeferredNode<*, *, *>>, ) { mutex.withLock { if ( @@ -137,7 +144,7 @@ internal sealed class MuxNode<K : Any, V, Output>(val lifecycle: MuxLifecycle<Ou depthTracker.updateIndirectRoots( additions, removals, - butNot = this as? MuxDeferredNode<*, *>, + butNot = this as? MuxDeferredNode<*, *, *>, ) ) { depthTracker.schedule(scheduler, this) @@ -149,7 +156,7 @@ internal sealed class MuxNode<K : Any, V, Output>(val lifecycle: MuxLifecycle<Ou scheduler: Scheduler, oldDepth: Int, newDepth: Int, - newIndirectSet: Set<MuxDeferredNode<*, *>>, + newIndirectSet: Set<MuxDeferredNode<*, *, *>>, ) { mutex.withLock { if ( @@ -157,7 +164,7 @@ internal sealed class MuxNode<K : Any, V, Output>(val lifecycle: MuxLifecycle<Ou depthTracker.removeDirectUpstream(oldDepth) or depthTracker.updateIndirectRoots( additions = newIndirectSet, - butNot = this as? MuxDeferredNode<*, *>, + butNot = this as? MuxDeferredNode<*, *, *>, ) ) { depthTracker.schedule(scheduler, this) @@ -177,7 +184,7 @@ internal sealed class MuxNode<K : Any, V, Output>(val lifecycle: MuxLifecycle<Ou suspend fun removeIndirectUpstream( scheduler: Scheduler, oldDepth: Int, - indirectSet: Set<MuxDeferredNode<*, *>>, + indirectSet: Set<MuxDeferredNode<*, *, *>>, key: K, ) { mutex.withLock { @@ -204,78 +211,83 @@ internal sealed class MuxNode<K : Any, V, Output>(val lifecycle: MuxLifecycle<Ou // MuxNode as a Pull (effectively making it a mapCheap). depthTracker.schedule(evalScope.scheduler, this) } -} -/** An input branch of a mux node, associated with a key. */ -internal class MuxBranchNode<K : Any, V>(private val muxNode: MuxNode<K, V, *>, val key: K) : - SchedulableNode { + /** An input branch of a mux node, associated with a key. */ + inner class BranchNode(val key: K) : SchedulableNode { - val schedulable = Schedulable.N(this) + val schedulable = Schedulable.N(this) - @Volatile lateinit var upstream: NodeConnection<V> + @Volatile lateinit var upstream: NodeConnection<V> - override suspend fun schedule(evalScope: EvalScope) { - muxNode.upstreamData[key] = upstream.directUpstream - muxNode.schedule(evalScope) - } + override suspend fun schedule(evalScope: EvalScope) { + upstreamData[key] = upstream.directUpstream + this@MuxNode.schedule(evalScope) + } - override suspend fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) { - muxNode.adjustDirectUpstream(scheduler, oldDepth, newDepth) - } + override suspend fun adjustDirectUpstream( + scheduler: Scheduler, + oldDepth: Int, + newDepth: Int, + ) { + this@MuxNode.adjustDirectUpstream(scheduler, oldDepth, newDepth) + } - override suspend fun moveIndirectUpstreamToDirect( - scheduler: Scheduler, - oldIndirectDepth: Int, - oldIndirectSet: Set<MuxDeferredNode<*, *>>, - newDirectDepth: Int, - ) { - muxNode.moveIndirectUpstreamToDirect( - scheduler, - oldIndirectDepth, - oldIndirectSet, - newDirectDepth, - ) - } + override suspend fun moveIndirectUpstreamToDirect( + scheduler: Scheduler, + oldIndirectDepth: Int, + oldIndirectSet: Set<MuxDeferredNode<*, *, *>>, + newDirectDepth: Int, + ) { + this@MuxNode.moveIndirectUpstreamToDirect( + scheduler, + oldIndirectDepth, + oldIndirectSet, + newDirectDepth, + ) + } - override suspend fun adjustIndirectUpstream( - scheduler: Scheduler, - oldDepth: Int, - newDepth: Int, - removals: Set<MuxDeferredNode<*, *>>, - additions: Set<MuxDeferredNode<*, *>>, - ) { - muxNode.adjustIndirectUpstream(scheduler, oldDepth, newDepth, removals, additions) - } + override suspend fun adjustIndirectUpstream( + scheduler: Scheduler, + oldDepth: Int, + newDepth: Int, + removals: Set<MuxDeferredNode<*, *, *>>, + additions: Set<MuxDeferredNode<*, *, *>>, + ) { + this@MuxNode.adjustIndirectUpstream(scheduler, oldDepth, newDepth, removals, additions) + } - override suspend fun moveDirectUpstreamToIndirect( - scheduler: Scheduler, - oldDirectDepth: Int, - newIndirectDepth: Int, - newIndirectSet: Set<MuxDeferredNode<*, *>>, - ) { - muxNode.moveDirectUpstreamToIndirect( - scheduler, - oldDirectDepth, - newIndirectDepth, - newIndirectSet, - ) - } + override suspend fun moveDirectUpstreamToIndirect( + scheduler: Scheduler, + oldDirectDepth: Int, + newIndirectDepth: Int, + newIndirectSet: Set<MuxDeferredNode<*, *, *>>, + ) { + this@MuxNode.moveDirectUpstreamToIndirect( + scheduler, + oldDirectDepth, + newIndirectDepth, + newIndirectSet, + ) + } - override suspend fun removeDirectUpstream(scheduler: Scheduler, depth: Int) { - muxNode.removeDirectUpstream(scheduler, depth, key) - } + override suspend fun removeDirectUpstream(scheduler: Scheduler, depth: Int) { + removeDirectUpstream(scheduler, depth, key) + } - override suspend fun removeIndirectUpstream( - scheduler: Scheduler, - depth: Int, - indirectSet: Set<MuxDeferredNode<*, *>>, - ) { - muxNode.removeIndirectUpstream(scheduler, depth, indirectSet, key) - } + override suspend fun removeIndirectUpstream( + scheduler: Scheduler, + depth: Int, + indirectSet: Set<MuxDeferredNode<*, *, *>>, + ) { + removeIndirectUpstream(scheduler, depth, indirectSet, key) + } - override fun toString(): String = "MuxBranchNode(key=$key, mux=$muxNode)" + override fun toString(): String = "MuxBranchNode(key=$key, mux=${this@MuxNode})" + } } +internal typealias BranchNode<W, K, V> = MuxNode<W, K, V, *>.BranchNode + /** Tracks lifecycle of MuxNode in the network. Essentially a mutable ref for MuxLifecycleState. */ internal class MuxLifecycle<A>(@Volatile var lifecycleState: MuxLifecycleState<A>) : TFlowImpl<A> { val mutex = Mutex() @@ -324,7 +336,7 @@ internal sealed interface MuxLifecycleState<out A> { override fun toString(): String = "Inactive" } - class Active<A>(val node: MuxNode<*, *, A>) : MuxLifecycleState<A> { + class Active<A>(val node: MuxNode<*, *, *, A>) : MuxLifecycleState<A> { override fun toString(): String = "Active(node=$node)" } @@ -332,11 +344,71 @@ internal sealed interface MuxLifecycleState<out A> { } internal interface MuxActivator<A> { - suspend fun activate(evalScope: EvalScope, lifecycle: MuxLifecycle<A>): MuxNode<*, *, A>? + suspend fun activate(evalScope: EvalScope, lifecycle: MuxLifecycle<A>): MuxNode<*, *, *, A>? } internal inline fun <A> MuxLifecycle(onSubscribe: MuxActivator<A>): TFlowImpl<A> = MuxLifecycle(MuxLifecycleState.Inactive(onSubscribe)) -internal fun <K, V> TFlowImpl<MuxResult<K, V>>.awaitValues(): TFlowImpl<Map<K, V>> = - mapImpl({ this@awaitValues }) { results -> results.mapValues { it.value.getPushEvent(this) } } +internal fun <K, V> TFlowImpl<MuxResult<MapHolder.W, K, V>>.awaitValues(): TFlowImpl<Map<K, V>> = + mapImpl({ this@awaitValues }) { results -> + results.asMapHolder().unwrapped.mapValues { it.value.getPushEvent(this) } + } + +// activation logic + +internal suspend fun <W, K, V, O> MuxNode<W, K, V, O>.initializeUpstream( + evalScope: EvalScope, + getStorage: suspend EvalScope.() -> Iterable<Map.Entry<K, TFlowImpl<V>>>, + storeFactory: MutableMapK.Factory<W, K>, +) { + val storage = getStorage(evalScope) + coroutineScope { + val initUpstream = buildList { + storage.forEach { (key, flow) -> + val branchNode = BranchNode(key) + add( + asyncImmediate(start = CoroutineStart.LAZY) { + flow.activate(evalScope, branchNode.schedulable)?.let { (conn, needsEval) -> + Triple( + key, + branchNode.apply { upstream = conn }, + if (needsEval) conn.directUpstream else null, + ) + } + } + ) + } + } + val results = initUpstream.awaitAll() + switchedIn = storeFactory.create(initUpstream.size) + upstreamData = storeFactory.create(initUpstream.size) + for (triple in results) { + triple?.let { (key, branch, upstream) -> + switchedIn[key] = branch + upstream?.let { upstreamData[key] = upstream } + } + } + } +} + +internal fun <W, K, V, O> MuxNode<W, K, V, O>.initializeDepth() { + switchedIn.forEach { (_, branch) -> + val conn = branch.upstream + if (conn.depthTracker.snapshotIsDirect) { + depthTracker.addDirectUpstream( + oldDepth = null, + newDepth = conn.depthTracker.snapshotDirectDepth, + ) + } else { + depthTracker.addIndirectUpstream( + oldDepth = null, + newDepth = conn.depthTracker.snapshotIndirectDepth, + ) + depthTracker.updateIndirectRoots( + additions = conn.depthTracker.snapshotIndirectRoots, + butNot = null, + ) + } + } +} diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/MuxDeferred.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/MuxDeferred.kt index 338ee0145530..7f40df508fb1 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/MuxDeferred.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/MuxDeferred.kt @@ -16,16 +16,18 @@ package com.android.systemui.kairos.internal +import com.android.systemui.kairos.internal.store.MutableArrayMapK +import com.android.systemui.kairos.internal.store.MutableMapK +import com.android.systemui.kairos.internal.store.SingletonMapK +import com.android.systemui.kairos.internal.store.StoreEntry +import com.android.systemui.kairos.internal.store.asArrayHolder +import com.android.systemui.kairos.internal.store.singleOf import com.android.systemui.kairos.internal.util.Key -import com.android.systemui.kairos.internal.util.associateByIndexTo import com.android.systemui.kairos.internal.util.hashString import com.android.systemui.kairos.internal.util.mapParallel -import com.android.systemui.kairos.internal.util.mapValuesNotNullParallelTo import com.android.systemui.kairos.util.Just -import com.android.systemui.kairos.util.Left import com.android.systemui.kairos.util.Maybe import com.android.systemui.kairos.util.None -import com.android.systemui.kairos.util.Right import com.android.systemui.kairos.util.These import com.android.systemui.kairos.util.flatMap import com.android.systemui.kairos.util.getMaybe @@ -34,27 +36,26 @@ import com.android.systemui.kairos.util.maybeThat import com.android.systemui.kairos.util.maybeThis import com.android.systemui.kairos.util.merge import com.android.systemui.kairos.util.orError -import com.android.systemui.kairos.util.partitionEithers import com.android.systemui.kairos.util.these -import java.util.TreeMap import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch import kotlinx.coroutines.sync.withLock -internal class MuxDeferredNode<K : Any, V>( - lifecycle: MuxLifecycle<Map<K, PullNode<V>>>, - val spec: MuxActivator<Map<K, PullNode<V>>>, -) : MuxNode<K, V, Map<K, PullNode<V>>>(lifecycle), Key<Map<K, PullNode<V>>> { +internal class MuxDeferredNode<W, K, V>( + lifecycle: MuxLifecycle<MuxResult<W, K, V>>, + val spec: MuxActivator<MuxResult<W, K, V>>, + factory: MutableMapK.Factory<W, K>, +) : MuxNode<W, K, V, MuxResult<W, K, V>>(lifecycle, factory), Key<MuxResult<W, K, V>> { val schedulable = Schedulable.M(this) - @Volatile var patches: NodeConnection<Map<K, Maybe<TFlowImpl<V>>>>? = null - @Volatile var patchData: Map<K, Maybe<TFlowImpl<V>>>? = null + @Volatile var patches: NodeConnection<Iterable<Map.Entry<K, Maybe<TFlowImpl<V>>>>>? = null + @Volatile var patchData: Iterable<Map.Entry<K, Maybe<TFlowImpl<V>>>>? = null override suspend fun visit(evalScope: EvalScope) { - val result = upstreamData.toMap() + val scheduleDownstream = upstreamData.isNotEmpty() + val result = upstreamData.readOnlyCopy() upstreamData.clear() - val scheduleDownstream = result.isNotEmpty() val compactDownstream = depthTracker.isDirty() if (scheduleDownstream || compactDownstream) { coroutineScope { @@ -79,7 +80,7 @@ internal class MuxDeferredNode<K : Any, V>( } } - override suspend fun getPushEvent(evalScope: EvalScope): Map<K, PullNode<V>> = + override suspend fun getPushEvent(evalScope: EvalScope): MuxResult<W, K, V> = evalScope.getCurrentValue(key = this) private suspend fun compactIfNeeded(evalScope: EvalScope) { @@ -94,7 +95,7 @@ internal class MuxDeferredNode<K : Any, V>( } // Process branch nodes coroutineScope { - switchedIn.values.forEach { branchNode -> + switchedIn.forEach { (_, branchNode) -> branchNode.upstream.let { launch { it.removeDownstreamAndDeactivateIfNeeded(branchNode.schedulable) } } @@ -115,23 +116,21 @@ internal class MuxDeferredNode<K : Any, V>( // fun? // We have a patch, process additions/updates and removals - val (adds, removes) = - patch - .asSequence() - .map { (k, newUpstream: Maybe<TFlowImpl<V>>) -> - when (newUpstream) { - is Just -> Left(k to newUpstream.value) - None -> Right(k) - } - } - .partitionEithers() + val adds = mutableListOf<Pair<K, TFlowImpl<V>>>() + val removes = mutableListOf<K>() + patch.forEach { (k, newUpstream) -> + when (newUpstream) { + is Just -> adds.add(k to newUpstream.value) + None -> removes.add(k) + } + } val severed = mutableListOf<NodeConnection<*>>() coroutineScope { // remove and sever removes.forEach { k -> - switchedIn.remove(k)?.let { branchNode: MuxBranchNode<K, V> -> + switchedIn.remove(k)?.let { branchNode: BranchNode -> val conn = branchNode.upstream severed.add(conn) launch { conn.removeDownstream(downstream = branchNode.schedulable) } @@ -142,13 +141,13 @@ internal class MuxDeferredNode<K : Any, V>( // add or replace adds .mapParallel { (k, newUpstream: TFlowImpl<V>) -> - val branchNode = MuxBranchNode(this@MuxDeferredNode, k) + val branchNode = BranchNode(k) k to newUpstream.activate(evalScope, branchNode.schedulable)?.let { (conn, _) -> branchNode.apply { upstream = conn } } } - .forEach { (k, newBranch: MuxBranchNode<K, V>?) -> + .forEach { (k, newBranch: BranchNode?) -> // remove old and sever, if present switchedIn.remove(k)?.let { branchNode -> val conn = branchNode.upstream @@ -204,7 +203,7 @@ internal class MuxDeferredNode<K : Any, V>( suspend fun removeIndirectPatchNode( scheduler: Scheduler, depth: Int, - indirectSet: Set<MuxDeferredNode<*, *>>, + indirectSet: Set<MuxDeferredNode<*, *, *>>, ) { // indirectly connected patches forward the indirectSet mutex.withLock { @@ -221,7 +220,7 @@ internal class MuxDeferredNode<K : Any, V>( suspend fun moveIndirectPatchNodeToDirect( scheduler: Scheduler, oldIndirectDepth: Int, - oldIndirectSet: Set<MuxDeferredNode<*, *>>, + oldIndirectSet: Set<MuxDeferredNode<*, *, *>>, ) { // directly connected patches are stored as an indirect singleton set of the patchNode mutex.withLock { @@ -238,7 +237,7 @@ internal class MuxDeferredNode<K : Any, V>( suspend fun moveDirectPatchNodeToIndirect( scheduler: Scheduler, newIndirectDepth: Int, - newIndirectSet: Set<MuxDeferredNode<*, *>>, + newIndirectSet: Set<MuxDeferredNode<*, *, *>>, ) { // indirectly connected patches forward the indirectSet mutex.withLock { @@ -256,8 +255,8 @@ internal class MuxDeferredNode<K : Any, V>( scheduler: Scheduler, oldDepth: Int, newDepth: Int, - removals: Set<MuxDeferredNode<*, *>>, - additions: Set<MuxDeferredNode<*, *>>, + removals: Set<MuxDeferredNode<*, *, *>>, + additions: Set<MuxDeferredNode<*, *, *>>, ) { // indirectly connected patches forward the indirectSet mutex.withLock { @@ -289,120 +288,98 @@ internal inline fun <A> switchDeferredImplSingle( ): TFlowImpl<A> = mapImpl({ switchDeferredImpl( - getStorage = { mapOf(Unit to getStorage()) }, - getPatches = { mapImpl(getPatches) { newFlow -> mapOf(Unit to just(newFlow)) } }, + getStorage = { singleOf(getStorage()).asIterable() }, + getPatches = { + mapImpl(getPatches) { newFlow -> singleOf(just(newFlow)).asIterable() } + }, + storeFactory = SingletonMapK.Factory(), ) }) { map -> map.getValue(Unit).getPushEvent(this) } -internal fun <K : Any, V> switchDeferredImpl( - getStorage: suspend EvalScope.() -> Map<K, TFlowImpl<V>>, - getPatches: suspend EvalScope.() -> TFlowImpl<Map<K, Maybe<TFlowImpl<V>>>>, -): TFlowImpl<Map<K, PullNode<V>>> = - MuxLifecycle( - object : MuxActivator<Map<K, PullNode<V>>> { - override suspend fun activate( - evalScope: EvalScope, - lifecycle: MuxLifecycle<Map<K, PullNode<V>>>, - ): MuxNode<*, *, Map<K, PullNode<V>>>? { - val storage: Map<K, TFlowImpl<V>> = getStorage(evalScope) - // Initialize mux node and switched-in connections. - val muxNode = - MuxDeferredNode(lifecycle, this).apply { - storage.mapValuesNotNullParallelTo(switchedIn) { (key, flow) -> - val branchNode = MuxBranchNode(this@apply, key) - flow.activate(evalScope, branchNode.schedulable)?.let { - (conn, needsEval) -> - branchNode - .apply { upstream = conn } - .also { - if (needsEval) { - upstreamData[key] = conn.directUpstream - } - } - } - } - } +internal fun <W, K, V> switchDeferredImpl( + getStorage: suspend EvalScope.() -> Iterable<Map.Entry<K, TFlowImpl<V>>>, + getPatches: suspend EvalScope.() -> TFlowImpl<Iterable<Map.Entry<K, Maybe<TFlowImpl<V>>>>>, + storeFactory: MutableMapK.Factory<W, K>, +): TFlowImpl<MuxResult<W, K, V>> = + MuxLifecycle(MuxDeferredActivator(getStorage, storeFactory, getPatches)) + +private class MuxDeferredActivator<W, K, V>( + private val getStorage: suspend EvalScope.() -> Iterable<Map.Entry<K, TFlowImpl<V>>>, + private val storeFactory: MutableMapK.Factory<W, K>, + private val getPatches: + suspend EvalScope.() -> TFlowImpl<Iterable<Map.Entry<K, Maybe<TFlowImpl<V>>>>>, +) : MuxActivator<MuxResult<W, K, V>> { + override suspend fun activate( + evalScope: EvalScope, + lifecycle: MuxLifecycle<MuxResult<W, K, V>>, + ): MuxNode<W, *, *, MuxResult<W, K, V>>? { + // Initialize mux node and switched-in connections. + val muxNode = + MuxDeferredNode(lifecycle, this, storeFactory).apply { + initializeUpstream(evalScope, getStorage, storeFactory) // Update depth based on all initial switched-in nodes. - muxNode.switchedIn.values.forEach { branch -> - val conn = branch.upstream - if (conn.depthTracker.snapshotIsDirect) { - muxNode.depthTracker.addDirectUpstream( - oldDepth = null, - newDepth = conn.depthTracker.snapshotDirectDepth, - ) - } else { - muxNode.depthTracker.addIndirectUpstream( - oldDepth = null, - newDepth = conn.depthTracker.snapshotIndirectDepth, - ) - muxNode.depthTracker.updateIndirectRoots( - additions = conn.depthTracker.snapshotIndirectRoots, - butNot = muxNode, - ) - } - } + initializeDepth() // We don't have our patches connection established yet, so for now pretend we have // a direct connection to patches. We will update downstream nodes later if this // turns out to be a lie. - muxNode.depthTracker.setIsIndirectRoot(true) - muxNode.depthTracker.reset() - - // Setup patches connection; deferring allows for a recursive connection, where - // muxNode is downstream of itself via patches. - var isIndirect = true - evalScope.deferAction { - val (patchesConn, needsEval) = - getPatches(evalScope).activate(evalScope, downstream = muxNode.schedulable) - ?: run { - isIndirect = false - // Turns out we can't connect to patches, so update our depth and - // propagate - muxNode.mutex.withLock { - if (muxNode.depthTracker.setIsIndirectRoot(false)) { - muxNode.depthTracker.schedule(evalScope.scheduler, muxNode) - } - } - return@deferAction - } - muxNode.patches = patchesConn - - if (!patchesConn.schedulerUpstream.depthTracker.snapshotIsDirect) { - // Turns out patches is indirect, so we are not a root. Update depth and - // propagate. + depthTracker.setIsIndirectRoot(true) + depthTracker.reset() + } + // Setup patches connection; deferring allows for a recursive connection, where + // muxNode is downstream of itself via patches. + var isIndirect = true + evalScope.deferAction { + val (patchesConn, needsEval) = + getPatches(evalScope).activate(evalScope, downstream = muxNode.schedulable) + ?: run { + isIndirect = false + // Turns out we can't connect to patches, so update our depth and + // propagate muxNode.mutex.withLock { - if ( - muxNode.depthTracker.setIsIndirectRoot(false) or - muxNode.depthTracker.addIndirectUpstream( - oldDepth = null, - newDepth = patchesConn.depthTracker.snapshotIndirectDepth, - ) or - muxNode.depthTracker.updateIndirectRoots( - additions = patchesConn.depthTracker.snapshotIndirectRoots - ) - ) { + if (muxNode.depthTracker.setIsIndirectRoot(false)) { muxNode.depthTracker.schedule(evalScope.scheduler, muxNode) } } + return@deferAction } - // Schedule mover to process patch emission at the end of this transaction, if - // needed. - if (needsEval) { - muxNode.patchData = patchesConn.getPushEvent(evalScope) - evalScope.scheduleMuxMover(muxNode) + muxNode.patches = patchesConn + + if (!patchesConn.schedulerUpstream.depthTracker.snapshotIsDirect) { + // Turns out patches is indirect, so we are not a root. Update depth and + // propagate. + muxNode.mutex.withLock { + if ( + muxNode.depthTracker.setIsIndirectRoot(false) or + muxNode.depthTracker.addIndirectUpstream( + oldDepth = null, + newDepth = patchesConn.depthTracker.snapshotIndirectDepth, + ) or + muxNode.depthTracker.updateIndirectRoots( + additions = patchesConn.depthTracker.snapshotIndirectRoots + ) + ) { + muxNode.depthTracker.schedule(evalScope.scheduler, muxNode) } } - - // Schedule for evaluation if any switched-in nodes have already emitted within - // this transaction. - if (muxNode.upstreamData.isNotEmpty()) { - muxNode.schedule(evalScope) - } - return muxNode.takeUnless { muxNode.switchedIn.isEmpty() && !isIndirect } + } + // Schedule mover to process patch emission at the end of this transaction, if + // needed. + if (needsEval) { + muxNode.patchData = patchesConn.getPushEvent(evalScope) + evalScope.scheduleMuxMover(muxNode) } } - ) + + // Schedule for evaluation if any switched-in nodes have already emitted within + // this transaction. + if (muxNode.upstreamData.isNotEmpty()) { + muxNode.schedule(evalScope) + } + return muxNode.takeUnless { muxNode.switchedIn.isEmpty() && !isIndirect } + } +} internal inline fun <A> mergeNodes( crossinline getPulse: suspend EvalScope.() -> TFlowImpl<A>, @@ -416,16 +393,22 @@ internal inline fun <A> mergeNodes( return merged.cached() } +internal fun <T> Iterable<T>.asIterableWithIndex(): Iterable<StoreEntry<Int, T>> = + asSequence().mapIndexed { i, t -> StoreEntry(i, t) }.asIterable() + internal inline fun <A, B> mergeNodes( crossinline getPulse: suspend EvalScope.() -> TFlowImpl<A>, crossinline getOther: suspend EvalScope.() -> TFlowImpl<B>, ): TFlowImpl<These<A, B>> { val storage = - mapOf( - 0 to mapImpl(getPulse) { These.thiz<A, B>(it) }, - 1 to mapImpl(getOther) { These.that(it) }, + listOf(mapImpl(getPulse) { These.thiz<A, B>(it) }, mapImpl(getOther) { These.that(it) }) + .asIterableWithIndex() + val switchNode = + switchDeferredImpl( + getStorage = { storage }, + getPatches = { neverImpl }, + storeFactory = MutableArrayMapK.Factory(), ) - val switchNode = switchDeferredImpl(getStorage = { storage }, getPatches = { neverImpl }) val merged = mapImpl({ switchNode }) { mergeResults -> val first = mergeResults.getMaybe(0).flatMap { it.getPushEvent(this).maybeThis() } @@ -440,12 +423,14 @@ internal inline fun <A> mergeNodes( ): TFlowImpl<List<A>> { val switchNode = switchDeferredImpl( - getStorage = { getPulses().associateByIndexTo(TreeMap()) }, + getStorage = { getPulses().asIterableWithIndex() }, getPatches = { neverImpl }, + storeFactory = MutableArrayMapK.Factory(), ) val merged = - mapImpl({ switchNode }) { mergeResults -> - mergeResults.values.map { it.getPushEvent(this) } + mapImpl({ switchNode }) { + val mergeResults = it.asArrayHolder() + mergeResults.map { (_, node) -> node.getPushEvent(this) } } return merged.cached() } @@ -455,12 +440,11 @@ internal inline fun <A> mergeNodesLeft( ): TFlowImpl<A> { val switchNode = switchDeferredImpl( - getStorage = { getPulses().associateByIndexTo(TreeMap()) }, + getStorage = { getPulses().asIterableWithIndex() }, getPatches = { neverImpl }, + storeFactory = MutableArrayMapK.Factory(), ) val merged = - mapImpl({ switchNode }) { mergeResults: Map<Int, PullNode<A>> -> - mergeResults.values.first().getPushEvent(this) - } + mapImpl({ switchNode }) { mergeResults -> mergeResults.values.first().getPushEvent(this) } return merged.cached() } diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/MuxPrompt.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/MuxPrompt.kt index dd0357b0413d..839c5a64272a 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/MuxPrompt.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/MuxPrompt.kt @@ -16,44 +16,48 @@ package com.android.systemui.kairos.internal +import com.android.systemui.kairos.internal.store.MutableMapK +import com.android.systemui.kairos.internal.store.SingletonMapK +import com.android.systemui.kairos.internal.store.singleOf import com.android.systemui.kairos.internal.util.Key import com.android.systemui.kairos.internal.util.launchImmediate import com.android.systemui.kairos.internal.util.mapParallel -import com.android.systemui.kairos.internal.util.mapValuesNotNullParallelTo import com.android.systemui.kairos.util.Just -import com.android.systemui.kairos.util.Left import com.android.systemui.kairos.util.Maybe import com.android.systemui.kairos.util.None -import com.android.systemui.kairos.util.Right -import com.android.systemui.kairos.util.partitionEithers +import com.android.systemui.kairos.util.just import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch import kotlinx.coroutines.sync.withLock -private typealias MuxPromptMovingResult<K, V> = Pair<MuxResult<K, V>, MuxResult<K, V>?> +private typealias MuxPromptMovingResult<W, K, V> = Pair<MuxResult<W, K, V>, MuxResult<W, K, V>?> -internal class MuxPromptMovingNode<K : Any, V>( - lifecycle: MuxLifecycle<MuxPromptMovingResult<K, V>>, - private val spec: MuxActivator<MuxPromptMovingResult<K, V>>, -) : MuxNode<K, V, MuxPromptMovingResult<K, V>>(lifecycle), Key<MuxPromptMovingResult<K, V>> { +internal class MuxPromptMovingNode<W, K, V>( + lifecycle: MuxLifecycle<MuxPromptMovingResult<W, K, V>>, + private val spec: MuxActivator<MuxPromptMovingResult<W, K, V>>, + factory: MutableMapK.Factory<W, K>, +) : + MuxNode<W, K, V, MuxPromptMovingResult<W, K, V>>(lifecycle, factory), + Key<MuxPromptMovingResult<W, K, V>> { - @Volatile var patchData: Map<K, Maybe<TFlowImpl<V>>>? = null - @Volatile var patches: MuxPromptPatchNode<K, V>? = null + @Volatile var patchData: Iterable<Map.Entry<K, Maybe<TFlowImpl<V>>>>? = null + @Volatile var patches: PatchNode? = null - @Volatile private var reEval: MuxPromptMovingResult<K, V>? = null + @Volatile private var reEval: MuxPromptMovingResult<W, K, V>? = null override suspend fun visit(evalScope: EvalScope) { - val preSwitchResults: MuxResult<K, V> = upstreamData.toMap() + val preSwitchNotEmpty = upstreamData.isNotEmpty() + val preSwitchResults: MuxResult<W, K, V> = upstreamData.readOnlyCopy() upstreamData.clear() - val patch: Map<K, Maybe<TFlowImpl<V>>>? = patchData + val patch: Iterable<Map.Entry<K, Maybe<TFlowImpl<V>>>>? = patchData patchData = null val (reschedule, evalResult) = reEval?.let { false to it } - ?: if (preSwitchResults.isNotEmpty() || patch?.isNotEmpty() == true) { - doEval(preSwitchResults, patch, evalScope) + ?: if (preSwitchNotEmpty || patch != null) { + doEval(preSwitchNotEmpty, preSwitchResults, patch, evalScope) } else { false to null } @@ -88,31 +92,30 @@ internal class MuxPromptMovingNode<K : Any, V>( } private suspend fun doEval( - preSwitchResults: MuxResult<K, V>, - patch: Map<K, Maybe<TFlowImpl<V>>>?, + preSwitchNotEmpty: Boolean, + preSwitchResults: MuxResult<W, K, V>, + patch: Iterable<Map.Entry<K, Maybe<TFlowImpl<V>>>>?, evalScope: EvalScope, - ): Pair<Boolean, MuxPromptMovingResult<K, V>?> { - val newlySwitchedIn: MuxResult<K, V>? = + ): Pair<Boolean, MuxPromptMovingResult<W, K, V>?> { + val newlySwitchedIn: MuxResult<W, K, V>? = patch?.let { // We have a patch, process additions/updates and removals - val (adds, removes) = - patch - .asSequence() - .map { (k, newUpstream: Maybe<TFlowImpl<V>>) -> - when (newUpstream) { - is Just -> Left(k to newUpstream.value) - None -> Right(k) - } - } - .partitionEithers() + val adds = mutableListOf<Pair<K, TFlowImpl<V>>>() + val removes = mutableListOf<K>() + patch.forEach { (k, newUpstream) -> + when (newUpstream) { + is Just -> adds.add(k to newUpstream.value) + None -> removes.add(k) + } + } - val additionsAndUpdates = mutableMapOf<K, PullNode<V>>() + val additionsAndUpdates = mutableListOf<Pair<K, PullNode<V>>>() val severed = mutableListOf<NodeConnection<*>>() coroutineScope { // remove and sever removes.forEach { k -> - switchedIn.remove(k)?.let { branchNode: MuxBranchNode<K, V> -> + switchedIn.remove(k)?.let { branchNode: BranchNode -> val conn: NodeConnection<V> = branchNode.upstream severed.add(conn) launchImmediate { @@ -125,16 +128,16 @@ internal class MuxPromptMovingNode<K : Any, V>( // add or replace adds .mapParallel { (k, newUpstream: TFlowImpl<V>) -> - val branchNode = MuxBranchNode(this@MuxPromptMovingNode, k) + val branchNode = BranchNode(k) k to newUpstream.activate(evalScope, branchNode.schedulable)?.let { (conn, _) -> branchNode.apply { upstream = conn } } } - .forEach { (k, newBranch: MuxBranchNode<K, V>?) -> + .forEach { (k, newBranch: BranchNode?) -> // remove old and sever, if present - switchedIn.remove(k)?.let { oldBranch: MuxBranchNode<K, V> -> + switchedIn.remove(k)?.let { oldBranch: BranchNode -> val conn: NodeConnection<V> = oldBranch.upstream severed.add(conn) launchImmediate { @@ -148,7 +151,7 @@ internal class MuxPromptMovingNode<K : Any, V>( // add new newBranch?.let { switchedIn[k] = newBranch - additionsAndUpdates[k] = newBranch.upstream.directUpstream + additionsAndUpdates.add(k to newBranch.upstream.directUpstream) val branchDepthTracker = newBranch.upstream.depthTracker if (branchDepthTracker.snapshotIsDirect) { depthTracker.addDirectUpstream( @@ -175,10 +178,14 @@ internal class MuxPromptMovingNode<K : Any, V>( } } - additionsAndUpdates.takeIf { it.isNotEmpty() } + val resultStore = storeFactory.create<PullNode<V>>(additionsAndUpdates.size) + for ((k, node) in additionsAndUpdates) { + resultStore[k] = node + } + resultStore.takeIf { it.isNotEmpty() }?.asReadOnly() } - return if (preSwitchResults.isNotEmpty() || newlySwitchedIn != null) { + return if (preSwitchNotEmpty || newlySwitchedIn != null) { (newlySwitchedIn != null) to (preSwitchResults to newlySwitchedIn) } else { false to null @@ -203,7 +210,7 @@ internal class MuxPromptMovingNode<K : Any, V>( } } - override suspend fun getPushEvent(evalScope: EvalScope): MuxPromptMovingResult<K, V> = + override suspend fun getPushEvent(evalScope: EvalScope): MuxPromptMovingResult<W, K, V> = evalScope.getCurrentValue(key = this) override suspend fun doDeactivate() { @@ -213,7 +220,7 @@ internal class MuxPromptMovingNode<K : Any, V>( lifecycle.lifecycleState = MuxLifecycleState.Inactive(spec) } // Process branch nodes - switchedIn.values.forEach { branchNode -> + switchedIn.forEach { (_, branchNode) -> branchNode.upstream.removeDownstreamAndDeactivateIfNeeded( downstream = branchNode.schedulable ) @@ -227,7 +234,7 @@ internal class MuxPromptMovingNode<K : Any, V>( suspend fun removeIndirectPatchNode( scheduler: Scheduler, oldDepth: Int, - indirectSet: Set<MuxDeferredNode<*, *>>, + indirectSet: Set<MuxDeferredNode<*, *, *>>, ) { mutex.withLock { patches = null @@ -248,190 +255,126 @@ internal class MuxPromptMovingNode<K : Any, V>( } } } -} -internal class MuxPromptEvalNode<K, V>( - private val movingNode: PullNode<MuxPromptMovingResult<K, V>> -) : PullNode<MuxResult<K, V>> { - override suspend fun getPushEvent(evalScope: EvalScope): MuxResult<K, V> = - movingNode.getPushEvent(evalScope).let { (preSwitchResults, newlySwitchedIn) -> - newlySwitchedIn?.toMap(preSwitchResults.toMutableMap()) ?: preSwitchResults - } -} + inner class PatchNode : SchedulableNode { -// TODO: inner class? -internal class MuxPromptPatchNode<K : Any, V>(private val muxNode: MuxPromptMovingNode<K, V>) : - SchedulableNode { + val schedulable = Schedulable.N(this) - val schedulable = Schedulable.N(this) + lateinit var upstream: NodeConnection<Iterable<Map.Entry<K, Maybe<TFlowImpl<V>>>>> - lateinit var upstream: NodeConnection<Map<K, Maybe<TFlowImpl<V>>>> - - override suspend fun schedule(evalScope: EvalScope) { - muxNode.patchData = upstream.getPushEvent(evalScope) - muxNode.schedule(evalScope) - } + override suspend fun schedule(evalScope: EvalScope) { + patchData = upstream.getPushEvent(evalScope) + this@MuxPromptMovingNode.schedule(evalScope) + } - override suspend fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) { - muxNode.adjustDirectUpstream(scheduler, oldDepth, newDepth) - } + override suspend fun adjustDirectUpstream( + scheduler: Scheduler, + oldDepth: Int, + newDepth: Int, + ) { + this@MuxPromptMovingNode.adjustDirectUpstream(scheduler, oldDepth, newDepth) + } - override suspend fun moveIndirectUpstreamToDirect( - scheduler: Scheduler, - oldIndirectDepth: Int, - oldIndirectSet: Set<MuxDeferredNode<*, *>>, - newDirectDepth: Int, - ) { - muxNode.moveIndirectUpstreamToDirect( - scheduler, - oldIndirectDepth, - oldIndirectSet, - newDirectDepth, - ) - } + override suspend fun moveIndirectUpstreamToDirect( + scheduler: Scheduler, + oldIndirectDepth: Int, + oldIndirectSet: Set<MuxDeferredNode<*, *, *>>, + newDirectDepth: Int, + ) { + this@MuxPromptMovingNode.moveIndirectUpstreamToDirect( + scheduler, + oldIndirectDepth, + oldIndirectSet, + newDirectDepth, + ) + } - override suspend fun adjustIndirectUpstream( - scheduler: Scheduler, - oldDepth: Int, - newDepth: Int, - removals: Set<MuxDeferredNode<*, *>>, - additions: Set<MuxDeferredNode<*, *>>, - ) { - muxNode.adjustIndirectUpstream(scheduler, oldDepth, newDepth, removals, additions) - } + override suspend fun adjustIndirectUpstream( + scheduler: Scheduler, + oldDepth: Int, + newDepth: Int, + removals: Set<MuxDeferredNode<*, *, *>>, + additions: Set<MuxDeferredNode<*, *, *>>, + ) { + this@MuxPromptMovingNode.adjustIndirectUpstream( + scheduler, + oldDepth, + newDepth, + removals, + additions, + ) + } - override suspend fun moveDirectUpstreamToIndirect( - scheduler: Scheduler, - oldDirectDepth: Int, - newIndirectDepth: Int, - newIndirectSet: Set<MuxDeferredNode<*, *>>, - ) { - muxNode.moveDirectUpstreamToIndirect( - scheduler, - oldDirectDepth, - newIndirectDepth, - newIndirectSet, - ) - } + override suspend fun moveDirectUpstreamToIndirect( + scheduler: Scheduler, + oldDirectDepth: Int, + newIndirectDepth: Int, + newIndirectSet: Set<MuxDeferredNode<*, *, *>>, + ) { + this@MuxPromptMovingNode.moveDirectUpstreamToIndirect( + scheduler, + oldDirectDepth, + newIndirectDepth, + newIndirectSet, + ) + } - override suspend fun removeDirectUpstream(scheduler: Scheduler, depth: Int) { - muxNode.removeDirectPatchNode(scheduler, depth) - } + override suspend fun removeDirectUpstream(scheduler: Scheduler, depth: Int) { + this@MuxPromptMovingNode.removeDirectPatchNode(scheduler, depth) + } - override suspend fun removeIndirectUpstream( - scheduler: Scheduler, - depth: Int, - indirectSet: Set<MuxDeferredNode<*, *>>, - ) { - muxNode.removeIndirectPatchNode(scheduler, depth, indirectSet) + override suspend fun removeIndirectUpstream( + scheduler: Scheduler, + depth: Int, + indirectSet: Set<MuxDeferredNode<*, *, *>>, + ) { + this@MuxPromptMovingNode.removeIndirectPatchNode(scheduler, depth, indirectSet) + } } } -internal fun <K : Any, V> switchPromptImpl( - getStorage: suspend EvalScope.() -> Map<K, TFlowImpl<V>>, - getPatches: suspend EvalScope.() -> TFlowImpl<Map<K, Maybe<TFlowImpl<V>>>>, -): TFlowImpl<MuxResult<K, V>> { - val moving = - MuxLifecycle( - object : MuxActivator<MuxPromptMovingResult<K, V>> { - override suspend fun activate( - evalScope: EvalScope, - lifecycle: MuxLifecycle<MuxPromptMovingResult<K, V>>, - ): MuxNode<*, *, MuxPromptMovingResult<K, V>>? { - val storage: Map<K, TFlowImpl<V>> = getStorage(evalScope) - // Initialize mux node and switched-in connections. - val movingNode = - MuxPromptMovingNode(lifecycle, this).apply { - coroutineScope { - launch { - storage.mapValuesNotNullParallelTo(switchedIn) { (key, flow) -> - val branchNode = MuxBranchNode(this@apply, key) - flow - .activate( - evalScope = evalScope, - downstream = branchNode.schedulable, - ) - ?.let { (conn, needsEval) -> - branchNode - .apply { upstream = conn } - .also { - if (needsEval) { - upstreamData[key] = conn.directUpstream - } - } - } - } - } - // Setup patches connection - val patchNode = MuxPromptPatchNode(this@apply) - getPatches(evalScope) - .activate( - evalScope = evalScope, - downstream = patchNode.schedulable, - ) - ?.let { (conn, needsEval) -> - patchNode.upstream = conn - patches = patchNode - - if (needsEval) { - patchData = conn.getPushEvent(evalScope) - } - } - } - } - // Update depth based on all initial switched-in nodes. - movingNode.switchedIn.values.forEach { branch -> - val conn = branch.upstream - if (conn.depthTracker.snapshotIsDirect) { - movingNode.depthTracker.addDirectUpstream( - oldDepth = null, - newDepth = conn.depthTracker.snapshotDirectDepth, - ) - } else { - movingNode.depthTracker.addIndirectUpstream( - oldDepth = null, - newDepth = conn.depthTracker.snapshotIndirectDepth, - ) - movingNode.depthTracker.updateIndirectRoots( - additions = conn.depthTracker.snapshotIndirectRoots, - butNot = null, - ) - } - } - // Update depth based on patches node. - movingNode.patches?.upstream?.let { conn -> - if (conn.depthTracker.snapshotIsDirect) { - movingNode.depthTracker.addDirectUpstream( - oldDepth = null, - newDepth = conn.depthTracker.snapshotDirectDepth, - ) - } else { - movingNode.depthTracker.addIndirectUpstream( - oldDepth = null, - newDepth = conn.depthTracker.snapshotIndirectDepth, - ) - movingNode.depthTracker.updateIndirectRoots( - additions = conn.depthTracker.snapshotIndirectRoots, - butNot = null, - ) - } - } - movingNode.depthTracker.reset() - - // Schedule for evaluation if any switched-in nodes or the patches node have - // already emitted within this transaction. - if (movingNode.patchData != null || movingNode.upstreamData.isNotEmpty()) { - movingNode.schedule(evalScope) +internal class MuxPromptEvalNode<W, K, V>( + private val movingNode: PullNode<MuxPromptMovingResult<W, K, V>>, + private val factory: MutableMapK.Factory<W, K>, +) : PullNode<MuxResult<W, K, V>> { + override suspend fun getPushEvent(evalScope: EvalScope): MuxResult<W, K, V> = + movingNode.getPushEvent(evalScope).let { (preSwitchResults, newlySwitchedIn) -> + newlySwitchedIn?.let { + factory + .create(preSwitchResults) + .also { store -> + newlySwitchedIn.forEach { k, pullNode -> store[k] = pullNode } } + .asReadOnly() + } ?: preSwitchResults + } +} - return movingNode.takeUnless { it.patches == null && it.switchedIn.isEmpty() } - } - } +internal inline fun <A> switchPromptImplSingle( + crossinline getStorage: suspend EvalScope.() -> TFlowImpl<A>, + crossinline getPatches: suspend EvalScope.() -> TFlowImpl<TFlowImpl<A>>, +): TFlowImpl<A> = + mapImpl({ + switchPromptImpl( + getStorage = { singleOf(getStorage()).asIterable() }, + getPatches = { + mapImpl(getPatches) { newFlow -> singleOf(just(newFlow)).asIterable() } + }, + storeFactory = SingletonMapK.Factory(), ) + }) { map -> + map.getValue(Unit).getPushEvent(this) + } +internal fun <W, K, V> switchPromptImpl( + getStorage: suspend EvalScope.() -> Iterable<Map.Entry<K, TFlowImpl<V>>>, + getPatches: suspend EvalScope.() -> TFlowImpl<Iterable<Map.Entry<K, Maybe<TFlowImpl<V>>>>>, + storeFactory: MutableMapK.Factory<W, K>, +): TFlowImpl<MuxResult<W, K, V>> { + val moving = MuxLifecycle(MuxPromptActivator(getStorage, storeFactory, getPatches)) val eval = TFlowCheap { downstream -> moving.activate(evalScope = this, downstream)?.let { (connection, needsEval) -> - val evalNode = MuxPromptEvalNode(connection.directUpstream) + val evalNode = MuxPromptEvalNode(connection.directUpstream, storeFactory) ActivationResult( connection = NodeConnection(evalNode, connection.schedulerUpstream), needsEval = needsEval, @@ -440,3 +383,64 @@ internal fun <K : Any, V> switchPromptImpl( } return eval.cached() } + +private class MuxPromptActivator<W, K, V>( + private val getStorage: suspend EvalScope.() -> Iterable<Map.Entry<K, TFlowImpl<V>>>, + private val storeFactory: MutableMapK.Factory<W, K>, + private val getPatches: + suspend EvalScope.() -> TFlowImpl<Iterable<Map.Entry<K, Maybe<TFlowImpl<V>>>>>, +) : MuxActivator<MuxPromptMovingResult<W, K, V>> { + override suspend fun activate( + evalScope: EvalScope, + lifecycle: MuxLifecycle<MuxPromptMovingResult<W, K, V>>, + ): MuxNode<W, *, *, MuxPromptMovingResult<W, K, V>>? { + // Initialize mux node and switched-in connections. + val movingNode = + MuxPromptMovingNode(lifecycle, this, storeFactory).apply { + coroutineScope { + launch { initializeUpstream(evalScope, getStorage, storeFactory) } + // Setup patches connection + val patchNode = PatchNode() + getPatches(evalScope) + .activate(evalScope = evalScope, downstream = patchNode.schedulable) + ?.let { (conn, needsEval) -> + patchNode.upstream = conn + patches = patchNode + if (needsEval) { + patchData = conn.getPushEvent(evalScope) + } + } + } + // Update depth based on all initial switched-in nodes. + initializeDepth() + // Update depth based on patches node. + patches?.upstream?.let { conn -> + if (conn.depthTracker.snapshotIsDirect) { + depthTracker.addDirectUpstream( + oldDepth = null, + newDepth = conn.depthTracker.snapshotDirectDepth, + ) + } else { + depthTracker.addIndirectUpstream( + oldDepth = null, + newDepth = conn.depthTracker.snapshotIndirectDepth, + ) + depthTracker.updateIndirectRoots( + additions = conn.depthTracker.snapshotIndirectRoots, + butNot = null, + ) + } + } + // Reset all depth adjustments, since no downstream has been notified + depthTracker.reset() + } + + // Schedule for evaluation if any switched-in nodes or the patches node have + // already emitted within this transaction. + if (movingNode.patchData != null || movingNode.upstreamData.isNotEmpty()) { + movingNode.schedule(evalScope) + } + + return movingNode.takeUnless { it.patches == null && it.switchedIn.isEmpty() } + } +} diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Network.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Network.kt index b2b3ca3001ad..79d4b7a843ac 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Network.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Network.kt @@ -59,7 +59,7 @@ internal class Network(val coroutineScope: CoroutineScope) : NetworkScope { private val stateWrites = ConcurrentLinkedQueue<TStateSource<*>>() private val outputsByDispatcher = ConcurrentHashMap<ContinuationInterceptor, ConcurrentLinkedQueue<Output<*>>>() - private val muxMovers = ConcurrentLinkedQueue<MuxDeferredNode<*, *>>() + private val muxMovers = ConcurrentLinkedQueue<MuxDeferredNode<*, *, *>>() private val deactivations = ConcurrentLinkedDeque<PushNode<*>>() private val outputDeactivations = ConcurrentLinkedQueue<Output<*>>() private val transactionMutex = Mutex() @@ -73,7 +73,7 @@ internal class Network(val coroutineScope: CoroutineScope) : NetworkScope { .add(output) } - override fun scheduleMuxMover(muxMover: MuxDeferredNode<*, *>) { + override fun scheduleMuxMover(muxMover: MuxDeferredNode<*, *, *>) { muxMovers.add(muxMover) } diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/NodeTypes.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/NodeTypes.kt index b9bef059d4b0..7a015d8ca1f6 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/NodeTypes.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/NodeTypes.kt @@ -29,7 +29,7 @@ internal sealed interface SchedulableNode { suspend fun moveIndirectUpstreamToDirect( scheduler: Scheduler, oldIndirectDepth: Int, - oldIndirectSet: Set<MuxDeferredNode<*, *>>, + oldIndirectSet: Set<MuxDeferredNode<*, *, *>>, newDirectDepth: Int, ) @@ -37,21 +37,21 @@ internal sealed interface SchedulableNode { scheduler: Scheduler, oldDepth: Int, newDepth: Int, - removals: Set<MuxDeferredNode<*, *>>, - additions: Set<MuxDeferredNode<*, *>>, + removals: Set<MuxDeferredNode<*, *, *>>, + additions: Set<MuxDeferredNode<*, *, *>>, ) suspend fun moveDirectUpstreamToIndirect( scheduler: Scheduler, oldDirectDepth: Int, newIndirectDepth: Int, - newIndirectSet: Set<MuxDeferredNode<*, *>>, + newIndirectSet: Set<MuxDeferredNode<*, *, *>>, ) suspend fun removeIndirectUpstream( scheduler: Scheduler, depth: Int, - indirectSet: Set<MuxDeferredNode<*, *>>, + indirectSet: Set<MuxDeferredNode<*, *, *>>, ) suspend fun removeDirectUpstream(scheduler: Scheduler, depth: Int) diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Scheduler.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Scheduler.kt index c12ef6ae6a5d..d046420517fe 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Scheduler.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Scheduler.kt @@ -25,22 +25,23 @@ import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch internal interface Scheduler { - fun schedule(depth: Int, node: MuxNode<*, *, *>) + fun schedule(depth: Int, node: MuxNode<*, *, *, *>) - fun scheduleIndirect(indirectDepth: Int, node: MuxNode<*, *, *>) + fun scheduleIndirect(indirectDepth: Int, node: MuxNode<*, *, *, *>) } internal class SchedulerImpl : Scheduler { - val enqueued = ConcurrentHashMap<MuxNode<*, *, *>, Any>() - val scheduledQ = PriorityBlockingQueue<Pair<Int, MuxNode<*, *, *>>>(16, compareBy { it.first }) + val enqueued = ConcurrentHashMap<MuxNode<*, *, *, *>, Any>() + val scheduledQ = + PriorityBlockingQueue<Pair<Int, MuxNode<*, *, *, *>>>(16, compareBy { it.first }) - override fun schedule(depth: Int, node: MuxNode<*, *, *>) { + override fun schedule(depth: Int, node: MuxNode<*, *, *, *>) { if (enqueued.putIfAbsent(node, node) == null) { scheduledQ.add(Pair(depth, node)) } } - override fun scheduleIndirect(indirectDepth: Int, node: MuxNode<*, *, *>) { + override fun scheduleIndirect(indirectDepth: Int, node: MuxNode<*, *, *, *>) { schedule(Int.MIN_VALUE + indirectDepth, node) } @@ -59,7 +60,9 @@ internal class SchedulerImpl : Scheduler { private suspend inline fun drain( crossinline onStep: - suspend (runStep: suspend (visit: suspend (MuxNode<*, *, *>) -> Unit) -> Unit) -> Unit + suspend ( + runStep: suspend (visit: suspend (MuxNode<*, *, *, *>) -> Unit) -> Unit + ) -> Unit ): Unit = coroutineScope { while (scheduledQ.isNotEmpty()) { val maxDepth = scheduledQ.peek()?.first ?: error("Unexpected empty scheduler") @@ -69,7 +72,7 @@ internal class SchedulerImpl : Scheduler { private suspend inline fun runStep( maxDepth: Int, - crossinline visit: suspend (MuxNode<*, *, *>) -> Unit, + crossinline visit: suspend (MuxNode<*, *, *, *>) -> Unit, ) = coroutineScope { while (scheduledQ.peek()?.first?.let { it <= maxDepth } == true) { val (d, node) = scheduledQ.remove() diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/StateScopeImpl.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/StateScopeImpl.kt index 06b5b1690391..94f94f510d48 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/StateScopeImpl.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/StateScopeImpl.kt @@ -29,6 +29,7 @@ import com.android.systemui.kairos.TStateInit import com.android.systemui.kairos.emptyTFlow import com.android.systemui.kairos.groupByKey import com.android.systemui.kairos.init +import com.android.systemui.kairos.internal.store.ConcurrentHashMapK import com.android.systemui.kairos.internal.util.mapValuesParallel import com.android.systemui.kairos.mapCheap import com.android.systemui.kairos.merge @@ -100,7 +101,7 @@ internal class StateScopeImpl(val evalScope: EvalScope, override val endSignal: .toTStateInternalDeferred(operatorName, initialValue.unwrapped) } - private fun <K : Any, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementallyInternal( + private fun <K, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementallyInternal( storage: TState<Map<K, TFlow<V>>> ): TFlow<Map<K, V>> { val name = "mergeIncrementally" @@ -114,21 +115,25 @@ internal class StateScopeImpl(val evalScope: EvalScope, override val endSignal: .getCurrentWithEpoch(this) .first .mapValuesParallel { (_, flow) -> flow.init.connect(this) } + .asIterable() }, getPatches = { mapImpl({ init.connect(this) }) { patch -> - patch.mapValuesParallel { (_, m) -> - m.map { flow -> flow.init.connect(this) } - } + patch + .mapValuesParallel { (_, m) -> + m.map { flow -> flow.init.connect(this) } + } + .asIterable() } }, + storeFactory = ConcurrentHashMapK.Factory(), ) .awaitValues(), ) ) } - private fun <K : Any, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementallyPromptInternal( + private fun <K, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementallyPromptInternal( storage: TState<Map<K, TFlow<V>>> ): TFlow<Map<K, V>> { val name = "mergeIncrementallyPrompt" @@ -142,14 +147,18 @@ internal class StateScopeImpl(val evalScope: EvalScope, override val endSignal: .getCurrentWithEpoch(this) .first .mapValuesParallel { (_, flow) -> flow.init.connect(this) } + .asIterable() }, getPatches = { mapImpl({ init.connect(this) }) { patch -> - patch.mapValuesParallel { (_, m) -> - m.map { flow -> flow.init.connect(this) } - } + patch + .mapValuesParallel { (_, m) -> + m.map { flow -> flow.init.connect(this) } + } + .asIterable() } }, + storeFactory = ConcurrentHashMapK.Factory(), ) .awaitValues(), ) @@ -217,14 +226,14 @@ internal class StateScopeImpl(val evalScope: EvalScope, override val endSignal: override fun <A> TFlow<A>.holdDeferred(initialValue: FrpDeferredValue<A>): TState<A> = toTStateDeferredInternal(initialValue) - override fun <K : Any, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementally( + override fun <K, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementally( initialTFlows: FrpDeferredValue<Map<K, TFlow<V>>> ): TFlow<Map<K, V>> { val storage: TState<Map<K, TFlow<V>>> = foldMapIncrementally(initialTFlows) return mergeIncrementallyInternal(storage) } - override fun <K : Any, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementallyPromptly( + override fun <K, V> TFlow<Map<K, Maybe<TFlow<V>>>>.mergeIncrementallyPromptly( initialTFlows: FrpDeferredValue<Map<K, TFlow<V>>> ): TFlow<Map<K, V>> { val storage: TState<Map<K, TFlow<V>>> = foldMapIncrementally(initialTFlows) diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/TStateImpl.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/TStateImpl.kt index c4a26a33e24d..916f22575b0c 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/TStateImpl.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/TStateImpl.kt @@ -16,10 +16,13 @@ package com.android.systemui.kairos.internal +import com.android.systemui.kairos.internal.store.ConcurrentHashMapK +import com.android.systemui.kairos.internal.store.MutableArrayMapK +import com.android.systemui.kairos.internal.store.MutableMapK +import com.android.systemui.kairos.internal.store.StoreEntry import com.android.systemui.kairos.internal.util.Key -import com.android.systemui.kairos.internal.util.associateByIndex import com.android.systemui.kairos.internal.util.hashString -import com.android.systemui.kairos.internal.util.mapValuesParallel +import com.android.systemui.kairos.internal.util.launchImmediate import com.android.systemui.kairos.util.Maybe import com.android.systemui.kairos.util.just import com.android.systemui.kairos.util.none @@ -28,6 +31,7 @@ import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.Deferred import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.coroutineScope internal sealed interface TStateImpl<out A> { val name: String? @@ -224,16 +228,10 @@ internal fun <A> TStateImpl<TStateImpl<A>>.flatten(name: String?, operator: Stri mergeNodes({ switchEvents }, { newInner.changes }) { _, new -> new } } val switchedChanges: TFlowImpl<A> = - mapImpl({ - switchPromptImpl( - getStorage = { - mapOf(Unit to this@flatten.getCurrentWithEpoch(evalScope = this).first.changes) - }, - getPatches = { mapImpl({ innerChanges }) { new -> mapOf(Unit to just(new)) } }, - ) - }) { map -> - map.getValue(Unit).getPushEvent(this) - } + switchPromptImplSingle( + getStorage = { this@flatten.getCurrentWithEpoch(evalScope = this).first.changes }, + getPatches = { innerChanges }, + ) lateinit var state: DerivedFlatten<A> state = DerivedFlatten(name, operator, this, switchedChanges.calm { state }) return state @@ -268,10 +266,8 @@ internal fun <A, B, Z> zipStates( l2: TStateImpl<B>, transform: suspend EvalScope.(A, B) -> Z, ): TStateImpl<Z> = - zipStates(null, operatorName, mapOf(0 to l1, 1 to l2)).map(name, operatorName) { - val a = it.getValue(0) - val b = it.getValue(1) - @Suppress("UNCHECKED_CAST") transform(a as A, b as B) + zipStateList(null, operatorName, listOf(l1, l2)).map(name, operatorName) { + @Suppress("UNCHECKED_CAST") transform(it[0] as A, it[1] as B) } internal fun <A, B, C, Z> zipStates( @@ -282,11 +278,8 @@ internal fun <A, B, C, Z> zipStates( l3: TStateImpl<C>, transform: suspend EvalScope.(A, B, C) -> Z, ): TStateImpl<Z> = - zipStates(null, operatorName, mapOf(0 to l1, 1 to l2, 2 to l3)).map(name, operatorName) { - val a = it.getValue(0) - val b = it.getValue(1) - val c = it.getValue(2) - @Suppress("UNCHECKED_CAST") transform(a as A, b as B, c as C) + zipStateList(null, operatorName, listOf(l1, l2, l3)).map(name, operatorName) { + @Suppress("UNCHECKED_CAST") transform(it[0] as A, it[1] as B, it[2] as C) } internal fun <A, B, C, D, Z> zipStates( @@ -298,15 +291,8 @@ internal fun <A, B, C, D, Z> zipStates( l4: TStateImpl<D>, transform: suspend EvalScope.(A, B, C, D) -> Z, ): TStateImpl<Z> = - zipStates(null, operatorName, mapOf(0 to l1, 1 to l2, 2 to l3, 3 to l4)).map( - name, - operatorName, - ) { - val a = it.getValue(0) - val b = it.getValue(1) - val c = it.getValue(2) - val d = it.getValue(3) - @Suppress("UNCHECKED_CAST") transform(a as A, b as B, c as C, d as D) + zipStateList(null, operatorName, listOf(l1, l2, l3, l4)).map(name, operatorName) { + @Suppress("UNCHECKED_CAST") transform(it[0] as A, it[1] as B, it[2] as C, it[3] as D) } internal fun <A, B, C, D, E, Z> zipStates( @@ -319,58 +305,122 @@ internal fun <A, B, C, D, E, Z> zipStates( l5: TStateImpl<E>, transform: suspend EvalScope.(A, B, C, D, E) -> Z, ): TStateImpl<Z> = - zipStates(null, operatorName, mapOf(0 to l1, 1 to l2, 2 to l3, 3 to l4, 4 to l5)).map( - name, - operatorName, - ) { - val a = it.getValue(0) - val b = it.getValue(1) - val c = it.getValue(2) - val d = it.getValue(3) - val e = it.getValue(4) - @Suppress("UNCHECKED_CAST") transform(a as A, b as B, c as C, d as D, e as E) + zipStateList(null, operatorName, listOf(l1, l2, l3, l4, l5)).map(name, operatorName) { + @Suppress("UNCHECKED_CAST") + transform(it[0] as A, it[1] as B, it[2] as C, it[3] as D, it[4] as E) } -internal fun <K : Any, A> zipStates( +internal fun <K, V> zipStateMap( + name: String?, + operatorName: String, + states: Map<K, TStateImpl<V>>, +): TStateImpl<Map<K, V>> = + zipStates( + name = name, + operatorName = operatorName, + numStates = states.size, + states = states.asIterable(), + storeFactory = ConcurrentHashMapK.Factory(), + ) + +internal fun <V> zipStateList( name: String?, operatorName: String, - states: Map<K, TStateImpl<A>>, -): TStateImpl<Map<K, A>> { - if (states.isEmpty()) return constS(name, operatorName, emptyMap()) - val stateChanges: Map<K, TFlowImpl<A>> = states.mapValues { it.value.changes } - lateinit var state: DerivedZipped<K, A> + states: List<TStateImpl<V>>, +): TStateImpl<List<V>> { + val zipped = + zipStates( + name = name, + operatorName = operatorName, + numStates = states.size, + states = + states + .asSequence() + .mapIndexed { index, tStateImpl -> StoreEntry(index, tStateImpl) } + .asIterable(), + storeFactory = MutableArrayMapK.Factory(), + ) + // Like mapCheap, but with caching (or like map, but without the calm changes, as they are not + // necessary). + return DerivedMap( + name = name, + operatorName = operatorName, + transform = { arrayStore -> arrayStore.values.toList() }, + upstream = zipped, + changes = mapImpl({ zipped.changes }) { arrayStore -> arrayStore.values.toList() }, + ) +} + +internal fun <W, K, A> zipStates( + name: String?, + operatorName: String, + numStates: Int, + states: Iterable<Map.Entry<K, TStateImpl<A>>>, + storeFactory: MutableMapK.Factory<W, K>, +): TStateImpl<MutableMapK<W, K, A>> { + if (numStates == 0) { + return constS(name, operatorName, storeFactory.create(0)) + } + val stateChanges = states.asSequence().map { (k, v) -> StoreEntry(k, v.changes) }.asIterable() + lateinit var state: DerivedZipped<W, K, A> // No need for calm; invariant ensures that changes will only emit when there's a difference - val changes: TFlowImpl<Map<K, A>> = + val changes = mapImpl({ - switchDeferredImpl(getStorage = { stateChanges }, getPatches = { neverImpl }) - }) { patch -> - states - .mapValues { (k, v) -> - if (k in patch) { - patch.getValue(k).getPushEvent(this) - } else { - v.getCurrentWithEpoch(evalScope = this).first + switchDeferredImpl( + getStorage = { stateChanges }, + getPatches = { neverImpl }, + storeFactory = storeFactory, + ) + }) { patch -> + val store = storeFactory.create<A>(numStates) + coroutineScope { + states.forEach { (k, state) -> + launchImmediate { + store[k] = + if (patch.contains(k)) { + patch.getValue(k).getPushEvent(evalScope = this@mapImpl) + } else { + state.getCurrentWithEpoch(evalScope = this@mapImpl).first + } + } } } - .also { state.setCache(it, epoch) } - } - state = DerivedZipped(name, operatorName, states, changes) + store.also { state.setCache(it, epoch) } + } + .cached() + state = + DerivedZipped( + name = name, + operatorName = operatorName, + upstreamSize = numStates, + upstream = states, + changes = changes, + storeFactory = storeFactory, + ) return state } -internal class DerivedZipped<K : Any, A>( +internal class DerivedZipped<W, K, A>( override val name: String?, override val operatorName: String, - val upstream: Map<K, TStateImpl<A>>, - changes: TFlowImpl<Map<K, A>>, -) : TStateDerived<Map<K, A>>(changes) { - override suspend fun recalc(evalScope: EvalScope): Pair<Map<K, A>, Long> { + private val upstreamSize: Int, + val upstream: Iterable<Map.Entry<K, TStateImpl<A>>>, + changes: TFlowImpl<MutableMapK<W, K, A>>, + private val storeFactory: MutableMapK.Factory<W, K>, +) : TStateDerived<MutableMapK<W, K, A>>(changes) { + override suspend fun recalc(evalScope: EvalScope): Pair<MutableMapK<W, K, A>, Long> { val newEpoch = AtomicLong() - return upstream.mapValuesParallel { - val (a, epoch) = it.value.getCurrentWithEpoch(evalScope) - newEpoch.accumulateAndGet(epoch, ::maxOf) - a - } to newEpoch.get() + val store = storeFactory.create<A>(upstreamSize) + coroutineScope { + for ((key, value) in upstream) { + launchImmediate { + val (a, epoch) = value.getCurrentWithEpoch(evalScope) + newEpoch.accumulateAndGet(epoch, ::maxOf) + store[key] = a + } + } + } + return store to newEpoch.get() } override fun toString(): String = "${this::class.simpleName}@$hashString" @@ -385,10 +435,5 @@ internal inline fun <A> zipStates( if (states.isEmpty()) { constS(name, operatorName, emptyList()) } else { - zipStates(null, operatorName, states.asIterable().associateByIndex()).mapCheap( - name, - operatorName, - ) { - it.values.toList() - } + zipStateList(null, operatorName, states) } diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/ArrayMapK.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/ArrayMapK.kt new file mode 100644 index 000000000000..f0c2f346e9b7 --- /dev/null +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/ArrayMapK.kt @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.systemui.kairos.internal.store + +import java.util.concurrent.atomic.AtomicReferenceArray + +/** A [Map] backed by a flat array. */ +internal class ArrayMapK<V>( + val unwrapped: List<MutableMap.MutableEntry<Int, V>>, + val originalCapacity: Int, +) : MapK<ArrayMapK.W, Int, V>, AbstractMap<Int, V>() { + object W + + override val entries: Set<Map.Entry<Int, V>> = + object : AbstractSet<Map.Entry<Int, V>>() { + override val size: Int + get() = unwrapped.size + + override fun iterator(): Iterator<Map.Entry<Int, V>> = unwrapped.iterator() + } +} + +@Suppress("NOTHING_TO_INLINE") +internal inline fun <V> MapK<ArrayMapK.W, Int, V>.asArrayHolder(): ArrayMapK<V> = + this as ArrayMapK<V> + +internal class MutableArrayMapK<V> +private constructor(private val storage: AtomicReferenceArray<MutableMap.MutableEntry<Int, V>?>) : + MutableMapK<ArrayMapK.W, Int, V>, AbstractMutableMap<Int, V>() { + + constructor(length: Int) : this(AtomicReferenceArray<MutableMap.MutableEntry<Int, V>?>(length)) + + override fun readOnlyCopy(): ArrayMapK<V> { + val size1 = storage.length() + return ArrayMapK( + buildList { + for (i in 0 until size1) { + storage.get(i)?.let { entry -> add(StoreEntry(entry.key, entry.value)) } + } + }, + size1, + ) + } + + override fun asReadOnly(): MapK<ArrayMapK.W, Int, V> = readOnlyCopy() + + private fun getNumEntries(): Int { + val capacity = storage.length() + var total = 0 + for (i in 0 until capacity) { + storage.get(i)?.let { total++ } + } + return total + } + + override fun put(key: Int, value: V): V? = + storage.get(key)?.value.also { storage.set(key, StoreEntry(key, value)) } + + override val entries: MutableSet<MutableMap.MutableEntry<Int, V>> = + object : AbstractMutableSet<MutableMap.MutableEntry<Int, V>>() { + override val size: Int + get() = getNumEntries() + + override fun add(element: MutableMap.MutableEntry<Int, V>): Boolean = + (storage.get(element.key) is MutableMap.MutableEntry<*, *>).also { + storage.set(element.key, element) + } + + override fun iterator(): MutableIterator<MutableMap.MutableEntry<Int, V>> = + object : MutableIterator<MutableMap.MutableEntry<Int, V>> { + + var cursor = -1 + var nextIndex = -1 + + override fun hasNext(): Boolean { + val capacity = storage.length() + if (nextIndex >= capacity) return false + if (nextIndex != cursor) return true + while (++nextIndex < capacity) { + if (storage.get(nextIndex) != null) { + return true + } + } + return false + } + + override fun next(): MutableMap.MutableEntry<Int, V> { + if (!hasNext()) throw NoSuchElementException() + cursor = nextIndex + return storage.get(cursor)!! + } + + override fun remove() { + check( + cursor >= 0 && + cursor < storage.length() && + storage.getAndSet(cursor, null) != null + ) + } + } + } + + class Factory : MutableMapK.Factory<ArrayMapK.W, Int> { + override fun <V> create(capacity: Int?) = + MutableArrayMapK<V>(checkNotNull(capacity) { "Cannot use ArrayMapK with null capacity." }) + + override fun <V> create(input: MapK<ArrayMapK.W, Int, V>): MutableArrayMapK<V> { + val holder = input.asArrayHolder() + return MutableArrayMapK( + AtomicReferenceArray<MutableMap.MutableEntry<Int, V>?>(holder.originalCapacity) + .apply { holder.unwrapped.forEach { set(it.key, it) } } + ) + } + } +} diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/MapHolder.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/MapHolder.kt new file mode 100644 index 000000000000..db2dde00f17a --- /dev/null +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/MapHolder.kt @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.systemui.kairos.internal.store + +import com.android.systemui.kairos.internal.util.ConcurrentNullableHashMap + +@JvmInline +internal value class MapHolder<K, V>(val unwrapped: Map<K, V>) : + MapK<MapHolder.W, K, V>, Map<K, V> by unwrapped { + object W +} + +@Suppress("NOTHING_TO_INLINE") +internal inline fun <K, V> MapK<MapHolder.W, K, V>.asMapHolder(): MapHolder<K, V> = + this as MapHolder<K, V> + +// TODO: preserve insertion order? +internal class ConcurrentHashMapK<K, V>(private val storage: ConcurrentNullableHashMap<K, V>) : + MutableMapK<MapHolder.W, K, V>, MutableMap<K, V> by storage { + + override fun readOnlyCopy() = MapHolder(storage.toMap()) + + override fun asReadOnly(): MapK<MapHolder.W, K, V> = MapHolder(storage) + + class Factory<K> : MutableMapK.Factory<MapHolder.W, K> { + override fun <V> create(capacity: Int?) = + ConcurrentHashMapK<K, V>( + capacity?.let { ConcurrentNullableHashMap(capacity) } ?: ConcurrentNullableHashMap() + ) + + override fun <V> create(input: MapK<MapHolder.W, K, V>) = + ConcurrentHashMapK( + ConcurrentNullableHashMap<K, V>().apply { + input.asMapHolder().unwrapped.forEach { (k, v) -> set(k, v) } + } + ) + } +} diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/MapK.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/MapK.kt new file mode 100644 index 000000000000..e193a4957bd0 --- /dev/null +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/MapK.kt @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.systemui.kairos.internal.store + +/** + * Higher-kinded encoding for [Map]. + * + * Let's say you want to write a class that is generic over both a map, and the type of data within + * the map: + * ``` kotlin + * class Foo<TMap, TKey, TValue> { + * val container: TMap<TKey, TElement> // disallowed! + * } + * ``` + * + * You can use `MapK` to represent the "higher-kinded" type variable `TMap`: + * ``` kotlin + * class Foo<TMap, TKey, TValue> { + * val container: MapK<TMap, TKey, TValue> // OK! + * } + * ``` + * + * Note that Kotlin will not let you use the generic type without parameters as `TMap`: + * ``` kotlin + * val fooHk: MapK<HashMap, Int, String> // not allowed: HashMap requires two type parameters + * ``` + * + * To work around this, you need to declare a special type-witness object. This object is only used + * at compile time and can be stripped out by a minifier because it's never used at runtime. + * + * ``` kotlin + * class Foo<A, B> : MapK<FooWitness, A, B> { ... } + * object FooWitness + * + * // safe, as long as Foo is the only implementor of MapK<FooWitness, *, *> + * fun <A, B> MapK<FooWitness, A, B>.asFoo(): Foo<A, B> = this as Foo<A, B> + * + * val fooStore: MapK<FooWitness, Int, String> = Foo() + * val foo: Foo<Int, String> = fooStore.asFoo() + * ``` + */ +internal interface MapK<W, K, V> : Map<K, V> + +internal interface MutableMapK<W, K, V> : MutableMap<K, V> { + + fun readOnlyCopy(): MapK<W, K, V> + + fun asReadOnly(): MapK<W, K, V> + + interface Factory<W, K> { + fun <V> create(capacity: Int?): MutableMapK<W, K, V> + + fun <V> create(input: MapK<W, K, V>): MutableMapK<W, K, V> + } +} + +internal object NoValue + +internal data class StoreEntry<K, V>(override var key: K, override var value: V) : + MutableMap.MutableEntry<K, V> { + override fun setValue(newValue: V): V = value.also { value = newValue } +} diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/Single.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/Single.kt new file mode 100644 index 000000000000..2d0894884a4c --- /dev/null +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/store/Single.kt @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.systemui.kairos.internal.store + +@Suppress("NOTHING_TO_INLINE") internal inline fun <V> singleOf(value: V) = Single<V>(value) + +/** A [Map] with a single element that has key [Unit]. */ +internal class Single<V>(val unwrapped: Any?) : MapK<Single.W, Unit, V>, AbstractMap<Unit, V>() { + + constructor() : this(NoValue) + + @Suppress("UNCHECKED_CAST") + override val entries: Set<Map.Entry<Unit, V>> = + if (unwrapped === NoValue) emptySet() else setOf(StoreEntry(Unit, unwrapped as V)) + + object W +} + +@Suppress("NOTHING_TO_INLINE") +internal inline fun <V> MapK<Single.W, Unit, V>.asSingle(): Single<V> = this as Single<V> + +internal class SingletonMapK<V>(@Volatile private var value: Any?) : + MutableMapK<Single.W, Unit, V>, AbstractMutableMap<Unit, V>() { + + constructor() : this(NoValue) + + override fun readOnlyCopy() = + Single<V>(if (value === NoValue) value else (value as MutableMap.MutableEntry<*, *>).value) + + override fun asReadOnly(): MapK<Single.W, Unit, V> = readOnlyCopy() + + @Suppress("UNCHECKED_CAST") + override fun put(key: Unit, value: V): V? = + (this.value as? MutableMap.MutableEntry<Unit, V>)?.value.also { + this.value = StoreEntry(Unit, value) + } + + override val entries: MutableSet<MutableMap.MutableEntry<Unit, V>> = + object : AbstractMutableSet<MutableMap.MutableEntry<Unit, V>>() { + override fun add(element: MutableMap.MutableEntry<Unit, V>): Boolean = + (value !== NoValue).also { value = element } + + override val size: Int + get() = if (value === NoValue) 0 else 1 + + override fun iterator(): MutableIterator<MutableMap.MutableEntry<Unit, V>> { + return object : MutableIterator<MutableMap.MutableEntry<Unit, V>> { + + var done = false + + override fun hasNext(): Boolean = value !== NoValue && !done + + override fun next(): MutableMap.MutableEntry<Unit, V> { + if (!hasNext()) throw NoSuchElementException() + done = true + @Suppress("UNCHECKED_CAST") + return value as MutableMap.MutableEntry<Unit, V> + } + + override fun remove() { + if (!done || value === NoValue) throw IllegalStateException() + value = NoValue + } + } + } + } + + internal class Factory : MutableMapK.Factory<Single.W, Unit> { + override fun <V> create(capacity: Int?): SingletonMapK<V> { + check(capacity == null || capacity == 0 || capacity == 1) { + "Can't use singleton store with capacity > 1. Got: $capacity" + } + return SingletonMapK() + } + + override fun <V> create(input: MapK<Single.W, Unit, V>): SingletonMapK<V> = + SingletonMapK(input.asSingle().unwrapped) + } +} diff --git a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/util/ConcurrentNullableHashMap.kt b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/util/ConcurrentNullableHashMap.kt index 6c8ae7cf6436..afeb0679fe12 100644 --- a/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/util/ConcurrentNullableHashMap.kt +++ b/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/util/ConcurrentNullableHashMap.kt @@ -16,31 +16,114 @@ package com.android.systemui.kairos.internal.util +import com.android.systemui.kairos.internal.store.NoValue import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ConcurrentMap + +internal class ConcurrentNullableHashMap<K, V> +private constructor(private val inner: ConcurrentHashMap<Any, Any>) : + ConcurrentMap<K, V>, AbstractMutableMap<K, V>() { -internal class ConcurrentNullableHashMap<K : Any, V> -private constructor(private val inner: ConcurrentHashMap<K, Any>) { constructor() : this(ConcurrentHashMap()) - @Suppress("UNCHECKED_CAST") - operator fun get(key: K): V? = inner[key]?.takeIf { it !== NullValue } as V? + constructor(capacity: Int) : this(ConcurrentHashMap(capacity)) + + override fun get(key: K): V? = inner[key ?: NullValue]?.let { toNullable<V>(it) } + + fun getValue(key: K): V = toNullable(inner.getValue(key ?: NullValue)) @Suppress("UNCHECKED_CAST") - fun put(key: K, value: V?): V? = - inner.put(key, value ?: NullValue)?.takeIf { it !== NullValue } as V? + override fun put(key: K, value: V): V? = + inner.put(key ?: NullValue, value ?: NullValue)?.takeIf { it !== NullValue } as V? - operator fun set(key: K, value: V?) { + operator fun set(key: K, value: V) { put(key, value) } - @Suppress("UNCHECKED_CAST") - fun toMap(): Map<K, V> = inner.mapValues { (_, v) -> v.takeIf { it !== NullValue } as V } + fun toMap(): Map<K, V> = + inner.asSequence().associate { (k, v) -> toNullable<K>(k) to toNullable(v) } - fun clear() { + override fun clear() { inner.clear() } + override fun remove(key: K, value: V): Boolean = inner.remove(key ?: NoValue, value ?: NoValue) + + override val entries: MutableSet<MutableMap.MutableEntry<K, V>> = + object : AbstractMutableSet<MutableMap.MutableEntry<K, V>>() { + val wrapped = inner.entries + + override fun add(element: MutableMap.MutableEntry<K, V>): Boolean { + val e = + object : MutableMap.MutableEntry<Any, Any> { + override val key: Any + get() = element.key ?: NullValue + + override val value: Any + get() = element.value ?: NullValue + + override fun setValue(newValue: Any): Any = + element.setValue(toNullable(newValue)) ?: NullValue + } + return wrapped.add(e) + } + + override val size: Int + get() = wrapped.size + + override fun iterator(): MutableIterator<MutableMap.MutableEntry<K, V>> { + val iter = wrapped.iterator() + return object : MutableIterator<MutableMap.MutableEntry<K, V>> { + override fun hasNext(): Boolean = iter.hasNext() + + override fun next(): MutableMap.MutableEntry<K, V> { + val element = iter.next() + return object : MutableMap.MutableEntry<K, V> { + override val key: K + get() = toNullable(element.key) + + override val value: V + get() = toNullable(element.value) + + override fun setValue(newValue: V): V = + toNullable(element.setValue(newValue ?: NullValue)) + } + } + + override fun remove() { + iter.remove() + } + } + } + } + + override fun replace(key: K, oldValue: V, newValue: V): Boolean = + inner.replace(key ?: NullValue, oldValue ?: NullValue, newValue ?: NullValue) + + override fun replace(key: K, value: V): V? = + inner.replace(key ?: NullValue, value ?: NullValue)?.let { toNullable<V>(it) } + + override fun putIfAbsent(key: K, value: V): V? = + inner.putIfAbsent(key ?: NullValue, value ?: NullValue)?.let { toNullable<V>(it) } + + @Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE") + private inline fun <T> toNullable(value: Any): T = value.takeIf { it !== NullValue } as T + fun isNotEmpty(): Boolean = inner.isNotEmpty() + + @Suppress("UNCHECKED_CAST") + override fun remove(key: K): V? = + inner.remove(key ?: NullValue)?.takeIf { it !== NullValue } as V? + + fun asSequence(): Sequence<Pair<K, V>> = + inner.asSequence().map { (key, value) -> toNullable<K>(key) to toNullable(value) } + + override fun isEmpty(): Boolean = inner.isEmpty() + + override fun containsKey(key: K): Boolean = inner.containsKey(key ?: NullValue) + + fun getOrPut(key: K, defaultValue: () -> V): V = + toNullable(inner.getOrPut(key) { defaultValue() ?: NullValue }) } private object NullValue |