import random

import matplotlib.pyplot as plt
import torch
def make_dataset(start=0.0, stop=20.0, step=0.005, slope=3.0, intercept=2.5,
                 x_noise_std=0.35, y_noise_std=0.75):
    """Build a noisy sample of the line ``y = slope * x + intercept``.

    Args:
        start: First value of the clean x grid.
        stop: Exclusive upper bound of the clean x grid.
        step: Spacing between consecutive grid values.
        slope: Slope of the underlying line.
        intercept: Intercept of the underlying line.
        x_noise_std: Standard deviation of the Gaussian noise added to x.
        y_noise_std: Standard deviation of the Gaussian noise added to y.

    Returns:
        A ``(grid, noisy_x, noisy_y)`` tuple of column tensors with shape
        ``(n, 1)``. ``grid`` is the clean, evenly spaced x grid and is never
        modified; ``noisy_x`` and ``noisy_y`` are the perturbed samples.
    """
    grid = torch.arange(
        start, stop, step, dtype=torch.get_default_dtype()
    ).reshape(-1, 1)
    # Out-of-place additions: the original script aliased the clean grid and
    # the noisy inputs (`x = X = ...`) and then perturbed them in place, so the
    # "clean" grid used for drawing the fitted line was silently noisy as well.
    noisy_x = grid + x_noise_std * torch.randn_like(grid)
    # Targets are computed from the clean grid, then perturbed independently.
    noisy_y = slope * grid + intercept + y_noise_std * torch.randn_like(grid)
    return grid, noisy_x, noisy_y


def fit_line(inputs, targets, steps: int = 10000, log_every: int = 1000,
             log_path: str = 'log.txt') -> torch.nn.Linear:
    """Fit a one-feature linear model with Adam on a mean-squared-error loss.

    Every ``log_every`` steps (starting at step 0) a progress line holding the
    step, the current parameters and the loss is printed and appended to
    ``log_path``.

    Args:
        inputs: Tensor of shape ``(n, 1)`` holding the x samples.
        targets: Tensor of shape ``(n, 1)`` holding the y samples.
        steps: Number of full-batch optimisation steps.
        log_every: Logging interval in steps; must be at least 1.
        log_path: File that progress lines are appended to.

    Returns:
        The trained ``torch.nn.Linear(1, 1)`` model, on ``inputs.device``.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError('log_every must be >= 1, got {}'.format(log_every))

    model = torch.nn.Linear(1, 1).to(inputs.device)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.MSELoss()

    # The context manager closes the log even if a training step raises.
    with open(log_path, 'a', encoding='utf-8') as log:
        for step in range(steps):
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            if step % log_every == 0:
                txt = 'step:{}, = {}, loss = {}'.format(
                    step, list(model.parameters()), loss)
                log.write(txt + '\n')
                print(txt)
    return model


def main():
    """Generate noisy data, fit a line to it and save the plot to save.png."""
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    grid, noisy_x, noisy_y = make_dataset()
    plt.plot(noisy_x.tolist(), noisy_y.tolist(),
             linestyle='', marker='.', zorder=10)

    model = fit_line(noisy_x.to(device), noisy_y.to(device))

    # Inference only: no autograd graph is needed for the plotted predictions.
    with torch.no_grad():
        fitted = model(grid.to(device))
    plt.plot(grid.tolist(), fitted.tolist(),
             color='y', linewidth=1, zorder=30)
    plt.savefig('save.png', dpi=500)


if __name__ == '__main__':
    main()