import os
import sys
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
import uuid
import time
import copy
import glob
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
from torch import Tensor, nn
import torch.nn.functional as F
import torch.distributed as dist
# use of FlexAttention contributed by @KoszarskyB
from torch.nn.attention.flex_attention import BlockMask, flex_attention
#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
# -----------------------------------------------------------------------------
# Custom operators: FP8 matmul by @YouJiacheng

@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
    @torch.compile
    def impl(x: Tensor, w: Tensor):
        assert x.is_contiguous() and w.is_contiguous()
        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
        out = torch._scaled_mm(
            x_f8,
            w_f8.T,
            out_dtype=torch.bfloat16,
            scale_a=x.new_tensor(x_s, dtype=torch.float32),
            scale_b=x.new_tensor(w_s, dtype=torch.float32),
            use_fast_accum=True,
        )
        return out, x_f8, w_f8
    return impl(x, w)

@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
    assert x.ndim == w.ndim == 2
    assert x.shape[1] == w.shape[1]
    assert x.device == w.device
    assert x.is_contiguous() and w.is_contiguous()
    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)

@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
    @torch.compile
    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
        assert grad.is_contiguous()
        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
        grad_x = torch._scaled_mm(
            grad_f8,
            w_f8.T.contiguous().T,
            out_dtype=torch.bfloat16,
            scale_a=grad_inv_s,
            scale_b=w_inv_s,
            use_fast_accum=False,
        )
        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
        grad_w = torch._scaled_mm(
            x_f8.T.contiguous(),
            grad_f8.T.contiguous().T,
            out_dtype=torch.float32,
            scale_a=x_inv_s,
            scale_b=grad_inv_s,
            use_fast_accum=False,
        ).T
        return grad_x, grad_w
    return impl(g, x_f8, w_f8)

@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)

def backward(ctx, grad_out: Tensor, *_):
    x_f8, w_f8 = ctx.saved_tensors
    x_s, w_s, grad_s = ctx.scales
    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
        grad_out, x_f8, w_f8, x_s, w_s, grad_s
    )
    return grad_x, grad_w, None, None, None

def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    *_, x_s, w_s, grad_s = inputs
    _, x_f8, w_f8 = output
    ctx.save_for_backward(x_f8, w_f8)
    ctx.scales = x_s, w_s, grad_s
    ctx.set_materialize_grads(False)

mm_op.register_autograd(backward, setup_context=setup_context)
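# A minimal usage sketch of the op above (illustrative only, not executed here;
# assumes a CUDA device with FP8 support and the scale values that the lm_head
# below actually uses):
#
#     x = torch.randn(64, 768, device="cuda", dtype=torch.bfloat16)
#     w = torch.randn(50304, 768, device="cuda", dtype=torch.bfloat16)
#     out, x_f8, w_f8 = torch.ops.nanogpt.mm(x, w, x_s=768**0.5/448, w_s=24/448, grad_s=1/448)
#
# The forward pass quantizes x/x_s and w/w_s into FP8 (e4m3 has max normal value
# 448, hence the /448 in the scales) and passes x_s, w_s back to _scaled_mm as
# dequantization factors, so `out` approximates x @ w.T in bfloat16.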
# -----------------------------------------------------------------------------
# Muon optimizer

@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
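# A quick sanity check one could run (hedged sketch, not executed during training;
# assumes a CUDA device since this run compiles the function for GPU):
#
#     G = torch.randn(256, 128, device="cuda")
#     X = zeropower_via_newtonschulz5(G, steps=5)
#     S = torch.linalg.svdvals(X.float())
#     # singular values should land roughly in [0.5, 1.5] rather than exactly at 1,
#     # matching the US'V^T characterization in the docstring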
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-Schulz
    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer should not be used for the embedding layer, the final fully connected layer,
      or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iteration steps to use.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, rank=0, world_size=1):
        self.rank = rank
        self.world_size = world_size
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: list[Tensor] = [*params]
        param_groups = []
        for size in {p.numel() for p in params}:
            b = torch.empty(world_size, size, dtype=torch.bfloat16, device="cuda")
            group = dict(params=[p for p in params if p.numel() == size],
                         update_buffer=b, update_buffer_views=[b[i] for i in range(world_size)])
            param_groups.append(group)
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            update_buffer: Tensor = group["update_buffer"]
            update_buffer_views: list[Tensor] = group["update_buffer_views"]
            # generate weight updates in distributed fashion
            params: list[Tensor] = group["params"]
            handle = None
            params_world = None
            def update_prev(): # optimized Muon implementation contributed by @YouJiacheng
                handle.wait()
                for p_world, g_world in zip(params_world, update_buffer_views):
                    p_world.add_(g_world.view_as(p_world),
                                 alpha=-group["lr"] * max(1, p_world.size(-2) / p_world.size(-1))**0.5)
            for base_i in range(len(params))[::self.world_size]:
                if base_i + self.rank < len(params):
                    p = params[base_i + self.rank]
                    g = p.grad
                    assert g is not None
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf: Tensor = state["momentum_buffer"]
                    buf.lerp_(g, 1 - group["momentum"])
                    g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                    g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]).flatten()
                else:
                    g = update_buffer_views[self.rank]
                if base_i > 0:
                    update_prev() # async all_gather instead of sync all_reduce by @YouJiacheng
                handle = dist.all_gather_into_tensor(update_buffer, g, async_op=True)
                params_world = params[base_i : base_i + self.world_size]
            update_prev()
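# Hedged usage sketch (not executed; assumes torch.distributed is already
# initialized and that every optimized parameter is a >=2D matrix, per the
# warnings above):
#
#     matrix_params = [p for p in model.parameters() if p.ndim >= 2]
#     opt = Muon(matrix_params, lr=0.05, momentum=0.95, rank=rank, world_size=world_size)
#     loss.backward()
#     opt.step()
#     model.zero_grad(set_to_none=True)
#
# Each rank orthogonalizes every world_size-th parameter's update and the results
# are exchanged with one async all_gather per group, instead of all-reducing gradients.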
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model

def norm(x: Tensor):
    return F.rms_norm(x, (x.size(-1),))

class CastedLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
        super().__init__(in_features, out_features, bias=False)
        self.use_fp8 = use_fp8
        self.x_s = x_s
        self.w_s = w_s
        self.grad_s = grad_s

    def reset_parameters(self) -> None:
        std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3)
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, x: Tensor):
        if self.use_fp8 and self.training:
            _x = x.flatten(0, -2)
            out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
            return out.reshape(*x.shape[:-1], -1)
        else:
            return F.linear(x, self.weight.type_as(x))

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)

    def forward(self, x_BTHD: Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)
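# Note: (y1, y2) above is the planar rotation applied to each (x1, x2) pair,
# i.e. multiplication by exp(-i*theta) in the complex view; the zeroed half of
# the frequencies (theta = 0) leaves those dimensions unrotated. A hedged check
# one could run (not executed here):
#
#     rot = Rotary(dim=128, max_seq_len=64)
#     x = torch.randn(1, 64, 1, 128)
#     assert torch.allclose(rot(x).norm(dim=-1), x.norm(dim=-1), atol=1e-4) # rotations preserve norm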
class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        hdim = num_heads * head_dim
        std = 0.5 * (dim ** -0.5)
        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        self.qkv_w = nn.Parameter(torch.empty(3, hdim, dim).uniform_(-bound, bound))
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.rotary = Rotary(head_dim, max_seq_len)
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977

    def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "Must use batch size = 1 for FlexAttention"
        q, k, v = F.linear(x, self.qkv_w.flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        q, k = norm(q), norm(k) # QK norm @Grad62304977
        q, k = self.rotary(q), self.rotary(k)
        if ve is not None:
            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
        else: # skip mid-layers token value embeddings by @YouJiacheng
            v = self.lambdas[0] * v
        # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        hdim = 4 * dim
        self.c_fc = CastedLinear(dim, hdim)
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977

    def forward(self, x: Tensor):
        x = self.c_fc(x)
        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
        self.mlp = MLP(dim)
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), ve, block_mask)
        x = x + self.mlp(norm(x))
        return x
# -----------------------------------------------------------------------------
# custom buckets for efficient allreduce

def initialize_buckets(parameters, bucket_size_bytes):
    buckets, current_bucket, current_size = [], [], 0
    for param in parameters:
        if param.requires_grad and param.grad is None:
            param.grad = torch.zeros_like(param)
        param_size = param.numel() * param.element_size()
        if current_size + param_size > bucket_size_bytes and current_bucket:
            buckets.append(current_bucket)
            current_bucket = []
            current_size = 0
        current_bucket.append(param)
        current_size += param_size
    if current_bucket:
        buckets.append(current_bucket)
    flat_buffers, bucket_info = [], []
    for bucket in buckets:
        grad_shapes = [param.grad.shape for param in bucket]
        total_elements = sum(param.grad.numel() for param in bucket)
        device, dtype = bucket[0].device, bucket[0].dtype
        flat_buffer = torch.zeros(total_elements, device=device, dtype=dtype)
        flat_buffers.append(flat_buffer)
        offsets, offset = [], 0
        for param in bucket:
            numel = param.grad.numel()
            offsets.append((offset, offset + numel))
            offset += numel
        bucket_info.append({'params': bucket, 'shapes': grad_shapes, 'offsets': offsets})
    return {'bucket_info': bucket_info, 'flat_buffers': flat_buffers}

def reduce_gradients(bucket_data):
    bucket_info, flat_buffers = bucket_data['bucket_info'], bucket_data['flat_buffers']
    handles = []
    for i, (info, flat_buffer) in enumerate(zip(bucket_info, flat_buffers)):
        for param_idx, param in enumerate(info['params']):
            if param.grad is not None:
                start, end = info['offsets'][param_idx]
                flat_buffer[start:end].copy_(param.grad.view(-1))
        handle = dist.all_reduce(flat_buffer, op=dist.ReduceOp.AVG, async_op=True)
        handles.append((handle, i))
    return handles

def unpack_gradients(bucket_data, handles):
    bucket_info, flat_buffers = bucket_data['bucket_info'], bucket_data['flat_buffers']
    for handle, bucket_idx in handles:
        handle.wait()
        info, flat_buffer = bucket_info[bucket_idx], flat_buffers[bucket_idx]
        for param_idx, param in enumerate(info['params']):
            if param.grad is not None:
                start, end = info['offsets'][param_idx]
                param.grad.copy_(flat_buffer[start:end].view(info['shapes'][param_idx]))
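# Hedged usage sketch of the three helpers above (not executed here; the training
# loop below uses them in exactly this order):
#
#     buckets = initialize_buckets(model.parameters(), 64 * 1024**2)
#     loss.backward()
#     handles = reduce_gradients(buckets)   # kick off one async all_reduce per bucket
#     ...                                   # overlap: host-side work while NCCL runs
#     unpack_gradients(buckets, handles)    # wait, then scatter averaged grads back
#     optimizer.step()
#
# Packing many small gradients into a few large flat buffers amortizes the
# per-collective launch overhead, which is why this beats one all_reduce per parameter.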
# -----------------------------------------------------------------------------
# The main model

def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
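# e.g. next_multiple_of_n(50257, n=128) == 50304, and next_multiple_of_n(0, n=128) == 128
# (the result is always at least n, which the window-size schedule below relies on)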
class GPT(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128),
                                    use_fp8=True, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448)
        self.lm_head.weight.detach().zero_() # @Grad62304977
        # Add learnable skip connection weights for decoder layers
        assert num_layers % 2 == 0
        self.skip_weights = nn.Parameter(torch.ones(num_layers//2))

    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
        BLOCK_SIZE = 128
        docs = (input_seq == 50256).cumsum(0)
        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask
        def dense_to_ordered(dense_blockmask: Tensor):
            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
        # manual block mask creation by @YouJiacheng
        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
        causal_blockmask_any = block_idx[:, None] >= block_idx
        causal_blockmask_all = block_idx[:, None] > block_idx
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
        blockmask_any = causal_blockmask_any & document_blockmask_any
        blockmask_all = causal_blockmask_all & document_blockmask_all
        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
        def build_bm(window_size_blocks: Tensor) -> BlockMask:
            return BlockMask.from_kv_blocks(
                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
                partial_kv_indices,
                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )
        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
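    # How build_bm budgets KV blocks per query row (a worked example under the
    # assumption window_size_blocks = 8): fully-visible blocks are capped at
    # 8 - 1 = 7, and partial blocks (where mask_mod is re-checked per token) get
    # max(8 - #full, 1) slots, so each query row sees at most ~8 KV blocks,
    # i.e. a sliding window of roughly 8 * 128 = 1024 tokens.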
    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
        assert input_seq.ndim == 1
        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
        assert len(ve) == len(self.blocks)
        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
        assert len(block_masks) == len(self.blocks)
        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
        # U-net design by @brendanh0gan
        skip_connections = []
        n = len(self.skip_weights)
        for i in range(len(self.blocks)):
            if i >= n:
                x = x + self.skip_weights[i - n] * skip_connections.pop()
            x = self.blocks[i](x, ve[i], x0, block_masks[i])
            if i < n:
                skip_connections.append(x)
        x = norm(x)
        logits = self.lm_head(x).float()
        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
        logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction='sum' if self.training else 'mean')
        return loss
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

def _load_data_shard(file: Path):
    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
    assert header[1] == 1, "unsupported version"
    num_tokens = int(header[2]) # number of tokens (claimed)
    with file.open("rb", buffering=0) as f:
        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
        f.seek(256 * 4)
        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
    return tokens
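# For reference, a hedged sketch of writing a shard in this layout (not part of
# this run; field meanings are taken from the asserts above):
#
#     import numpy as np
#     def write_shard(path, tokens_u16):
#         header = np.zeros(256, dtype=np.int32)
#         header[0] = 20240520          # magic
#         header[1] = 1                 # version
#         header[2] = len(tokens_u16)   # token count
#         with open(path, "wb") as f:
#             f.write(header.tobytes())
#             f.write(np.asarray(tokens_u16, dtype=np.uint16).tobytes())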
def distributed_data_generator(filename_pattern: str, batch_size: int, rank: int, world_size: int):
    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
    assert batch_size % world_size == 0
    local_batch_size = batch_size // world_size
    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
    tokens, pos = _load_data_shard(next(file_iter)), 0
    while True:
        if pos + batch_size + 1 >= len(tokens):
            tokens, pos = _load_data_shard(next(file_iter)), 0
        buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
        inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
        targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
        pos += batch_size
        yield inputs, targets
# -----------------------------------------------------------------------------
# int main

@dataclass
class Hyperparameters:
    # data
    train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
    val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
    val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    train_seq_len = 48*1024 # FlexAttention sequence length
    val_seq_len = 4*64*1024 # FlexAttention sequence length for validation
    # optimization
    num_iterations = 1770 # number of iterations to run
    cooldown_frac = 0.4 # fraction of training spent cooling down the learning rate
    # architecture
    vocab_size = 50257
    # evaluation and logging
    val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end
    save_checkpoint = False
args = Hyperparameters()

# torchrun sets these env variables
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert world_size == 8 # this code is designed for 8xH100
assert torch.cuda.is_available()
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
master_process = (rank == 0) # this process will do logging, checkpointing etc.

# begin logging
logfile = None
if master_process:
    run_id = uuid.uuid4()
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{run_id}.txt"
    print(logfile)
def print0(s, console=False):
    if master_process:
        with open(logfile, "a") as f:
            if console:
                print(s)
            print(s, file=f)

# begin by printing this file (the Python code)
print0(code)
print0("="*100)
# log information about the hardware/software environment this is running on
print0(f"Running Python {sys.version}")
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
def nvidia_smi():
    import subprocess # avoid top level import
    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
print0(nvidia_smi())
print0("="*100)

########################################
#    Construct model and optimizer     #
########################################

model: nn.Module = GPT(vocab_size=args.vocab_size, num_layers=12, num_heads=6, model_dim=768,
                       max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.bfloat16()
for param in model.parameters():
    dist.broadcast(param.detach(), 0)

# collect the parameters to optimize
hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [model.lm_head.weight]

# init the optimizer(s)
adam_params = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), eps=1e-10, fused=True)
optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, rank=rank, world_size=world_size)
optimizers = [optimizer1, optimizer2]
for opt in optimizers:
    for group in opt.param_groups:
        group["initial_lr"] = group["lr"]
# init the gradient buckets
gradient_buckets = initialize_buckets(model.parameters(), 64.0 * 1024**2) # 64 MiB buckets
# learning rate schedule: stable then decay
def get_lr(step: int):
    x = step / args.num_iterations # progress in training
    assert 0 <= x < 1
    if x < 1 - args.cooldown_frac:
        return 1.0
    else:
        w = (1 - x) / args.cooldown_frac
        return w * 1.0 + (1 - w) * 0.1
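# e.g. with num_iterations=1770 and cooldown_frac=0.4: steps 0-1061 return 1.0,
# then the multiplier decays linearly toward 0.1 (step 1416 has x=0.8, so w=0.5
# and the multiplier is 0.55); the final step approaches 0.1.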
# attention window size schedule: linearly increase
@lru_cache(1)
def get_window_size_blocks_helper(window_size: int):
    return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
def get_window_size_blocks(step: int):
    x = step / args.num_iterations # progress in training
    assert 0 <= x <= 1
    # Linearly increase the block-wise sliding window size over training 128 -> 1792
    # increase by @fernbear.bsky.social; block-wise by @YouJiacheng
    window_size = next_multiple_of_n(1728 * x, n=128)
    return get_window_size_blocks_helper(window_size)
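# e.g. step 0 gives a 128-token window (1 block); at the final step 1728 rounds
# up to 1792 tokens (14 blocks). The lru_cache keeps the host-to-device transfer
# off the hot path, since the window size only changes every ~130 steps.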
model: nn.Module = torch.compile(model, dynamic=False)

########################################
#            Warmup kernels            #
########################################

# Warmup the training kernels, then re-initialize the state so we aren't cheating
warmup_steps = 10
initial_state = dict(model=copy.deepcopy(model.state_dict()),
                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
for _ in range(warmup_steps):
    inputs = targets = torch.randint(0, args.vocab_size, size=(args.train_seq_len,), device="cuda")
    model(inputs.to(torch.int32), targets, get_window_size_blocks(0)).backward()
    for param in model.parameters():
        dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
model.load_state_dict(initial_state["model"])
for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
    opt.load_state_dict(opt_state)
del initial_state

########################################
#        Training and validation       #
########################################

train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, rank, world_size)
training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.perf_counter()
# begin training
train_steps = args.num_iterations
for step in range(train_steps + 1):
    last_step = (step == train_steps)

    # --------------- VALIDATION SECTION -----------------
    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.perf_counter() - t0)
        model.eval()
        val_batch_size = world_size * args.val_seq_len
        assert args.val_tokens % val_batch_size == 0
        val_steps = args.val_tokens // val_batch_size
        val_loader = distributed_data_generator(args.val_files, val_batch_size, rank, world_size)
        val_loss = 0
        with torch.no_grad():
            for _ in range(val_steps):
                inputs, targets = next(val_loader)
                val_loss += model(inputs, targets, get_window_size_blocks(step))
        val_loss /= val_steps
        del val_loader
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
        model.train()
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.perf_counter()

    if last_step:
        if master_process and args.save_checkpoint:
            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
            os.makedirs(f"logs/{run_id}", exist_ok=True)
            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
        # the last step only has the validation loop, so break to avoid training
        break

    # --------------- TRAINING SECTION -----------------
    inputs, targets = next(train_loader)
    model(inputs, targets, get_window_size_blocks(step)).backward()
    #for param in model.parameters():
    #    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
    handles = reduce_gradients(gradient_buckets) # does the same thing as the commented two lines above, but faster
    # set optimization hyperparameters
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * get_lr(step)
    for group in optimizer2.param_groups:
        frac = min(step / 300, 1) # momentum warmup for muon
        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
    # step the optimizers
    unpack_gradients(gradient_buckets, handles)
    for opt in optimizers:
        opt.step()
    # null the gradients
    model.zero_grad(set_to_none=True)
    # logging
    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)

print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
dist.destroy_process_group()
====================================================================================================
Running Python 3.12.7 (main, May 24 2025, 20:59:58) [GCC 13.2.0]
Running PyTorch 2.8.0.dev20250524+cu126 compiled for CUDA 12.6
Sat May 24 21:45:46 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.124.06             Driver Version: 570.124.06     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:61:00.0 Off |                    0 |
| N/A   31C    P0            124W /  700W |    5856MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:62:00.0 Off |                    0 |
| N/A   32C    P0            123W /  700W |    1518MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:63:00.0 Off |                    0 |
| N/A   32C    P0            121W /  700W |    1518MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:64:00.0 Off |                    0 |
| N/A   28C    P0            111W /  700W |    1518MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:6A:00.0 Off |                    0 |
| N/A   28C    P0            117W /  700W |    1518MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:6B:00.0 Off |                    0 |
| N/A   32C    P0            117W /  700W |    1518MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:6C:00.0 Off |                    0 |
| N/A   31C    P0            118W /  700W |    1518MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:6D:00.0 Off |                    0 |
| N/A   28C    P0            117W /  700W |    1518MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      7975      C   /usr/local/bin/python                        1508MiB |
|    0   N/A  N/A      7976      C   /usr/local/bin/python                         614MiB |
|    0   N/A  N/A      7977      C   /usr/local/bin/python                         614MiB |
|    0   N/A  N/A      7978      C   /usr/local/bin/python                         614MiB |
|    0   N/A  N/A      7979      C   /usr/local/bin/python                         614MiB |
|    0   N/A  N/A      7980      C   /usr/local/bin/python                         614MiB |
|    0   N/A  N/A      7981      C   /usr/local/bin/python                         614MiB |
|    0   N/A  N/A      7982      C   /usr/local/bin/python                         614MiB |
|    1   N/A  N/A      7976      C   /usr/local/bin/python                        1508MiB |
|    2   N/A  N/A      7977      C   /usr/local/bin/python                        1508MiB |
|    3   N/A  N/A      7978      C   /usr/local/bin/python                        1508MiB |
|    4   N/A  N/A      7979      C   /usr/local/bin/python                        1508MiB |
|    5   N/A  N/A      7980      C   /usr/local/bin/python                        1508MiB |
|    6   N/A  N/A      7981      C   /usr/local/bin/python                        1508MiB |
|    7   N/A  N/A      7982      C   /usr/local/bin/python                        1508MiB |
+-----------------------------------------------------------------------------------------+

====================================================================================================
step:0/1770 val_loss:10.8258 train_time:0ms step_avg:0.04ms
step:1/1770 train_time:94ms step_avg:94.43ms
step:2/1770 train_time:160ms step_avg:80.24ms
step:3/1770 train_time:252ms step_avg:84.15ms
step:4/1770 train_time:347ms step_avg:86.65ms
step:5/1770 train_time:441ms step_avg:88.26ms
step:6/1770 train_time:535ms step_avg:89.25ms
step:7/1770 train_time:630ms step_avg:90.07ms
step:8/1770 train_time:725ms step_avg:90.59ms
step:9/1770 train_time:819ms step_avg:91.03ms
step:10/1770 train_time:914ms step_avg:91.41ms
step:11/1770 train_time:1009ms step_avg:91.71ms
step:12/1770 train_time:1106ms step_avg:92.18ms
step:13/1770 train_time:1204ms step_avg:92.59ms
step:14/1770 train_time:1300ms step_avg:92.83ms
step:15/1770 train_time:1395ms step_avg:92.98ms
step:16/1770 train_time:1490ms step_avg:93.10ms
step:17/1770 train_time:1585ms step_avg:93.24ms
step:18/1770 train_time:1680ms step_avg:93.32ms
step:19/1770 train_time:1774ms step_avg:93.38ms
step:20/1770 train_time:1868ms step_avg:93.42ms
step:21/1770 train_time:1964ms step_avg:93.50ms
step:22/1770 train_time:2059ms step_avg:93.60ms
step:23/1770 train_time:2154ms step_avg:93.66ms
step:24/1770 train_time:2249ms step_avg:93.71ms
step:25/1770 train_time:2344ms step_avg:93.78ms
step:26/1770 train_time:2439ms step_avg:93.82ms
step:27/1770 train_time:2535ms step_avg:93.87ms
step:28/1770 train_time:2630ms step_avg:93.92ms
step:29/1770 train_time:2725ms step_avg:93.95ms
step:30/1770 train_time:2820ms step_avg:93.99ms
step:31/1770 train_time:2914ms step_avg:94.00ms
step:32/1770 train_time:3009ms step_avg:94.02ms
step:33/1770 train_time:3105ms step_avg:94.10ms
step:34/1770 train_time:3201ms step_avg:94.13ms
step:35/1770 train_time:3296ms step_avg:94.16ms
step:36/1770 train_time:3391ms step_avg:94.20ms
step:37/1770 train_time:3488ms step_avg:94.26ms
step:38/1770 train_time:3584ms step_avg:94.30ms
step:39/1770 train_time:3679ms step_avg:94.34ms
step:40/1770 train_time:3773ms step_avg:94.33ms
step:41/1770 train_time:3869ms step_avg:94.35ms
step:42/1770 train_time:3964ms step_avg:94.39ms
step:43/1770 train_time:4060ms step_avg:94.42ms
step:44/1770 train_time:4155ms step_avg:94.42ms
step:45/1770 train_time:4249ms step_avg:94.43ms
step:46/1770 train_time:4344ms step_avg:94.43ms
step:47/1770 train_time:4440ms step_avg:94.47ms
step:48/1770 train_time:4535ms step_avg:94.48ms
step:49/1770 train_time:4631ms step_avg:94.50ms
step:50/1770 train_time:4726ms step_avg:94.51ms
step:51/1770 train_time:4822ms step_avg:94.54ms
step:52/1770 train_time:4916ms step_avg:94.54ms
step:53/1770 train_time:5011ms step_avg:94.54ms
step:54/1770 train_time:5106ms step_avg:94.55ms
step:55/1770 train_time:5202ms step_avg:94.59ms
step:56/1770 train_time:5298ms step_avg:94.60ms
step:57/1770 train_time:5392ms step_avg:94.60ms
step:58/1770 train_time:5487ms step_avg:94.60ms
step:59/1770 train_time:5583ms step_avg:94.63ms
step:60/1770 train_time:5679ms step_avg:94.64ms
step:61/1770 train_time:5773ms step_avg:94.64ms
step:62/1770 train_time:5867ms step_avg:94.63ms
step:63/1770 train_time:5962ms step_avg:94.64ms
step:64/1770 train_time:6058ms step_avg:94.65ms
step:65/1770 train_time:6152ms step_avg:94.65ms
step:66/1770 train_time:6247ms step_avg:94.65ms
step:67/1770 train_time:6344ms step_avg:94.69ms
step:68/1770 train_time:6439ms step_avg:94.69ms
step:69/1770 train_time:6535ms step_avg:94.70ms
step:70/1770 train_time:6629ms step_avg:94.70ms
step:71/1770 train_time:6724ms step_avg:94.70ms
step:72/1770 train_time:6818ms step_avg:94.70ms
step:73/1770 train_time:6913ms step_avg:94.70ms
step:74/1770 train_time:7008ms step_avg:94.70ms
step:75/1770 train_time:7103ms step_avg:94.70ms
step:76/1770 train_time:7197ms step_avg:94.70ms
step:77/1770 train_time:7293ms step_avg:94.72ms
step:78/1770 train_time:7389ms step_avg:94.73ms
step:79/1770 train_time:7487ms step_avg:94.77ms
step:80/1770 train_time:7584ms step_avg:94.80ms
step:81/1770 train_time:7679ms step_avg:94.80ms
step:82/1770 train_time:7773ms step_avg:94.80ms
step:83/1770 train_time:7868ms step_avg:94.79ms
step:84/1770 train_time:7962ms step_avg:94.79ms
step:85/1770 train_time:8057ms step_avg:94.79ms
step:86/1770 train_time:8152ms step_avg:94.79ms
step:87/1770 train_time:8247ms step_avg:94.79ms
step:88/1770 train_time:8343ms step_avg:94.81ms
step:89/1770 train_time:8438ms step_avg:94.81ms
step:90/1770 train_time:8534ms step_avg:94.82ms
step:91/1770 train_time:8629ms step_avg:94.82ms
step:92/1770 train_time:8724ms step_avg:94.82ms
step:93/1770 train_time:8819ms step_avg:94.82ms
step:94/1770 train_time:8914ms step_avg:94.83ms
step:95/1770 train_time:9008ms step_avg:94.82ms
step:96/1770 train_time:9104ms step_avg:94.83ms
step:97/1770 train_time:9199ms step_avg:94.83ms
step:98/1770 train_time:9293ms step_avg:94.83ms
step:99/1770 train_time:9388ms step_avg:94.83ms
step:100/1770 train_time:9485ms step_avg:94.85ms
step:101/1770 train_time:9581ms step_avg:94.86ms
step:102/1770 train_time:9676ms step_avg:94.86ms
step:103/1770 train_time:9771ms step_avg:94.86ms
step:104/1770 train_time:9866ms step_avg:94.87ms
step:105/1770 train_time:9962ms step_avg:94.87ms
step:106/1770 train_time:10056ms step_avg:94.87ms
step:107/1770 train_time:10151ms step_avg:94.87ms
step:108/1770 train_time:10246ms step_avg:94.87ms
step:109/1770 train_time:10341ms step_avg:94.87ms
step:110/1770 train_time:10436ms step_avg:94.88ms
step:111/1770 train_time:10531ms step_avg:94.88ms
step:112/1770 train_time:10626ms step_avg:94.88ms
step:113/1770 train_time:10722ms step_avg:94.88ms
step:114/1770 train_time:10817ms step_avg:94.88ms
step:115/1770 train_time:10912ms step_avg:94.89ms
step:116/1770 train_time:11007ms step_avg:94.89ms
step:117/1770 train_time:11103ms step_avg:94.90ms
step:118/1770 train_time:11198ms step_avg:94.90ms
step:119/1770 train_time:11292ms step_avg:94.89ms
step:120/1770 train_time:11387ms step_avg:94.89ms
step:121/1770 train_time:11482ms step_avg:94.89ms
step:122/1770 train_time:11577ms step_avg:94.90ms
step:123/1770 train_time:11672ms step_avg:94.89ms
step:124/1770 train_time:11766ms step_avg:94.89ms
step:125/1770 train_time:11862ms step_avg:94.90ms
step:125/1770 val_loss:4.6494 train_time:11943ms step_avg:95.54ms
step:126/1770 train_time:11975ms step_avg:95.04ms
step:127/1770 train_time:12059ms step_avg:94.95ms
step:128/1770 train_time:12161ms step_avg:95.00ms
step:129/1770 train_time:12257ms step_avg:95.02ms
step:130/1770 train_time:12352ms step_avg:95.02ms
step:131/1770 train_time:12447ms step_avg:95.01ms
step:132/1770 train_time:12541ms step_avg:95.01ms
step:133/1770 train_time:12636ms step_avg:95.00ms
step:134/1770 train_time:12731ms step_avg:95.01ms
step:135/1770 train_time:12826ms step_avg:95.01ms
step:136/1770 train_time:12921ms step_avg:95.01ms
step:137/1770 train_time:13016ms step_avg:95.01ms
step:138/1770 train_time:13113ms step_avg:95.02ms
step:139/1770 train_time:13209ms step_avg:95.03ms
step:140/1770 train_time:13306ms step_avg:95.04ms
step:141/1770 train_time:13401ms step_avg:95.05ms
step:142/1770 train_time:13498ms step_avg:95.05ms
step:143/1770 train_time:13592ms step_avg:95.05ms
step:144/1770 train_time:13688ms step_avg:95.05ms
step:145/1770 train_time:13783ms step_avg:95.05ms
step:146/1770 train_time:13878ms step_avg:95.06ms
step:147/1770 train_time:13974ms step_avg:95.06ms
step:148/1770 train_time:14070ms step_avg:95.07ms
step:149/1770 train_time:14166ms step_avg:95.07ms
step:150/1770 train_time:14262ms step_avg:95.08ms
step:151/1770 train_time:14358ms step_avg:95.09ms
step:152/1770 train_time:14454ms step_avg:95.09ms
step:153/1770 train_time:14549ms step_avg:95.09ms
step:154/1770 train_time:14645ms step_avg:95.10ms
step:155/1770 train_time:14740ms step_avg:95.10ms
step:156/1770 train_time:14835ms step_avg:95.10ms
step:157/1770 train_time:14931ms step_avg:95.10ms
step:158/1770 train_time:15027ms step_avg:95.11ms
step:159/1770 train_time:15123ms step_avg:95.11ms
step:160/1770 train_time:15219ms step_avg:95.12ms
step:161/1770 train_time:15315ms step_avg:95.13ms
step:162/1770 train_time:15411ms step_avg:95.13ms
step:163/1770 train_time:15506ms step_avg:95.13ms
step:164/1770 train_time:15602ms step_avg:95.14ms
step:165/1770 train_time:15699ms step_avg:95.14ms
step:166/1770 train_time:15794ms step_avg:95.14ms
step:167/1770 train_time:15889ms step_avg:95.14ms
step:168/1770 train_time:15985ms step_avg:95.15ms
step:169/1770 train_time:16081ms step_avg:95.15ms
step:170/1770 train_time:16177ms step_avg:95.16ms
step:171/1770 train_time:16273ms step_avg:95.16ms
step:172/1770 train_time:16369ms step_avg:95.17ms
step:173/1770 train_time:16464ms step_avg:95.17ms
step:174/1770 train_time:16560ms step_avg:95.17ms
step:175/1770 train_time:16656ms step_avg:95.18ms
step:176/1770 train_time:16752ms step_avg:95.18ms
step:177/1770 train_time:16848ms step_avg:95.19ms
step:178/1770 train_time:16943ms step_avg:95.18ms
step:179/1770 train_time:17039ms step_avg:95.19ms
step:180/1770 train_time:17135ms step_avg:95.20ms
step:181/1770 train_time:17232ms step_avg:95.20ms
step:182/1770 train_time:17327ms step_avg:95.20ms
step:183/1770 train_time:17422ms step_avg:95.20ms
step:184/1770 train_time:17518ms step_avg:95.21ms
step:185/1770 train_time:17614ms step_avg:95.21ms
step:186/1770 train_time:17709ms step_avg:95.21ms
step:187/1770 train_time:17805ms step_avg:95.21ms
step:188/1770 train_time:17900ms step_avg:95.21ms
step:189/1770 train_time:17996ms step_avg:95.22ms
step:190/1770 train_time:18091ms step_avg:95.22ms
step:191/1770 train_time:18187ms step_avg:95.22ms
step:192/1770 train_time:18282ms step_avg:95.22ms
step:193/1770 train_time:18378ms step_avg:95.22ms
step:194/1770 train_time:18474ms step_avg:95.23ms
step:195/1770 train_time:18569ms step_avg:95.23ms
step:196/1770 train_time:18665ms step_avg:95.23ms
step:197/1770 train_time:18762ms step_avg:95.24ms
step:198/1770 train_time:18857ms step_avg:95.24ms
step:199/1770 train_time:18953ms step_avg:95.24ms
step:200/1770 train_time:19048ms step_avg:95.24ms
step:201/1770 train_time:19143ms step_avg:95.24ms
step:202/1770 train_time:19240ms step_avg:95.25ms
step:203/1770 train_time:19336ms step_avg:95.25ms
step:204/1770 train_time:19432ms step_avg:95.25ms
step:205/1770 train_time:19528ms step_avg:95.26ms
step:206/1770 train_time:19623ms step_avg:95.26ms
step:207/1770 train_time:19719ms step_avg:95.26ms
step:208/1770 train_time:19815ms step_avg:95.26ms
step:209/1770 train_time:19910ms step_avg:95.26ms
step:210/1770 train_time:20005ms step_avg:95.26ms
step:211/1770 train_time:20101ms step_avg:95.26ms
step:212/1770 train_time:20196ms step_avg:95.26ms
step:213/1770 train_time:20291ms step_avg:95.26ms
step:214/1770 train_time:20387ms step_avg:95.27ms
step:215/1770 train_time:20482ms step_avg:95.27ms
step:216/1770 train_time:20578ms step_avg:95.27ms
step:217/1770 train_time:20673ms step_avg:95.27ms
step:218/1770 train_time:20768ms step_avg:95.27ms
step:219/1770 train_time:20863ms step_avg:95.27ms
step:220/1770 train_time:20960ms step_avg:95.27ms
step:221/1770 train_time:21056ms step_avg:95.27ms
step:222/1770 train_time:21152ms step_avg:95.28ms
step:223/1770 train_time:21247ms step_avg:95.28ms
step:224/1770 train_time:21343ms step_avg:95.28ms
step:225/1770 train_time:21439ms step_avg:95.28ms
step:226/1770 train_time:21535ms step_avg:95.29ms
step:227/1770 train_time:21630ms step_avg:95.29ms
step:228/1770 train_time:21725ms step_avg:95.28ms
step:229/1770 train_time:21821ms step_avg:95.29ms
step:230/1770 train_time:21918ms step_avg:95.29ms
step:231/1770 train_time:22013ms step_avg:95.29ms
step:232/1770 train_time:22108ms step_avg:95.29ms
step:233/1770 train_time:22203ms step_avg:95.29ms
step:234/1770 train_time:22299ms step_avg:95.30ms
step:235/1770 train_time:22395ms step_avg:95.30ms
step:236/1770 train_time:22490ms step_avg:95.30ms
step:237/1770 train_time:22586ms step_avg:95.30ms
step:238/1770 train_time:22681ms step_avg:95.30ms
step:239/1770 train_time:22778ms step_avg:95.30ms
step:240/1770 train_time:22873ms step_avg:95.31ms
step:241/1770 train_time:22969ms step_avg:95.31ms
step:242/1770 train_time:23065ms step_avg:95.31ms
step:243/1770 train_time:23161ms step_avg:95.31ms
step:244/1770 train_time:23257ms step_avg:95.31ms
step:245/1770 train_time:23353ms step_avg:95.32ms
step:246/1770 train_time:23448ms step_avg:95.32ms
step:247/1770 train_time:23543ms step_avg:95.32ms
step:248/1770 train_time:23639ms step_avg:95.32ms
step:249/1770 train_time:23735ms step_avg:95.32ms
step:250/1770 train_time:23831ms step_avg:95.32ms
step:250/1770 val_loss:4.1101 train_time:23911ms step_avg:95.64ms
step:251/1770 train_time:23940ms step_avg:95.38ms
step:252/1770 train_time:24026ms step_avg:95.34ms
step:253/1770 train_time:24127ms step_avg:95.36ms
step:254/1770 train_time:24224ms step_avg:95.37ms
step:255/1770 train_time:24320ms step_avg:95.37ms
step:256/1770 train_time:24415ms step_avg:95.37ms
step:257/1770 train_time:24510ms step_avg:95.37ms
step:258/1770 train_time:24605ms step_avg:95.37ms
step:259/1770 train_time:24701ms step_avg:95.37ms
step:260/1770 train_time:24795ms step_avg:95.37ms
step:261/1770 train_time:24890ms step_avg:95.37ms
step:262/1770 train_time:24987ms step_avg:95.37ms
step:263/1770 train_time:25084ms step_avg:95.38ms
step:264/1770 train_time:25181ms step_avg:95.38ms
step:265/1770 train_time:25278ms step_avg:95.39ms
step:266/1770 train_time:25373ms step_avg:95.39ms
step:267/1770 train_time:25470ms step_avg:95.40ms
step:268/1770 train_time:25566ms step_avg:95.40ms
step:269/1770 train_time:25662ms step_avg:95.40ms
step:270/1770 train_time:25757ms step_avg:95.40ms
step:271/1770 train_time:25853ms step_avg:95.40ms
step:272/1770 train_time:25949ms step_avg:95.40ms
step:273/1770 train_time:26046ms step_avg:95.41ms
step:274/1770 train_time:26143ms step_avg:95.41ms
step:275/1770 train_time:26240ms step_avg:95.42ms
step:276/1770 train_time:26337ms step_avg:95.42ms
step:277/1770 train_time:26433ms step_avg:95.42ms
step:278/1770 train_time:26529ms step_avg:95.43ms
step:279/1770 train_time:26625ms step_avg:95.43ms
step:280/1770 train_time:26720ms step_avg:95.43ms
step:281/1770 train_time:26816ms step_avg:95.43ms
step:282/1770 train_time:26912ms step_avg:95.43ms
step:283/1770 train_time:27008ms step_avg:95.43ms
step:284/1770 train_time:27104ms step_avg:95.44ms
step:285/1770 train_time:27200ms step_avg:95.44ms
step:286/1770 train_time:27296ms step_avg:95.44ms
step:287/1770 train_time:27393ms step_avg:95.44ms
step:288/1770 train_time:27489ms step_avg:95.45ms
step:289/1770 train_time:27586ms step_avg:95.45ms
step:290/1770 train_time:27682ms step_avg:95.45ms
step:291/1770 train_time:27778ms step_avg:95.46ms
step:292/1770 train_time:27874ms step_avg:95.46ms
step:293/1770 train_time:27970ms step_avg:95.46ms
step:294/1770 train_time:28066ms step_avg:95.46ms
step:295/1770 train_time:28163ms step_avg:95.47ms
step:296/1770 train_time:28260ms step_avg:95.47ms
step:297/1770 train_time:28355ms step_avg:95.47ms
step:298/1770 train_time:28451ms step_avg:95.47ms
step:299/1770 train_time:28547ms step_avg:95.48ms
step:300/1770 train_time:28643ms step_avg:95.48ms
step:301/1770 train_time:28739ms step_avg:95.48ms
step:302/1770 train_time:28835ms step_avg:95.48ms
step:303/1770 train_time:28930ms step_avg:95.48ms
step:304/1770 train_time:29026ms step_avg:95.48ms
step:305/1770 train_time:29123ms step_avg:95.48ms
step:306/1770 train_time:29220ms step_avg:95.49ms
step:307/1770 train_time:29316ms step_avg:95.49ms
step:308/1770 train_time:29413ms step_avg:95.50ms
step:309/1770 train_time:29509ms step_avg:95.50ms
step:310/1770 train_time:29605ms step_avg:95.50ms
step:311/1770 train_time:29701ms step_avg:95.50ms
step:312/1770 train_time:29797ms step_avg:95.50ms
step:313/1770 train_time:29893ms step_avg:95.50ms
step:314/1770 train_time:29988ms step_avg:95.50ms
step:315/1770 train_time:30085ms step_avg:95.51ms
step:316/1770 train_time:30181ms step_avg:95.51ms
step:317/1770 train_time:30277ms step_avg:95.51ms
step:318/1770 train_time:30374ms step_avg:95.52ms
step:319/1770 train_time:30471ms step_avg:95.52ms
step:320/1770 train_time:30568ms step_avg:95.52ms
step:321/1770 train_time:30664ms step_avg:95.53ms
step:322/1770 train_time:30761ms step_avg:95.53ms
step:323/1770 train_time:30857ms step_avg:95.53ms
step:324/1770 train_time:30954ms step_avg:95.54ms
step:325/1770 train_time:31049ms step_avg:95.54ms
step:326/1770 train_time:31145ms step_avg:95.54ms
step:327/1770 train_time:31242ms step_avg:95.54ms
step:328/1770 train_time:31338ms step_avg:95.54ms
step:329/1770 train_time:31434ms step_avg:95.54ms
step:330/1770 train_time:31529ms step_avg:95.54ms
step:331/1770 train_time:31625ms step_avg:95.54ms
step:332/1770 train_time:31721ms step_avg:95.55ms
step:333/1770 train_time:31818ms step_avg:95.55ms
step:334/1770 train_time:31914ms step_avg:95.55ms
step:335/1770 train_time:32011ms step_avg:95.55ms
step:336/1770 train_time:32107ms step_avg:95.56ms
step:337/1770 train_time:32204ms step_avg:95.56ms
step:338/1770 train_time:32301ms step_avg:95.56ms
step:339/1770 train_time:32397ms step_avg:95.57ms
step:340/1770 train_time:32492ms step_avg:95.57ms
step:341/1770 train_time:32588ms step_avg:95.57ms
step:342/1770 train_time:32685ms step_avg:95.57ms
step:343/1770 train_time:32782ms step_avg:95.57ms
step:344/1770 train_time:32878ms step_avg:95.58ms
step:345/1770 train_time:32974ms step_avg:95.58ms
step:346/1770 train_time:33070ms step_avg:95.58ms
step:347/1770 train_time:33167ms step_avg:95.58ms
step:348/1770 train_time:33263ms step_avg:95.58ms
step:349/1770 train_time:33360ms step_avg:95.59ms
step:350/1770 train_time:33456ms step_avg:95.59ms
step:351/1770 train_time:33552ms step_avg:95.59ms
step:352/1770 train_time:33648ms step_avg:95.59ms
step:353/1770 train_time:33745ms step_avg:95.60ms
step:354/1770 train_time:33841ms step_avg:95.60ms
step:355/1770 train_time:33937ms step_avg:95.60ms
step:356/1770 train_time:34033ms step_avg:95.60ms
step:357/1770 train_time:34129ms step_avg:95.60ms
step:358/1770 train_time:34226ms step_avg:95.60ms
step:359/1770 train_time:34323ms step_avg:95.61ms
step:360/1770 train_time:34420ms step_avg:95.61ms
step:361/1770 train_time:34516ms step_avg:95.61ms
step:362/1770 train_time:34612ms step_avg:95.61ms
step:363/1770 train_time:34708ms step_avg:95.61ms
step:364/1770 train_time:34805ms step_avg:95.62ms
step:365/1770 train_time:34902ms step_avg:95.62ms
step:366/1770 train_time:34998ms step_avg:95.62ms
step:367/1770 train_time:35094ms step_avg:95.62ms
step:368/1770 train_time:35189ms step_avg:95.62ms
step:369/1770 train_time:35286ms step_avg:95.63ms
step:370/1770 train_time:35383ms step_avg:95.63ms
step:371/1770 train_time:35480ms step_avg:95.63ms
step:372/1770 train_time:35576ms step_avg:95.63ms
step:373/1770 train_time:35672ms step_avg:95.64ms
step:374/1770 train_time:35768ms step_avg:95.64ms
step:375/1770 train_time:35866ms step_avg:95.64ms
| step:375/1770 val_loss:3.9086 train_time:35947ms step_avg:95.86ms | |
| step:376/1770 train_time:35976ms step_avg:95.68ms | |
| step:377/1770 train_time:36066ms step_avg:95.67ms | |
| step:378/1770 train_time:36165ms step_avg:95.67ms | |
| step:379/1770 train_time:36262ms step_avg:95.68ms | |
| step:380/1770 train_time:36357ms step_avg:95.68ms | |
| step:381/1770 train_time:36453ms step_avg:95.68ms | |
| step:382/1770 train_time:36548ms step_avg:95.68ms | |
| step:383/1770 train_time:36644ms step_avg:95.68ms | |
| step:384/1770 train_time:36740ms step_avg:95.68ms | |
| step:385/1770 train_time:36835ms step_avg:95.68ms | |
| step:386/1770 train_time:36931ms step_avg:95.68ms | |
| step:387/1770 train_time:37030ms step_avg:95.68ms | |
| step:388/1770 train_time:37129ms step_avg:95.69ms | |
| step:389/1770 train_time:37227ms step_avg:95.70ms | |
| step:390/1770 train_time:37323ms step_avg:95.70ms | |
| step:391/1770 train_time:37419ms step_avg:95.70ms | |
| step:392/1770 train_time:37515ms step_avg:95.70ms | |
| step:393/1770 train_time:37610ms step_avg:95.70ms | |
| step:394/1770 train_time:37707ms step_avg:95.70ms | |
| step:395/1770 train_time:37803ms step_avg:95.70ms | |
| step:396/1770 train_time:37901ms step_avg:95.71ms | |
| step:397/1770 train_time:37999ms step_avg:95.72ms | |
| step:398/1770 train_time:38098ms step_avg:95.72ms | |
| step:399/1770 train_time:38197ms step_avg:95.73ms | |
| step:400/1770 train_time:38296ms step_avg:95.74ms | |
| step:401/1770 train_time:38394ms step_avg:95.75ms | |
| step:402/1770 train_time:38492ms step_avg:95.75ms | |
| step:403/1770 train_time:38589ms step_avg:95.75ms | |
| step:404/1770 train_time:38688ms step_avg:95.76ms | |
| step:405/1770 train_time:38786ms step_avg:95.77ms | |
| step:406/1770 train_time:38884ms step_avg:95.77ms | |
| step:407/1770 train_time:38982ms step_avg:95.78ms | |
| step:408/1770 train_time:39080ms step_avg:95.79ms | |
| step:409/1770 train_time:39180ms step_avg:95.79ms | |
| step:410/1770 train_time:39278ms step_avg:95.80ms | |
| step:411/1770 train_time:39377ms step_avg:95.81ms | |
| step:412/1770 train_time:39475ms step_avg:95.81ms | |
| step:413/1770 train_time:39574ms step_avg:95.82ms | |
| step:414/1770 train_time:39672ms step_avg:95.83ms | |
| step:415/1770 train_time:39771ms step_avg:95.83ms | |
| step:416/1770 train_time:39869ms step_avg:95.84ms | |
| step:417/1770 train_time:39967ms step_avg:95.84ms | |
| step:418/1770 train_time:40067ms step_avg:95.85ms | |
| step:419/1770 train_time:40167ms step_avg:95.86ms | |
| step:420/1770 train_time:40266ms step_avg:95.87ms | |
| step:421/1770 train_time:40366ms step_avg:95.88ms | |
| step:422/1770 train_time:40464ms step_avg:95.89ms | |
| step:423/1770 train_time:40563ms step_avg:95.89ms | |
| step:424/1770 train_time:40662ms step_avg:95.90ms | |
| step:425/1770 train_time:40760ms step_avg:95.91ms | |
| step:426/1770 train_time:40858ms step_avg:95.91ms | |
| step:427/1770 train_time:40956ms step_avg:95.91ms | |
| step:428/1770 train_time:41054ms step_avg:95.92ms | |
| step:429/1770 train_time:41153ms step_avg:95.93ms | |
| step:430/1770 train_time:41252ms step_avg:95.94ms | |
| step:431/1770 train_time:41351ms step_avg:95.94ms | |
| step:432/1770 train_time:41450ms step_avg:95.95ms | |
| step:433/1770 train_time:41550ms step_avg:95.96ms | |
| step:434/1770 train_time:41648ms step_avg:95.96ms | |
| step:435/1770 train_time:41748ms step_avg:95.97ms | |
| step:436/1770 train_time:41848ms step_avg:95.98ms | |
| step:437/1770 train_time:41947ms step_avg:95.99ms | |
| step:438/1770 train_time:42046ms step_avg:96.00ms | |
| step:439/1770 train_time:42145ms step_avg:96.00ms | |
| step:440/1770 train_time:42243ms step_avg:96.01ms | |
| step:441/1770 train_time:42341ms step_avg:96.01ms | |
| step:442/1770 train_time:42440ms step_avg:96.02ms | |
| step:443/1770 train_time:42538ms step_avg:96.02ms | |
| step:444/1770 train_time:42636ms step_avg:96.03ms | |
| step:445/1770 train_time:42734ms step_avg:96.03ms | |
| step:446/1770 train_time:42832ms step_avg:96.04ms | |
| step:447/1770 train_time:42930ms step_avg:96.04ms | |
| step:448/1770 train_time:43028ms step_avg:96.05ms | |
| step:449/1770 train_time:43127ms step_avg:96.05ms | |
| step:450/1770 train_time:43225ms step_avg:96.06ms | |
| step:451/1770 train_time:43324ms step_avg:96.06ms | |
| step:452/1770 train_time:43423ms step_avg:96.07ms | |
| step:453/1770 train_time:43521ms step_avg:96.07ms | |
| step:454/1770 train_time:43619ms step_avg:96.08ms | |
| step:455/1770 train_time:43717ms step_avg:96.08ms | |
| step:456/1770 train_time:43815ms step_avg:96.09ms | |
| step:457/1770 train_time:43914ms step_avg:96.09ms | |
| step:458/1770 train_time:44013ms step_avg:96.10ms | |
| step:459/1770 train_time:44112ms step_avg:96.10ms | |
| step:460/1770 train_time:44211ms step_avg:96.11ms | |
| step:461/1770 train_time:44310ms step_avg:96.12ms | |
| step:462/1770 train_time:44410ms step_avg:96.13ms | |
| step:463/1770 train_time:44508ms step_avg:96.13ms | |
| step:464/1770 train_time:44607ms step_avg:96.14ms | |
| step:465/1770 train_time:44706ms step_avg:96.14ms | |
| step:466/1770 train_time:44804ms step_avg:96.15ms | |
| step:467/1770 train_time:44903ms step_avg:96.15ms | |
| step:468/1770 train_time:45001ms step_avg:96.16ms | |
| step:469/1770 train_time:45100ms step_avg:96.16ms | |
| step:470/1770 train_time:45198ms step_avg:96.17ms | |
| step:471/1770 train_time:45297ms step_avg:96.17ms | |
| step:472/1770 train_time:45395ms step_avg:96.17ms | |
| step:473/1770 train_time:45493ms step_avg:96.18ms | |
| step:474/1770 train_time:45591ms step_avg:96.18ms | |
| step:475/1770 train_time:45690ms step_avg:96.19ms | |
| step:476/1770 train_time:45789ms step_avg:96.20ms | |
| step:477/1770 train_time:45888ms step_avg:96.20ms | |
| step:478/1770 train_time:45988ms step_avg:96.21ms | |
| step:479/1770 train_time:46087ms step_avg:96.21ms | |
| step:480/1770 train_time:46186ms step_avg:96.22ms | |
| step:481/1770 train_time:46286ms step_avg:96.23ms | |
| step:482/1770 train_time:46385ms step_avg:96.23ms | |
| step:483/1770 train_time:46484ms step_avg:96.24ms | |
| step:484/1770 train_time:46583ms step_avg:96.25ms | |
| step:485/1770 train_time:46682ms step_avg:96.25ms | |
| step:486/1770 train_time:46780ms step_avg:96.25ms | |
| step:487/1770 train_time:46878ms step_avg:96.26ms | |
| step:488/1770 train_time:46976ms step_avg:96.26ms | |
| step:489/1770 train_time:47075ms step_avg:96.27ms | |
| step:490/1770 train_time:47174ms step_avg:96.27ms | |
| step:491/1770 train_time:47272ms step_avg:96.28ms | |
| step:492/1770 train_time:47371ms step_avg:96.28ms | |
| step:493/1770 train_time:47470ms step_avg:96.29ms | |
| step:494/1770 train_time:47569ms step_avg:96.29ms | |
| step:495/1770 train_time:47668ms step_avg:96.30ms | |
| step:496/1770 train_time:47768ms step_avg:96.31ms | |
| step:497/1770 train_time:47868ms step_avg:96.31ms | |
| step:498/1770 train_time:47968ms step_avg:96.32ms | |
| step:499/1770 train_time:48067ms step_avg:96.33ms | |
| step:500/1770 train_time:48166ms step_avg:96.33ms | |
| step:500/1770 val_loss:3.7533 train_time:48249ms step_avg:96.50ms | |
| step:501/1770 train_time:48282ms step_avg:96.37ms | |
| step:502/1770 train_time:48373ms step_avg:96.36ms | |
| step:503/1770 train_time:48473ms step_avg:96.37ms | |
| step:504/1770 train_time:48573ms step_avg:96.38ms | |
| step:505/1770 train_time:48672ms step_avg:96.38ms | |
| step:506/1770 train_time:48770ms step_avg:96.38ms | |
| step:507/1770 train_time:48868ms step_avg:96.39ms | |
| step:508/1770 train_time:48966ms step_avg:96.39ms | |
| step:509/1770 train_time:49064ms step_avg:96.39ms | |
| step:510/1770 train_time:49162ms step_avg:96.40ms | |
| step:511/1770 train_time:49260ms step_avg:96.40ms | |
| step:512/1770 train_time:49361ms step_avg:96.41ms | |
| step:513/1770 train_time:49461ms step_avg:96.42ms | |
| step:514/1770 train_time:49560ms step_avg:96.42ms | |
| step:515/1770 train_time:49659ms step_avg:96.43ms | |
| step:516/1770 train_time:49758ms step_avg:96.43ms | |
| step:517/1770 train_time:49857ms step_avg:96.44ms | |
| step:518/1770 train_time:49955ms step_avg:96.44ms | |
| step:519/1770 train_time:50053ms step_avg:96.44ms | |
| step:520/1770 train_time:50151ms step_avg:96.44ms | |
| step:521/1770 train_time:50250ms step_avg:96.45ms | |
| step:522/1770 train_time:50351ms step_avg:96.46ms | |
| step:523/1770 train_time:50451ms step_avg:96.46ms | |
| step:524/1770 train_time:50550ms step_avg:96.47ms | |
| step:525/1770 train_time:50650ms step_avg:96.48ms | |
| step:526/1770 train_time:50748ms step_avg:96.48ms | |
| step:527/1770 train_time:50847ms step_avg:96.48ms | |
| step:528/1770 train_time:50945ms step_avg:96.49ms | |
| step:529/1770 train_time:51044ms step_avg:96.49ms | |
| step:530/1770 train_time:51143ms step_avg:96.50ms | |
| step:531/1770 train_time:51242ms step_avg:96.50ms | |
| step:532/1770 train_time:51341ms step_avg:96.51ms | |
| step:533/1770 train_time:51440ms step_avg:96.51ms | |
| step:534/1770 train_time:51539ms step_avg:96.52ms | |
| step:535/1770 train_time:51639ms step_avg:96.52ms | |
| step:536/1770 train_time:51739ms step_avg:96.53ms | |
| step:537/1770 train_time:51838ms step_avg:96.53ms | |
| step:538/1770 train_time:51936ms step_avg:96.54ms | |
| step:539/1770 train_time:52035ms step_avg:96.54ms | |
| step:540/1770 train_time:52134ms step_avg:96.54ms | |
| step:541/1770 train_time:52233ms step_avg:96.55ms | |
| step:542/1770 train_time:52333ms step_avg:96.55ms | |
| step:543/1770 train_time:52433ms step_avg:96.56ms | |
| step:544/1770 train_time:52533ms step_avg:96.57ms | |
| step:545/1770 train_time:52633ms step_avg:96.57ms | |
| step:546/1770 train_time:52734ms step_avg:96.58ms | |
| step:547/1770 train_time:52834ms step_avg:96.59ms | |
| step:548/1770 train_time:52933ms step_avg:96.59ms | |
| step:549/1770 train_time:53033ms step_avg:96.60ms | |
| step:550/1770 train_time:53132ms step_avg:96.60ms | |
| step:551/1770 train_time:53232ms step_avg:96.61ms | |
| step:552/1770 train_time:53332ms step_avg:96.62ms | |
| step:553/1770 train_time:53432ms step_avg:96.62ms | |
| step:554/1770 train_time:53531ms step_avg:96.63ms | |
| step:555/1770 train_time:53630ms step_avg:96.63ms | |
| step:556/1770 train_time:53729ms step_avg:96.63ms | |
| step:557/1770 train_time:53828ms step_avg:96.64ms | |
| step:558/1770 train_time:53928ms step_avg:96.64ms | |
| step:559/1770 train_time:54027ms step_avg:96.65ms | |
| step:560/1770 train_time:54126ms step_avg:96.65ms | |
| step:561/1770 train_time:54224ms step_avg:96.66ms | |
| step:562/1770 train_time:54323ms step_avg:96.66ms | |
| step:563/1770 train_time:54422ms step_avg:96.66ms | |
| step:564/1770 train_time:54521ms step_avg:96.67ms | |
| step:565/1770 train_time:54619ms step_avg:96.67ms | |
| step:566/1770 train_time:54718ms step_avg:96.68ms | |
| step:567/1770 train_time:54817ms step_avg:96.68ms | |
| step:568/1770 train_time:54917ms step_avg:96.68ms | |
| step:569/1770 train_time:55017ms step_avg:96.69ms | |
| step:570/1770 train_time:55115ms step_avg:96.69ms | |
| step:571/1770 train_time:55215ms step_avg:96.70ms | |
| step:572/1770 train_time:55315ms step_avg:96.70ms | |
| step:573/1770 train_time:55415ms step_avg:96.71ms | |
| step:574/1770 train_time:55515ms step_avg:96.72ms | |
| step:575/1770 train_time:55614ms step_avg:96.72ms | |
| step:576/1770 train_time:55714ms step_avg:96.73ms | |
| step:577/1770 train_time:55814ms step_avg:96.73ms | |
| step:578/1770 train_time:55913ms step_avg:96.73ms | |
| step:579/1770 train_time:56012ms step_avg:96.74ms | |
| step:580/1770 train_time:56112ms step_avg:96.74ms | |
| step:581/1770 train_time:56212ms step_avg:96.75ms | |
| step:582/1770 train_time:56313ms step_avg:96.76ms | |
| step:583/1770 train_time:56412ms step_avg:96.76ms | |
| step:584/1770 train_time:56511ms step_avg:96.77ms | |
| step:585/1770 train_time:56612ms step_avg:96.77ms | |
| step:586/1770 train_time:56710ms step_avg:96.78ms | |
| step:587/1770 train_time:56809ms step_avg:96.78ms | |
| step:588/1770 train_time:56909ms step_avg:96.78ms | |
| step:589/1770 train_time:57007ms step_avg:96.79ms | |
| step:590/1770 train_time:57106ms step_avg:96.79ms | |
| step:591/1770 train_time:57205ms step_avg:96.79ms | |
| step:592/1770 train_time:57304ms step_avg:96.80ms | |
| step:593/1770 train_time:57403ms step_avg:96.80ms | |
| step:594/1770 train_time:57502ms step_avg:96.80ms | |
| step:595/1770 train_time:57601ms step_avg:96.81ms | |
| step:596/1770 train_time:57699ms step_avg:96.81ms | |
| step:597/1770 train_time:57799ms step_avg:96.82ms | |
| step:598/1770 train_time:57897ms step_avg:96.82ms | |
| step:599/1770 train_time:57997ms step_avg:96.82ms | |
| step:600/1770 train_time:58096ms step_avg:96.83ms | |
| step:601/1770 train_time:58195ms step_avg:96.83ms | |
| step:602/1770 train_time:58295ms step_avg:96.84ms | |
| step:603/1770 train_time:58395ms step_avg:96.84ms | |
| step:604/1770 train_time:58495ms step_avg:96.85ms | |
| step:605/1770 train_time:58594ms step_avg:96.85ms | |
| step:606/1770 train_time:58694ms step_avg:96.85ms | |
| step:607/1770 train_time:58794ms step_avg:96.86ms | |
| step:608/1770 train_time:58893ms step_avg:96.86ms | |
| step:609/1770 train_time:58993ms step_avg:96.87ms | |
| step:610/1770 train_time:59093ms step_avg:96.87ms | |
| step:611/1770 train_time:59192ms step_avg:96.88ms | |
| step:612/1770 train_time:59292ms step_avg:96.88ms | |
| step:613/1770 train_time:59392ms step_avg:96.89ms | |
| step:614/1770 train_time:59491ms step_avg:96.89ms | |
| step:615/1770 train_time:59591ms step_avg:96.90ms | |
| step:616/1770 train_time:59690ms step_avg:96.90ms | |
| step:617/1770 train_time:59789ms step_avg:96.90ms | |
| step:618/1770 train_time:59889ms step_avg:96.91ms | |
| step:619/1770 train_time:59988ms step_avg:96.91ms | |
| step:620/1770 train_time:60087ms step_avg:96.92ms | |
| step:621/1770 train_time:60187ms step_avg:96.92ms | |
| step:622/1770 train_time:60286ms step_avg:96.92ms | |
| step:623/1770 train_time:60385ms step_avg:96.93ms | |
| step:624/1770 train_time:60484ms step_avg:96.93ms | |
| step:625/1770 train_time:60582ms step_avg:96.93ms | |
| step:625/1770 val_loss:3.6686 train_time:60665ms step_avg:97.06ms | |
| step:626/1770 train_time:60695ms step_avg:96.96ms | |
| step:627/1770 train_time:60787ms step_avg:96.95ms | |
| step:628/1770 train_time:60889ms step_avg:96.96ms | |
| step:629/1770 train_time:60988ms step_avg:96.96ms | |
| step:630/1770 train_time:61087ms step_avg:96.96ms | |
| step:631/1770 train_time:61185ms step_avg:96.97ms | |
| step:632/1770 train_time:61284ms step_avg:96.97ms | |
| step:633/1770 train_time:61382ms step_avg:96.97ms | |
| step:634/1770 train_time:61481ms step_avg:96.97ms | |
| step:635/1770 train_time:61578ms step_avg:96.97ms | |
| step:636/1770 train_time:61677ms step_avg:96.98ms | |
| step:637/1770 train_time:61777ms step_avg:96.98ms | |
| step:638/1770 train_time:61879ms step_avg:96.99ms | |
| step:639/1770 train_time:61978ms step_avg:96.99ms | |
| step:640/1770 train_time:62078ms step_avg:97.00ms | |
| step:641/1770 train_time:62176ms step_avg:97.00ms | |
| step:642/1770 train_time:62275ms step_avg:97.00ms | |
| step:643/1770 train_time:62375ms step_avg:97.01ms | |
| step:644/1770 train_time:62475ms step_avg:97.01ms | |
| step:645/1770 train_time:62574ms step_avg:97.01ms | |
| step:646/1770 train_time:62673ms step_avg:97.02ms | |
| step:647/1770 train_time:62772ms step_avg:97.02ms | |
| step:648/1770 train_time:62872ms step_avg:97.03ms | |
| step:649/1770 train_time:62973ms step_avg:97.03ms | |
| step:650/1770 train_time:63074ms step_avg:97.04ms | |
| step:651/1770 train_time:63174ms step_avg:97.04ms | |
| step:652/1770 train_time:63274ms step_avg:97.05ms | |
| step:653/1770 train_time:63373ms step_avg:97.05ms | |
| step:654/1770 train_time:63473ms step_avg:97.05ms | |
| step:655/1770 train_time:63572ms step_avg:97.06ms | |
| step:656/1770 train_time:63671ms step_avg:97.06ms | |
| step:657/1770 train_time:63771ms step_avg:97.06ms | |
| step:658/1770 train_time:63872ms step_avg:97.07ms | |
| step:659/1770 train_time:63975ms step_avg:97.08ms | |
| step:660/1770 train_time:64077ms step_avg:97.09ms | |
| step:661/1770 train_time:64178ms step_avg:97.09ms | |
| step:662/1770 train_time:64279ms step_avg:97.10ms | |
| step:663/1770 train_time:64379ms step_avg:97.10ms | |
| step:664/1770 train_time:64480ms step_avg:97.11ms | |
| step:665/1770 train_time:64581ms step_avg:97.11ms | |
| step:666/1770 train_time:64682ms step_avg:97.12ms | |
| step:667/1770 train_time:64782ms step_avg:97.13ms | |
| step:668/1770 train_time:64883ms step_avg:97.13ms | |
| step:669/1770 train_time:64984ms step_avg:97.14ms | |
| step:670/1770 train_time:65085ms step_avg:97.14ms | |
| step:671/1770 train_time:65187ms step_avg:97.15ms | |
| step:672/1770 train_time:65288ms step_avg:97.15ms | |
| step:673/1770 train_time:65389ms step_avg:97.16ms | |
| step:674/1770 train_time:65490ms step_avg:97.17ms | |
| step:675/1770 train_time:65590ms step_avg:97.17ms | |
| step:676/1770 train_time:65692ms step_avg:97.18ms | |
| step:677/1770 train_time:65793ms step_avg:97.18ms | |
| step:678/1770 train_time:65894ms step_avg:97.19ms | |
| step:679/1770 train_time:65996ms step_avg:97.20ms | |
| step:680/1770 train_time:66097ms step_avg:97.20ms | |
| step:681/1770 train_time:66199ms step_avg:97.21ms | |
| step:682/1770 train_time:66301ms step_avg:97.22ms | |
| step:683/1770 train_time:66401ms step_avg:97.22ms | |
| step:684/1770 train_time:66502ms step_avg:97.22ms | |
| step:685/1770 train_time:66602ms step_avg:97.23ms | |
| step:686/1770 train_time:66703ms step_avg:97.23ms | |
| step:687/1770 train_time:66803ms step_avg:97.24ms | |
| step:688/1770 train_time:66903ms step_avg:97.24ms | |
| step:689/1770 train_time:67004ms step_avg:97.25ms | |
| step:690/1770 train_time:67104ms step_avg:97.25ms | |
| step:691/1770 train_time:67205ms step_avg:97.26ms | |
| step:692/1770 train_time:67306ms step_avg:97.26ms | |
| step:693/1770 train_time:67407ms step_avg:97.27ms | |
| step:694/1770 train_time:67507ms step_avg:97.27ms | |
| step:695/1770 train_time:67608ms step_avg:97.28ms | |
| step:696/1770 train_time:67709ms step_avg:97.28ms | |
| step:697/1770 train_time:67809ms step_avg:97.29ms | |
| step:698/1770 train_time:67910ms step_avg:97.29ms | |
| step:699/1770 train_time:68011ms step_avg:97.30ms | |
| step:700/1770 train_time:68113ms step_avg:97.30ms | |
| step:701/1770 train_time:68215ms step_avg:97.31ms | |
| step:702/1770 train_time:68316ms step_avg:97.32ms | |
| step:703/1770 train_time:68418ms step_avg:97.32ms | |
| step:704/1770 train_time:68519ms step_avg:97.33ms | |
| step:705/1770 train_time:68619ms step_avg:97.33ms | |
| step:706/1770 train_time:68719ms step_avg:97.34ms | |
| step:707/1770 train_time:68820ms step_avg:97.34ms | |
| step:708/1770 train_time:68921ms step_avg:97.35ms | |
| step:709/1770 train_time:69022ms step_avg:97.35ms | |
| step:710/1770 train_time:69123ms step_avg:97.36ms | |
| step:711/1770 train_time:69223ms step_avg:97.36ms | |
| step:712/1770 train_time:69323ms step_avg:97.36ms | |
| step:713/1770 train_time:69424ms step_avg:97.37ms | |
| step:714/1770 train_time:69525ms step_avg:97.37ms | |
| step:715/1770 train_time:69626ms step_avg:97.38ms | |
| step:716/1770 train_time:69726ms step_avg:97.38ms | |
| step:717/1770 train_time:69827ms step_avg:97.39ms | |
| step:718/1770 train_time:69928ms step_avg:97.39ms | |
| step:719/1770 train_time:70029ms step_avg:97.40ms | |
| step:720/1770 train_time:70129ms step_avg:97.40ms | |
| step:721/1770 train_time:70231ms step_avg:97.41ms | |
| step:722/1770 train_time:70332ms step_avg:97.41ms | |
| step:723/1770 train_time:70434ms step_avg:97.42ms | |
| step:724/1770 train_time:70537ms step_avg:97.43ms | |
| step:725/1770 train_time:70638ms step_avg:97.43ms | |
| step:726/1770 train_time:70740ms step_avg:97.44ms | |
| step:727/1770 train_time:70841ms step_avg:97.44ms | |
| step:728/1770 train_time:70941ms step_avg:97.45ms | |
| step:729/1770 train_time:71042ms step_avg:97.45ms | |
| step:730/1770 train_time:71142ms step_avg:97.46ms | |
| step:731/1770 train_time:71242ms step_avg:97.46ms | |
| step:732/1770 train_time:71343ms step_avg:97.46ms | |
| step:733/1770 train_time:71443ms step_avg:97.47ms | |
| step:734/1770 train_time:71543ms step_avg:97.47ms | |
| step:735/1770 train_time:71643ms step_avg:97.47ms | |
| step:736/1770 train_time:71746ms step_avg:97.48ms | |
| step:737/1770 train_time:71847ms step_avg:97.49ms | |
| step:738/1770 train_time:71948ms step_avg:97.49ms | |
| step:739/1770 train_time:72049ms step_avg:97.49ms | |
| step:740/1770 train_time:72150ms step_avg:97.50ms | |
| step:741/1770 train_time:72250ms step_avg:97.50ms | |
| step:742/1770 train_time:72352ms step_avg:97.51ms | |
| step:743/1770 train_time:72453ms step_avg:97.51ms | |
| step:744/1770 train_time:72555ms step_avg:97.52ms | |
| step:745/1770 train_time:72657ms step_avg:97.53ms | |
| step:746/1770 train_time:72760ms step_avg:97.53ms | |
| step:747/1770 train_time:72862ms step_avg:97.54ms | |
| step:748/1770 train_time:72962ms step_avg:97.54ms | |
| step:749/1770 train_time:73062ms step_avg:97.55ms | |
| step:750/1770 train_time:73162ms step_avg:97.55ms | |
| step:750/1770 val_loss:3.6018 train_time:73246ms step_avg:97.66ms | |
| step:751/1770 train_time:73275ms step_avg:97.57ms | |
| step:752/1770 train_time:73372ms step_avg:97.57ms | |
| step:753/1770 train_time:73474ms step_avg:97.58ms | |
| step:754/1770 train_time:73575ms step_avg:97.58ms | |
| step:755/1770 train_time:73676ms step_avg:97.58ms | |
| step:756/1770 train_time:73777ms step_avg:97.59ms | |
| step:757/1770 train_time:73877ms step_avg:97.59ms | |
| step:758/1770 train_time:73977ms step_avg:97.60ms | |
| step:759/1770 train_time:74077ms step_avg:97.60ms | |
| step:760/1770 train_time:74178ms step_avg:97.60ms | |
| step:761/1770 train_time:74283ms step_avg:97.61ms | |
| step:762/1770 train_time:74386ms step_avg:97.62ms | |
| step:763/1770 train_time:74489ms step_avg:97.63ms | |
| step:764/1770 train_time:74589ms step_avg:97.63ms | |
| step:765/1770 train_time:74689ms step_avg:97.63ms | |
| step:766/1770 train_time:74789ms step_avg:97.64ms | |
| step:767/1770 train_time:74889ms step_avg:97.64ms | |
| step:768/1770 train_time:74989ms step_avg:97.64ms | |
| step:769/1770 train_time:75089ms step_avg:97.65ms | |
| step:770/1770 train_time:75189ms step_avg:97.65ms | |
| step:771/1770 train_time:75290ms step_avg:97.65ms | |
| step:772/1770 train_time:75391ms step_avg:97.66ms | |
| step:773/1770 train_time:75492ms step_avg:97.66ms | |
| step:774/1770 train_time:75593ms step_avg:97.67ms | |
| step:775/1770 train_time:75693ms step_avg:97.67ms | |
| step:776/1770 train_time:75794ms step_avg:97.67ms | |
| step:777/1770 train_time:75894ms step_avg:97.68ms | |
| step:778/1770 train_time:75995ms step_avg:97.68ms | |
| step:779/1770 train_time:76096ms step_avg:97.68ms | |
| step:780/1770 train_time:76196ms step_avg:97.69ms | |
| step:781/1770 train_time:76297ms step_avg:97.69ms | |
| step:782/1770 train_time:76398ms step_avg:97.70ms | |
| step:783/1770 train_time:76499ms step_avg:97.70ms | |
| step:784/1770 train_time:76600ms step_avg:97.70ms | |
| step:785/1770 train_time:76702ms step_avg:97.71ms | |
| step:786/1770 train_time:76805ms step_avg:97.72ms | |
| step:787/1770 train_time:76906ms step_avg:97.72ms | |
| step:788/1770 train_time:77007ms step_avg:97.72ms | |
| step:789/1770 train_time:77108ms step_avg:97.73ms | |
| step:790/1770 train_time:77209ms step_avg:97.73ms | |
| step:791/1770 train_time:77310ms step_avg:97.74ms | |
| step:792/1770 train_time:77411ms step_avg:97.74ms | |
| step:793/1770 train_time:77513ms step_avg:97.75ms | |
| step:794/1770 train_time:77614ms step_avg:97.75ms | |
| step:795/1770 train_time:77715ms step_avg:97.76ms | |
| step:796/1770 train_time:77816ms step_avg:97.76ms | |
| step:797/1770 train_time:77917ms step_avg:97.76ms | |
| step:798/1770 train_time:78019ms step_avg:97.77ms | |
| step:799/1770 train_time:78121ms step_avg:97.77ms | |
| step:800/1770 train_time:78223ms step_avg:97.78ms | |
| step:801/1770 train_time:78325ms step_avg:97.78ms | |
| step:802/1770 train_time:78427ms step_avg:97.79ms | |
| step:803/1770 train_time:78529ms step_avg:97.79ms | |
| step:804/1770 train_time:78630ms step_avg:97.80ms | |
| step:805/1770 train_time:78730ms step_avg:97.80ms | |
| step:806/1770 train_time:78831ms step_avg:97.80ms | |
| step:807/1770 train_time:78931ms step_avg:97.81ms | |
| step:808/1770 train_time:79032ms step_avg:97.81ms | |
| step:809/1770 train_time:79133ms step_avg:97.82ms | |
| step:810/1770 train_time:79234ms step_avg:97.82ms | |
| step:811/1770 train_time:79335ms step_avg:97.82ms | |
| step:812/1770 train_time:79437ms step_avg:97.83ms | |
| step:813/1770 train_time:79538ms step_avg:97.83ms | |
| step:814/1770 train_time:79640ms step_avg:97.84ms | |
| step:815/1770 train_time:79742ms step_avg:97.84ms | |
| step:816/1770 train_time:79843ms step_avg:97.85ms | |
| step:817/1770 train_time:79945ms step_avg:97.85ms | |
| step:818/1770 train_time:80047ms step_avg:97.86ms | |
| step:819/1770 train_time:80150ms step_avg:97.86ms | |
| step:820/1770 train_time:80250ms step_avg:97.87ms | |
| step:821/1770 train_time:80351ms step_avg:97.87ms | |
| step:822/1770 train_time:80451ms step_avg:97.87ms | |
| step:823/1770 train_time:80552ms step_avg:97.88ms | |
| step:824/1770 train_time:80653ms step_avg:97.88ms | |
| step:825/1770 train_time:80754ms step_avg:97.88ms | |
| step:826/1770 train_time:80856ms step_avg:97.89ms | |
| step:827/1770 train_time:80957ms step_avg:97.89ms | |
| step:828/1770 train_time:81059ms step_avg:97.90ms | |
| step:829/1770 train_time:81160ms step_avg:97.90ms | |
| step:830/1770 train_time:81262ms step_avg:97.91ms | |
| step:831/1770 train_time:81364ms step_avg:97.91ms | |
| step:832/1770 train_time:81465ms step_avg:97.91ms | |
| step:833/1770 train_time:81566ms step_avg:97.92ms | |
| step:834/1770 train_time:81668ms step_avg:97.92ms | |
| step:835/1770 train_time:81769ms step_avg:97.93ms | |
| step:836/1770 train_time:81870ms step_avg:97.93ms | |
| step:837/1770 train_time:81970ms step_avg:97.93ms | |
| step:838/1770 train_time:82071ms step_avg:97.94ms | |
| step:839/1770 train_time:82171ms step_avg:97.94ms | |
| step:840/1770 train_time:82273ms step_avg:97.94ms | |
| step:841/1770 train_time:82374ms step_avg:97.95ms | |
| step:842/1770 train_time:82476ms step_avg:97.95ms | |
| step:843/1770 train_time:82577ms step_avg:97.96ms | |
| step:844/1770 train_time:82677ms step_avg:97.96ms | |
| step:845/1770 train_time:82779ms step_avg:97.96ms | |
| step:846/1770 train_time:82881ms step_avg:97.97ms | |
| step:847/1770 train_time:82983ms step_avg:97.97ms | |
| step:848/1770 train_time:83085ms step_avg:97.98ms | |
| step:849/1770 train_time:83187ms step_avg:97.98ms | |
| step:850/1770 train_time:83287ms step_avg:97.99ms | |
| step:851/1770 train_time:83388ms step_avg:97.99ms | |
| step:852/1770 train_time:83489ms step_avg:97.99ms | |
| step:853/1770 train_time:83590ms step_avg:98.00ms | |
| step:854/1770 train_time:83691ms step_avg:98.00ms | |
| step:855/1770 train_time:83792ms step_avg:98.00ms | |
| step:856/1770 train_time:83894ms step_avg:98.01ms | |
| step:857/1770 train_time:83996ms step_avg:98.01ms | |
| step:858/1770 train_time:84098ms step_avg:98.02ms | |
| step:859/1770 train_time:84198ms step_avg:98.02ms | |
| step:860/1770 train_time:84299ms step_avg:98.02ms | |
| step:861/1770 train_time:84400ms step_avg:98.03ms | |
| step:862/1770 train_time:84502ms step_avg:98.03ms | |
| step:863/1770 train_time:84604ms step_avg:98.03ms | |
| step:864/1770 train_time:84707ms step_avg:98.04ms | |
| step:865/1770 train_time:84808ms step_avg:98.04ms | |
| step:866/1770 train_time:84908ms step_avg:98.05ms | |
| step:867/1770 train_time:85010ms step_avg:98.05ms | |
| step:868/1770 train_time:85111ms step_avg:98.05ms | |
| step:869/1770 train_time:85211ms step_avg:98.06ms | |
| step:870/1770 train_time:85312ms step_avg:98.06ms | |
| step:871/1770 train_time:85413ms step_avg:98.06ms | |
| step:872/1770 train_time:85514ms step_avg:98.07ms | |
| step:873/1770 train_time:85615ms step_avg:98.07ms | |
| step:874/1770 train_time:85717ms step_avg:98.07ms | |
| step:875/1770 train_time:85818ms step_avg:98.08ms | |
| step:875/1770 val_loss:3.5522 train_time:85903ms step_avg:98.18ms | |
| step:876/1770 train_time:85933ms step_avg:98.10ms | |
| step:877/1770 train_time:86029ms step_avg:98.09ms | |
| step:878/1770 train_time:86134ms step_avg:98.10ms | |
| step:879/1770 train_time:86235ms step_avg:98.11ms | |
| step:880/1770 train_time:86336ms step_avg:98.11ms | |
| step:881/1770 train_time:86435ms step_avg:98.11ms | |
| step:882/1770 train_time:86535ms step_avg:98.11ms | |
| step:883/1770 train_time:86635ms step_avg:98.11ms | |
| step:884/1770 train_time:86735ms step_avg:98.12ms | |
| step:885/1770 train_time:86835ms step_avg:98.12ms | |
| step:886/1770 train_time:86936ms step_avg:98.12ms | |
| step:887/1770 train_time:87038ms step_avg:98.13ms | |
| step:888/1770 train_time:87140ms step_avg:98.13ms | |
| step:889/1770 train_time:87241ms step_avg:98.13ms | |
| step:890/1770 train_time:87342ms step_avg:98.14ms | |
| step:891/1770 train_time:87443ms step_avg:98.14ms | |
| step:892/1770 train_time:87544ms step_avg:98.14ms | |
| step:893/1770 train_time:87645ms step_avg:98.15ms | |
| step:894/1770 train_time:87747ms step_avg:98.15ms | |
| step:895/1770 train_time:87850ms step_avg:98.16ms | |
| step:896/1770 train_time:87951ms step_avg:98.16ms | |
| step:897/1770 train_time:88053ms step_avg:98.16ms | |
| step:898/1770 train_time:88155ms step_avg:98.17ms | |
| step:899/1770 train_time:88256ms step_avg:98.17ms | |
| step:900/1770 train_time:88357ms step_avg:98.17ms | |
| step:901/1770 train_time:88457ms step_avg:98.18ms | |
| step:902/1770 train_time:88558ms step_avg:98.18ms | |
| step:903/1770 train_time:88659ms step_avg:98.18ms | |
| step:904/1770 train_time:88759ms step_avg:98.19ms | |
| step:905/1770 train_time:88861ms step_avg:98.19ms | |
| step:906/1770 train_time:88962ms step_avg:98.19ms | |
| step:907/1770 train_time:89063ms step_avg:98.20ms | |
| step:908/1770 train_time:89165ms step_avg:98.20ms | |
| step:909/1770 train_time:89267ms step_avg:98.20ms | |
| step:910/1770 train_time:89368ms step_avg:98.21ms | |
| step:911/1770 train_time:89469ms step_avg:98.21ms | |
| step:912/1770 train_time:89572ms step_avg:98.21ms | |
| step:913/1770 train_time:89674ms step_avg:98.22ms | |
| step:914/1770 train_time:89776ms step_avg:98.22ms | |
| step:915/1770 train_time:89877ms step_avg:98.23ms | |
| step:916/1770 train_time:89977ms step_avg:98.23ms | |
| step:917/1770 train_time:90078ms step_avg:98.23ms | |
| step:918/1770 train_time:90179ms step_avg:98.23ms | |
| step:919/1770 train_time:90280ms step_avg:98.24ms | |
| step:920/1770 train_time:90383ms step_avg:98.24ms | |
| step:921/1770 train_time:90485ms step_avg:98.25ms | |
| step:922/1770 train_time:90588ms step_avg:98.25ms | |
| step:923/1770 train_time:90691ms step_avg:98.26ms | |
| step:924/1770 train_time:90796ms step_avg:98.26ms | |
| step:925/1770 train_time:90898ms step_avg:98.27ms | |
| step:926/1770 train_time:90999ms step_avg:98.27ms | |
| step:927/1770 train_time:91101ms step_avg:98.28ms | |
| step:928/1770 train_time:91204ms step_avg:98.28ms | |
| step:929/1770 train_time:91307ms step_avg:98.29ms | |
| step:930/1770 train_time:91410ms step_avg:98.29ms | |
| step:931/1770 train_time:91512ms step_avg:98.29ms | |
| step:932/1770 train_time:91614ms step_avg:98.30ms | |
| step:933/1770 train_time:91717ms step_avg:98.30ms | |
| step:934/1770 train_time:91819ms step_avg:98.31ms | |
| step:935/1770 train_time:91921ms step_avg:98.31ms | |
| step:936/1770 train_time:92023ms step_avg:98.32ms | |
| step:937/1770 train_time:92125ms step_avg:98.32ms | |
| step:938/1770 train_time:92228ms step_avg:98.32ms | |
| step:939/1770 train_time:92331ms step_avg:98.33ms | |
| step:940/1770 train_time:92434ms step_avg:98.33ms | |
| step:941/1770 train_time:92537ms step_avg:98.34ms | |
| step:942/1770 train_time:92638ms step_avg:98.34ms | |
| step:943/1770 train_time:92742ms step_avg:98.35ms | |
| step:944/1770 train_time:92844ms step_avg:98.35ms | |
| step:945/1770 train_time:92946ms step_avg:98.36ms | |
| step:946/1770 train_time:93050ms step_avg:98.36ms | |
| step:947/1770 train_time:93153ms step_avg:98.37ms | |
| step:948/1770 train_time:93256ms step_avg:98.37ms | |
| step:949/1770 train_time:93358ms step_avg:98.37ms | |
| step:950/1770 train_time:93460ms step_avg:98.38ms | |
| step:951/1770 train_time:93562ms step_avg:98.38ms | |
| step:952/1770 train_time:93665ms step_avg:98.39ms | |
| step:953/1770 train_time:93767ms step_avg:98.39ms | |
| step:954/1770 train_time:93871ms step_avg:98.40ms | |
| step:955/1770 train_time:93973ms step_avg:98.40ms | |
| step:956/1770 train_time:94075ms step_avg:98.41ms | |
| step:957/1770 train_time:94178ms step_avg:98.41ms | |
| step:958/1770 train_time:94280ms step_avg:98.41ms | |
| step:959/1770 train_time:94383ms step_avg:98.42ms | |
| step:960/1770 train_time:94485ms step_avg:98.42ms | |
| step:961/1770 train_time:94587ms step_avg:98.43ms | |
| step:962/1770 train_time:94690ms step_avg:98.43ms | |
| step:963/1770 train_time:94793ms step_avg:98.44ms | |
| step:964/1770 train_time:94897ms step_avg:98.44ms | |
| step:965/1770 train_time:94999ms step_avg:98.44ms | |
| step:966/1770 train_time:95101ms step_avg:98.45ms | |
| step:967/1770 train_time:95203ms step_avg:98.45ms | |
| step:968/1770 train_time:95306ms step_avg:98.46ms | |
| step:969/1770 train_time:95409ms step_avg:98.46ms | |
| step:970/1770 train_time:95512ms step_avg:98.47ms | |
| step:971/1770 train_time:95615ms step_avg:98.47ms | |
| step:972/1770 train_time:95716ms step_avg:98.47ms | |
| step:973/1770 train_time:95820ms step_avg:98.48ms | |
| step:974/1770 train_time:95922ms step_avg:98.48ms | |
| step:975/1770 train_time:96024ms step_avg:98.49ms | |
| step:976/1770 train_time:96127ms step_avg:98.49ms | |
| step:977/1770 train_time:96230ms step_avg:98.50ms | |
| step:978/1770 train_time:96333ms step_avg:98.50ms | |
| step:979/1770 train_time:96436ms step_avg:98.50ms | |
| step:980/1770 train_time:96538ms step_avg:98.51ms | |
| step:981/1770 train_time:96641ms step_avg:98.51ms | |
| step:982/1770 train_time:96743ms step_avg:98.52ms | |
| step:983/1770 train_time:96845ms step_avg:98.52ms | |
| step:984/1770 train_time:96948ms step_avg:98.52ms | |
| step:985/1770 train_time:97051ms step_avg:98.53ms | |
| step:986/1770 train_time:97154ms step_avg:98.53ms | |
| step:987/1770 train_time:97256ms step_avg:98.54ms | |
| step:988/1770 train_time:97359ms step_avg:98.54ms | |
| step:989/1770 train_time:97460ms step_avg:98.54ms | |
| step:990/1770 train_time:97562ms step_avg:98.55ms | |
| step:991/1770 train_time:97664ms step_avg:98.55ms | |
| step:992/1770 train_time:97767ms step_avg:98.56ms | |
| step:993/1770 train_time:97870ms step_avg:98.56ms | |
| step:994/1770 train_time:97973ms step_avg:98.56ms | |
| step:995/1770 train_time:98076ms step_avg:98.57ms | |
| step:996/1770 train_time:98179ms step_avg:98.57ms | |
| step:997/1770 train_time:98281ms step_avg:98.58ms | |
| step:998/1770 train_time:98385ms step_avg:98.58ms | |
| step:999/1770 train_time:98488ms step_avg:98.59ms | |
| step:1000/1770 train_time:98592ms step_avg:98.59ms | |
| step:1000/1770 val_loss:3.5141 train_time:98678ms step_avg:98.68ms | |
| step:1001/1770 train_time:98707ms step_avg:98.61ms | |
| step:1002/1770 train_time:98808ms step_avg:98.61ms | |
| step:1003/1770 train_time:98912ms step_avg:98.62ms | |
| step:1004/1770 train_time:99014ms step_avg:98.62ms | |
| step:1005/1770 train_time:99115ms step_avg:98.62ms | |
| step:1006/1770 train_time:99216ms step_avg:98.62ms | |
| step:1007/1770 train_time:99318ms step_avg:98.63ms | |
| step:1008/1770 train_time:99420ms step_avg:98.63ms | |
| step:1009/1770 train_time:99522ms step_avg:98.63ms | |
| step:1010/1770 train_time:99623ms step_avg:98.64ms | |
| step:1011/1770 train_time:99727ms step_avg:98.64ms | |
| step:1012/1770 train_time:99833ms step_avg:98.65ms | |
| step:1013/1770 train_time:99936ms step_avg:98.65ms | |
| step:1014/1770 train_time:100038ms step_avg:98.66ms | |
| step:1015/1770 train_time:100141ms step_avg:98.66ms | |
| step:1016/1770 train_time:100244ms step_avg:98.67ms | |
| step:1017/1770 train_time:100347ms step_avg:98.67ms | |
| step:1018/1770 train_time:100450ms step_avg:98.67ms | |
| step:1019/1770 train_time:100553ms step_avg:98.68ms | |
| step:1020/1770 train_time:100654ms step_avg:98.68ms | |
| step:1021/1770 train_time:100757ms step_avg:98.68ms | |
| step:1022/1770 train_time:100859ms step_avg:98.69ms | |
| step:1023/1770 train_time:100961ms step_avg:98.69ms | |
| step:1024/1770 train_time:101065ms step_avg:98.70ms | |
| step:1025/1770 train_time:101168ms step_avg:98.70ms | |
| step:1026/1770 train_time:101272ms step_avg:98.71ms | |
| step:1027/1770 train_time:101374ms step_avg:98.71ms | |
| step:1028/1770 train_time:101476ms step_avg:98.71ms | |
| step:1029/1770 train_time:101578ms step_avg:98.72ms | |
| step:1030/1770 train_time:101680ms step_avg:98.72ms | |
| step:1031/1770 train_time:101783ms step_avg:98.72ms | |
| step:1032/1770 train_time:101885ms step_avg:98.73ms | |
| step:1033/1770 train_time:101990ms step_avg:98.73ms | |
| step:1034/1770 train_time:102092ms step_avg:98.74ms | |
| step:1035/1770 train_time:102195ms step_avg:98.74ms | |
| step:1036/1770 train_time:102297ms step_avg:98.74ms | |
| step:1037/1770 train_time:102399ms step_avg:98.75ms | |
| step:1038/1770 train_time:102501ms step_avg:98.75ms | |
| step:1039/1770 train_time:102604ms step_avg:98.75ms | |
| step:1040/1770 train_time:102706ms step_avg:98.76ms | |
| step:1041/1770 train_time:102808ms step_avg:98.76ms | |
| step:1042/1770 train_time:102911ms step_avg:98.76ms | |
| step:1043/1770 train_time:103013ms step_avg:98.77ms | |
| step:1044/1770 train_time:103115ms step_avg:98.77ms | |
| step:1045/1770 train_time:103217ms step_avg:98.77ms | |
| step:1046/1770 train_time:103320ms step_avg:98.78ms | |
| step:1047/1770 train_time:103422ms step_avg:98.78ms | |
| step:1048/1770 train_time:103524ms step_avg:98.78ms | |
| step:1049/1770 train_time:103627ms step_avg:98.79ms | |
| step:1050/1770 train_time:103731ms step_avg:98.79ms | |
| step:1051/1770 train_time:103834ms step_avg:98.80ms | |
| step:1052/1770 train_time:103937ms step_avg:98.80ms | |
| step:1053/1770 train_time:104039ms step_avg:98.80ms | |
| step:1054/1770 train_time:104142ms step_avg:98.81ms | |
| step:1055/1770 train_time:104246ms step_avg:98.81ms | |
| step:1056/1770 train_time:104349ms step_avg:98.82ms | |
| step:1057/1770 train_time:104452ms step_avg:98.82ms | |
| step:1058/1770 train_time:104554ms step_avg:98.82ms | |
| step:1059/1770 train_time:104656ms step_avg:98.83ms | |
| step:1060/1770 train_time:104759ms step_avg:98.83ms | |
| step:1061/1770 train_time:104862ms step_avg:98.83ms | |
| step:1062/1770 train_time:104966ms step_avg:98.84ms | |
| step:1063/1770 train_time:105070ms step_avg:98.84ms | |
| step:1064/1770 train_time:105173ms step_avg:98.85ms | |
| step:1065/1770 train_time:105275ms step_avg:98.85ms | |
| step:1066/1770 train_time:105378ms step_avg:98.85ms | |
| step:1067/1770 train_time:105481ms step_avg:98.86ms | |
| step:1068/1770 train_time:105584ms step_avg:98.86ms | |
| step:1069/1770 train_time:105687ms step_avg:98.87ms | |
| step:1070/1770 train_time:105790ms step_avg:98.87ms | |
| step:1071/1770 train_time:105893ms step_avg:98.87ms | |
| step:1072/1770 train_time:105995ms step_avg:98.88ms | |
| step:1073/1770 train_time:106098ms step_avg:98.88ms | |
| step:1074/1770 train_time:106201ms step_avg:98.88ms | |
| step:1075/1770 train_time:106303ms step_avg:98.89ms | |
| step:1076/1770 train_time:106408ms step_avg:98.89ms | |
| step:1077/1770 train_time:106511ms step_avg:98.90ms | |
| step:1078/1770 train_time:106614ms step_avg:98.90ms | |
| step:1079/1770 train_time:106716ms step_avg:98.90ms | |
| step:1080/1770 train_time:106819ms step_avg:98.91ms | |
| step:1081/1770 train_time:106921ms step_avg:98.91ms | |
| step:1082/1770 train_time:107025ms step_avg:98.91ms | |
| step:1083/1770 train_time:107128ms step_avg:98.92ms | |
| step:1084/1770 train_time:107231ms step_avg:98.92ms | |
| step:1085/1770 train_time:107334ms step_avg:98.93ms | |
| step:1086/1770 train_time:107436ms step_avg:98.93ms | |
| step:1087/1770 train_time:107538ms step_avg:98.93ms | |
| step:1088/1770 train_time:107640ms step_avg:98.93ms | |
| step:1089/1770 train_time:107743ms step_avg:98.94ms | |
| step:1090/1770 train_time:107847ms step_avg:98.94ms | |
| step:1091/1770 train_time:107950ms step_avg:98.95ms | |
| step:1092/1770 train_time:108054ms step_avg:98.95ms | |
| step:1093/1770 train_time:108156ms step_avg:98.95ms | |
| step:1094/1770 train_time:108259ms step_avg:98.96ms | |
| step:1095/1770 train_time:108363ms step_avg:98.96ms | |
| step:1096/1770 train_time:108466ms step_avg:98.97ms | |
| step:1097/1770 train_time:108571ms step_avg:98.97ms | |
| step:1098/1770 train_time:108673ms step_avg:98.97ms | |
| step:1099/1770 train_time:108774ms step_avg:98.98ms | |
| step:1100/1770 train_time:108877ms step_avg:98.98ms | |
| step:1101/1770 train_time:108979ms step_avg:98.98ms | |
| step:1102/1770 train_time:109082ms step_avg:98.99ms | |
| step:1103/1770 train_time:109186ms step_avg:98.99ms | |
| step:1104/1770 train_time:109290ms step_avg:98.99ms | |
| step:1105/1770 train_time:109393ms step_avg:99.00ms | |
| step:1106/1770 train_time:109495ms step_avg:99.00ms | |
| step:1107/1770 train_time:109599ms step_avg:99.01ms | |
| step:1108/1770 train_time:109702ms step_avg:99.01ms | |
| step:1109/1770 train_time:109804ms step_avg:99.01ms | |
| step:1110/1770 train_time:109909ms step_avg:99.02ms | |
| step:1111/1770 train_time:110011ms step_avg:99.02ms | |
| step:1112/1770 train_time:110114ms step_avg:99.02ms | |
| step:1113/1770 train_time:110216ms step_avg:99.03ms | |
| step:1114/1770 train_time:110319ms step_avg:99.03ms | |
| step:1115/1770 train_time:110421ms step_avg:99.03ms | |
| step:1116/1770 train_time:110524ms step_avg:99.04ms | |
| step:1117/1770 train_time:110629ms step_avg:99.04ms | |
| step:1118/1770 train_time:110731ms step_avg:99.04ms | |
| step:1119/1770 train_time:110834ms step_avg:99.05ms | |
| step:1120/1770 train_time:110936ms step_avg:99.05ms | |
| step:1121/1770 train_time:111039ms step_avg:99.05ms | |
| step:1122/1770 train_time:111141ms step_avg:99.06ms | |
| step:1123/1770 train_time:111244ms step_avg:99.06ms | |
| step:1124/1770 train_time:111347ms step_avg:99.06ms | |
| step:1125/1770 train_time:111451ms step_avg:99.07ms | |
| step:1125/1770 val_loss:3.4733 train_time:111537ms step_avg:99.14ms | |
| step:1126/1770 train_time:111566ms step_avg:99.08ms | |
| step:1127/1770 train_time:111665ms step_avg:99.08ms | |
| step:1128/1770 train_time:111770ms step_avg:99.09ms | |
| step:1129/1770 train_time:111872ms step_avg:99.09ms | |
| step:1130/1770 train_time:111974ms step_avg:99.09ms | |
| step:1131/1770 train_time:112077ms step_avg:99.10ms | |
| step:1132/1770 train_time:112178ms step_avg:99.10ms | |
| step:1133/1770 train_time:112280ms step_avg:99.10ms | |
| step:1134/1770 train_time:112382ms step_avg:99.10ms | |
| step:1135/1770 train_time:112485ms step_avg:99.11ms | |
| step:1136/1770 train_time:112590ms step_avg:99.11ms | |
| step:1137/1770 train_time:112694ms step_avg:99.12ms | |
| step:1138/1770 train_time:112798ms step_avg:99.12ms | |
| step:1139/1770 train_time:112903ms step_avg:99.12ms | |
| step:1140/1770 train_time:113007ms step_avg:99.13ms | |
| step:1141/1770 train_time:113109ms step_avg:99.13ms | |
| step:1142/1770 train_time:113212ms step_avg:99.13ms | |
| step:1143/1770 train_time:113314ms step_avg:99.14ms | |
| step:1144/1770 train_time:113416ms step_avg:99.14ms | |
| step:1145/1770 train_time:113519ms step_avg:99.14ms | |
| step:1146/1770 train_time:113622ms step_avg:99.15ms | |
| step:1147/1770 train_time:113727ms step_avg:99.15ms | |
| step:1148/1770 train_time:113831ms step_avg:99.16ms | |
| step:1149/1770 train_time:113935ms step_avg:99.16ms | |
| step:1150/1770 train_time:114038ms step_avg:99.16ms | |
| step:1151/1770 train_time:114143ms step_avg:99.17ms | |
| step:1152/1770 train_time:114246ms step_avg:99.17ms | |
| step:1153/1770 train_time:114350ms step_avg:99.18ms | |
| step:1154/1770 train_time:114451ms step_avg:99.18ms | |
| step:1155/1770 train_time:114554ms step_avg:99.18ms | |
| step:1156/1770 train_time:114657ms step_avg:99.18ms | |
| step:1157/1770 train_time:114760ms step_avg:99.19ms | |
| step:1158/1770 train_time:114863ms step_avg:99.19ms | |
| step:1159/1770 train_time:114967ms step_avg:99.19ms | |
| step:1160/1770 train_time:115070ms step_avg:99.20ms | |
| step:1161/1770 train_time:115173ms step_avg:99.20ms | |
| step:1162/1770 train_time:115275ms step_avg:99.20ms | |
| step:1163/1770 train_time:115377ms step_avg:99.21ms | |
| step:1164/1770 train_time:115480ms step_avg:99.21ms | |
| step:1165/1770 train_time:115583ms step_avg:99.21ms | |
| step:1166/1770 train_time:115686ms step_avg:99.22ms | |
| step:1167/1770 train_time:115789ms step_avg:99.22ms | |
| step:1168/1770 train_time:115892ms step_avg:99.22ms | |
| step:1169/1770 train_time:115995ms step_avg:99.23ms | |
| step:1170/1770 train_time:116098ms step_avg:99.23ms | |
| step:1171/1770 train_time:116201ms step_avg:99.23ms | |
| step:1172/1770 train_time:116305ms step_avg:99.24ms | |
| step:1173/1770 train_time:116409ms step_avg:99.24ms | |
| step:1174/1770 train_time:116511ms step_avg:99.24ms | |
| step:1175/1770 train_time:116613ms step_avg:99.25ms | |
| step:1176/1770 train_time:116715ms step_avg:99.25ms | |
| step:1177/1770 train_time:116818ms step_avg:99.25ms | |
| step:1178/1770 train_time:116921ms step_avg:99.25ms | |
| step:1179/1770 train_time:117024ms step_avg:99.26ms | |
| step:1180/1770 train_time:117127ms step_avg:99.26ms | |
| step:1181/1770 train_time:117235ms step_avg:99.27ms | |
| step:1182/1770 train_time:117333ms step_avg:99.27ms | |
| step:1183/1770 train_time:117438ms step_avg:99.27ms | |
| step:1184/1770 train_time:117541ms step_avg:99.27ms | |
| step:1185/1770 train_time:117646ms step_avg:99.28ms | |
| step:1186/1770 train_time:117750ms step_avg:99.28ms | |
| step:1187/1770 train_time:117854ms step_avg:99.29ms | |
| step:1188/1770 train_time:117958ms step_avg:99.29ms | |
| step:1189/1770 train_time:118062ms step_avg:99.30ms | |
| step:1190/1770 train_time:118168ms step_avg:99.30ms | |
| step:1191/1770 train_time:118272ms step_avg:99.31ms | |
| step:1192/1770 train_time:118376ms step_avg:99.31ms | |
| step:1193/1770 train_time:118480ms step_avg:99.31ms | |
| step:1194/1770 train_time:118585ms step_avg:99.32ms | |
| step:1195/1770 train_time:118691ms step_avg:99.32ms | |
| step:1196/1770 train_time:118796ms step_avg:99.33ms | |
| step:1197/1770 train_time:118900ms step_avg:99.33ms | |
| step:1198/1770 train_time:119004ms step_avg:99.34ms | |
| step:1199/1770 train_time:119108ms step_avg:99.34ms | |
| step:1200/1770 train_time:119213ms step_avg:99.34ms | |
| step:1201/1770 train_time:119316ms step_avg:99.35ms | |
| step:1202/1770 train_time:119420ms step_avg:99.35ms | |
| step:1203/1770 train_time:119525ms step_avg:99.36ms | |
| step:1204/1770 train_time:119629ms step_avg:99.36ms | |
| step:1205/1770 train_time:119733ms step_avg:99.36ms | |
| step:1206/1770 train_time:119839ms step_avg:99.37ms | |
| step:1207/1770 train_time:119942ms step_avg:99.37ms | |
| step:1208/1770 train_time:120047ms step_avg:99.38ms | |
| step:1209/1770 train_time:120152ms step_avg:99.38ms | |
| step:1210/1770 train_time:120256ms step_avg:99.38ms | |
| step:1211/1770 train_time:120360ms step_avg:99.39ms | |
| step:1212/1770 train_time:120465ms step_avg:99.39ms | |
| step:1213/1770 train_time:120569ms step_avg:99.40ms | |
| step:1214/1770 train_time:120673ms step_avg:99.40ms | |
| step:1215/1770 train_time:120777ms step_avg:99.41ms | |
| step:1216/1770 train_time:120881ms step_avg:99.41ms | |
| step:1217/1770 train_time:120985ms step_avg:99.41ms | |
| step:1218/1770 train_time:121090ms step_avg:99.42ms | |
| step:1219/1770 train_time:121194ms step_avg:99.42ms | |
| step:1220/1770 train_time:121298ms step_avg:99.42ms | |
| step:1221/1770 train_time:121402ms step_avg:99.43ms | |
| step:1222/1770 train_time:121507ms step_avg:99.43ms | |
| step:1223/1770 train_time:121611ms step_avg:99.44ms | |
| step:1224/1770 train_time:121715ms step_avg:99.44ms | |
| step:1225/1770 train_time:121819ms step_avg:99.44ms | |
| step:1226/1770 train_time:121922ms step_avg:99.45ms | |
| step:1227/1770 train_time:122027ms step_avg:99.45ms | |
| step:1228/1770 train_time:122132ms step_avg:99.46ms | |
| step:1229/1770 train_time:122236ms step_avg:99.46ms | |
| step:1230/1770 train_time:122339ms step_avg:99.46ms | |
| step:1231/1770 train_time:122444ms step_avg:99.47ms | |
| step:1232/1770 train_time:122549ms step_avg:99.47ms | |
| step:1233/1770 train_time:122652ms step_avg:99.47ms | |
| step:1234/1770 train_time:122755ms step_avg:99.48ms | |
| step:1235/1770 train_time:122859ms step_avg:99.48ms | |
| step:1236/1770 train_time:122965ms step_avg:99.49ms | |
| step:1237/1770 train_time:123071ms step_avg:99.49ms | |
| step:1238/1770 train_time:123174ms step_avg:99.49ms | |
| step:1239/1770 train_time:123277ms step_avg:99.50ms | |
| step:1240/1770 train_time:123381ms step_avg:99.50ms | |
| step:1241/1770 train_time:123486ms step_avg:99.51ms | |
| step:1242/1770 train_time:123592ms step_avg:99.51ms | |
| step:1243/1770 train_time:123695ms step_avg:99.51ms | |
| step:1244/1770 train_time:123799ms step_avg:99.52ms | |
| step:1245/1770 train_time:123903ms step_avg:99.52ms | |
| step:1246/1770 train_time:124009ms step_avg:99.53ms | |
| step:1247/1770 train_time:124113ms step_avg:99.53ms | |
| step:1248/1770 train_time:124217ms step_avg:99.53ms | |
| step:1249/1770 train_time:124321ms step_avg:99.54ms | |
| step:1250/1770 train_time:124425ms step_avg:99.54ms | |
| step:1250/1770 val_loss:3.4258 train_time:124515ms step_avg:99.61ms | |
| step:1251/1770 train_time:124543ms step_avg:99.56ms | |
| step:1252/1770 train_time:124645ms step_avg:99.56ms | |
| step:1253/1770 train_time:124749ms step_avg:99.56ms | |
| step:1254/1770 train_time:124852ms step_avg:99.56ms | |
| step:1255/1770 train_time:124956ms step_avg:99.57ms | |
| step:1256/1770 train_time:125059ms step_avg:99.57ms | |
| step:1257/1770 train_time:125162ms step_avg:99.57ms | |
| step:1258/1770 train_time:125265ms step_avg:99.58ms | |
| step:1259/1770 train_time:125369ms step_avg:99.58ms | |
| step:1260/1770 train_time:125474ms step_avg:99.58ms | |
| step:1261/1770 train_time:125579ms step_avg:99.59ms | |
| step:1262/1770 train_time:125684ms step_avg:99.59ms | |
| step:1263/1770 train_time:125789ms step_avg:99.60ms | |
| step:1264/1770 train_time:125895ms step_avg:99.60ms | |
| step:1265/1770 train_time:125999ms step_avg:99.60ms | |
| step:1266/1770 train_time:126103ms step_avg:99.61ms | |
| step:1267/1770 train_time:126206ms step_avg:99.61ms | |
| step:1268/1770 train_time:126309ms step_avg:99.61ms | |
| step:1269/1770 train_time:126413ms step_avg:99.62ms | |
| step:1270/1770 train_time:126518ms step_avg:99.62ms | |
| step:1271/1770 train_time:126622ms step_avg:99.62ms | |
| step:1272/1770 train_time:126726ms step_avg:99.63ms | |
| step:1273/1770 train_time:126831ms step_avg:99.63ms | |
| step:1274/1770 train_time:126936ms step_avg:99.64ms | |
| step:1275/1770 train_time:127041ms step_avg:99.64ms | |
| step:1276/1770 train_time:127144ms step_avg:99.64ms | |
| step:1277/1770 train_time:127247ms step_avg:99.65ms | |
| step:1278/1770 train_time:127353ms step_avg:99.65ms | |
| step:1279/1770 train_time:127458ms step_avg:99.65ms | |
| step:1280/1770 train_time:127562ms step_avg:99.66ms | |
| step:1281/1770 train_time:127665ms step_avg:99.66ms | |
| step:1282/1770 train_time:127770ms step_avg:99.66ms | |
| step:1283/1770 train_time:127876ms step_avg:99.67ms | |
| step:1284/1770 train_time:127980ms step_avg:99.67ms | |
| step:1285/1770 train_time:128084ms step_avg:99.68ms | |
| step:1286/1770 train_time:128188ms step_avg:99.68ms | |
| step:1287/1770 train_time:128292ms step_avg:99.68ms | |
| step:1288/1770 train_time:128396ms step_avg:99.69ms | |
| step:1289/1770 train_time:128499ms step_avg:99.69ms | |
| step:1290/1770 train_time:128602ms step_avg:99.69ms | |
| step:1291/1770 train_time:128706ms step_avg:99.69ms | |
| step:1292/1770 train_time:128810ms step_avg:99.70ms | |
| step:1293/1770 train_time:128916ms step_avg:99.70ms | |
| step:1294/1770 train_time:129021ms step_avg:99.71ms | |
| step:1295/1770 train_time:129124ms step_avg:99.71ms | |
| step:1296/1770 train_time:129228ms step_avg:99.71ms | |
| step:1297/1770 train_time:129333ms step_avg:99.72ms | |
| step:1298/1770 train_time:129438ms step_avg:99.72ms | |
| step:1299/1770 train_time:129542ms step_avg:99.72ms | |
| step:1300/1770 train_time:129645ms step_avg:99.73ms | |
| step:1301/1770 train_time:129749ms step_avg:99.73ms | |
| step:1302/1770 train_time:129854ms step_avg:99.73ms | |
| step:1303/1770 train_time:129959ms step_avg:99.74ms | |
| step:1304/1770 train_time:130063ms step_avg:99.74ms | |
| step:1305/1770 train_time:130166ms step_avg:99.74ms | |
| step:1306/1770 train_time:130270ms step_avg:99.75ms | |
| step:1307/1770 train_time:130376ms step_avg:99.75ms | |
| step:1308/1770 train_time:130479ms step_avg:99.75ms | |
| step:1309/1770 train_time:130582ms step_avg:99.76ms | |
| step:1310/1770 train_time:130686ms step_avg:99.76ms | |
| step:1311/1770 train_time:130790ms step_avg:99.76ms | |
| step:1312/1770 train_time:130896ms step_avg:99.77ms | |
| step:1313/1770 train_time:131000ms step_avg:99.77ms | |
| step:1314/1770 train_time:131104ms step_avg:99.77ms | |
| step:1315/1770 train_time:131207ms step_avg:99.78ms | |
| step:1316/1770 train_time:131311ms step_avg:99.78ms | |
| step:1317/1770 train_time:131417ms step_avg:99.79ms | |
| step:1318/1770 train_time:131522ms step_avg:99.79ms | |
| step:1319/1770 train_time:131626ms step_avg:99.79ms | |
| step:1320/1770 train_time:131730ms step_avg:99.80ms | |
| step:1321/1770 train_time:131835ms step_avg:99.80ms | |
| step:1322/1770 train_time:131940ms step_avg:99.80ms | |
| step:1323/1770 train_time:132044ms step_avg:99.81ms | |
| step:1324/1770 train_time:132148ms step_avg:99.81ms | |
| step:1325/1770 train_time:132256ms step_avg:99.82ms | |
| step:1326/1770 train_time:132360ms step_avg:99.82ms | |
| step:1327/1770 train_time:132467ms step_avg:99.82ms | |
| step:1328/1770 train_time:132571ms step_avg:99.83ms | |
| step:1329/1770 train_time:132676ms step_avg:99.83ms | |
| step:1330/1770 train_time:132779ms step_avg:99.83ms | |
| step:1331/1770 train_time:132883ms step_avg:99.84ms | |
| step:1332/1770 train_time:132986ms step_avg:99.84ms | |
| step:1333/1770 train_time:133090ms step_avg:99.84ms | |
| step:1334/1770 train_time:133195ms step_avg:99.85ms | |
| step:1335/1770 train_time:133300ms step_avg:99.85ms | |
| step:1336/1770 train_time:133405ms step_avg:99.85ms | |
| step:1337/1770 train_time:133510ms step_avg:99.86ms | |
| step:1338/1770 train_time:133615ms step_avg:99.86ms | |
| step:1339/1770 train_time:133720ms step_avg:99.87ms | |
| step:1340/1770 train_time:133824ms step_avg:99.87ms | |
| step:1341/1770 train_time:133928ms step_avg:99.87ms | |
| step:1342/1770 train_time:134032ms step_avg:99.88ms | |
| step:1343/1770 train_time:134137ms step_avg:99.88ms | |
| step:1344/1770 train_time:134240ms step_avg:99.88ms | |
| step:1345/1770 train_time:134344ms step_avg:99.88ms | |
| step:1346/1770 train_time:134449ms step_avg:99.89ms | |
| step:1347/1770 train_time:134554ms step_avg:99.89ms | |
| step:1348/1770 train_time:134659ms step_avg:99.90ms | |
| step:1349/1770 train_time:134763ms step_avg:99.90ms | |
| step:1350/1770 train_time:134867ms step_avg:99.90ms | |
| step:1351/1770 train_time:134971ms step_avg:99.90ms | |
| step:1352/1770 train_time:135075ms step_avg:99.91ms | |
| step:1353/1770 train_time:135180ms step_avg:99.91ms | |
| step:1354/1770 train_time:135284ms step_avg:99.91ms | |
| step:1355/1770 train_time:135387ms step_avg:99.92ms | |
| step:1356/1770 train_time:135492ms step_avg:99.92ms | |
| step:1357/1770 train_time:135598ms step_avg:99.92ms | |
| step:1358/1770 train_time:135702ms step_avg:99.93ms | |
| step:1359/1770 train_time:135805ms step_avg:99.93ms | |
| step:1360/1770 train_time:135910ms step_avg:99.93ms | |
| step:1361/1770 train_time:136015ms step_avg:99.94ms | |
| step:1362/1770 train_time:136119ms step_avg:99.94ms | |
| step:1363/1770 train_time:136222ms step_avg:99.94ms | |
| step:1364/1770 train_time:136325ms step_avg:99.95ms | |
| step:1365/1770 train_time:136430ms step_avg:99.95ms | |
| step:1366/1770 train_time:136536ms step_avg:99.95ms | |
| step:1367/1770 train_time:136639ms step_avg:99.96ms | |
| step:1368/1770 train_time:136742ms step_avg:99.96ms | |
| step:1369/1770 train_time:136846ms step_avg:99.96ms | |
| step:1370/1770 train_time:136951ms step_avg:99.96ms | |
| step:1371/1770 train_time:137057ms step_avg:99.97ms | |
| step:1372/1770 train_time:137162ms step_avg:99.97ms | |
| step:1373/1770 train_time:137266ms step_avg:99.98ms | |
| step:1374/1770 train_time:137373ms step_avg:99.98ms | |
| step:1375/1770 train_time:137478ms step_avg:99.98ms | |
| step:1375/1770 val_loss:3.3816 train_time:137566ms step_avg:100.05ms | |
| step:1376/1770 train_time:137595ms step_avg:100.00ms | |
| step:1377/1770 train_time:137698ms step_avg:100.00ms | |
| step:1378/1770 train_time:137802ms step_avg:100.00ms | |
| step:1379/1770 train_time:137905ms step_avg:100.00ms | |
| step:1380/1770 train_time:138009ms step_avg:100.01ms | |
| step:1381/1770 train_time:138113ms step_avg:100.01ms | |
| step:1382/1770 train_time:138217ms step_avg:100.01ms | |
| step:1383/1770 train_time:138321ms step_avg:100.01ms | |
| step:1384/1770 train_time:138424ms step_avg:100.02ms | |
| step:1385/1770 train_time:138528ms step_avg:100.02ms | |
| step:1386/1770 train_time:138634ms step_avg:100.02ms | |
| step:1387/1770 train_time:138741ms step_avg:100.03ms | |
| step:1388/1770 train_time:138846ms step_avg:100.03ms | |
| step:1389/1770 train_time:138950ms step_avg:100.04ms | |
| step:1390/1770 train_time:139054ms step_avg:100.04ms | |
| step:1391/1770 train_time:139158ms step_avg:100.04ms | |
| step:1392/1770 train_time:139262ms step_avg:100.04ms | |
| step:1393/1770 train_time:139366ms step_avg:100.05ms | |
| step:1394/1770 train_time:139469ms step_avg:100.05ms | |
| step:1395/1770 train_time:139574ms step_avg:100.05ms | |
| step:1396/1770 train_time:139680ms step_avg:100.06ms | |
| step:1397/1770 train_time:139785ms step_avg:100.06ms | |
| step:1398/1770 train_time:139889ms step_avg:100.06ms | |
| step:1399/1770 train_time:139994ms step_avg:100.07ms | |
| step:1400/1770 train_time:140100ms step_avg:100.07ms | |
| step:1401/1770 train_time:140203ms step_avg:100.07ms | |
| step:1402/1770 train_time:140306ms step_avg:100.08ms | |
| step:1403/1770 train_time:140410ms step_avg:100.08ms | |
| step:1404/1770 train_time:140515ms step_avg:100.08ms | |
| step:1405/1770 train_time:140620ms step_avg:100.09ms | |
| step:1406/1770 train_time:140725ms step_avg:100.09ms | |
| step:1407/1770 train_time:140829ms step_avg:100.09ms | |
| step:1408/1770 train_time:140933ms step_avg:100.09ms | |
| step:1409/1770 train_time:141039ms step_avg:100.10ms | |
| step:1410/1770 train_time:141143ms step_avg:100.10ms | |
| step:1411/1770 train_time:141247ms step_avg:100.10ms | |
| step:1412/1770 train_time:141351ms step_avg:100.11ms | |
| step:1413/1770 train_time:141456ms step_avg:100.11ms | |
| step:1414/1770 train_time:141562ms step_avg:100.11ms | |
| step:1415/1770 train_time:141667ms step_avg:100.12ms | |
| step:1416/1770 train_time:141771ms step_avg:100.12ms | |
| step:1417/1770 train_time:141876ms step_avg:100.12ms | |
| step:1418/1770 train_time:141981ms step_avg:100.13ms | |
| step:1419/1770 train_time:142086ms step_avg:100.13ms | |
| step:1420/1770 train_time:142189ms step_avg:100.13ms | |
| step:1421/1770 train_time:142293ms step_avg:100.14ms | |
| step:1422/1770 train_time:142398ms step_avg:100.14ms | |
| step:1423/1770 train_time:142502ms step_avg:100.14ms | |
| step:1424/1770 train_time:142606ms step_avg:100.14ms | |
| step:1425/1770 train_time:142710ms step_avg:100.15ms | |
| step:1426/1770 train_time:142816ms step_avg:100.15ms | |
| step:1427/1770 train_time:142920ms step_avg:100.15ms | |
| step:1428/1770 train_time:143025ms step_avg:100.16ms | |
| step:1429/1770 train_time:143128ms step_avg:100.16ms | |
| step:1430/1770 train_time:143233ms step_avg:100.16ms | |
| step:1431/1770 train_time:143339ms step_avg:100.17ms | |
| step:1432/1770 train_time:143443ms step_avg:100.17ms | |
| step:1433/1770 train_time:143546ms step_avg:100.17ms | |
| step:1434/1770 train_time:143650ms step_avg:100.17ms | |
| step:1435/1770 train_time:143754ms step_avg:100.18ms | |
| step:1436/1770 train_time:143863ms step_avg:100.18ms | |
| step:1437/1770 train_time:143967ms step_avg:100.19ms | |
| step:1438/1770 train_time:144071ms step_avg:100.19ms | |
| step:1439/1770 train_time:144174ms step_avg:100.19ms | |
| step:1440/1770 train_time:144279ms step_avg:100.19ms | |
| step:1441/1770 train_time:144384ms step_avg:100.20ms | |
| step:1442/1770 train_time:144487ms step_avg:100.20ms | |
| step:1443/1770 train_time:144592ms step_avg:100.20ms | |
| step:1444/1770 train_time:144698ms step_avg:100.21ms | |
| step:1445/1770 train_time:144803ms step_avg:100.21ms | |
| step:1446/1770 train_time:144908ms step_avg:100.21ms | |
| step:1447/1770 train_time:145012ms step_avg:100.22ms | |
| step:1448/1770 train_time:145117ms step_avg:100.22ms | |
| step:1449/1770 train_time:145225ms step_avg:100.22ms | |
| step:1450/1770 train_time:145331ms step_avg:100.23ms | |
| step:1451/1770 train_time:145438ms step_avg:100.23ms | |
| step:1452/1770 train_time:145542ms step_avg:100.24ms | |
| step:1453/1770 train_time:145646ms step_avg:100.24ms | |
| step:1454/1770 train_time:145752ms step_avg:100.24ms | |
| step:1455/1770 train_time:145859ms step_avg:100.25ms | |
| step:1456/1770 train_time:145964ms step_avg:100.25ms | |
| step:1457/1770 train_time:146070ms step_avg:100.25ms | |
| step:1458/1770 train_time:146175ms step_avg:100.26ms | |
| step:1459/1770 train_time:146281ms step_avg:100.26ms | |
| step:1460/1770 train_time:146386ms step_avg:100.26ms | |
| step:1461/1770 train_time:146491ms step_avg:100.27ms | |
| step:1462/1770 train_time:146599ms step_avg:100.27ms | |
| step:1463/1770 train_time:146704ms step_avg:100.28ms | |
| step:1464/1770 train_time:146810ms step_avg:100.28ms | |
| step:1465/1770 train_time:146915ms step_avg:100.28ms | |
| step:1466/1770 train_time:147020ms step_avg:100.29ms | |
| step:1467/1770 train_time:147126ms step_avg:100.29ms | |
| step:1468/1770 train_time:147232ms step_avg:100.29ms | |
| step:1469/1770 train_time:147338ms step_avg:100.30ms | |
| step:1470/1770 train_time:147442ms step_avg:100.30ms | |
| step:1471/1770 train_time:147547ms step_avg:100.30ms | |
| step:1472/1770 train_time:147653ms step_avg:100.31ms | |
| step:1473/1770 train_time:147761ms step_avg:100.31ms | |
| step:1474/1770 train_time:147866ms step_avg:100.32ms | |
| step:1475/1770 train_time:147972ms step_avg:100.32ms | |
| step:1476/1770 train_time:148079ms step_avg:100.32ms | |
| step:1477/1770 train_time:148186ms step_avg:100.33ms | |
| step:1478/1770 train_time:148291ms step_avg:100.33ms | |
| step:1479/1770 train_time:148399ms step_avg:100.34ms | |
| step:1480/1770 train_time:148504ms step_avg:100.34ms | |
| step:1481/1770 train_time:148610ms step_avg:100.34ms | |
| step:1482/1770 train_time:148715ms step_avg:100.35ms | |
| step:1483/1770 train_time:148821ms step_avg:100.35ms | |
| step:1484/1770 train_time:148926ms step_avg:100.35ms | |
| step:1485/1770 train_time:149031ms step_avg:100.36ms | |
| step:1486/1770 train_time:149140ms step_avg:100.36ms | |
| step:1487/1770 train_time:149244ms step_avg:100.37ms | |
| step:1488/1770 train_time:149350ms step_avg:100.37ms | |
| step:1489/1770 train_time:149458ms step_avg:100.37ms | |
| step:1490/1770 train_time:149563ms step_avg:100.38ms | |
| step:1491/1770 train_time:149668ms step_avg:100.38ms | |
| step:1492/1770 train_time:149774ms step_avg:100.38ms | |
| step:1493/1770 train_time:149880ms step_avg:100.39ms | |
| step:1494/1770 train_time:149987ms step_avg:100.39ms | |
| step:1495/1770 train_time:150092ms step_avg:100.40ms | |
| step:1496/1770 train_time:150198ms step_avg:100.40ms | |
| step:1497/1770 train_time:150303ms step_avg:100.40ms | |
| step:1498/1770 train_time:150408ms step_avg:100.41ms | |
| step:1499/1770 train_time:150514ms step_avg:100.41ms | |
| step:1500/1770 train_time:150619ms step_avg:100.41ms | |
| step:1500/1770 val_loss:3.3445 train_time:150708ms step_avg:100.47ms | |
| step:1501/1770 train_time:150737ms step_avg:100.42ms | |
| step:1502/1770 train_time:150841ms step_avg:100.43ms | |
| step:1503/1770 train_time:150946ms step_avg:100.43ms | |
| step:1504/1770 train_time:151052ms step_avg:100.43ms | |
| step:1505/1770 train_time:151159ms step_avg:100.44ms | |
| step:1506/1770 train_time:151264ms step_avg:100.44ms | |
| step:1507/1770 train_time:151369ms step_avg:100.44ms | |
| step:1508/1770 train_time:151474ms step_avg:100.45ms | |
| step:1509/1770 train_time:151579ms step_avg:100.45ms | |
| step:1510/1770 train_time:151684ms step_avg:100.45ms | |
| step:1511/1770 train_time:151791ms step_avg:100.46ms | |
| step:1512/1770 train_time:151897ms step_avg:100.46ms | |
| step:1513/1770 train_time:152002ms step_avg:100.46ms | |
| step:1514/1770 train_time:152107ms step_avg:100.47ms | |
| step:1515/1770 train_time:152213ms step_avg:100.47ms | |
| step:1516/1770 train_time:152318ms step_avg:100.47ms | |
| step:1517/1770 train_time:152424ms step_avg:100.48ms | |
| step:1518/1770 train_time:152530ms step_avg:100.48ms | |
| step:1519/1770 train_time:152634ms step_avg:100.48ms | |
| step:1520/1770 train_time:152741ms step_avg:100.49ms | |
| step:1521/1770 train_time:152846ms step_avg:100.49ms | |
| step:1522/1770 train_time:152952ms step_avg:100.49ms | |
| step:1523/1770 train_time:153059ms step_avg:100.50ms | |
| step:1524/1770 train_time:153163ms step_avg:100.50ms | |
| step:1525/1770 train_time:153268ms step_avg:100.50ms | |
| step:1526/1770 train_time:153374ms step_avg:100.51ms | |
| step:1527/1770 train_time:153480ms step_avg:100.51ms | |
| step:1528/1770 train_time:153585ms step_avg:100.51ms | |
| step:1529/1770 train_time:153692ms step_avg:100.52ms | |
| step:1530/1770 train_time:153798ms step_avg:100.52ms | |
| step:1531/1770 train_time:153904ms step_avg:100.52ms | |
| step:1532/1770 train_time:154010ms step_avg:100.53ms | |
| step:1533/1770 train_time:154117ms step_avg:100.53ms | |
| step:1534/1770 train_time:154221ms step_avg:100.54ms | |
| step:1535/1770 train_time:154327ms step_avg:100.54ms | |
| step:1536/1770 train_time:154433ms step_avg:100.54ms | |
| step:1537/1770 train_time:154538ms step_avg:100.55ms | |
| step:1538/1770 train_time:154645ms step_avg:100.55ms | |
| step:1539/1770 train_time:154751ms step_avg:100.55ms | |
| step:1540/1770 train_time:154858ms step_avg:100.56ms | |
| step:1541/1770 train_time:154966ms step_avg:100.56ms | |
| step:1542/1770 train_time:155071ms step_avg:100.56ms | |
| step:1543/1770 train_time:155178ms step_avg:100.57ms | |
| step:1544/1770 train_time:155284ms step_avg:100.57ms | |
| step:1545/1770 train_time:155389ms step_avg:100.58ms | |
| step:1546/1770 train_time:155496ms step_avg:100.58ms | |
| step:1547/1770 train_time:155601ms step_avg:100.58ms | |
| step:1548/1770 train_time:155706ms step_avg:100.59ms | |
| step:1549/1770 train_time:155811ms step_avg:100.59ms | |
| step:1550/1770 train_time:155916ms step_avg:100.59ms | |
| step:1551/1770 train_time:156021ms step_avg:100.59ms | |
| step:1552/1770 train_time:156128ms step_avg:100.60ms | |
| step:1553/1770 train_time:156233ms step_avg:100.60ms | |
| step:1554/1770 train_time:156338ms step_avg:100.60ms | |
| step:1555/1770 train_time:156443ms step_avg:100.61ms | |
| step:1556/1770 train_time:156548ms step_avg:100.61ms | |
| step:1557/1770 train_time:156655ms step_avg:100.61ms | |
| step:1558/1770 train_time:156759ms step_avg:100.62ms | |
| step:1559/1770 train_time:156863ms step_avg:100.62ms | |
| step:1560/1770 train_time:156969ms step_avg:100.62ms | |
| step:1561/1770 train_time:157076ms step_avg:100.63ms | |
| step:1562/1770 train_time:157181ms step_avg:100.63ms | |
| step:1563/1770 train_time:157285ms step_avg:100.63ms | |
| step:1564/1770 train_time:157390ms step_avg:100.63ms | |
| step:1565/1770 train_time:157496ms step_avg:100.64ms | |
| step:1566/1770 train_time:157601ms step_avg:100.64ms | |
| step:1567/1770 train_time:157707ms step_avg:100.64ms | |
| step:1568/1770 train_time:157813ms step_avg:100.65ms | |
| step:1569/1770 train_time:157919ms step_avg:100.65ms | |
| step:1570/1770 train_time:158025ms step_avg:100.65ms | |
| step:1571/1770 train_time:158130ms step_avg:100.66ms | |
| step:1572/1770 train_time:158236ms step_avg:100.66ms | |
| step:1573/1770 train_time:158343ms step_avg:100.66ms | |
| step:1574/1770 train_time:158448ms step_avg:100.67ms | |
| step:1575/1770 train_time:158554ms step_avg:100.67ms | |
| step:1576/1770 train_time:158659ms step_avg:100.67ms | |
| step:1577/1770 train_time:158766ms step_avg:100.68ms | |
| step:1578/1770 train_time:158873ms step_avg:100.68ms | |
| step:1579/1770 train_time:158978ms step_avg:100.68ms | |
| step:1580/1770 train_time:159084ms step_avg:100.69ms | |
| step:1581/1770 train_time:159192ms step_avg:100.69ms | |
| step:1582/1770 train_time:159298ms step_avg:100.69ms | |
| step:1583/1770 train_time:159403ms step_avg:100.70ms | |
| step:1584/1770 train_time:159509ms step_avg:100.70ms | |
| step:1585/1770 train_time:159614ms step_avg:100.70ms | |
| step:1586/1770 train_time:159722ms step_avg:100.71ms | |
| step:1587/1770 train_time:159828ms step_avg:100.71ms | |
| step:1588/1770 train_time:159935ms step_avg:100.71ms | |
| step:1589/1770 train_time:160041ms step_avg:100.72ms | |
| step:1590/1770 train_time:160146ms step_avg:100.72ms | |
| step:1591/1770 train_time:160250ms step_avg:100.72ms | |
| step:1592/1770 train_time:160358ms step_avg:100.73ms | |
| step:1593/1770 train_time:160463ms step_avg:100.73ms | |
| step:1594/1770 train_time:160568ms step_avg:100.73ms | |
| step:1595/1770 train_time:160674ms step_avg:100.74ms | |
| step:1596/1770 train_time:160780ms step_avg:100.74ms | |
| step:1597/1770 train_time:160884ms step_avg:100.74ms | |
| step:1598/1770 train_time:160990ms step_avg:100.74ms | |
| step:1599/1770 train_time:161097ms step_avg:100.75ms | |
| step:1600/1770 train_time:161203ms step_avg:100.75ms | |
| step:1601/1770 train_time:161309ms step_avg:100.76ms | |
| step:1602/1770 train_time:161415ms step_avg:100.76ms | |
| step:1603/1770 train_time:161519ms step_avg:100.76ms | |
| step:1604/1770 train_time:161625ms step_avg:100.76ms | |
| step:1605/1770 train_time:161730ms step_avg:100.77ms | |
| step:1606/1770 train_time:161836ms step_avg:100.77ms | |
| step:1607/1770 train_time:161945ms step_avg:100.77ms | |
| step:1608/1770 train_time:162052ms step_avg:100.78ms | |
| step:1609/1770 train_time:162158ms step_avg:100.78ms | |
| step:1610/1770 train_time:162264ms step_avg:100.78ms | |
| step:1611/1770 train_time:162369ms step_avg:100.79ms | |
| step:1612/1770 train_time:162477ms step_avg:100.79ms | |
| step:1613/1770 train_time:162582ms step_avg:100.80ms | |
| step:1614/1770 train_time:162687ms step_avg:100.80ms | |
| step:1615/1770 train_time:162796ms step_avg:100.80ms | |
| step:1616/1770 train_time:162901ms step_avg:100.80ms | |
| step:1617/1770 train_time:163006ms step_avg:100.81ms | |
| step:1618/1770 train_time:163114ms step_avg:100.81ms | |
| step:1619/1770 train_time:163220ms step_avg:100.82ms | |
| step:1620/1770 train_time:163327ms step_avg:100.82ms | |
| step:1621/1770 train_time:163433ms step_avg:100.82ms | |
| step:1622/1770 train_time:163539ms step_avg:100.83ms | |
| step:1623/1770 train_time:163645ms step_avg:100.83ms | |
| step:1624/1770 train_time:163749ms step_avg:100.83ms | |
| step:1625/1770 train_time:163855ms step_avg:100.83ms | |
| step:1625/1770 val_loss:3.3099 train_time:163944ms step_avg:100.89ms | |
| step:1626/1770 train_time:163973ms step_avg:100.84ms | |
| step:1627/1770 train_time:164075ms step_avg:100.84ms | |
| step:1628/1770 train_time:164180ms step_avg:100.85ms | |
| step:1629/1770 train_time:164285ms step_avg:100.85ms | |
| step:1630/1770 train_time:164389ms step_avg:100.85ms | |
| step:1631/1770 train_time:164494ms step_avg:100.85ms | |
| step:1632/1770 train_time:164598ms step_avg:100.86ms | |
| step:1633/1770 train_time:164703ms step_avg:100.86ms | |
| step:1634/1770 train_time:164808ms step_avg:100.86ms | |
| step:1635/1770 train_time:164913ms step_avg:100.86ms | |
| step:1636/1770 train_time:165019ms step_avg:100.87ms | |
| step:1637/1770 train_time:165127ms step_avg:100.87ms | |
| step:1638/1770 train_time:165231ms step_avg:100.87ms | |
| step:1639/1770 train_time:165339ms step_avg:100.88ms | |
| step:1640/1770 train_time:165445ms step_avg:100.88ms | |
| step:1641/1770 train_time:165550ms step_avg:100.88ms | |
| step:1642/1770 train_time:165655ms step_avg:100.89ms | |
| step:1643/1770 train_time:165761ms step_avg:100.89ms | |
| step:1644/1770 train_time:165867ms step_avg:100.89ms | |
| step:1645/1770 train_time:165973ms step_avg:100.90ms | |
| step:1646/1770 train_time:166078ms step_avg:100.90ms | |
| step:1647/1770 train_time:166185ms step_avg:100.90ms | |
| step:1648/1770 train_time:166290ms step_avg:100.90ms | |
| step:1649/1770 train_time:166395ms step_avg:100.91ms | |
| step:1650/1770 train_time:166501ms step_avg:100.91ms | |
| step:1651/1770 train_time:166607ms step_avg:100.91ms | |
| step:1652/1770 train_time:166712ms step_avg:100.92ms | |
| step:1653/1770 train_time:166817ms step_avg:100.92ms | |
| step:1654/1770 train_time:166925ms step_avg:100.92ms | |
| step:1655/1770 train_time:167032ms step_avg:100.93ms | |
| step:1656/1770 train_time:167137ms step_avg:100.93ms | |
| step:1657/1770 train_time:167243ms step_avg:100.93ms | |
| step:1658/1770 train_time:167349ms step_avg:100.93ms | |
| step:1659/1770 train_time:167456ms step_avg:100.94ms | |
| step:1660/1770 train_time:167562ms step_avg:100.94ms | |
| step:1661/1770 train_time:167668ms step_avg:100.94ms | |
| step:1662/1770 train_time:167772ms step_avg:100.95ms | |
| step:1663/1770 train_time:167877ms step_avg:100.95ms | |
| step:1664/1770 train_time:167983ms step_avg:100.95ms | |
| step:1665/1770 train_time:168088ms step_avg:100.95ms | |
| step:1666/1770 train_time:168194ms step_avg:100.96ms | |
| step:1667/1770 train_time:168299ms step_avg:100.96ms | |
| step:1668/1770 train_time:168406ms step_avg:100.96ms | |
| step:1669/1770 train_time:168511ms step_avg:100.96ms | |
| step:1670/1770 train_time:168616ms step_avg:100.97ms | |
| step:1671/1770 train_time:168722ms step_avg:100.97ms | |
| step:1672/1770 train_time:168828ms step_avg:100.97ms | |
| step:1673/1770 train_time:168934ms step_avg:100.98ms | |
| step:1674/1770 train_time:169039ms step_avg:100.98ms | |
| step:1675/1770 train_time:169146ms step_avg:100.98ms | |
| step:1676/1770 train_time:169252ms step_avg:100.99ms | |
| step:1677/1770 train_time:169359ms step_avg:100.99ms | |
| step:1678/1770 train_time:169465ms step_avg:100.99ms | |
| step:1679/1770 train_time:169571ms step_avg:101.00ms | |
| step:1680/1770 train_time:169676ms step_avg:101.00ms | |
| step:1681/1770 train_time:169781ms step_avg:101.00ms | |
| step:1682/1770 train_time:169891ms step_avg:101.01ms | |
| step:1683/1770 train_time:169995ms step_avg:101.01ms | |
| step:1684/1770 train_time:170100ms step_avg:101.01ms | |
| step:1685/1770 train_time:170207ms step_avg:101.01ms | |
| step:1686/1770 train_time:170315ms step_avg:101.02ms | |
| step:1687/1770 train_time:170421ms step_avg:101.02ms | |
| step:1688/1770 train_time:170528ms step_avg:101.02ms | |
| step:1689/1770 train_time:170634ms step_avg:101.03ms | |
| step:1690/1770 train_time:170738ms step_avg:101.03ms | |
| step:1691/1770 train_time:170845ms step_avg:101.03ms | |
| step:1692/1770 train_time:170952ms step_avg:101.04ms | |
| step:1693/1770 train_time:171057ms step_avg:101.04ms | |
| step:1694/1770 train_time:171164ms step_avg:101.04ms | |
| step:1695/1770 train_time:171270ms step_avg:101.04ms | |
| step:1696/1770 train_time:171375ms step_avg:101.05ms | |
| step:1697/1770 train_time:171482ms step_avg:101.05ms | |
| step:1698/1770 train_time:171588ms step_avg:101.05ms | |
| step:1699/1770 train_time:171691ms step_avg:101.05ms | |
| step:1700/1770 train_time:171796ms step_avg:101.06ms | |
| step:1701/1770 train_time:171901ms step_avg:101.06ms | |
| step:1702/1770 train_time:172006ms step_avg:101.06ms | |
| step:1703/1770 train_time:172112ms step_avg:101.06ms | |
| step:1704/1770 train_time:172219ms step_avg:101.07ms | |
| step:1705/1770 train_time:172326ms step_avg:101.07ms | |
| step:1706/1770 train_time:172431ms step_avg:101.07ms | |
| step:1707/1770 train_time:172539ms step_avg:101.08ms | |
| step:1708/1770 train_time:172646ms step_avg:101.08ms | |
| step:1709/1770 train_time:172753ms step_avg:101.08ms | |
| step:1710/1770 train_time:172863ms step_avg:101.09ms | |
| step:1711/1770 train_time:172972ms step_avg:101.09ms | |
| step:1712/1770 train_time:173078ms step_avg:101.10ms | |
| step:1713/1770 train_time:173184ms step_avg:101.10ms | |
| step:1714/1770 train_time:173290ms step_avg:101.10ms | |
| step:1715/1770 train_time:173397ms step_avg:101.11ms | |
| step:1716/1770 train_time:173505ms step_avg:101.11ms | |
| step:1717/1770 train_time:173611ms step_avg:101.11ms | |
| step:1718/1770 train_time:173717ms step_avg:101.12ms | |
| step:1719/1770 train_time:173824ms step_avg:101.12ms | |
| step:1720/1770 train_time:173933ms step_avg:101.12ms | |
| step:1721/1770 train_time:174039ms step_avg:101.13ms | |
| step:1722/1770 train_time:174148ms step_avg:101.13ms | |
| step:1723/1770 train_time:174256ms step_avg:101.14ms | |
| step:1724/1770 train_time:174365ms step_avg:101.14ms | |
| step:1725/1770 train_time:174474ms step_avg:101.14ms | |
| step:1726/1770 train_time:174581ms step_avg:101.15ms | |
| step:1727/1770 train_time:174689ms step_avg:101.15ms | |
| step:1728/1770 train_time:174797ms step_avg:101.16ms | |
| step:1729/1770 train_time:174904ms step_avg:101.16ms | |
| step:1730/1770 train_time:175010ms step_avg:101.16ms | |
| step:1731/1770 train_time:175117ms step_avg:101.17ms | |
| step:1732/1770 train_time:175224ms step_avg:101.17ms | |
| step:1733/1770 train_time:175332ms step_avg:101.17ms | |
| step:1734/1770 train_time:175438ms step_avg:101.18ms | |
| step:1735/1770 train_time:175545ms step_avg:101.18ms | |
| step:1736/1770 train_time:175651ms step_avg:101.18ms | |
| step:1737/1770 train_time:175757ms step_avg:101.18ms | |
| step:1738/1770 train_time:175865ms step_avg:101.19ms | |
| step:1739/1770 train_time:175971ms step_avg:101.19ms | |
| step:1740/1770 train_time:176077ms step_avg:101.19ms | |
| step:1741/1770 train_time:176186ms step_avg:101.20ms | |
| step:1742/1770 train_time:176294ms step_avg:101.20ms | |
| step:1743/1770 train_time:176400ms step_avg:101.20ms | |
| step:1744/1770 train_time:176507ms step_avg:101.21ms | |
| step:1745/1770 train_time:176613ms step_avg:101.21ms | |
| step:1746/1770 train_time:176722ms step_avg:101.22ms | |
| step:1747/1770 train_time:176828ms step_avg:101.22ms | |
| step:1748/1770 train_time:176936ms step_avg:101.22ms | |
| step:1749/1770 train_time:177042ms step_avg:101.22ms | |
| step:1750/1770 train_time:177149ms step_avg:101.23ms | |
| step:1750/1770 val_loss:3.2828 train_time:177238ms step_avg:101.28ms | |
| step:1751/1770 train_time:177267ms step_avg:101.24ms | |
| step:1752/1770 train_time:177369ms step_avg:101.24ms | |
| step:1753/1770 train_time:177474ms step_avg:101.24ms | |
| step:1754/1770 train_time:177580ms step_avg:101.24ms | |
| step:1755/1770 train_time:177685ms step_avg:101.25ms | |
| step:1756/1770 train_time:177792ms step_avg:101.25ms | |
| step:1757/1770 train_time:177899ms step_avg:101.25ms | |
| step:1758/1770 train_time:178004ms step_avg:101.25ms | |
| step:1759/1770 train_time:178111ms step_avg:101.26ms | |
| step:1760/1770 train_time:178218ms step_avg:101.26ms | |
| step:1761/1770 train_time:178327ms step_avg:101.26ms | |
| step:1762/1770 train_time:178435ms step_avg:101.27ms | |
| step:1763/1770 train_time:178540ms step_avg:101.27ms | |
| step:1764/1770 train_time:178646ms step_avg:101.27ms | |
| step:1765/1770 train_time:178752ms step_avg:101.28ms | |
| step:1766/1770 train_time:178860ms step_avg:101.28ms | |
| step:1767/1770 train_time:178966ms step_avg:101.28ms | |
| step:1768/1770 train_time:179073ms step_avg:101.29ms | |
| step:1769/1770 train_time:179178ms step_avg:101.29ms | |
| step:1770/1770 train_time:179285ms step_avg:101.29ms | |
| step:1770/1770 val_loss:3.2798 train_time:179377ms step_avg:101.34ms | |
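| # --------------------------------------------------------------------------- | |
| # Reading the log above: step_avg is simply cumulative wall-clock time divided | |
| # by steps completed (e.g. 179285ms / 1770 = 101.29ms on the final line), and | |
| # val_loss lines appear every 125 steps plus once at the final step. A minimal | |
| # sketch of that timing bookkeeping -- the names below (training_step, | |
| # num_iterations) are illustrative stand-ins, not necessarily what this | |
| # script uses: | |
| import time | |
| num_iterations = 1770 | |
| def training_step(): | |
|     time.sleep(0.1)  # stand-in for the real ~100ms forward/backward/update | |
| t0 = time.perf_counter() | |
| for step in range(1, num_iterations + 1): | |
|     training_step() | |
|     train_time_ms = 1000 * (time.perf_counter() - t0)  # cumulative, in ms | |
|     print(f"step:{step}/{num_iterations} train_time:{train_time_ms:.0f}ms " | |
|           f"step_avg:{train_time_ms / step:.2f}ms")  # avg = total / steps | |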
| peak memory allocated: 31481 MiB reserved: 46892 MiB | |
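| # The peak-memory line above is presumably produced from torch.cuda's | |
| # allocator statistics; a minimal sketch under that assumption. "allocated" | |
| # is the high-water mark of memory occupied by tensors, while "reserved" | |
| # also counts memory the caching allocator holds but has not handed out | |
| # (MiB = bytes / 2**20): | |
| import torch | |
| alloc_mib = torch.cuda.max_memory_allocated() >> 20    # peak tensor usage | |
| reserved_mib = torch.cuda.max_memory_reserved() >> 20  # peak allocator pool | |
| print(f"peak memory allocated: {alloc_mib} MiB reserved: {reserved_mib} MiB") | |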