Skip to content

Commit

Permalink
#157 Use bidding network for bidding
Browse files Browse the repository at this point in the history
  • Loading branch information
b0n541 committed Jan 13, 2025
1 parent 412878b commit 942875a
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 148 deletions.
113 changes: 113 additions & 0 deletions jskat-base/src/main/kotlin/org/jskat/ai/deeplearning/AIPLayerDL.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package org.jskat.ai.deeplearning

import org.jskat.data.GameContract
import org.jskat.player.AbstractJSkatPlayer
import org.jskat.util.Card
import org.jskat.util.CardList
import org.jskat.util.GameType
import org.jskat.util.SkatConstants
import org.jskat.util.rule.GrandRule
import org.jskat.util.rule.SuitRule
import org.slf4j.LoggerFactory
import kotlin.random.Random

class AIPLayerDL(val name: String = "AIPlayerDL") : AbstractJSkatPlayer() {

val logger = LoggerFactory.getLogger(javaClass)

val random = Random.Default

val biddingModel = BiddingModel()

val grandRule = GrandRule()
val suitRule = SuitRule()

override fun prepareForNewGame() {
// nothing to do for AIPLayerDL
}

override fun isAIPlayer(): Boolean {
return true
}

override fun startGame() {
// nothing to do for AIPLayerDL
}

override fun bidMore(nextBidValue: Int): Int {
if (nextBidValue <= maxBid()) {
return nextBidValue
}

return 0
}

override fun holdBid(currBidValue: Int): Boolean {
return currBidValue <= maxBid()
}

private fun maxBid(): Int {
val gameType = biddingModel.predictGameType(knowledge.playerPosition, knowledge.ownCards)

logger.info("Bidding model prediction: $gameType")

if (gameType != GameType.PASSED_IN) {
var matadors = when (gameType) {
GameType.GRAND -> grandRule.getMatadors(knowledge.ownCards, gameType)
GameType.CLUBS, GameType.SPADES, GameType.HEARTS, GameType.DIAMONDS ->
suitRule.getMatadors(
knowledge.ownCards, gameType
)

else -> 0
}

logger.info("Matadors: $matadors")

// TODO use calculations from skat rules and SkatConstants
// TODO take hand and ouvert into account
return (matadors + 1) * SkatConstants.getGameBaseValue(gameType, false, false)
}

return 0
}

override fun pickUpSkat(): Boolean {
// TODO find better cards
return false
}

override fun getCardsToDiscard(): CardList {
TODO("Not yet implemented")
}

override fun announceGame(): GameContract {
val gameType = biddingModel.predictGameType(knowledge.playerPosition, knowledge.ownCards)
return GameContract(gameType)
}

override fun playGrandHand(): Boolean {
// TODO
return false
}

override fun callContra(): Boolean {
// TODO
return false
}

override fun callRe(): Boolean {
// TODO
return false
}

override fun playCard(): Card {
val possibleCards = getPlayableCards(knowledge.trickCards)
val index = random.nextInt(possibleCards.size())
return possibleCards[index]
}

override fun finalizeGame() {
// nothing to do for AIPLayerDL
}
}
147 changes: 0 additions & 147 deletions jskat-base/src/main/kotlin/org/jskat/ai/deeplearning/AIPlayerDL.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fun main() {
"♦A", "♦T", "♦K", "♦Q", "♦J", "♦9", "♦8", "♦7",
"maxBidForehand", "maxBidMiddlehand", "maxBidRearhand",
"gameType", "hand", "ouvert", "annSchneider", "annSchwarz",
"won", "declarerScore", "schneider", "schwarz"
"won", "schneider", "schwarz", "declarerScore"
)
.setSkipHeaderRecord(true)
.setIgnoreHeaderCase(true)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package org.jskat.ai.deeplearning

import ai.djl.Model
import ai.djl.inference.Predictor
import ai.djl.modality.Classifications
import ai.djl.nn.Activation
import ai.djl.nn.SequentialBlock
import ai.djl.nn.core.Linear
import org.jskat.util.Card
import org.jskat.util.CardList
import org.jskat.util.GameType
import org.jskat.util.Player
import org.slf4j.LoggerFactory
import java.nio.file.Paths

class BiddingModel() {

private val logger = LoggerFactory.getLogger(BiddingModel::class.java)

private val block = SequentialBlock()
.add(Linear.builder().setUnits(33).build())
.add(Activation::relu)
.add(Linear.builder().setUnits(128).build())
.add(Activation::relu)
.add(Linear.builder().setUnits(128).build())
.add(Activation::relu)
.add(Linear.builder().setUnits(64).build())
.add(Activation::relu)
.add(Linear.builder().setUnits(64).build())
.add(Activation::relu)
.add(Linear.builder().setUnits(6).build())

private val model = Model.newInstance("bidnet")

private val gameTypeTranslator = GameTypeClassificationTranslator()

private val predictor: Predictor<FloatArray, Classifications>

init {
model.block = block
model.load(Paths.get(BiddingModel::class.java.classLoader.getResource("data/model").toURI()))
predictor = model.newPredictor(gameTypeTranslator)
}

fun predictGameType(position: Player, hand: CardList): GameType {
val classes = predictor.predict(toFloatArray(position) + toFloatArray(hand))
val best = classes.best<Classifications.Classification>()

logger.info("Best game type ${best.className} with probability of ${best.probability}")

if (best.probability > 0.5) {
return GameType.valueOf(best.className)
}
return GameType.PASSED_IN
}

private fun toFloatArray(postion: Player): FloatArray {
val result = floatArrayOf(0.0f, 0.0f, 0.0f)
when (postion) {
Player.FOREHAND -> result[0] = 1.0f
Player.MIDDLEHAND -> result[1] = 1.0f
Player.REARHAND -> result[2] = 1.0f
}
return result
}

private fun toFloatArray(hand: CardList): FloatArray {
val result = FloatArray(32) { 0.0f }

for (card in Card.values()) {
if (hand.contains(card)) {
result[card.ordinal] = 1.0f
}
}

return result
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.jskat.ai.deeplearning

import ai.djl.modality.Classifications
import ai.djl.ndarray.NDList
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext

class GameTypeClassificationTranslator : Translator<FloatArray, Classifications> {
override fun processInput(
ctx: TranslatorContext,
input: FloatArray
): NDList? {
return NDList(ctx.ndManager?.create(input))
}

override fun processOutput(
ctx: TranslatorContext,
list: NDList
): Classifications {
return Classifications(
listOf("CLUBS", "DIAMONDS", "GRAND", "HEARTS", "NULL", "SPADES"),
list.singletonOrThrow().softmax(0)
)
}
}
Loading

0 comments on commit 942875a

Please sign in to comment.