This case comes from the official ML-Agents examples (GitHub: https://github.com/Unity-Technologies/ml-agents ); this article is a detailed companion explanation.
This article builds on my previous two introductory articles, which you should be familiar with first: The Use of ML-Agents in Unity Reinforcement Learning, and ML-Agents Commands and Configuration.
My previous related articles are:
Push-Box Game in the ML-Agents Examples
Wall Jump Game in the ML-Agents Examples
Food Collector in the ML-Agents Examples
Two-Player Soccer in the ML-Agents Examples
Unity Artificial Intelligence: A Self-Evolving Five-Player Soccer Game
Dungeon Escape in the ML-Agents Examples
Robot Walking in the ML-Agents Examples
Map Matching in the ML-Agents Examples
Environment description
As shown in the figure, the agent is placed in a circular room where numbered tiles appear at random positions on the wall. The agent must touch the tiles in ascending order; each correctly touched tile turns green and gives a reward of +1. As soon as a tile is touched out of order, the episode ends with a reward of -1.
The challenge of this case is that we never tell the agent what the correct order is; it has to discover, through trial and error in the environment, the behavior of visiting the tiles from smallest to largest. Moreover, the number of numbered tiles on the wall is not fixed, which means the agent must accept a different number of inputs in each episode. How should that be handled?
State input: a new sensor, the Buffer Sensor, is used here.
The function of this sensor is to receive a state input whose count varies. Each call passes in one vector, which we can represent with the array listObservation; calling m_BufferSensor.AppendObservation(listObservation) hands it to the Buffer Sensor, which can receive any number of such vectors, as long as every vector has the same dimension. In other words, even though the number of input vectors differs each time, the network can still be trained to produce the desired output. The implementation is not in the project code; it is built into the ML-Agents package. Judging from my experience, a self-attention network structure should be what is used to take in a variable number of input vectors.
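To make the contract concrete, here is a minimal sketch, assuming only the standard ML-Agents API (Agent, BufferSensorComponent, AppendObservation); the Targets list and the 3-float layout are hypothetical, not part of this project:

using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using UnityEngine;

// Minimal sketch: each step, a variable number of fixed-size vectors
// is appended to the Buffer Sensor. Everything except the ML-Agents
// calls is hypothetical.
public class VariableInputAgent : Agent
{
    // Hypothetical entities whose count changes from episode to episode
    public List<Transform> Targets = new List<Transform>();

    BufferSensorComponent m_Buffer;

    public override void Initialize()
    {
        // The component is added in the inspector; its "Observable Size"
        // must equal the length of every appended array (3 here)
        m_Buffer = GetComponent<BufferSensorComponent>();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        foreach (var t in Targets)
        {
            m_Buffer.AppendObservation(new float[]
            {
                (t.position.x - transform.position.x) / 20f,  // relative x
                (t.position.z - transform.position.z) / 20f,  // relative z
                1f                                            // e.g. a flag
            });
        }
    }
}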
In addition to the Buffer Sensor input, a four-dimensional vector observation is also used: the x and z components of the vector from the center of the arena to the agent's position, and the x and z components of the agent's forward direction.
Action output: three discrete branches, each taking a value from 0 to 2. The first branch controls moving forward and backward, the second controls strafing left and right, and the third controls turning left and right.
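For reference, the same action space can be written out with ML-Agents' ActionSpec type (in the project it is simply configured in the Behavior Parameters inspector); the helper class below is hypothetical:

using Unity.MLAgents.Actuators;

// Hypothetical helper: the Sorter's action space expressed in code.
// Branch 0: 0 = idle, 1 = move forward, 2 = move backward
// Branch 1: 0 = idle, 1 = strafe right, 2 = strafe left
// Branch 2: 0 = idle, 1 = turn left,    2 = turn right
public static class SorterActionSpace
{
    public static readonly ActionSpec Spec = ActionSpec.MakeDiscrete(3, 3, 3);
}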
Code explanation
The components attached to the agent are the Decision Requester, Model Overrider, Behavior Parameters, and the Buffer Sensor described above. That leaves only the agent script, SorterAgent.cs, explained below:
Using directives:
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;
Define variables:
// The default maximum number of numbered tiles, adjustable with a slider in the editor over the range 1-20
[Range(1, 20)]
public int DefaultMaxNumTiles;
// The highest possible tile value
private const int k_HighestTileValue = 20;
// Number of tiles spawned this episode
int m_NumberOfTilesToSpawn;
// Maximum number of tiles
int m_MaxNumberOfTiles;
// Rigidbody of the agent
Rigidbody m_AgentRb;
// BufferSensorComponent is the sensor that allows observing a variable number of inputs
BufferSensorComponent m_BufferSensor;
// List of all numbered tiles
public List<NumberTile> NumberTilesList = new List<NumberTile>();
// List of tiles currently visible in the scene
private List<NumberTile> CurrentlyVisibleTilesList = new List<NumberTile>();
// List of tiles that have already been touched
private List<Transform> AlreadyTouchedList = new List<Transform>();
// List of wall positions already occupied
private List<int> m_UsedPositionsList = new List<int>();
// Starting position of the agent
private Vector3 m_StartingPos;
// The whole scene (area) object
GameObject m_Area;
// Environment parameters, readable from the configuration file
EnvironmentParameters m_ResetParams;
// Index of the next tile the agent is expected to touch
private int m_NextExpectedTileIndex;
Initialization method Initialize():
public override void Initialize()
{
    // Get the parent object
    m_Area = transform.parent.gameObject;
    // Get the maximum number of tiles
    m_MaxNumberOfTiles = k_HighestTileValue;
    // Get environment parameters from the configuration file
    m_ResetParams = Academy.Instance.EnvironmentParameters;
    // Get the Buffer Sensor component
    m_BufferSensor = GetComponent<BufferSensorComponent>();
    // Get the rigidbody
    m_AgentRb = GetComponent<Rigidbody>();
    // Record the starting position
    m_StartingPos = transform.position;
}
State input method CollectObservations():
public override void CollectObservations(VectorSensor sensor)
{
    // Components of the vector from the area center to the agent on the x and z axes
    sensor.AddObservation((transform.position.x - m_Area.transform.position.x) / 20f);
    sensor.AddObservation((transform.position.z - m_Area.transform.position.z) / 20f);
    // Components of the agent's forward direction on the x and z axes
    sensor.AddObservation(transform.forward.x);
    sensor.AddObservation(transform.forward.z);
    foreach (var item in CurrentlyVisibleTilesList)
    {
        // An array holding one observation entry; its length is the maximum
        // number of tiles + 3, zero-initialized by default
        float[] listObservation = new float[k_HighestTileValue + 3];
        // One-hot encoding of the tile's number
        listObservation[item.NumberValue] = 1.0f;
        // Get the tile's coordinates (the child object carries the real position;
        // the tile's own transform stays at the scene center to make rotation easy)
        var tileTransform = item.transform.GetChild(1);
        // Relative x and z components between the tile and the agent
        listObservation[k_HighestTileValue] = (tileTransform.position.x - transform.position.x) / 20f;
        listObservation[k_HighestTileValue + 1] = (tileTransform.position.z - transform.position.z) / 20f;
        // Whether the tile has already been touched
        listObservation[k_HighestTileValue + 2] = item.IsVisited ? 1.0f : 0.0f;
        // Append the array to the Buffer Sensor (it is not fed straight into the
        // network because the number of arrays varies from step to step)
        m_BufferSensor.AppendObservation(listObservation);
    }
}
Action output method OnActionReceived:
public override void OnActionReceived(ActionBuffers actionBuffers)
{
    // Move the agent
    MoveAgent(actionBuffers.DiscreteActions);
    // Time penalty, to encourage the agent to finish as quickly as possible
    AddReward(-1f / MaxStep);
}

public void MoveAgent(ActionSegment<int> act)
{
    var dirToGo = Vector3.zero;
    var rotateDir = Vector3.zero;
    // The three discrete outputs of the neural network
    var forwardAxis = act[0];
    var rightAxis = act[1];
    var rotateAxis = act[2];
    // The first branch controls forward and backward movement
    switch (forwardAxis)
    {
        case 1:
            dirToGo = transform.forward * 1f;
            break;
        case 2:
            dirToGo = transform.forward * -1f;
            break;
    }
    // The second branch controls strafing left and right
    switch (rightAxis)
    {
        case 1:
            dirToGo = transform.right * 1f;
            break;
        case 2:
            dirToGo = transform.right * -1f;
            break;
    }
    // The third branch controls turning left and right
    switch (rotateAxis)
    {
        case 1:
            rotateDir = transform.up * -1f;
            break;
        case 2:
            rotateDir = transform.up * 1f;
            break;
    }
    // Execute the action
    transform.Rotate(rotateDir, Time.deltaTime * 200f);
    m_AgentRb.AddForce(dirToGo * 2, ForceMode.VelocityChange);
}
The method OnEpisodeBegin, executed at the start of each episode:
public override void OnEpisodeBegin()
{
    // Read the number of tiles from the configuration file; if absent, use DefaultMaxNumTiles
    m_MaxNumberOfTiles = (int)m_ResetParams.GetWithDefault("num_tiles", DefaultMaxNumTiles);
    // Randomly choose how many tiles to spawn
    m_NumberOfTilesToSpawn = Random.Range(1, m_MaxNumberOfTiles + 1);
    // Select which tiles to spawn and add them to the list
    SelectTilesToShow();
    // Spawn the tiles and set their positions
    SetTilePositions();
    transform.position = m_StartingPos;
    m_AgentRb.velocity = Vector3.zero;
    m_AgentRb.angularVelocity = Vector3.zero;
}

void SelectTilesToShow()
{
    // Clear both lists
    CurrentlyVisibleTilesList.Clear();
    AlreadyTouchedList.Clear();
    // A total of numLeft tiles are spawned
    int numLeft = m_NumberOfTilesToSpawn;
    while (numLeft > 0)
    {
        // Draw a random number within range to pick the corresponding tile
        int rndInt = Random.Range(0, k_HighestTileValue);
        var tmp = NumberTilesList[rndInt];
        // If that tile is not yet in the list, add it
        if (!CurrentlyVisibleTilesList.Contains(tmp))
        {
            CurrentlyVisibleTilesList.Add(tmp);
            numLeft--;
        }
    }
    // Sort the tile list in ascending numerical order
    CurrentlyVisibleTilesList.Sort((x, y) => x.NumberValue.CompareTo(y.NumberValue));
    m_NextExpectedTileIndex = 0;
}

void SetTilePositions()
{
    // Clear the list of used positions
    m_UsedPositionsList.Clear();
    // Reset the state of all tiles; the ResetTile method is defined in the tile script below
    foreach (var item in NumberTilesList)
    {
        item.ResetTile();
        item.gameObject.SetActive(false);
    }
    foreach (var item in CurrentlyVisibleTilesList)
    {
        bool posChosen = false;
        // rndPosIndx determines the tile's rotation angle (i.e. where it sits on the circular wall)
        int rndPosIndx = 0;
        while (!posChosen)
        {
            rndPosIndx = Random.Range(0, k_HighestTileValue);
            // If this angle has not been taken yet, add it to the list
            if (!m_UsedPositionsList.Contains(rndPosIndx))
            {
                m_UsedPositionsList.Add(rndPosIndx);
                posChosen = true;
            }
        }
        // Apply the rotation and activate the tile
        item.transform.localRotation = Quaternion.Euler(0, rndPosIndx * (360f / k_HighestTileValue), 0);
        item.gameObject.SetActive(true);
    }
}
The method OnCollisionEnter, executed when a collision with another object begins:
private void OnCollisionEnter(Collision col)
{
    // Only collisions with numbered tiles are of interest
    if (!col.gameObject.CompareTag("tile"))
    {
        return;
    }
    // Tiles that have already been touched are also excluded
    if (AlreadyTouchedList.Contains(col.transform))
    {
        return;
    }
    // Wrong order: reward -1 and end the episode
    if (col.transform.parent != CurrentlyVisibleTilesList[m_NextExpectedTileIndex].transform)
    {
        AddReward(-1);
        EndEpisode();
    }
    // The correct tile was hit
    else
    {
        // Reward +1
        AddReward(1);
        // Change the tile's material
        var tile = col.gameObject.GetComponentInParent<NumberTile>();
        tile.VisitTile();
        // Advance the expected index
        m_NextExpectedTileIndex++;
        // Add the tile to the touched list
        AlreadyTouchedList.Add(col.transform);
        // If every tile has been visited, the episode ends
        if (m_NextExpectedTileIndex == m_NumberOfTilesToSpawn)
        {
            EndEpisode();
        }
    }
}
When the agent has no model and you want to control it manually (for example, to record demonstrations), the Heuristic method is used:
public override void Heuristic(in ActionBuffers actionsOut)
{
    var discreteActionsOut = actionsOut.DiscreteActions;
    // Forward / backward
    if (Input.GetKey(KeyCode.W))
    {
        discreteActionsOut[0] = 1;
    }
    if (Input.GetKey(KeyCode.S))
    {
        discreteActionsOut[0] = 2;
    }
    // Rotate
    if (Input.GetKey(KeyCode.A))
    {
        discreteActionsOut[2] = 1;
    }
    if (Input.GetKey(KeyCode.D))
    {
        discreteActionsOut[2] = 2;
    }
    // Strafe
    if (Input.GetKey(KeyCode.E))
    {
        discreteActionsOut[1] = 1;
    }
    if (Input.GetKey(KeyCode.Q))
    {
        discreteActionsOut[1] = 2;
    }
}
The script NumberTile.cs, mounted on each numbered tile:
using UnityEngine;

public class NumberTile : MonoBehaviour
{
    // The number on the tile
    public int NumberValue;
    // The default material, and the material used after a successful touch
    public Material DefaultMaterial;
    public Material SuccessMaterial;
    // Whether the tile has been touched
    private bool m_Visited;
    // The renderer used to swap materials
    private MeshRenderer m_Renderer;

    public bool IsVisited
    {
        get { return m_Visited; }
    }

    // Swap to the success material
    public void VisitTile()
    {
        m_Renderer.sharedMaterial = SuccessMaterial;
        m_Visited = true;
    }

    // Reset the tile: restore the material and the m_Visited flag
    public void ResetTile()
    {
        if (m_Renderer is null)
        {
            m_Renderer = GetComponentInChildren<MeshRenderer>();
        }
        m_Renderer.sharedMaterial = DefaultMaterial;
        m_Visited = false;
    }
}
Configuration file
behaviors:
  Sorter:
    trainer_type: ppo
    hyperparameters:
      batch_size: 512
      buffer_size: 40960
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: constant
    network_settings:
      normalize: False
      hidden_units: 128
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 5000000
    time_horizon: 256
    summary_freq: 10000
environment_parameters:
  num_tiles:
    curriculum:
      - name: Lesson0 # The '-' is important as this is a list
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.3
        value: 2.0
      - name: Lesson1
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.4
        value: 4.0
      - name: Lesson2
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.45
        value: 6.0
      - name: Lesson3
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.5
        value: 8.0
      - name: Lesson4
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.55
        value: 10.0
      - name: Lesson5
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.6
        value: 12.0
      - name: Lesson6
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.65
        value: 14.0
      - name: Lesson7
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.7
        value: 16.0
      - name: Lesson8
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.75
        value: 18.0
      - name: Lesson9
        value: 20.0
As you can see, the configuration uses the most common PPO algorithm, and a plain PPO at that, without "accessories" such as LSTM or intrinsic reward modules. The only difference is the addition of Curriculum Learning: an agent that can sort dozens of tiles is hard to train in one go, so we arrange its tasks from easy to hard. It starts by sorting 2 tiles, the count increases by 2 with each lesson, and finally reaches 20. Here, measure: progress with a threshold of 0.3 means the lesson completes once 30% of the total training steps have elapsed. For a detailed explanation of the curriculum learning parameters, see my earlier article Wall Jump Game in the ML-Agents Examples.
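Training is then launched with mlagents-learn in the usual way (see my earlier commands article); the config path below assumes the default layout of the ml-agents repository:

mlagents-learn config/ppo/Sorter.yaml --run-id=Sorter

After the command starts, press Play in the Unity editor to begin training.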
Effect demonstration
Postscript
Compared with the previous cases, the innovation of this case is the introduction of the Buffer Sensor, which receives a variable number of input vectors, something the sensors used in earlier cases could not do. It handles situations where the amount of information the agent receives changes with the environment, as it does here. Such situations are common: when an agent faces enemies, the number of enemies is uncertain, and so is the number of bullets they fire; a Buffer Sensor is then needed to accept a varying number of inputs (a sketch follows). Of course, such training usually requires more samples, and inputs of every count need to be covered, otherwise the model will overfit. To achieve this, the Curriculum Learning described above is used to diversify the training samples, moving from easy to hard and making the agent's policy robust.
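To make that scenario concrete, here is a hedged sketch (all tags and names are hypothetical; only the ML-Agents calls are real API) that packs a varying number of enemies and bullets into one Buffer Sensor by prefixing each entry with a one-hot type tag:

using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using UnityEngine;

public class CombatObservations : Agent
{
    BufferSensorComponent m_Buffer;   // its Observable Size must be 4 here

    public override void Initialize()
    {
        m_Buffer = GetComponent<BufferSensorComponent>();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // One entry per enemy: type tag (1, 0) plus relative x/z position
        foreach (var enemy in GameObject.FindGameObjectsWithTag("enemy"))
        {
            m_Buffer.AppendObservation(new[]
            {
                1f, 0f,
                (enemy.transform.position.x - transform.position.x) / 20f,
                (enemy.transform.position.z - transform.position.z) / 20f
            });
        }
        // One entry per bullet: type tag (0, 1) plus relative x/z position
        foreach (var bullet in GameObject.FindGameObjectsWithTag("bullet"))
        {
            m_Buffer.AppendObservation(new[]
            {
                0f, 1f,
                (bullet.transform.position.x - transform.position.x) / 20f,
                (bullet.transform.position.z - transform.position.z) / 20f
            });
        }
    }
}

Because every entry appended to a Buffer Sensor must have the same dimension, mixing entity types requires a shared fixed layout like the type tag used here.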