// AI-Stock-Trader/WebServer/Controllers/Automation.cs

using Controllers.DataBase;
using Controllers.Payment;
using Controllers.PythonInterop;
using Entities;
using Newtonsoft.Json;
namespace Controllers.Automation {
    /// <summary>
    /// Drives the automated trading cycle: scores each user's watched stocks through the
    /// AI module, sells the user's open positions and buys the highest-scoring stock.
    /// </summary>
    public class AutomationController {
        // Sentinel returned by sellStock/buyStock when a payment step fails.
        private const float TradeFailed = -1f;

        // Key suffix of the per-user purchase history.
        // NOTE(review): the history used to be read from "history" but written back to
        // "Stocks", so trades recorded here were never visible to the next run. Both
        // directions now use "history" - confirm nothing else depends on the "Stocks" key.
        private const string HistoryKey = "history";

        private readonly AIModule _aiModule;
        private readonly DbDriver _dbDriver;
        private readonly IPayment _paymentProcessor;

        /// <summary>Creates the controller from its collaborators.</summary>
        /// <param name="aiModule">AI module used for predictions, prices and training.</param>
        /// <param name="dbDriver">Key/value store holding users, watch lists, money and history.</param>
        /// <param name="PaymentProcessor">Payment backend used to settle buys and sells.</param>
        public AutomationController(AIModule aiModule, DbDriver dbDriver, IPayment PaymentProcessor) {
            _aiModule = aiModule;
            _dbDriver = dbDriver;
            _paymentProcessor = PaymentProcessor;
        }

        /// <summary>
        /// Starts a background run that, for every user, predicts scores for the watched
        /// stocks, sells the open positions and buys the best-scoring stock. Returns
        /// immediately; failures are written to the console.
        /// </summary>
        /// <param name="DaysBefore">Forwarded unchanged to the AI module's prediction call.</param>
        /// <param name="testmode">When true, only the fixed user "TESTMODE" is processed.</param>
        public void GlobalPredictAI(int DaysBefore = 0, bool testmode = false) {
            // Run on a background thread so the caller is not blocked.
            Task.Run(() => {
                try {
                    List<string> users = testmode
                        ? new List<string>() { "TESTMODE" }
                        : LoadList<string>("Users");

                    // The per-user body is synchronous on purpose: an async lambda here
                    // would be async void, and Parallel.ForEach would return before any
                    // user had actually been processed.
                    Parallel.ForEach(users, username => {
                        try {
                            ProcessUser(username, DaysBefore);
                        } catch (Exception ex) {
                            // One user's failure must not abort the other users.
                            Console.WriteLine($"Automation failed for user '{username}': {ex}");
                        }
                    });
                } catch (Exception ex) {
                    Console.WriteLine("Global prediction run failed: " + ex);
                }
            });
        }

        /// <summary>
        /// Starts a background task that pulls and then trains the AI model. Returns
        /// immediately; failures are written to the console.
        /// </summary>
        public void GlobalTrainAI() {
            Task.Run(() => {
                try {
                    _aiModule.PullAI();
                    _aiModule.TrainAI();
                } catch (Exception ex) {
                    Console.WriteLine("Global training run failed: " + ex);
                }
            });
        }

        // Runs one full predict / sell / buy / persist cycle for a single user.
        private void ProcessUser(string username, int daysBefore) {
            string dbPrefix = $"[{username.ToLower()}]:";
            List<Stock> tracked = LoadList<Stock>(dbPrefix + "watched");

            // Score every watched stock concurrently; each iteration writes to its own Stock.
            Parallel.ForEach(tracked, stock => {
                (string, float) prediction = _aiModule.PredictAI(stock.Symbol, daysBefore);
                // A non-empty first item is an error message from the AI module.
                if (!string.IsNullOrEmpty(prediction.Item1)) {
                    Console.WriteLine(prediction.Item1);
                }
                stock.Score = prediction.Item2;
            });

            // Placeholder so that an empty watch list still yields a stock to buy.
            Stock best = new Stock() { Symbol = "NVDA", Score = -400 };
            foreach (Stock candidate in tracked) {
                if (best.Score < candidate.Score) {
                    best = candidate;
                }
            }

            // Load the user's balance; an unreadable value is treated as zero.
            string moneyStr = _dbDriver.Get(dbPrefix + "money");
            if (!float.TryParse(moneyStr, out float money)) {
                Console.WriteLine("Unable to load users money");
                money = 0;
            }

            float afterSale = sellStock(username, money);
            if (afterSale == TradeFailed) {
                Console.WriteLine("Failed to sell stocks");
                return;
            }

            float afterBuy = buyStock(username, afterSale, best.Symbol);
            bool bought = afterBuy != TradeFailed;
            if (!bought) {
                Console.WriteLine("Failed to buy stocks");
            }

            // The sale has already been recorded at this point, so the balance is saved
            // even when the purchase failed; otherwise the sale proceeds would be lost.
            _dbDriver.Set(dbPrefix + "watched", JsonConvert.SerializeObject(tracked));
            _dbDriver.Set(dbPrefix + "money", (bought ? afterBuy : afterSale).ToString());
        }

        // Reads a JSON list stored under the given key; a missing or empty value yields an empty list.
        private List<T> LoadList<T>(string key) {
            string? raw = _dbDriver.Get(key);
            if (string.IsNullOrEmpty(raw)) {
                return new List<T>();
            }
            return JsonConvert.DeserializeObject<List<T>>(raw) ?? new List<T>();
        }

        /// <summary>
        /// Sells every position in the user's history that is not yet marked as sold.
        /// </summary>
        /// <param name="username">User whose positions are sold.</param>
        /// <param name="Money">The user's balance before the sale.</param>
        /// <returns>The balance plus the sale proceeds, or -1 when a payment step failed.</returns>
        private float sellStock(string username, float Money) {
            string dbPrefix = $"[{username.ToLower()}]:";
            List<PurchasedStock> history = LoadList<PurchasedStock>(dbPrefix + HistoryKey);

            float totalSale = 0;
            foreach (PurchasedStock cur in history) {
                if (cur.Sold == false) {
                    float sellPrice = cur.Quantity * _aiModule.GetCurrentPrice(cur.Symbol);

                    // Try to create a payment session.
                    (bool, string) createResult = _paymentProcessor.CreatePayment(username);
                    if (!createResult.Item1) {
                        Console.WriteLine("Create Payment Failure: " + createResult.Item2);
                        return TradeFailed;
                    }

                    // Try to sell the stock through that session.
                    (bool, string) paymentResult = _paymentProcessor.TrySell(createResult.Item2, sellPrice);
                    if (!paymentResult.Item1) {
                        Console.WriteLine("Process Payment Failure: " + paymentResult.Item2);
                        return TradeFailed;
                    }

                    // Mark the position as sold so later runs do not sell it again.
                    // NOTE(review): assumes PurchasedStock.Sold is settable - confirm.
                    cur.Sold = true;
                    totalSale += sellPrice;
                }
            }

            // Save the updated history under the same key it was read from.
            _dbDriver.Set(dbPrefix + HistoryKey, JsonConvert.SerializeObject(history));
            return Money + totalSale;
        }

        /// <summary>
        /// Buys as many whole shares of <paramref name="stockSymbol"/> as the balance allows.
        /// </summary>
        /// <param name="username">User the purchase is made for.</param>
        /// <param name="Money">Balance available for the purchase.</param>
        /// <param name="stockSymbol">Symbol of the stock to buy.</param>
        /// <returns>
        /// The remaining balance; the unchanged balance when not even one share is
        /// affordable; or -1 when the price is unusable or a payment step failed.
        /// </returns>
        private float buyStock(string username, float Money, string stockSymbol) {
            string dbPrefix = $"[{username.ToLower()}]:";

            float stockPrice = _aiModule.GetCurrentPrice(stockSymbol);
            // Rejects zero, negative, NaN and infinite prices, which would otherwise make
            // the quantity computation below meaningless.
            if (!(stockPrice > 0f) || float.IsInfinity(stockPrice)) {
                Console.WriteLine($"Invalid price for {stockSymbol}: {stockPrice}");
                return TradeFailed;
            }

            // Max shares the user can afford [ int cast truncates the decimal ].
            int maxQty = (int)(Money / stockPrice);
            if (maxQty <= 0) {
                // Nothing affordable: skip the payment and do not record an empty purchase.
                Console.WriteLine($"Insufficient funds to buy {stockSymbol}");
                return Money;
            }
            float cost = stockPrice * maxQty;

            // Try to create a payment session.
            (bool, string) createResult = _paymentProcessor.CreatePayment(username);
            if (!createResult.Item1) {
                Console.WriteLine("Create Payment Failure: " + createResult.Item2);
                return TradeFailed;
            }

            // Try to pay for the stock.
            (bool, string) result = _paymentProcessor.TryPayment(createResult.Item2, cost);
            if (!result.Item1) {
                Console.WriteLine("Process Payment Failure: " + result.Item2);
                return TradeFailed;
            }

            // Record the purchase in the history, read only after the payment succeeded
            // so the stored list is as fresh as possible.
            List<PurchasedStock> history = LoadList<PurchasedStock>(dbPrefix + HistoryKey);
            history.Add(new PurchasedStock() {
                Symbol = stockSymbol.ToUpper(),
                PurchasePrice = stockPrice,
                Quantity = maxQty,
            });
            _dbDriver.Set(dbPrefix + HistoryKey, JsonConvert.SerializeObject(history));

            return Money - cost;
        }
    }
}